88 votes

Comment appliquer numpy.linalg.norm à chaque ligne d'une matrice ?

J'ai une matrice 2D et je veux prendre la norme de chaque ligne. Mais lorsque j'utilise numpy.linalg.norm(X) directement, il prend la norme de la matrice entière.

Je peux prendre la norme de chaque ligne en utilisant une boucle for et ensuite prendre la norme de chaque ligne X[i] mais cela prend énormément de temps car j'ai 30 000 lignes.

Avez-vous des suggestions pour trouver un moyen plus rapide ? Ou est-il possible d'appliquer np.linalg.norm à chaque ligne d'une matrice ?

92voto

unutbu Points 222216

Notez que, comme perimosocordiae montre à partir de la version 1.9 de NumPy, np.linalg.norm(x, axis=1) est le moyen le plus rapide de calculer la norme L2.


Si vous calculez une norme L2, vous pouvez la calculer directement (en utilisant la fonction axis=-1 pour additionner les lignes) :

np.sum(np.abs(x)**2,axis=-1)**(1./2)

Les normes Lp peuvent être calculées de manière similaire, bien sûr.

Il est considérablement plus rapide que np.apply_along_axis mais peut-être pas aussi pratique :

In [48]: %timeit np.apply_along_axis(np.linalg.norm, 1, x)
1000 loops, best of 3: 208 us per loop

In [49]: %timeit np.sum(np.abs(x)**2,axis=-1)**(1./2)
100000 loops, best of 3: 18.3 us per loop

Autre ord Les formes de norm peut aussi être calculée directement (avec des gains de vitesse similaires) :

In [55]: %timeit np.apply_along_axis(lambda row:np.linalg.norm(row,ord=1), 1, x)
1000 loops, best of 3: 203 us per loop

In [54]: %timeit np.sum(abs(x), axis=-1)
100000 loops, best of 3: 10.9 us per loop

9 votes

Pourquoi faire np.abs(x) si l'on met x au carré de toute façon ?

11 votes

@Patrick : Si le dtype de x est complexe, cela fait une différence. Par exemple, si x = np.array([(1+1j,2+1j)]) puis np.sum(np.abs(x)**2,axis=-1)**(1./2) es array([ 2.64575131]) , tandis que np.sum(x**2,axis=-1)**(1./2) es array([ 2.20320266+1.36165413j]) .

3 votes

@perimosocordiae affichée une mise à jour qui numpy.linalg.norm avec son nouveau axis est actuellement l'approche la plus rapide.

59voto

perimosocordiae Points 4582

Ressusciter une vieille question suite à une mise à jour de numpy. A partir de la version 1.9, numpy.linalg.norm accepte désormais un axis argument. [ code , documentation ]

C'est la nouvelle méthode la plus rapide en ville :

In [10]: x = np.random.random((500,500))

In [11]: %timeit np.apply_along_axis(np.linalg.norm, 1, x)
10 loops, best of 3: 21 ms per loop

In [12]: %timeit np.sum(np.abs(x)**2,axis=-1)**(1./2)
100 loops, best of 3: 2.6 ms per loop

In [13]: %timeit np.linalg.norm(x, axis=1)
1000 loops, best of 3: 1.4 ms per loop

Et pour prouver que c'est calculer la même chose :

In [14]: np.allclose(np.linalg.norm(x, axis=1), np.sum(np.abs(x)**2,axis=-1)**(1./2))
Out[14]: True

20voto

Nico Points 893

Bien plus rapide que la réponse acceptée, l'utilisation de l'outil NumPy einsum ,

numpy.sqrt(numpy.einsum('ij,ij->i', a, a))

Notez l'échelle logarithmique :

enter image description here


Code pour reproduire l'intrigue :

import numpy
import perfplot

def sum_sqrt(a):
    return numpy.sqrt(numpy.sum(numpy.abs(a) ** 2, axis=-1))

def apply_norm_along_axis(a):
    return numpy.apply_along_axis(numpy.linalg.norm, 1, a)

def norm_axis(a):
    return numpy.linalg.norm(a, axis=1)

def einsum_sqrt(a):
    return numpy.sqrt(numpy.einsum("ij,ij->i", a, a))

perfplot.show(
    setup=lambda n: numpy.random.rand(n, 3),
    kernels=[sum_sqrt, apply_norm_along_axis, norm_axis, einsum_sqrt],
    n_range=[2 ** k for k in range(20)],
    xlabel="len(a)",
)

2 votes

Pourriez-vous nous expliquer ce que c'est/ce que ça fait ?

0 votes

Einsum ne manque jamais de livrer :)

7voto

NPE Points 169956

Essayez ce qui suit :

In [16]: numpy.apply_along_axis(numpy.linalg.norm, 1, a)
Out[16]: array([ 5.38516481,  1.41421356,  5.38516481])

a est votre tableau 2D.

Ce qui précède permet de calculer la norme L2. Pour une norme différente, vous pouvez utiliser quelque chose comme.. :

In [22]: numpy.apply_along_axis(lambda row:numpy.linalg.norm(row,ord=1), 1, a)
Out[22]: array([9, 2, 9])

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X