3 votes

Tensorflow : décroissance de poids vs normalisation des logits

Je travaille sur un CNN dont l'architecture est similaire à celle du CNN. cifar10 exemple. Mon CNN se compose de deux couches convolutionnelles, chacune suivie d'une couche de max-pooling, et de trois couches entièrement connectées. Toutes les couches, sauf la dernière, utilisent relu fonction d'activation. La couche finale produit des logits d'ordre 10^5 donc l'exécution de la fonction softmax aboutit à un encodage à un coup. J'ai essayé de résoudre ce problème de deux manières différentes.

Tout d'abord, j'ai simplement remis à l'échelle les logits à [-1, 1] c'est-à-dire qu'on les a normalisés. Cela semble résoudre le problème, la formation se déroule bien et CNN produit un résultat raisonnable. Mais je ne suis pas sûr que ce soit la bonne façon de procéder, j'ai l'impression que la normalisation des logits est une solution de rechange qui ne résout pas le problème initial. enter image description here

Deuxièmement, j'ai appliqué la décroissance du poids. Sans la normalisation des logits, l'entropie croisée part de grandes valeurs, mais elle diminue régulièrement comme la perte totale de poids. Cependant, la précision produit un modèle étrange. De plus, le réseau formé avec la décroissance de poids produit des résultats bien pires que celui avec la normalisation des logits. enter image description here

La dégradation du poids est ajoutée comme suit :

def weight_decay(var, wd):
    weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('weight_losses', weight_decay)

Précision

predictions = tf.nn.softmax(logits)
one_hot_pred = tf.argmax(predictions, 1)
correct_pred = tf.equal(one_hot_pred, tf.argmax(y, 1))
accuracy_batch = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

Alors comment aborder le problème des logits de grande amplitude ? Quelle est la meilleure pratique ? Quelle peut être la raison pour laquelle l'application de la décroissance du poids produit un résultat moins bon ?

-1voto

dkk Points 1

A mon avis, tf.nn.l2_loss(var) n'est pas normalisé, en d'autres termes, peut-être que votre wd est trop grande, ou ne convient pas d'une manière ou d'une autre.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X