53 votes

Quelle est la différence entre tf.truncated_normal et tf.random_normal?

tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None) génère des valeurs aléatoires à partir d'une distribution normale.

tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None) génère des valeurs aléatoires à partir d'une distribution normale tronquée.

J'ai essayé de googler «distribution normale tronquée». Mais je n'ai pas compris grand chose.

80voto

Salvador Dali Points 11667

La documentation dit tout: Pour la distribution normale tronquée:

Les valeurs générées suivent une distribution normale avec une moyenne spécifiée et l'écart-type, sauf que les valeurs dont l'ampleur est plus de 2 écarts-types de la moyenne sont supprimés et re-cueillies.

Très probablement, il est facile de comprendre la différence, en traçant le graphe de vous-même (%magie est parce que j'utilise jupyter portable):

import tensorflow as tf
import matplotlib.pyplot as plt

%matplotlib inline  

n = 500000
A = tf.truncated_normal((n,))
B = tf.random_normal((n,))
with tf.Session() as sess:
    a, b = sess.run([A, B])

Et maintenant

plt.hist(a, 100, (-4.2, 4.2));
plt.hist(b, 100, (-4.2, 4.2));

enter image description here


Le point pour l'utilisation normale tronquée est de surmonter la saturation de la tome des fonctions comme sigmoïde (où, si la valeur est trop grande/petite, le neurone s'arrête d'apprentissage).

25voto

ksooklall Points 1992

tf.truncated_normal() sélectionne nombres aléatoires à partir d'une distribution normale dont la moyenne est proche de 0 et les valeurs sont proches de 0 Ex. -0.1 à 0.1. Il est appelé tronquée parce que votre couper la queue d'une distribution normale.

tf.random_normal() sélectionne nombres aléatoires à partir d'une distribution normale dont la moyenne est proche de 0; toutefois, les valeurs peuvent être un peu plus espacés. Ex. -2 à 2

Dans la pratique (Apprentissage automatique), vous souhaitez généralement de votre poids pour être proche de 0.

8voto

Martin Svedin Points 148

La documentation de l'API pour les tf.truncated_normal() décrit la fonction en tant que:

Les sorties de valeurs aléatoires à partir d'une distribution normale tronquée.

Les valeurs générées suivent une distribution normale avec moyenne et écart-type, sauf que les valeurs dont l'ampleur est plus de 2 écarts-types de la moyenne sont supprimés et re-cueillies.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X