{- Přeskočit navigaci. (Skip navigation.)
Domů (Home)

Memoization, monads (code) -}

module MemoFix where
{-
This is my exercise program for memoizing values of a generic recursively
defined function. My aim was to make the actual function definition as clear
as possible.
-}

import IO
import Maybe
import Control.Monad.State
import Control.Monad.Fix
import Control.Monad.Identity
import Data.Array
import Data.Map (Map)
import qualified Data.Map as Map
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import System.Directory (createDirectoryIfMissing)

{- Type vocabulary for open-recursive, memoizable functions.

A MonadicFn a b m is a monadic function from keys of type a to values of
type b. A "Rec" type is one step of open recursion: it is handed the function
to use for recursive calls and returns the function itself, so taking its
fixed point (fix) yields the actual recursive function.

The Memo* names specialise these to a State store s, or to a StateT store s
layered over an arbitrary monad m; the store holds memoized key::a / value::b
pairs. -}
type MonadicFn a b m = a -> m b
type MonadicFnRec a b m = MonadicFn a b m -> MonadicFn a b m

-- Pure stores: the memo table s is threaded through State.
type Memo a b s = MonadicFn a b (State s)
type MemoRec a b s = MonadicFnRec a b (State s)

-- Transformer stores: the memo table s is threaded through StateT over m.
type MemoT a b m s = MonadicFn a b (StateT s m)
type MemoRecT a b m s = MonadicFnRec a b (StateT s m)

{- Memoizing wrapper backed by a Data.Map store kept in State.

Given the un-memoized step f and a key x: when x already has an entry in the
map, that stored value is returned and f is not run; otherwise f x is run,
its result is inserted under x, and that result is returned. -}
memoMap :: (Ord a) => MonadicFnRec a b (State (Map a b))
memoMap f x = gets (Map.lookup x) >>= maybe miss return
  where
    -- Cache miss: compute, then record under x (re-reading the state, which
    -- the recursive computation may have extended).
    miss = do
      result <- f x
      modify (Map.insert x result)
      return result

{- Transformer variant of memoMap: the Data.Map store lives in a StateT over
an arbitrary monad m. A key with an entry is answered from the map; otherwise
f x runs and its result is stored under x before being returned. -}
memoMapT :: (Ord a, Monad m) => MonadicFnRec a b (StateT (Map a b) m)
memoMapT f x = do
  hit <- gets (Map.lookup x)
  case hit of
    Just cached -> return cached
    Nothing -> do
      fresh <- f x
      modify (Map.insert x fresh)
      return fresh

{- Memoizing wrapper backed by a Data.IntMap store kept in State; the
Int-keyed counterpart of memoMap. When x already has an entry in the IntMap
that value is returned without running f; otherwise f x is run and its result
is recorded under x before being returned. -}
memoIntMap :: MonadicFnRec Int b (State (IntMap b))
memoIntMap f x = gets (IntMap.lookup x) >>= maybe fill return
  where
    fill = f x >>= \v -> modify (IntMap.insert x v) >> return v

{- Transformer variant of memoIntMap: the IntMap store lives in a StateT over
an arbitrary monad m. Hits are served from the IntMap; misses run f x and
record the result under x. -}
memoIntMapT :: (Monad m) => MonadicFnRec Int b (StateT (IntMap b) m)
memoIntMapT f x = do
  hit <- gets (IntMap.lookup x)
  case hit of
    Just cached -> return cached
    Nothing -> do
      fresh <- f x
      modify (IntMap.insert x fresh)
      return fresh

