149 votes

Test si un tableau numpy ne contient que des zéros

Nous initialisons un tableau numpy avec des zéros comme ci-dessous :

np.zeros((N,N+1))

Mais comment vérifier si tous les éléments d'une matrice numpy array n*n donnée sont nuls.
La méthode doit juste retourner un True si toutes les valeurs sont effectivement nulles.

228voto

superbatfish Points 1300

Les autres réponses affichées ici fonctionneront, mais la fonction la plus claire et la plus efficace à utiliser est la suivante numpy.any() :

>>> all_zeros = not np.any(a)

ou

>>> all_zeros = not a.any()
  • Cette solution est préférable à numpy.all(a==0) car il utilise moins de RAM. (Elle ne nécessite pas le tableau temporaire créé par la fonction a==0 terme.)
  • En outre, il est plus rapide que numpy.count_nonzero(a) car il peut retourner immédiatement lorsque le premier élément non nul a été trouvé.
    • Edit : Comme @Rachel l'a souligné dans les commentaires, np.any() n'utilise plus la logique de "court-circuit", donc vous ne verrez pas d'avantage de vitesse pour les petites matrices.

91voto

Prashant Kumar Points 5220

Vérifiez numpy.count_nonzero .

>>> np.count_nonzero(np.eye(4))
4
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
5

73voto

J'utiliserais np.all ici, si vous avez un tableau a :

>>> np.all(a==0)

13voto

Rachel Points 225

Comme le dit une autre réponse, vous pouvez profiter des évaluations véridiques/fausses si vous savez que 0 est le seul élément falsy possible dans votre tableau. Tous les éléments d'un tableau sont faussés s'il n'y a pas d'éléments véridiques dans ce tableau.

>>> a = np.zeros(10)
>>> not np.any(a)
True

Cependant, la réponse affirmait que any était plus rapide que les autres options, en partie à cause des courts-circuits. En 2018, la méthode de Numpy all y any ne pas court-circuiter .

Si vous faites souvent ce genre de chose, il est très facile de fabriquer vos propres versions de court-circuitage en utilisant numba :

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True

Celles-ci ont tendance à être plus rapides que les versions de Numpy, même lorsqu'elles ne sont pas court-circuitées. count_nonzero est le plus lent.

Quelques entrées pour vérifier les performances :

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]

Vérifiez :

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

* Aide all y any équivalences :

np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))

2voto

Saankhya Mondal Points 73

Ça va marcher.

def check(arr):
    if np.all(arr == 0):
        return True
    return False

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X