7 votes

Stabilité du logsoftmax

Je sais comment rendre le softmax stable en ajoutant à chaque élément -max _i x_i. Cela évite le débordement et le débordement inférieur. Maintenant, prendre le logarithme de cela peut provoquer un débordement. Le log softmax(x) peut évaluer à zéro, conduisant à -infini.

Je ne suis pas sûr comment le résoudre. Je sais que c'est un problème courant. J'ai lu plusieurs réponses à ce sujet, que je n'ai pas comprises. Mais je suis toujours confus sur comment résoudre ce problème.

PS : Si vous fournissez un exemple simple, ce serait génial.

13voto

e_soroush Points 1725

Pour stabiliser Logsoftmax, la plupart des implémentations telles que Tensorflow et Thenao, utilisent une astuce qui retire le composant le plus grand max(x_i). Cette astuce est souvent utilisée pour calculer de manière stable softmax. Pour logsoftmax, nous commençons par :

formule

Après avoir extrait exp(b) et utilisé le fait que log(exp(x)) = x, nous avons :

formule

Si nous définissons b = max(x_i), cette nouvelle équation a à la fois des conditions de stabilité de dépassement et de sous-débit.


En termes de code, si x est un vecteur :

def log_softmax(x):
    x_off = x - np.max(x)
    return x_off - np.log(np.sum(np.exp(x_off)))

Voir aussi : https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/

1voto

user3113854 Points 26
logsoftmax = logits - log(reduce_sum(exp(logits), dim))

référence : https://www.tensorflow.org/api_docs/python/tf/nn/log_softmax

0voto

Keshav Kumar Points 1

Utilisez simplement ceci car il prend en charge Nan

tf.nn.softmax_cross_entropy_with_logits(
    labels, logits, axis=-1, name=None
)

logits = tf.constant([[4, 5, 1000]], dtype = tf.float32)
labels = tf.constant([[1,0,1]], dtype = tf.float32)

# Cas-1
output = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
print(output) 
>>> tf.Tensor([996.], shape=(1,), dtype=float32)

#Cas-2
a = tf.nn.softmax(logits)
output = tf.reduce_sum(-(labels * tf.math.log(a)))
print(output) 
>>> tf.Tensor(nan, shape=(), dtype=float32)

# cela se produit car la valeur du softmax est tronquée à zéro

print(a) 
>>>

-1voto

prosti Points 4630

Les astuces mathématiques ne peuvent pas vous aider à faire en sorte que log 0 soit autre chose que -∞. Si vous y réfléchissez bien, la seule solution est de normaliser les données pour éviter d'en arriver là.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X