Dans TensorFlow <2, la fonction de formation pour un acteur DDPG peut être implémentée de manière concise en utilisant tf.keras.backend.function
comme suit :
critic_output = self.critic([self.actor(state_input), state_input])
actor_updates = self.optimizer_actor.get_updates(params=self.actor.trainable_weights,
loss=-tf.keras.backend.mean(critic_output))
self.actor_train_on_batch = tf.keras.backend.function(inputs=[state_input],
outputs=[self.actor(state_input)],
updates=actor_updates)
Ensuite, lors de chaque étape de formation appelant self.actor_train_on_batch([np.array(state_batch)])
calculera les gradients et effectuera les mises à jour.
Cependant, l'exécution de ce programme sur TF 2.0 donne l'erreur suivante en raison de l'activation par défaut de la fonction "eager mode" :
actor_updates = self.optimizer_actor.get_updates(params=self.actor.trainable_weights, loss=-tf.keras.backend.mean(critic_output))
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\optimizer_v2\optimizer_v2.py", line 448, in get_updates
grads = self.get_gradients(loss, params)
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\optimizer_v2\optimizer_v2.py", line 361, in get_gradients
grads = gradients.gradients(loss, params)
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\ops\gradients_impl.py", line 158, in gradients
unconnected_gradients)
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\ops\gradients_util.py", line 547, in _GradientsHelper
raise RuntimeError("tf.gradients is not supported when eager execution "
RuntimeError: tf.gradients is not supported when eager execution is enabled. Use tf.GradientTape instead.
Comme prévu, la désactivation de l'exécution rapide via tf.compat.v1.disable_eager_execution()
corrige le problème.
Cependant, je ne veux pas désactiver l'exécution rapide pour tout - je voudrais utiliser uniquement l'API 2.0. L'exception suggère d'utiliser tf.GradientTape
au lieu de tf.gradients
mais c'est un appel interne.
Question : Quelle est la manière appropriée de calculer -tf.keras.backend.mean(critic_output)
en mode graphique (dans TensorFlow 2.0) ?