Notez que weighted_cross_entropy_with_logits
pondéré par la variante de l' sigmoid_cross_entropy_with_logits
. Sigmoïde de la croix de l'entropie est généralement utilisé pour les binaires de classification. Oui, il peut gérer plusieurs labels, mais sigmoïde de la croix de l'entropie fait, fondamentalement, un (binaire) décision sur chacun d'eux, par exemple, pour un visage net, ceux (non mutuellement exclusifs) les étiquettes pourraient être "le sujet porte des lunettes?", "L'objet de la femelle?", etc.
Dans la classification binaire(s), chaque canal de sortie correspond à un fichier binaire (soft) de la décision. Par conséquent, la pondération doit se faire dans le calcul de la perte. C'est ce qu' weighted_cross_entropy_with_logits
n', par la pondération d'un terme de l'entropie croisée sur l'autre.
Dans excluent mutuellement multilabel classification, nous utilisons softmax_cross_entropy_with_logits
, qui se comporte différemment: chaque canal de sortie correspond à la partition d'une classe candidat. La décision vient après, en comparant les résultats de chaque canal.
Pondération avant la décision finale est donc une simple question de modifier les partitions avant de les comparer, généralement par multiplication avec des poids. Par exemple, pour une tâche de classification ternaire,
# your class weights
class_weights = tf.constant([[1.0, 2.0, 3.0]])
# deduce weights for batch samples based on their true label
weights = tf.reduce_sum(class_weights * onehot_labels, axis=1)
# compute your (unweighted) softmax cross entropy loss
unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(onehot_labels, logits)
# apply the weights, relying on broadcasting of the multiplication
weighted_losses = unweighted_losses * weights
# reduce the result to get your final loss
loss = tf.reduce_mean(weighted_losses)
Vous pouvez également appuyer sur tf.losses.softmax_cross_entropy
pour gérer les trois dernières étapes.
Dans votre cas, où vous avez besoin pour s'attaquer à des données de déséquilibre, la classe de poids pourrait en effet être inversement proportionnelle à leur fréquence dans le train de données. De les normaliser, de sorte que leur somme soit égale à un, ou le nombre de classes a aussi du sens.
Notez que dans ce qui précède, nous avons pénalisé la perte de la vraie étiquette des échantillons. Nous avons pu également pénalisé la perte de l' estime des étiquettes en définissant simplement
weights = class_weights
et le reste du code n'a pas besoin de changer grâce à la radiodiffusion à de la magie.
Dans le cas général, vous voulez des poids qui varient selon le genre d'erreur que vous faites. En d'autres termes, à chaque paire d'étiquettes X
et Y
, vous pouvez choisir la façon de sanctionner le choix de l'étiquette X
quand le vrai label est - Y
. Vous vous retrouvez avec un ensemble avant la matrice de poids, ce qui entraîne weights
au-dessus de tous les (num_samples, num_classes)
tenseur. Cela va un peu au-delà de ce que vous voulez, mais il pourrait être utile de savoir cependant que seul votre définition du poids tenseur besoin de changer dans le code ci-dessus.