it-swarm.dev

¿Memoización en Haskell?

Cualquier indicador sobre cómo resolver eficientemente la siguiente función en Haskell, para grandes números (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

He visto ejemplos de memoización en Haskell para resolver los números de fibonacci, que involucraron computar (perezosamente) todos los números de fibonacci hasta el n requerido. Pero en este caso, para una n dada, solo necesitamos calcular muy pocos resultados intermedios.

Gracias

127
Angel de Vicente

Podemos hacer esto de manera muy eficiente haciendo una estructura que podamos indexar en tiempo sub-lineal.

Pero primero,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Definamos f, pero hagamos que use 'recursión abierta' en lugar de llamarse a sí mismo directamente.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

Puede obtener un f no remarcado usando fix f

Esto le permitirá probar que f hace lo que quiere decir con valores pequeños de f llamando, por ejemplo: fix f 123 = 144

Podríamos memorizar esto definiendo:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

Eso funciona de manera pasivamente bien, y reemplaza lo que iba a tomar O (n ^ 3) time con algo que memoriza los resultados intermedios.

Pero todavía se necesita un tiempo lineal solo para indexar y encontrar la respuesta memorizada para mf. Esto significa que resulta como:

*Main Data.List> faster_f 123801
248604

son tolerables, pero el resultado no es mucho mejor que eso. ¡Podemos hacerlo mejor!

Primero, definamos un árbol infinito:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

Y luego definiremos una forma de indexar en él, para que podamos encontrar un nodo con el índice n en O (log n) time en su lugar:

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

... y podemos encontrar un árbol lleno de números naturales para que sea conveniente, por lo que no tenemos que jugar con esos índices:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Como podemos indexar, puedes convertir un árbol en una lista:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

Puedes verificar el trabajo hasta el momento verificando que toList nats te da [0..]

Ahora,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

funciona igual que en la lista anterior, pero en lugar de tomar un tiempo lineal para encontrar cada nodo, puede perseguirlo en el tiempo logarítmico.

El resultado es considerablemente más rápido:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

De hecho, es mucho más rápido que puede revisar y reemplazar Int con Integer arriba y obtener respuestas ridículamente grandes casi instantáneamente

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
245
Edward KMETT

La respuesta de Edward es una joya tan maravillosa que la he duplicado y he proporcionado implementaciones de memoList y memoTree combinators que memorizan una función en forma abierta-recursiva.

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
17
Tom Ellis

No es la forma más eficiente, pero memoriza:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

cuando se solicita f !! 144, se verifica que f !! 143 existe, pero su valor exacto no se calcula. Todavía se establece como un resultado desconocido de un cálculo. Los únicos valores exactos calculados son los necesarios.

Así que inicialmente, en cuanto a cuánto se ha calculado, el programa no sabe nada.

f = .... 

Cuando hacemos la solicitud f !! 12, comienza a hacer algunos patrones de coincidencia:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Ahora empieza a calcular

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

Esto recursivamente hace otra demanda en f, por lo que calculamos

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

Ahora podemos hacer una copia de seguridad de algunos

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

Lo que significa que el programa ahora sabe:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Continuando a goteo:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

Lo que significa que el programa ahora sabe:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Ahora continuamos con nuestro cálculo de f!!6:

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

Lo que significa que el programa ahora sabe:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Ahora continuamos con nuestro cálculo de f!!12:

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

Lo que significa que el programa ahora sabe:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

Así que el cálculo se hace con bastante pereza. El programa sabe que existe algún valor para f !! 8, que es igual a g 8, pero no tiene idea de qué es g 8.

12
rampion

Como se indicó en la respuesta de Edward Kmett, para acelerar las cosas, debe almacenar en la memoria costosas computaciones y poder acceder a ellas rápidamente.

Para mantener la función no monádica, la solución de construir un árbol perezoso infinito, con una forma adecuada de indexarlo (como se muestra en las publicaciones anteriores) cumple con ese objetivo. Si abandona la naturaleza no monádica de la función, puede usar los contenedores asociativos estándar disponibles en Haskell en combinación con mónadas "similares a estado" (como State o ST).

Si bien el principal inconveniente es que obtiene una función no monádica, ya no tiene que indexar la estructura, y solo puede usar implementaciones estándar de contenedores asociativos.

Para hacerlo, primero debe volver a escribir su función para aceptar cualquier tipo de mónada:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

Para sus pruebas, aún puede definir una función que no realice una memorización utilizando Data.Function.fix, aunque es un poco más detallado:

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

Luego puede usar la mónada estatal en combinación con Data.Map para acelerar las cosas:

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

Con cambios menores, puede adaptar el código para que funcione con Data.HashMap en su lugar:

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

En lugar de estructuras de datos persistentes, también puede probar estructuras de datos mutables (como Data.HashTable) en combinación con la mónada ST:

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

En comparación con la implementación sin ningún tipo de memorización, cualquiera de estas implementaciones le permite, para grandes insumos, obtener resultados en microsegundos en lugar de tener que esperar varios segundos.

Al usar Criterion como punto de referencia, pude observar que la implementación con Data.HashMap en realidad tuvo un rendimiento ligeramente mejor (alrededor del 20%) que el Data.Map y Data.HashTable para los cuales los tiempos fueron muy similares.

Me parecieron un poco sorprendentes los resultados del benchmark. Mi sensación inicial fue que el HashTable superaría la implementación de HashMap porque es mutable. Puede haber algún defecto de rendimiento oculto en esta última implementación.

8
Quentin

Esta es una adición a la excelente respuesta de Edward Kmett.

Cuando probé su código, las definiciones de nats y index me parecieron bastante misteriosas, así que escribo una versión alternativa que me pareció más fácil de entender.

Defino index y nats en términos de index' y nats'.

index' t n se define en el rango [1..]. (Recuerde que index t está definido en el rango [0..].) Funciona busca en el árbol tratando a n como una cadena de bits, y leyendo los bits a la inversa. Si el bit es 1, toma la rama derecha. Si el bit es 0, toma la rama de la izquierda. Se detiene cuando llega al último bit (que debe ser un 1).

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

Del mismo modo que nats se define para index, de manera que index nats n == n siempre es verdadero, nats' se define para index'.

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

Ahora, nats y index son simplemente nats' y index' pero con los valores cambiados en 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
8
Pitarou

Un par de años más tarde, miré esto y me di cuenta de que hay una forma sencilla de memorizar esto en tiempo lineal usando zipWith y una función auxiliar:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilate tiene la útil propiedad de dilate n xs !! i == xs !! div i n.

Entonces, suponiendo que nos dan f (0), esto simplifica el cálculo para

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

Se parece mucho a nuestra descripción original del problema y da una solución lineal (sum $ take n fs tomará O (n)).

4
rampion

Otro apéndice más a la respuesta de Edward Kmett: un ejemplo autocontenido:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

Úselo de la siguiente manera para memorizar una función con un solo entero arg (por ejemplo, fibonacci):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

Solo se almacenarán en caché los valores para argumentos no negativos.

Para también almacenar en caché los valores de los argumentos negativos, use memoInt, definido de la siguiente manera:

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

Para almacenar en caché los valores de las funciones con dos argumentos enteros, use memoIntInt, definido de la siguiente manera:

memoIntInt f = memoInt (\n -> memoInt (f n))
2
Neal Young

Una solución sin indexación, y no basada en la de Edward KMETT.

Factoré subárboles comunes a un padre común (f(n/4) se comparte entre f(n/2) y f(n/4), y f(n/6) se comparte entre f(2) y f(3)). Al guardarlos como una sola variable en la matriz, el cálculo del subárbol se realiza una vez.

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

El código no se extiende fácilmente a una función de memorización general (al menos, no sabría cómo hacerlo), y realmente tienes que pensar cómo se superponen los subproblemas, pero estrategia debería funcionar para Parámetros múltiples no enteros generales. (Lo pensé para dos parámetros de cadena.)

La nota se desecha después de cada cálculo. (Una vez más, estaba pensando en dos parámetros de cadena.)

No sé si esto es más eficiente que las otras respuestas. Cada búsqueda es técnicamente solo uno o dos pasos ("Mire a su hijo o al niño de su hijo"), pero puede haber un gran uso de memoria adicional.

Edit: Esta solución no es correcta todavía. El compartir es incompleto.

Edit: Debería estar compartiendo subchildren correctamente ahora, pero me di cuenta de que este problema tiene un montón de intercambio no trivial: n/2/2/2 y n/3/3 podrían ser los mismos. El problema no es un buen ajuste para mi estrategia.

2
leewz