9 votes

Le RMSE personnalisé n'est pas le même que la racine du MSE intégré de Keras pour la même prédiction.

J'ai défini une fonction RMSE personnalisée :

def rmse(y_pred, y_true):
    return K.sqrt(K.mean(K.square(y_pred - y_true)))

Je l'évaluais par rapport à l'erreur quadratique moyenne fournie par Keras :

keras.losses.mean_squared_error(y_true, y_pred)

Les valeurs que j'obtiens pour les métriques MSE et RMSE respectivement pour certaines (les mêmes) prédictions sont les suivantes :

mse: 115.7218 - rmse: 8.0966

Maintenant, quand je prends la racine de l'EMS, j'obtiens 10.7574 qui est évidemment plus élevée que la RMSE produite par la fonction RMSE personnalisée. Je n'ai pas été en mesure de comprendre pourquoi il en est ainsi, et je n'ai pas trouvé d'articles connexes sur ce sujet particulier. Y a-t-il une erreur dans la fonction RMSE que je ne vois tout simplement pas ? Ou est-ce lié à la façon dont Keras définit la fonction RMSE ? axis=-1 dans la fonction MSE (dont je n'ai pas encore bien compris le but) ?

C'est ici que j'invoque le RMSE et le MSE :

model.compile(loss="mae", optimizer="adam", metrics=["mse", rmse])

Je m'attendrais donc à ce que la racine de l'EQM soit la même que l'EQM.

J'ai initialement posé cette question sur Cross Validated mais elle a été mise en attente comme hors sujet.

6voto

Manoj Mohan Points 4906

Y a-t-il une erreur dans la fonction de perte RMSE que je n'ai simplement pas que je ne vois pas ? Ou est-ce lié à la façon dont Keras définit axis=-1 dans la fonction de perte MSE (dont je n'ai pas encore bien compris le but) ? fonction de perte MSE (dont je n'ai pas encore totalement compris le but) ?

Lorsque Keras effectue le calcul de la perte, la dimension du lot est conservée, ce qui explique la raison pour laquelle l axis=-1 . La valeur retournée est un tenseur. Ceci est dû au fait que la perte pour chaque échantillon mai doivent être pondérés avant de prendre la moyenne, selon que certains arguments sont passés dans la fonction fit() méthode comme sample_weight .

J'obtiens les mêmes résultats avec les deux approches.

from tensorflow import keras
import numpy as np
from keras import backend as K

def rmse(y_pred, y_true):
    return K.sqrt(K.mean(K.square(y_pred - y_true)))

l1 = keras.layers.Input(shape=(32))
l2 = keras.layers.Dense(10)(l1)
model = keras.Model(inputs=l1, outputs=l2)

train_examples = np.random.randn(5,32)
train_labels=np.random.randn(5,10)

Approche MSE

model.compile(loss='mse', optimizer='adam')
model.evaluate(train_examples, train_labels)

Approche RMSE

model.compile(loss=rmse, optimizer='adam')
model.evaluate(train_examples, train_labels)

Sortie

5/5 [==============================] - 0s 8ms/sample - loss: 1.9011
5/5 [==============================] - 0s 2ms/sample - loss: 1.3788

sqrt(1.9011) = 1.3788

0voto

Bien que sqrt(mse) est égal à rmse pour une configuration de modèle simple comme l'a montré la réponse de Manoj, j'ai rencontré ce problème pour une configuration de modèle complexe et je n'ai pas pu comprendre pourquoi cela s'est produit. Cependant, j'ai trouvé une solution de contournement pour se débarrasser de ce problème si quelqu'un a vraiment besoin de surveiller rmse comme métrique, mais je suis confronté au même problème dans la question. J'ai utilisé le LambdaCallback dans le callbacks pour imprimer le rmse de formation et de validation après chaque époque et cela a fonctionné :

def rmse(y_true, y_pred):
    return K.sqrt(K.mean(K.square(y_true - y_pred)))

rmse_print_callback = keras.callbacks.LambdaCallback(on_epoch_end=lambda epoch,logs: 
                 print(f"rmse: {rmse(training_labels, model.predict(training_data)):.4f} - val_rmse: {rmse(validation_labels, model.predict(validation_data)):.4f}"))

model.fit(training_data, training_labels, epochs= 100, callbacks=[rmse_print_callback])

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X