2 votes

Comment vectoriser cette multiplication ?

J'ai une matrice X de forme (ni*43*91)x67 et un tenseur W de forme 67x43x91. ni varie

J'ai besoin d'obtenir un vecteur (ni*43*91) y en faisant le point entre les ni premières lignes de X et la première colonne de W pour obtenir les ni premiers éléments de y et les ni deuxièmes lignes de X avec la deuxième colonne de W pour obtenir les ni deuxièmes éléments de y, et ainsi de suite. Lorsque je n'ai plus de colonnes dans W, je passe à la dimension suivante et je continue.

J'ai deux masques dim2 et dim3, tous deux de forme (ni*43*91), dans l'ordre. Pour l'instant c'est ce que je fais (simplifié) et c'est très lent

for d3 in range(91):
  for d2 in range(43):
    mask = ((dim3 == d3) & (dim2 == d2))
    curr_X = X[mask, :]
    curr_W = W[:,d2,d3]
    curr_y = numpy.dot(curr_X,curr_W)
    y[mask] = curr_y

Est-il possible de le faire sans les boucles "for" ?

0voto

Jaime Points 25540

Je ne comprends pas très bien ce que votre dim2 y dim3 sont des tableaux, et comment est mask mais d'après votre description, vous voulez quelque chose d'approchant :

ni = 10
a, b, c = 43, 91, 67
X = np.random.rand(ni*a*b, c)
W = np.random.rand(c, a, b)

X = X.reshape(ni, a*b, c)
W = W.reshape(c, a*b)

y = np.einsum('ijk, kj -> ij', X, W)
y = y.reshape(-1)

Si vous mettez à jour votre question avec un code de travail, c'est-à-dire une description complète de dim2 y dim3 Nous pouvons affiner cette méthode pour qu'elle renvoie exactement la même chose, si ce n'est pas déjà le cas.

0voto

flonk Points 584

Tout d'abord, ce que vous voulez faire n'est pas clair, car votre code ne fonctionne pas. Je ne peux que deviner vous voulez le faire :

from numpy import *
from numpy.random import rand

ni=12
A=67
B=43
C=91

X = rand(ni*B*C,A) 
W = rand(A,B,C)

y = zeros((ni*B*C))

for k in xrange(len(y)):
    b = (k/ni)/C
    c = (k/ni) % C

    #print 'y[%i] = dot(X[%i,:],W[:,%i,%i])'%(k,k,b,c)

    y[k] = dot(X[k,:],W[:,b,c])

Si vous définissez simplement A,B,C,ni à des valeurs inférieures et décommenter le fichier print -Vous verrez rapidement ce que fait cet algorithme.

Si c'est ce que vous voulez, vous pouvez le faire plus rapidement avec cette phrase :

y2 = sum(X * (W.reshape((A,B*C)).swapaxes(0,1).repeat(ni,axis=0)),axis=1)

Malgré quelques réarrangements d'index, l'astuce cruciale consiste à utiliser repeat car dans la boucle, les indices b,c "gel" pour ni étapes, tandis que k se développe.

Je suis un peu pressé en ce moment, mais si vous avez besoin de plus d'explications, laissez un commentaire.

0voto

Il est assez difficile de comprendre, à partir de la question, quel est le résultat souhaité, mais le résultat Je pense que que vous recherchez peut être obtenue assez facilement de la manière suivante :

y = (X.T * W[:,dim2,dim3]).sum(axis=0)

Comparaison de l'exactitude et de la rapidité :

import numpy as np

# some test data, the sorting isn't really necessary
N1, N2, N3 = 67, 43, 91
ni_avg = 1.75
N = int(ni_avg * N2 * N3)

dim2 = np.random.randint(N2, size=N)
dim3 = np.sort(np.random.randint(N3, size=N))
for d3 in range(N3):
    dim2[dim3==d3].sort()

X = np.random.rand(N, N1)
W = np.random.rand(N1, N2, N3)

# original code
def original():
    y = np.empty(X.shape[0])
    for d2 in range(W.shape[1]):
        for d3 in range(W.shape[2]):
            mask = ((dim3 == d3) & (dim2 == d2))
            curr_X = X[mask, :]
            curr_W = W[:,d2,d3]
            curr_y = numpy.dot(curr_X,curr_W)
            y[mask] = curr_y
    return y

# comparison
%timeit original()
# 1 loops, best of 3: 672 ms per loop
%timeit (X.T * W[:,dim2,dim3]).sum(axis=0)
# 10 loops, best of 3: 31.8 ms per loop
np.allclose(original(), np.sum(X.T * W[:,dim2,dim3], axis=0))
# True

Une solution un peu plus rapide consisterait à utiliser

y = np.einsum('ij,ji->i', X, W[:,dim2,dim3])

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X