102 votes

Comment fonctionne la fonction python numpy.where() ?

Je joue avec numpy et en creusant dans la documentation, je suis tombé sur de la magie. En fait, je parle de numpy.where() :

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

Comment parviennent-ils à faire en sorte qu'en interne, vous puissiez passer quelque chose comme x > 5 dans une méthode ? Je suppose que cela a quelque chose à voir avec __gt__ mais je cherche une explication détaillée.

77voto

Joe Kington Points 68089

Comment font-ils pour que, en interne, vous puissiez passer quelque chose comme x > 5 dans une méthode ?

La réponse courte est qu'ils ne le font pas.

Toute sorte d'opération logique sur un tableau numpy renvoie un tableau booléen. (c'est-à-dire __gt__ , __lt__ etc. renvoient tous des tableaux booléens où la condition donnée est vraie).

Par exemple

x = np.arange(9).reshape(3,3)
print x > 5

rendements :

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

C'est la même raison pour laquelle quelque chose comme if x > 5: soulève une ValueError si x est un tableau numpy. C'est un tableau de valeurs Vrai/Faux, pas une seule valeur.

De plus, les tableaux numpy peuvent être indexés par des tableaux booléens. Par exemple x[x>5] donne [6 7 8] dans ce cas.

Honnêtement, c'est assez rare que vous ayez vraiment besoin numpy.where mais il renvoie juste les indicies où un tableau booléen est True . En général, vous pouvez faire ce dont vous avez besoin avec une simple indexation booléenne.

11 votes

Je tiens à souligner que numpy.where ont deux "modes opérationnels", le premier renvoie l'adresse de l'utilisateur. indices , donde condition is True et si les paramètres facultatifs x y y sont présents (même forme que condition ou diffusable sous une telle forme !), il renverra des valeurs de x quand condition is True sinon de y . Cela fait donc where plus polyvalent et lui permet d'être utilisé plus souvent. Merci

1 votes

Il peut également y avoir des surcharges dans certains cas en utilisant l'option __getitem__ syntaxe de [] sur l'un ou l'autre numpy.where o numpy.take . Puisque __getitem__ doit aussi supporter le découpage, il y a des frais généraux. J'ai constaté des différences de vitesse notables en travaillant avec les structures de données Pandas de Python et en indexant logiquement de très grandes colonnes. Dans ces cas, si vous n'avez pas besoin du découpage en tranches, alors take y where sont en fait meilleures.

25voto

Garrett Berg Points 884

Ancienne réponse c'est un peu confus. Il vous donne les LIEUX (tous) où votre affirmation est vraie.

donc :

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

Je l'utilise comme une alternative à list.index(), mais il a aussi de nombreuses autres utilisations. Je ne l'ai jamais utilisé avec des tableaux 2D.

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

Nouvelle réponse Il semble que la personne demandait quelque chose de plus fondamental.

La question était de savoir comment VOUS pouviez mettre en œuvre quelque chose qui permette à une fonction (telle que where) de savoir ce qui a été demandé.

Notez d'abord que l'appel de n'importe lequel des opérateurs de comparaison fait une chose intéressante.

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

Cela se fait en surchargeant la méthode "__gt__". Par exemple :

>>> class demo(object):
    def __gt__(self, item):
        print item

>>> a = demo()
>>> a > 4
4

Comme vous pouvez le voir, "a > 4" était un code valide.

Vous pouvez obtenir une liste complète et la documentation de toutes les fonctions surchargées ici : http://docs.python.org/reference/datamodel.html

Ce qui est incroyable, c'est la simplicité avec laquelle on peut le faire. TOUTES les opérations en Python sont effectuées de cette manière. Dire a > b est équivalent à a. gt (b) !

3 votes

Cette surcharge de l'opérateur de comparaison ne semble pas fonctionner correctement avec des expressions logiques plus complexes - par exemple je ne peux pas faire np.where(a > 30 and a < 50) o np.where(30 < a < 50) parce qu'elle finit par essayer d'évaluer le ET logique de deux tableaux de booléens, ce qui n'a pas beaucoup de sens. Existe-t-il un moyen d'écrire une telle condition avec np.where ?

0 votes

@meowsqueak np.where((a > 30) & (a < 50))

0 votes

Pourquoi np.where() renvoie-t-il une liste dans votre exemple ?

4voto

Piyush Singh Points 1871

np.where renvoie un tuple de longueur égale à la dimension du numpy ndarray sur lequel il est appelé (en d'autres termes ndim ) et chaque élément du tuple est un ndarray numpy d'indices de toutes les valeurs du ndarray initial pour lesquelles la condition est vraie (ne confondez pas dimension et forme).

Par exemple :

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))

y est un tuple de longueur 2 car x.ndim est 2. Le premier élément du tuple contient les numéros de ligne de tous les éléments supérieurs à 4 et le deuxième élément contient les numéros de colonne de tous les éléments supérieurs à 4. Comme vous pouvez le voir, [1,2,2,2] correspond aux numéros de ligne de 5,6,7,8 et [2,0,1,2] correspond aux numéros de colonne de 5,6,7,8. Notez que le tableau ndarray est parcouru le long de la première dimension (par ligne).

De même,

x=np.arange(27).reshape(3,3,3)
np.where(x>4)

retournera un tuple de longueur 3 car x a 3 dimensions.

Mais attendez, il y a plus à np.where !

lorsque deux arguments supplémentaires sont ajoutés à np.where il effectuera une opération de remplacement pour toutes les combinaisons ligne-colonne par paire qui sont obtenues par le tuple ci-dessus.

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X