3 votes

Accélérer la fonction qui prend une fonction en argument avec numba

Je tente d'utiliser numba pour accélérer une fonction qui prend une autre fonction en argument. Un exemple minimal serait le suivant :

import numba as nb

def f(x):
    return x*x

@nb.jit(nopython=True)
def call_func(func,x):
    return func(x)

if __name__ == '__main__':
    print(call_func(f,5))

Cependant, cela ne fonctionne pas, car apparemment numba ne sait pas quoi faire avec cet argument de fonction. La trace de la pile est assez longue :

Traceback (most recent call last):
  File "numba_function.py", line 15, in 
    print(call_func(f,5))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
    raise e
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 579, in compile
    cres = self._compiler.compile(args, return_type)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 80, in compile
    flags=flags, locals=self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 740, in compile_extra
    return pipeline.compile_extra(func)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 360, in compile_extra
    return self._compile_bytecode()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 699, in _compile_bytecode
    return self._compile_core()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 686, in _compile_core
    res = pm.run(self.status)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 246, in run
    raise patched_exception
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 238, in run
    stage()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 452, in stage_nopython_frontend
    self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 841, in type_inference_stage
    infer.propagate()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 773, in propagate
    raise errors[0]
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 129, in propagate
    constraint(typeinfer)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 380, in __call__
    self.resolve(typeinfer, typevars, fnty)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 402, in resolve
    raise TypingError(msg, loc=self.loc)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Invalid usage of pyobject with parameters (int64)
No type info available for pyobject as a callable.
File "numba_function.py", line 10
[1] During: resolving callee type: pyobject
[2] During: typing of call at numba_function.py (10)

Cette erreur peut avoir été causée par le(s) argument(s) suivant(s) :
- argument 0 : impossible de déterminer le type Numba de la classe 'function'

Y a-t-il un moyen de corriger cela ?

4voto

MSeifert Points 6307

Il dépend si la func que vous passez à call_func peut être compilée en mode nopython.

Si elle ne peut pas être compilée en mode nopython, alors c'est impossible car numba ne prend pas en charge les appels Python à l'intérieur d'une fonction nopython (c'est la raison pour laquelle il est appelé nopython).

Cependant, si elle peut être compilée en mode nopython, vous pouvez utiliser une fermeture :

import numba as nb

def f(x):
    return x*x

def call_func(func, x):
    func = nb.njit(func)   # compile func in nopython mode!
    @nb.njit
    def inner(x):
        return func(x)
    return inner(x)

if __name__ == '__main__':
    print(call_func(f,5))

Cette approche présente quelques inconvénients évidents car elle nécessite de compiler func et inner à chaque fois que vous appelez call_func. Cela signifie que c'est viable uniquement si l'accélération due à la compilation de la fonction est supérieure au coût de la compilation. Vous pouvez atténuer ce surcoût si vous appelez call_func avec la même fonction plusieurs fois :

import numba as nb

def f(x):
    return x*x

def call_func(func):  # accepte uniquement func
    func = nb.njit(func)   # compile func en mode nopython!
    @nb.njit
    def inner(x):
        return func(x)
    return inner  # retourne la fermeture

if __name__ == '__main__':
    call_func_with_f = call_func(f)   # compile une seule fois
    print(call_func_with_f(5))        # appel de la version compilée
    print(call_func_with_f(5))        # appel de la version compilée
    print(call_func_with_f(5))        # appel de la version compilée
    print(call_func_with_f(5))        # appel de la version compilée
    print(call_func_with_f(5))        # appel de la version compilée

Juste une note générale : Je ne créerais pas de fonctions numba qui prennent un argument de fonction. Si vous ne pouvez pas coder en dur la fonction, numba ne peut pas produire des fonctions vraiment rapides et si vous incluez également le coût de compilation pour les fermetures, cela ne vaut généralement pas la peine.

2voto

jdehesa Points 22254

Comme suggéré par le message d'erreur, Numba ne peut pas gérer des valeurs de type function. Vous pouvez vérifier dans la documentation avec quels types Numba peut fonctionner. La raison en est que Numba ne peut généralement pas optimiser (jit-compiler) des fonctions arbitraires en mode nopython, elles sont considérées essentiellement comme une boîte noire (en fait, la fonction passée pourrait même être une fonction native!).

L'approche habituelle serait de demander à Numba d'optimiser la fonction appelée à la place. Si vous ne pouvez pas ajouter le décorateur à la fonction (par exemple, parce qu'elle ne fait pas partie de votre code source), vous pouvez toujours l'utiliser manuellement comme suit :

import numba as nb

def f(x):
    return x*x

if __name__ == '__main__':
    f_opt = nb.jit(nopython=True)(f)
    print(f_opt(5))

De toute évidence, cela échouera toujours si f ne peut pas non plus être compilé par Numba, mais dans ce cas, il n'y a pas grand-chose que vous puissiez faire de toute façon.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X