146 votes

La mémorisation en Haskell ?

Des indications sur la manière de résoudre efficacement la fonction suivante en Haskell, pour les grands nombres. (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

J'ai vu des exemples de mémorisation en Haskell pour résoudre le fibonacci. ce qui impliquait de calculer (paresseusement) tous les nombres de fibonacci jusqu'au nombre requis n. Mais dans ce cas, pour un n donné, nous n'avons besoin de calculer que très peu de résultats intermédiaires.

Gracias

111 votes

Seulement dans le sens où c'est un travail que je fais à la maison :-)

274voto

Edward Kmett Points 18369

Nous pouvons le faire très efficacement en créant une structure que nous pouvons indexer en temps sub-linéaire.

Mais d'abord,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Définissons f mais faites-lui utiliser la "récursion ouverte" plutôt que de l'appeler directement.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

Vous pouvez obtenir un non-mémorisé f en utilisant fix f

Cela vous permettra de tester que f fait ce que vous voulez dire pour de petites valeurs de f en appelant, par exemple : fix f 123 = 144

Nous pourrions mémoriser ceci en définissant :

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

Cela fonctionne passablement bien, et remplace ce qui allait prendre O(n^3) temps avec quelque chose qui mémorise les résultats intermédiaires.

Mais cela prend toujours du temps linéaire juste pour indexer pour trouver la réponse mémorisée pour mf . Cela signifie que des résultats comme :

*Main Data.List> faster_f 123801
248604

sont tolérables, mais le résultat ne s'étend pas beaucoup mieux que cela. Nous pouvons faire mieux !

Tout d'abord, définissons un arbre infini :

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

Puis nous définirons un moyen de l'indexer, afin de trouver un nœud avec l'indice n en O(log n) à la place :

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

... et nous pouvons trouver qu'un arbre rempli de nombres naturels est pratique pour ne pas avoir à manipuler ces indices :

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Puisque nous pouvons indexer, vous pouvez simplement convertir un arbre en une liste :

toList :: Tree a -> [a]
toList as = map (index as) [0..]

Vous pouvez vérifier le travail effectué jusqu'à présent en vérifiant que toList nats vous donne [0..]

Maintenant,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

fonctionne de la même manière que la liste ci-dessus, mais au lieu de prendre un temps linéaire pour trouver chaque nœud, on peut le poursuivre en temps logarithmique.

Le résultat est considérablement plus rapide :

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

En fait, c'est tellement plus rapide que vous pouvez passer et remplacer Int con Integer ci-dessus et obtenir des réponses ridiculement grandes presque instantanément

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358

Pour une bibliothèque prête à l'emploi qui implémente la mémorisation basée sur l'arbre, utilisez MemoTrie :

$ stack repl --package MemoTrie

Prelude> import Data.MemoTrie
Prelude Data.MemoTrie> :set -XLambdaCase
Prelude Data.MemoTrie> :{
Prelude Data.MemoTrie| fastest_f' :: Integer -> Integer
Prelude Data.MemoTrie| fastest_f' = memo $ \case
Prelude Data.MemoTrie|   0 -> 0
Prelude Data.MemoTrie|   n -> max n (fastest_f'(n `div` 2) + fastest_f'(n `div` 3) + fastest_f'(n `div` 4))
Prelude Data.MemoTrie| :}
Prelude Data.MemoTrie> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358

5 votes

J'ai essayé ce code et, de manière intéressante, f_faster semblait être plus lent que f. Je suppose que ces références de liste ont vraiment ralenti les choses. La définition de nats et d'index m'a semblé assez mystérieuse, j'ai donc ajouté ma propre réponse qui pourrait rendre les choses plus claires.

0 votes

EdwardKmett J'ai passé des heures à apprendre et à rechercher comment cela fonctionne et c'est très intelligent. Mais ce que je n'arrive pas à trouver, c'est pourquoi la liste infinie prend tellement plus de mémoire que l'arbre infini ? Par exemple, si vous appelez "fastest_f 111111111" en regardant l'utilisation de la mémoire de ghci, vous pouvez voir qu'il n'utilise presque rien. Mais lorsque vous appelez "fastest_f 111111111", il utilise environ 1,5 Go, puis ghci se termine parce que je n'ai plus de mémoire. J'ai testé les appels suivants en utilisant le :set +s de ghci et fastest_f améliore sa vitesse à presque rien et faster_f aussi. Alors, que se passe-t-il ?

6 votes

Dans le cas d'une liste infinie, il s'agit d'une liste liée de 111111111 éléments. Dans le cas de l'arbre, il s'agit de log n * le nombre de nœuds atteints.

20voto

Tom Ellis Points 3455

La réponse d'Edward est un joyau si merveilleux que je l'ai dupliqué et fourni des implémentations de memoList y memoTree combinateurs qui mémorisent une fonction sous forme récursive ouverte.

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)

-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f

-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f

13voto

rampion Points 38697

Ce n'est pas le moyen le plus efficace, mais il permet de mémoriser :

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

lors de la demande f !! 144 on vérifie que f !! 143 existe, mais sa valeur exacte n'est pas calculée. Elle est toujours définie comme un résultat inconnu d'un calcul. Les seules valeurs exactes calculées sont celles qui sont nécessaires.

Donc au départ, en ce qui concerne le montant calculé, le programme ne sait rien.

f = .... 

Quand nous faisons la demande f !! 12 il commence à faire des recherches de motifs :

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Maintenant, il commence à calculer

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

Cela fait récursivement une autre demande sur f, donc nous calculons

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

Maintenant nous pouvons remonter un peu

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

Ce qui signifie que le programme sait maintenant :

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Ça continue d'arriver au compte-gouttes :

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

Ce qui signifie que le programme sait maintenant :

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Maintenant nous continuons avec notre calcul de f!!6 :

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

Ce qui signifie que le programme sait maintenant :

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Maintenant nous continuons avec notre calcul de f!!12 :

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

Ce qui signifie que le programme sait maintenant :

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

Le calcul est donc effectué de manière assez paresseuse. Le programme sait qu'une certaine valeur pour f !! 8 existe, qu'il est égal à g 8 mais il n'a aucune idée de ce que g 8 est.

0 votes

Merci pour celui-ci. Comment créer et utiliser un espace de solution à deux dimensions ? S'agirait-il d'une liste de listes ? et g n m = (something with) f!!a!!b

1 votes

Bien sûr, vous pourriez. Pour une vraie solution, cependant, j'utiliserais probablement une bibliothèque de mémorisation, telle que memocombinateurs

0 votes

C'est O(n^2) malheureusement.

9voto

Pitarou Points 1678

Ceci est un addendum à l'excellente réponse d'Edward Kmett.

Lorsque j'ai essayé son code, les définitions de nats y index semblait assez mystérieux, alors j'ai écrit une version alternative que je trouvais plus facile à comprendre.

Je définis index y nats en termes de index' y nats' .

index' t n est défini sur l'intervalle [1..] . (Rappelons que index t est défini sur l'intervalle [0..] .) Elle recherche l'arbre en traitant n comme une chaîne de bits, et lire les bits en sens inverse. Si le bit est 1 il prend la branche de droite. Si le bit est 0 il prend la branche de gauche. Il s'arrête lorsqu'il atteint le dernier bit (qui doit être un 1 ).

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

Tout comme nats est défini pour index de sorte que index nats n == n est toujours vrai, nats' est défini pour index' .

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

Ahora, nats y index sont simplement nats' y index' mais avec les valeurs décalées de 1 :

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'

0 votes

Merci. Je mémorise une fonction multivariable, et cela m'a vraiment aidé à comprendre ce que font vraiment index et nats.

2voto

Neal Young Points 136

Encore un ajout à la réponse d'Edward Kmett : un exemple autonome :

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

Utilisez-le comme suit pour mémoriser une fonction avec un seul argument entier (par exemple, fibonacci) :

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

Seules les valeurs pour les arguments non-négatifs seront mises en cache.

Pour mettre également en cache les valeurs des arguments négatifs, utilisez memoInt définis comme suit :

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

Pour mettre en cache les valeurs des fonctions avec deux arguments entiers, utilisez memoIntInt définis comme suit :

memoIntInt f = memoInt (\n -> memoInt (f n))

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X