2 votes

Comment appliquer un dégradé personnalisé : erreur de conversion de type

J'implémente DDPG. J'ai besoin de calculer un gradient personnalisé (le gradient des poids de l'acteur qui maximise la sortie du critique), puis de l'appliquer avec un optimiseur. Cependant, pour une raison quelconque, j'obtiens une erreur de type mystérieuse lorsque j'essaie de l'exécuter.

J'ai essayé de regarder d'autres tutoriels et de faire des recherches sur stack overflow mais je n'ai pas trouvé comment corriger l'erreur.

Voici un exemple de code qui donne l'erreur (le calcul réel est plus compliqué mais donne une réponse de la même forme) :

actor = Sequential()
actor.add(Dense(2, input_shape=(6,)))
# actor_inputs is randomly sampled
sess = K.get_session()
grad_op = K.gradients(actor.output, actor.trainable_weights)
grads = sess.run(grad_op, feed_dict={actor.input: actor_inputs})
opt = tf.keras.optimizers.Adam(lr=1e-4)
opt.apply_gradients(zip(grads, actor.trainable_weights))

Le gradient que cela calcule semble correct. Je m'attendais à ce que l'optimiseur l'applique au réseau mais j'obtiens l'erreur suivante dans la fenêtre de l'optimiseur apply_gradients appeler :

Tensor conversion requested dtype float32_ref for Tensor with dtype float32: 'Tensor("Adam_24/dense_95/kernel/m/Initializer/zeros:0", shape=(6, 2), dtype=float32)'

Voici le résultat de quelques impressions de test des données pertinentes :

print(actor_inputs) :

[[-0.43979521  0.         -1.28554755  0.          0.94703663 -0.32112555]]

print(grad_op) :

[<tf.Tensor 'gradients_2/dense_95/MatMul_grad/MatMul_1:0' shape=(6, 2) dtype=float32>, <tf.Tensor 'gradients_2/dense_95/BiasAdd_grad/BiasAddGrad:0' shape=(2,) dtype=float32>]

print(grads) :

[array([[ 3.003665  ,  3.003665  ],
       [ 0.        ,  0.        ],
       [-2.2157073 , -2.2157073 ],
       [ 0.        ,  0.        ],
       [-0.8517535 , -0.8517535 ],
       [ 0.52394277,  0.52394277]], dtype=float32), array([1., 1.], dtype=float32)]

print(actor.trainable_weights) :

[<tf.Variable 'dense_95/kernel:0' shape=(6, 2) dtype=float32_ref>, <tf.Variable 'dense_95/bias:0' shape=(2,) dtype=float32_ref>]

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X