Conv2d_transpose() transpose simplement les poids et les retourne de 180 degrés. Ensuite, elle applique la méthode standard conv2d(). "Transpose" signifie pratiquement qu'elle change l'ordre des "colonnes" dans le tenseur des poids. Veuillez consulter l'exemple ci-dessous.
Voici un exemple qui utilise des convolutions avec stride=1 et padding='SAME'. C'est un cas simple mais le même raisonnement pourrait être appliqué aux autres cas.
Disons que nous avons :
- Entrée : Image MNIST de 28x28x1, forme = [28,28,1].
- Couche convolutive : 32 filtres de 7x7, forme des poids = [7, 7, 1, 32], nom = W_conv1
Si nous effectuons une convolution de l'entrée, les activations des auront la forme suivante : [1,28,28,32].
activations = sess.run(h_conv1,feed_dict={x:np.reshape(image,[1,784])})
Où :
W_conv1 = weight_variable([7, 7, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = conv2d(x, W_conv1, strides=[1, 1, 1, 1], padding='SAME') + b_conv1
Pour obtenir la "déconvolution" ou la "convolution transposée", nous pouvons utiliser conv2d_transpose() sur les activations de la convolution de cette manière :
deconv = conv2d_transpose(activations,W_conv1, output_shape=[1,28,28,1],padding='SAME')
OU en utilisant conv2d() nous devons transposer et inverser les poids :
transposed_weights = tf.transpose(W_conv1, perm=[0, 1, 3, 2])
Ici, nous changeons l'ordre des "colonnes" de [0,1,2,3] à [0,1,3,2]. Ainsi, à partir de [7, 7, 1, 32], nous obtiendrons un tenseur de forme=[7,7,32,1]. Ensuite, nous inversons les poids :
for i in range(n_filters):
# Flip the weights by 180 degrees
transposed_and_flipped_weights[:,:,i,0] = sess.run(tf.reverse(transposed_weights[:,:,i,0], axis=[0, 1]))
Nous pouvons alors calculer la convolution avec conv2d() comme :
strides = [1,1,1,1]
deconv = conv2d(activations,transposed_and_flipped_weights,strides=strides,padding='SAME')
Et nous obtiendrons le même résultat que précédemment. On peut également obtenir le même résultat avec conv2d_backprop_input() en utilisant :
deconv = conv2d_backprop_input([1,28,28,1],W_conv1,activations, strides=strides, padding='SAME')
Les résultats sont présentés ici :
Test des fonctions conv2d(), conv2d_tranposed() et conv2d_backprop_input()
Nous pouvons constater que les résultats sont les mêmes. Pour le voir d'une meilleure façon, veuillez consulter mon code à l'adresse suivante :
https://github.com/simo23/conv2d_transpose
Ici, je reproduis la sortie de la fonction conv2d_transpose() en utilisant la fonction standard conv2d().