J'essaie de comprendre le rôle de la Flatten
dans Keras. Voici mon code, qui est un simple réseau à deux couches. Il prend en entrée des données bidimensionnelles de forme (3, 2), et sort des données unidimensionnelles de forme (1, 4) :
model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')
x = np.array([[[1, 2], [3, 4], [5, 6]]])
y = model.predict(x)
print y.shape
Cela imprime que y
a la forme (1, 4). Cependant, si je supprime le Flatten
la ligne, puis il imprime que y
a la forme (1, 3, 4).
Je ne comprends pas. D'après ce que je comprends des réseaux neuronaux, le model.add(Dense(16, input_shape=(3, 2)))
crée une couche cachée entièrement connectée, avec 16 nœuds. Chacun de ces nœuds est connecté à chacun des 3x2 éléments d'entrée. Par conséquent, les 16 nœuds à la sortie de cette première couche sont déjà "plats". Ainsi, la forme de la sortie de la première couche devrait être (1, 16). Ensuite, la deuxième couche prend cette forme en entrée et produit des données de forme (1, 4).
Donc, si la sortie de la première couche est déjà "plate" et de forme (1, 16), pourquoi dois-je l'aplatir davantage ?
0 votes
Il peut être utile de comprendre Flatten en le comparant à GlobalPooling. stackoverflow.com/a/63502664/10375049