1. Je veux multiplier un lot de matrices avec un lot de matrices de la même longueur deux à deux
M = tf.random_normal((batch_size, n, m))
N = tf.random_normal((batch_size, m, p))
# python >= 3.5
MN = M @ N
# or the old way,
MN = tf.matmul(M, N)
# MN has shape (batch_size, n, p)
2. Je veux multiplier un lot de matrices avec un lot de vecteurs de même longueur, par paires
Nous retombons de cas 1 par l'ajout et la suppression d'une dimension à l' v
.
M = tf.random_normal((batch_size, n, m))
v = tf.random_normal((batch_size, m))
Mv = (M @ v[..., None])[..., 0]
# Mv has shape (batch_size, n)
3. Je veux multiplier une matrice avec un lot de matrices
Dans ce cas, nous ne pouvons pas simplement ajouter un lot de la dimension de l' 1
de la matrice seule, car tf.matmul
ne diffuse pas dans le lot de dimension.
3.1. La matrice est sur le côté droit
Dans ce cas, nous pouvons traiter la matrice lot comme une seule matrice de grande taille, à l'aide d'un simple remodeler.
M = tf.random_normal((batch_size, n, m))
N = tf.random_normal((m, p))
MN = tf.reshape(tf.reshape(M, [-1, m]) @ N, [-1, n, p])
# MN has shape (batch_size, n, p)
3.2. La matrice est sur le côté gauche
Ce cas est plus compliqué. Nous pouvons revenir à des cas 3.1 par la transposition des matrices.
MT = tf.matrix_transpose(M)
NT = tf.matrix_transpose(N)
NTMT = tf.reshape(tf.reshape(NT, [-1, m]) @ MT, [-1, p, n])
MN = tf.matrix_transpose(NTMT)
Toutefois, la transposition peut être une opération coûteuse, et ici il est fait deux fois sur un ensemble de matrices. Il peut être préférable de simplement dupliquer M
pour correspondre le lot de dimension:
MN = tf.tile(M[None], [batch_size, 1, 1]) @ N
Le profilage dire l'option qui convient le mieux pour un problème donné/combinaison matérielle.
4. Je veux multiplier une matrice avec un lot de vecteurs
Cela ressemble à l'affaire 3.2 comme la matrice est sur la gauche, mais il est en fait plus simple en raison de la transposition d'un vecteur est essentiellement un no-op. Nous terminons avec
M = tf.random_normal((n, m))
v = tf.random_normal((batch_size, m))
MT = tf.matrix_transpose(M)
Mv = v @ MT
Qu'en einsum
?
Tous les précédents multiplications pourrait avoir été écrite avec l' tf.einsum
couteau de l'armée suisse. Par exemple, la première solution pour le 3.2 pourrait être écrit simplement comme
MN = tf.einsum('nm,bmp->bnp', M, N)
Notez, cependant, que einsum
est finalement en s'appuyant sur tranpose
et matmul
pour le calcul.
Donc, même si einsum
est un moyen très pratique pour écrire la matrice de multiplications, il masque la complexité des opérations de sous - par exemple, il n'est pas simple à deviner combien de fois einsum
expression transposer vos données, et donc coûteux, l'opération sera. Aussi, il peut cacher le fait qu'il pourrait y avoir plusieurs solutions de rechange pour la même opération (voir 3.2) et pourrait ne pas forcément choisir la meilleure option.
Pour cette raison, je serais personnellement utiliser des formules explicites comme ceux ci-dessus afin de mieux transmettre leurs respectifs de la complexité. Bien que si vous savez ce que vous faites et comme la simplicité de l' einsum
de la syntaxe, puis par tous signifie aller pour elle.