J'ai un modèle fonctionnel dans Keras (Resnet50 à partir des exemples du répertoire). Je l'ai entraîné avec ImageDataGenerator
y flow_from_directory
les données et le modèle sauvegardé à .h5
fichier. Lorsque j'appelle model.predict
J'obtiens un tableau de probabilités de classe. Mais je veux les associer à des étiquettes de classe (dans mon cas, des noms de dossiers). Comment puis-je les obtenir ? J'ai trouvé que je pouvais utiliser model.predict_classes
y model.predict_proba
mais je n'ai pas ces fonctions dans le modèle fonctionnel, seulement dans le modèle séquentiel.
Réponses
Trop de publicités?y_prob = model.predict(x)
y_classes = y_prob.argmax(axis=-1)
Comme suggéré aquí .
Lorsqu'on utilise flow_from_directory, le problème est de savoir comment interpréter les sorties de probabilité. En d'autres termes, comment faire correspondre les sorties de probabilité et les étiquettes de classe, étant donné que la façon dont flow_from_directory crée des vecteurs à un coup n'est pas connue à l'avance.
Nous pouvons obtenir un dictionnaire qui associe les étiquettes de classe à l'index du vecteur de prédiction que nous obtenons en sortie lorsque nous utilisons la fonction
generator= train_datagen.flow_from_directory("train", batch_size=batch_size)
label_map = (generator.class_indices)
La variable label_map est un dictionnaire comme ceci
{'class_14': 5, 'class_10': 1, 'class_11': 2, 'class_12': 3, 'class_13': 4, 'class_2': 6, 'class_3': 7, 'class_1': 0, 'class_6': 10, 'class_7': 11, 'class_4': 8, 'class_5': 9, 'class_8': 12, 'class_9': 13}
On peut ensuite en déduire la relation entre les scores de probabilité et les noms de classe.
En gros, vous pouvez créer ce dictionnaire par ce code.
from glob import glob
class_names = glob("*") # Reads all the folders in which images are present
class_names = sorted(class_names) # Sorting them
name_id_map = dict(zip(class_names, range(len(class_names))))
La variable name_id_map dans le code ci-dessus contient également le même dictionnaire que celui obtenu à partir de la fonction class_indices de flow_from_directory.
J'espère que cela vous aidera !
MISE À JOUR : cette méthode n'est plus valable pour les nouvelles versions de Keras. Veuillez utiliser argmax()
comme dans la réponse d'Emilia Apostolova.
Les modèles d'API fonctionnels n'ont que le predict()
qui, pour la classification, renvoie les probabilités de classe. Vous pouvez ensuite sélectionner les classes les plus probables en utilisant la fonction probas_to_classes()
fonction d'utilité. Exemple :
y_proba = model.predict(x)
y_classes = keras.np_utils.probas_to_classes(y_proba)
Ceci est équivalent à model.predict_classes(x)
sur le modèle séquentiel.
La raison en est que l'API fonctionnelle prend en charge une classe plus générale de tâches où predict_classes()
n'aurait pas de sens.
Plus d'informations : https://github.com/fchollet/keras/issues/2524
Vous devez utiliser l'index des étiquettes que vous avez, voici ce que je fais pour la classification des textes :
labels_index = { "website" : 0, "money" : 1 ....} # données labels = [1, 2, 1...] label_categories = to_categorical(np.asarray(labels)) # pour alimenter le modèle
Ensuite, pour les prédictions :
texts = ["hello, rejoins moi sur skype", "bonjour comment ça va ?", "tu me donnes de l'argent"]
sequences = tokenizer.texts_to_sequences(texts)
data = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)
predictions = model.predict(data)
t = 0
for text in texts:
i = 0
print("Prediction for \"%s\": " % (text))
for label in labels_index:
print("\t%s ==> %f" % (label, predictions[t][i]))
i = i + 1
t = t + 1
Cela donne :
Prediction for "hello, rejoins moi sur skype":
website ==> 0.759483
money ==> 0.037091
under ==> 0.010587
camsite ==> 0.114436
email ==> 0.075975
abuse ==> 0.002428
Prediction for "bonjour comment ça va ?":
website ==> 0.433079
money ==> 0.084878
under ==> 0.048375
camsite ==> 0.036674
email ==> 0.369197
abuse ==> 0.027798
Prediction for "tu me donnes de l'argent":
website ==> 0.006223
money ==> 0.095308
under ==> 0.003586
camsite ==> 0.003115
email ==> 0.884112
abuse ==> 0.007655