11 votes

L'entraînement d'un réseau de neurones entièrement convolutif avec des entrées de taille variable prend un temps déraisonnable avec Keras/TensorFlow.

J'essaie d'implémenter un FCNN pour la classification d'images qui peut accepter des entrées de taille variable. Le modèle est construit en Keras avec un backend TensorFlow.

Prenons l'exemple ludique suivant :

model = Sequential()

# width and height are None because we want to process images of variable size 
# nb_channels is either 1 (grayscale) or 3 (rgb)
model.add(Convolution2D(32, 3, 3, input_shape=(nb_channels, None, None), border_mode='same'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Convolution2D(32, 3, 3, border_mode='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Convolution2D(16, 1, 1))
model.add(Activation('relu'))

model.add(Convolution2D(8, 1, 1))
model.add(Activation('relu'))

# reduce the number of dimensions to the number of classes
model.add(Convolution2D(nb_classses, 1, 1))
model.add(Activation('relu'))

# do global pooling to yield one value per class
model.add(GlobalAveragePooling2D())

model.add(Activation('softmax'))

Ce modèle fonctionne bien mais je rencontre un problème de performance. L'apprentissage sur des images de taille variable prend un temps déraisonnable par rapport à l'apprentissage sur les entrées de taille fixe. Si je redimensionne toutes les images à la taille maximale de l'ensemble de données, l'entraînement du modèle prend toujours beaucoup moins de temps que l'entraînement sur l'entrée de taille variable. Alors, est-ce que input_shape=(nb_channels, None, None) la bonne façon de spécifier une entrée de taille variable ? Existe-t-il un moyen d'atténuer ce problème de performance ?

Mise à jour

model.summary() pour un modèle avec 3 classes et des images en niveaux de gris :

Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
convolution2d_1 (Convolution2D)  (None, 32, None, None 320         convolution2d_input_1[0][0]      
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 32, None, None 0           convolution2d_1[0][0]            
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D)    (None, 32, None, None 0           activation_1[0][0]               
____________________________________________________________________________________________________
convolution2d_2 (Convolution2D)  (None, 32, None, None 9248        maxpooling2d_1[0][0]             
____________________________________________________________________________________________________
maxpooling2d_2 (MaxPooling2D)    (None, 32, None, None 0           convolution2d_2[0][0]            
____________________________________________________________________________________________________
convolution2d_3 (Convolution2D)  (None, 16, None, None 528         maxpooling2d_2[0][0]             
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 16, None, None 0           convolution2d_3[0][0]            
____________________________________________________________________________________________________
convolution2d_4 (Convolution2D)  (None, 8, None, None) 136         activation_2[0][0]               
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 8, None, None) 0           convolution2d_4[0][0]            
____________________________________________________________________________________________________
convolution2d_5 (Convolution2D)  (None, 3, None, None) 27          activation_3[0][0]               
____________________________________________________________________________________________________
activation_4 (Activation)        (None, 3, None, None) 0           convolution2d_5[0][0]            
____________________________________________________________________________________________________
globalaveragepooling2d_1 (Global (None, 3)             0           activation_4[0][0]               
____________________________________________________________________________________________________
activation_5 (Activation)        (None, 3)             0           globalaveragepooling2d_1[0][0]   
====================================================================================================
Total params: 10,259
Trainable params: 10,259
Non-trainable params: 0

0voto

Mark Parris Points 1

Les images de tailles différentes impliquent des images de choses similaires à une échelle différente. Si cette différence d'échelle est significative, la position relative des objets similaires se déplacera du centre du cadre vers le haut et la gauche au fur et à mesure que la taille de l'image diminue. L'architecture de réseau (simple) présentée est consciente de l'espace, il serait donc cohérent que le taux de convergence du modèle se dégrade car les données d'une échelle très différente seraient incohérentes. Cette architecture n'est pas bien adaptée à la recherche d'une même chose à des endroits différents ou multiples.

Un certain degré de cisaillement, de rotation, de miroir aiderait le modèle à se généraliser, mais en le redimensionnant à une taille cohérente. Ainsi, lorsque vous redimensionnez, vous réglez le problème d'échelle et rendez les données d'entrée spatialement cohérentes.

En bref, je pense que c'est que cette architecture de réseau n'est pas adaptée / capable pour la tâche que vous lui donnez, c'est-à-dire différentes échelles.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X