{- Memoizing wrapper backed by an immutable Array of Maybe values kept in
State. Keys inside the array bounds are cached (Nothing marks "not computed
yet"); keys outside the bounds are passed straight to f and never cached.
Each cache write produces a new array via (//). -}
memoArray :: (Ix a) => MonadicFnRec a b (State (Array a (Maybe b)))
memoArray f x = do
  table <- get
  if not (inRange (bounds table) x)
    then f x
    else case table ! x of
      Just v -> return v
      Nothing -> do
        v <- f x
        -- Re-read the state: the recursive call may have filled other slots.
        modify (// [(x, Just v)])
        return v

{- Transformer variant of memoArray: the Array-of-Maybe store lives in a
StateT over an arbitrary monad m. In-bounds keys are cached, out-of-bounds
keys go straight to f uncached. -}
memoArrayT :: (Monad m, Ix a) => MonadicFnRec a b (StateT (Array a (Maybe b)) m)
memoArrayT f x = get >>= \table ->
  if inRange (bounds table) x
    then maybe fill return (table ! x)
    else f x
  where
    fill = do
      v <- f x
      modify (// [(x, Just v)])
      return v

{- A very simple file memo for computations running in StateT s IO: the value
for key x is stored in the file "/tmp/memo/" ++ show x.

The file is consulted BEFORE the wrapped step runs, so a hit skips the whole
(recursive) computation. Previously f x was always executed first and the
file was only used to replace its result, which made the cache useless.
On a miss, f x runs and memoIOImpl writes the result.
NOTE(review): the stored text is parsed lazily with read, so a corrupt cache
file surfaces as an error only when the value is used. -}
memoIOT :: (Show a, Show b, Read b) => MonadicFnRec a b (StateT s IO)
memoIOT f x = do
  cached <- lift (readCached `catch` onMissing)
  case cached of
    Just v  -> return v
    Nothing -> f x >>= lift . memoIOImpl x
  where
    fname = "/tmp/memo/" ++ show x
    -- bracket closes the handle even if hGetLine fails (e.g. empty file).
    readCached = do
      line <- bracket (openFile fname ReadMode) hClose hGetLine
      return (Just (read line))
    -- Only "file does not exist" counts as a cache miss; other I/O errors
    -- are re-raised.
    onMissing err
      | isDoesNotExistError err = return Nothing
      | otherwise               = ioError err

{- A very simple file memo in plain IO: the value for key x is stored in the
file "/tmp/memo/" ++ show x.

The file is consulted BEFORE the wrapped step runs, so a hit skips the whole
(recursive) computation. Previously f x was always executed first and the
file was only used to replace its result, which made the cache useless.
On a miss, f x runs and memoIOImpl writes the result.
NOTE(review): the stored text is parsed lazily with read, so a corrupt cache
file surfaces as an error only when the value is used. -}
memoIO :: (Show a, Show b, Read b) => MonadicFnRec a b IO
memoIO f x = do
  cached <- readCached `catch` onMissing
  case cached of
    Just v  -> return v
    Nothing -> f x >>= memoIOImpl x
  where
    fname = "/tmp/memo/" ++ show x
    -- bracket closes the handle even if hGetLine fails (e.g. empty file).
    readCached = do
      line <- bracket (openFile fname ReadMode) hClose hGetLine
      return (Just (read line))
    -- Only "file does not exist" counts as a cache miss; other I/O errors
    -- are re-raised.
    onMissing err
      | isDoesNotExistError err = return Nothing
      | otherwise               = ioError err

{- File-cache primitive shared by memoIO and memoIOT.

memoIOImpl x v returns the value stored in "/tmp/memo/" ++ show x when that
file exists. If it does not exist, it announces "Writing <file>", saves show v
there and returns v. Any other I/O error is re-raised.

Fixes over the previous version:
  * the read handle is closed via bracket even when hGetLine fails (e.g. on an
    empty cache file), instead of leaking;
  * the /tmp/memo directory is created on demand, so the first write no
    longer fails when the directory is missing. -}
memoIOImpl :: (Show a, Show b, Read b) => a -> b -> IO b
memoIOImpl x v = tryRead `catch` save
  where
    fname = "/tmp/memo/" ++ show x
    tryRead = do
      -- putStrLn $ "Reading " ++ fname
      line <- bracket (openFile fname ReadMode) hClose hGetLine
      return (read line)
    save err
      | isDoesNotExistError err = do
          putStrLn $ "Writing " ++ fname
          createDirectoryIfMissing True "/tmp/memo"
          writeFile fname (show v)
          return v
      | otherwise = ioError err

{- Generic memoizing fixed point. memo is a caching wrapper (e.g. memoMapT
or memoArrayT) and f is the open-recursive step; the knot is tied so that
every recursive call f makes goes back through memo. -}
memoFix :: MemoRecT a b m s -> MemoRecT a b m s -> MemoT a b m s
memoFix memo f = knot
  where
    knot = memo (f knot)

{- Same knot-tying as memoFix, expressed with fix: the function passed to f
for its recursive calls is the memo-wrapped result itself. -}
memoFixT :: MemoRecT a b m s -> MemoRecT a b m s -> MemoT a b m s
memoFixT memo f = fix (\self -> memo (f self))

{- Baseline without any caching: the plain fixed point of the step f, so
repeated sub-calls are recomputed each time. -}
memoNoneFix :: MemoRecT a b m s -> MemoT a b m s
memoNoneFix f = let self = f self in self

{- Fixed point of f with every call routed through memoIntMapT, i.e. results
are cached in an IntMap state. -}
memoIntMapFix :: (Monad m) => MemoRecT Int b m (IntMap b) -> MemoT Int b m (IntMap b)
memoIntMapFix f = self
  where
    self = memoIntMapT (f self)

{- Fixed point of f with every call routed through memoMapT, i.e. results
are cached in a Map held in StateT over m. -}
memoMapFixT :: (Ord a, Monad m) => MemoRecT a b m (Map a b) -> MemoT a b m (Map a b)
memoMapFixT f = fix (\self -> memoMapT (f self))

{- Pure-State counterpart of memoMapFixT: fixed point of f with every call
routed through memoMap. -}
memoMapFix :: (Ord a) => MemoRec a b (Map a b) -> Memo a b (Map a b)
memoMapFix f = let self = memoMap (f self) in self

{- Fixed point of f with every call routed through memoArrayT; only keys
inside the bounds of the supplied array are cached. -}
memoArrayFix
  :: (Ix a, Monad m)
  => MemoRecT a b m (Array a (Maybe b))
  -> MemoT a b m (Array a (Maybe b))
memoArrayFix f = go
  where
    go = memoArrayT (f go)

{- If you want to use it as a module, remove the rest from here -}
{- -- 8-< ----------------------------------------------------- -}

{- Fibonacci step function in open-recursion form (the first argument is the
function to use for recursive calls), based on the less-known identities

  F(2n)   = F(n+1)^2 - F(n-1)^2
  F(2n+1) = F(2n+2) - F(2n)

Even arguments recurse on roughly half their value. Negative arguments use
F(-n) = (-1)^(n+1) * F(n); without that case, fib (-2) called itself forever.
It fits MemoRecT a b m s when the monad is StateT s m, so it can be passed to
any of the memoizing fixed points above. -}
fib :: (Integral a, Integral b, Monad m) => (a -> m b) -> (a -> m b)
fib _ 0 = return 0
fib _ 1 = return 1
fib _ 2 = return 1
fib _ 3 = return 2
fib rec x
  | x < 0 = do
      v <- rec (negate x)
      -- (-1)^(n+1) with n = -x: negative exactly when x is even.
      return (if even x then negate v else v)
  | q == 0 = do
      v1 <- rec (p + 1)
      v2 <- rec (p - 1)
      return (v1 * v1 - v2 * v2)
  | otherwise = do
      v1 <- rec (x + 1)
      v2 <- rec (x - 1)
      return (v1 - v2)
  where
    (p, q) = x `divMod` 2

{- Textbook Fibonacci step function in open-recursion form:
F(n) = F(n-1) + F(n-2). Negative arguments use the same recurrence rearranged
as F(n) = F(n+2) - F(n+1), which climbs towards the base cases; previously a
negative argument recursed downwards forever. -}
fib' :: (Integral b, Monad m) => (Integer -> m b) -> (Integer -> m b)
fib' _ 0 = return 0
fib' _ 1 = return 1
fib' rec x
  | x < 0 = do
      v1 <- rec (x + 2)
      v2 <- rec (x + 1)
      return (v1 - v2)
  | otherwise = do
      v1 <- rec (x - 1)
      v2 <- rec (x - 2)
      return (v1 + v2)

{- Demo: compute and print Fibonacci numbers with three memoization set-ups:
file cache only, a Map in State, and an Array in front of the file cache.
The file-backed runs read and write cache files under /tmp/memo/.
The section headers now match the ranges actually computed (previously they
claimed 10^1..10^8 and 3^1..3^12). -}
main :: IO ()
main = do
  putStrLn "First 20 Fibonacci numbers"
  mapM (fix (memoIO . fib)) [1 .. 20] >>= print
  putStrLn "10^1..10^6 Fibonacci numbers using Map: "
  mapM_ (\n -> print (evalState (memoMapFix fib (10 ^ n)) em)) [1 .. 6]
  putStrLn "3^1..3^10 Fibonacci numbers using files and an array: "
  mapM_ (\n -> evalStateT (fix (memoArrayT . memoIOT . fib) (3 ^ n)) ea >>= print) [1 .. 10]
  where
    -- Fresh, empty Map store for each Map-based run.
    em = Map.empty
    -- Array store covering keys 0..2000; larger keys bypass it (see memoArrayT).
    ea = let n = 2000 in array (0, n) [ (x, Nothing) | x <- [0 .. n] ] :: (Array Integer (Maybe Integer))

{- Attachment: MemoFix.hs (size 6.08 KB) -}