137 votes

Chargement d'un modèle Keras formé et poursuite de la formation

Je me demandais si il était possible de sauvegarder une partie formés Keras modèle et la poursuite de la formation après le chargement du modèle.

La raison pour cela est que je vais avoir plus de données sur la formation dans l'avenir et je ne veux pas de recycler l'ensemble de la maquette de nouveau.

Les fonctions que j'utilise sont:

#Partly train model
model.fit(first_training, first_classes, batch_size=32, nb_epoch=20)

#Save partly trained model
model.save('partly_trained.h5')

#Load partly trained model
from keras.models import load_model
model = load_model('partly_trained.h5')

#Continue training
model.fit(second_training, second_classes, batch_size=32, nb_epoch=20)

Edit 1: ajout de travail entièrement exemple

Avec le premier jeu de données après 10 les époques, la perte de la dernière époque sera 0.0748 et l'exactitude 0.9863.

Après l'enregistrement, la suppression et recharger le modèle de la perte et de l'exactitude du modèle de formation sur le deuxième jeu de données va être 0.1711 et 0.9504 respectivement.

Est-ce causé par les nouvelles données d'entraînement ou par un tout nouveau modèle appris?

"""
Model by: http://machinelearningmastery.com/
"""
# load (downloaded if needed) the MNIST dataset
import numpy
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import np_utils
from keras.models import load_model
numpy.random.seed(7)

def baseline_model():
    model = Sequential()
    model.add(Dense(num_pixels, input_dim=num_pixels, init='normal', activation='relu'))
    model.add(Dense(num_classes, init='normal', activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

if __name__ == '__main__':
    # load data
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # flatten 28*28 images to a 784 vector for each image
    num_pixels = X_train.shape[1] * X_train.shape[2]
    X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')
    X_test = X_test.reshape(X_test.shape[0], num_pixels).astype('float32')
    # normalize inputs from 0-255 to 0-1
    X_train = X_train / 255
    X_test = X_test / 255
    # one hot encode outputs
    y_train = np_utils.to_categorical(y_train)
    y_test = np_utils.to_categorical(y_test)
    num_classes = y_test.shape[1]

    # build the model
    model = baseline_model()

    #Partly train model
    dataset1_x = X_train[:3000]
    dataset1_y = y_train[:3000]
    model.fit(dataset1_x, dataset1_y, nb_epoch=10, batch_size=200, verbose=2)

    # Final evaluation of the model
    scores = model.evaluate(X_test, y_test, verbose=0)
    print("Baseline Error: %.2f%%" % (100-scores[1]*100))

    #Save partly trained model
    model.save('partly_trained.h5')
    del model

    #Reload model
    model = load_model('partly_trained.h5')

    #Continue training
    dataset2_x = X_train[3000:]
    dataset2_y = y_train[3000:]
    model.fit(dataset2_x, dataset2_y, nb_epoch=10, batch_size=200, verbose=2)
    scores = model.evaluate(X_test, y_test, verbose=0)
    print("Baseline Error: %.2f%%" % (100-scores[1]*100))

49voto

Marcin Możejko Points 19602

En fait, - model.save enregistre toutes les informations nécessaires au redémarrage de la formation dans votre cas. La seule chose qui pourrait être gâchée par le rechargement du modèle est l'état de votre optimiseur. Pour vérifier cela - essayez de save , rechargez le modèle et entraînez-le sur les données de formation.

10voto

Wolfgang Points 41

Le problème est peut-être que vous utilisez un autre assistant ou différents arguments pour votre optimizer. J'ai juste eu le même problème avec un modèle de pré-entraîné, à l'aide de

reduce_lr = ReduceLROnPlateau(monitor='loss', factor=lr_reduction_factor,
                              patience=patience, min_lr=min_lr, verbose=1)

pour la pré-entraîné modèle, selon lequel l'apprentissage initial, le taux commence à 0,0003 et au cours de la pré-formation, il est réduit à la min_learning taux, qui est 0.000003

Je viens de copier cette ligne sur le script qui utilise le pré-formés modèle et ça l'a vraiment mauvaise précision. Jusqu'à ce que j'ai remarqué que la dernière d'apprentissage taux de la pré-entraîné modèle était le min au taux d'apprentissage, c'est à dire 0.000003. Et si je commence avec l'apprentissage des taux d', j'obtiens exactement la même précision de commencer avec la sortie de la pré-entraîné modèle qui fait sens, en commençant par l'apprentissage d'un taux qui est 100 fois plus grande que la dernière apprentissage taux utilisé pour le pré-entraîné modèle donnera lieu à un énorme dépassement de GD et donc fortement diminué la précision.

3voto

shahar_m Points 727

Notez que Keras a parfois des problèmes avec les modèles chargés, comme ici . Cela peut expliquer des cas dans lesquels vous ne partez pas de la même précision entraînée.

2voto

CodePoet Points 71

Vous pouvez également utiliser Concept Drift, reportez- vous à la section Si vous souhaitez reconvertir un modèle lorsque de nouvelles observations sont disponibles . Il y a aussi le concept d'oubli catastrophique qui est discuté par un groupe de travaux universitaires. En voici une avec MNIST Enquête empirique sur l'oubli catastrophique

1voto

flowgrad Points 91

Tous les ci-dessus permet, vous devez reprendre à partir du même taux() comme le LR lorsque le modèle et le poids ont été sauvés. Définir directement sur l'optimiseur.

Notez que l'amélioration de la il n'est pas garanti, car le modèle peut avoir atteint le minimum local, qui peut être global. Il n'y a pas de point de reprise d'un modèle afin de rechercher un autre minimum local, sauf si vous avez l'intention d'augmenter le taux d'apprentissage de manière contrôlée et pousser le modèle dans un mieux possible le minimum non loin de là.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X