Je suis sûr qu'il existe de nombreuses façons de procéder, mais j'ai bricolé et j'ai créé ma propre version.
Tout d'abord, un callback personnalisé permet de saisir et de mettre à jour l'historique à la fin de chaque époque. J'y ai aussi une callback pour sauvegarder le modèle. Ces deux fonctions sont pratiques car si vous vous plantez ou si vous vous arrêtez, vous pouvez reprendre l'entraînement à la dernière époque terminée.
class LossHistory(Callback):
# https://stackoverflow.com/a/53653154/852795
def on_epoch_end(self, epoch, logs = None):
new_history = {}
for k, v in logs.items(): # compile new history from logs
new_history[k] = [v] # convert values into lists
current_history = loadHist(history_filename) # load history from current training
current_history = appendHist(current_history, new_history) # append the logs
saveHist(history_filename, current_history) # save history from current training
model_checkpoint = ModelCheckpoint(model_filename, verbose = 0, period = 1)
history_checkpoint = LossHistory()
callbacks_list = [model_checkpoint, history_checkpoint]
Ensuite, voici quelques fonctions d'aide qui font exactement ce qu'elles sont censées faire. Elles sont toutes appelées par la fonction LossHistory()
le rappel.
# https://stackoverflow.com/a/54092401/852795
import json, codecs
def saveHist(path, history):
with codecs.open(path, 'w', encoding='utf-8') as f:
json.dump(history, f, separators=(',', ':'), sort_keys=True, indent=4)
def loadHist(path):
n = {} # set history to empty
if os.path.exists(path): # reload history if it exists
with codecs.open(path, 'r', encoding='utf-8') as f:
n = json.loads(f.read())
return n
def appendHist(h1, h2):
if h1 == {}:
return h2
else:
dest = {}
for key, value in h1.items():
dest[key] = value + h2[key]
return dest
Après cela, tout ce dont vous avez besoin est de définir history_filename
à quelque chose comme data/model-history.json
ainsi que l'ensemble model_filesname
à quelque chose comme data/model.h5
. Une dernière astuce pour être sûr de ne pas gâcher votre historique à la fin de l'entraînement, en supposant que vous vous arrêtez et recommencez, et que vous maintenez les rappels, est de faire ceci :
new_history = model.fit(X_train, y_train,
batch_size = batch_size,
nb_epoch = nb_epoch,
validation_data=(X_test, y_test),
callbacks=callbacks_list)
history = appendHist(history, new_history.history)
Quand vous voulez, history = loadHist(history_filename)
récupère votre histoire.
Le problème vient du json et des listes mais je n'ai pas réussi à le faire fonctionner sans le convertir par itération. Quoi qu'il en soit, je sais que cela fonctionne parce que j'ai travaillé dessus pendant des jours maintenant. Le site pickled.dump
réponse à https://stackoverflow.com/a/44674337/852795 pourrait être meilleur, mais je ne sais pas ce que c'est. Si j'ai oublié quelque chose ou si vous n'arrivez pas à le faire fonctionner, faites-le moi savoir.
2 votes
Si cela vous aide, vous pouvez aussi bien utiliser la fonction
CSVLogger()
de keras, comme décrit ici : keras.io/callbacks/#csvlogger1 votes
Quelqu'un peut-il recommander une méthode permettant de sauvegarder l'objet historique renvoyé par la fonction
fit
? Il contient des informations utiles dans.params
attribut que j'aimerais garder aussi. Oui, je peux sauvegarder leparams
&history
séparément ou les combiner dans un dict, par exemple, mais je suis intéressé par un moyen simple de sauvegarder l'ensemble de lahistory
objet.