Stateモナドの使い道 純粋関数内で状態を扱う

System.Random について調べるコードを考えてみよう。
0から9までのランダムな整数を繰り返し生成するとき、最初に5が現れるのが何回目か知りたいとする。

System.Random には randomRs という関数があり、型と範囲と乱数生成器を指定すると、ランダムな値の無限リストに評価される。
この関数とリストを操作する関数を使えば、下記のようにすっきり記述できる。

場合によってはこれで十分だろう。ただ、処理効率や可読性の面で、より手続き型に近い記述にしたい場面がありそうだ。

count :: R.RandomGen g => g -> Int -> Int
count g n = length $ takeWhile (/=n) $ R.randomRs (0::Int,9) g

--
-- trace version
--
count' :: R.RandomGen g => g -> Int -> Int
count' g n = length $ takeWhile (\x -> trace ("x: " ++ show x) (x/=n)) $ R.randomRs (0::Int,9) g


方法1. Stateモナドを使う

System.Random には randomR という関数があり、型と範囲と乱数生成器を指定すると、ランダムな値と新しい乱数生成器の組に評価される。randomRs の単発版である。

execState は初期状態と Stateモナドを使用する関数を指定すると、終了状態に評価される関数だ。
Stateモナドは、副作用を扱うという点ではIOモナドと似ているが、純粋関数内に閉じ込められる点が異なる。
Stateモナド内で現在の状態を得るにはgetを、状態を更新するにはputを使用する。


countState :: R.RandomGen g => g -> Int -> Int
countState g n = snd $ execState loop (g,0)
  where
    loop :: R.RandomGen g => State (g,Int) ()
    loop = do
      (g,i) <- get
      let (x,g') = R.randomR (0::Int,9) g
      when (x/=n) $ put (g',i+1) >> loop -- n と同じ値が出るまで、状態を書き換えて繰り返す

--
-- trace version
--
countState' :: R.RandomGen g => g -> Int -> Int
countState' g n = snd $ execState loop (g,0)
  where
    loop :: R.RandomGen g => State (g,Int) ()
    loop = do
      (g,i) <- get
      let (x,g') = R.randomR (0::Int,9) g
      when (trace ("x: " ++ show x) (x/=n)) $ put (g',i+1) >> loop

方法2. STモナドを使う

また、IOモナドから入出力に関する機能を取り除き、純粋関数内で評価できるようにした、STモナドも使用できる。
IOモナド内で IORef を使用する感覚で STモナド内で使用できる STRefという型があり、状態を保持することができる。
STRef は 複数作成しても良いので状態の管理が複雑なときは便利かもしれない。


countST :: R.RandomGen g => g -> Int -> Int
countST g n = runST $ do
  ref <- newSTRef (g,0)
  loop ref
  (_,i) <- readSTRef ref -- 繰り返した回数を取り出して報告
  return i
    where
      loop :: R.RandomGen g => STRef s (g,Int) -> ST s ()
      loop ref = do
        (g,i) <- readSTRef ref
        let (x,g') = R.randomR (0::Int,9) g
        when (x/=n) $ writeSTRef ref (g',i+1) >> loop ref -- n と同じ値が出るまで、状態を書き換えて繰り返す

--
-- trace version
--
countST' :: R.RandomGen g => g -> Int -> Int
countST' g n = runST $ do
  ref <- newSTRef (g,0)
  loop ref
  (_,i) <- readSTRef ref
  return i
    where
      loop :: R.RandomGen g => STRef s (g,Int) -> ST s ()
      loop ref = do
        (g,i) <- readSTRef ref
        let (x,g') = R.randomR (0::Int,9) g
        when (trace ("x: " ++ show x) (x/=n)) $ writeSTRef ref (g',i+1) >> loop ref

全部のせておく。

import qualified System.Random as R
import Debug.Trace (trace)

import Control.Monad (when)
import Control.Monad.State(State,execState,get,put)
import Control.Monad.ST(ST,runST)
import Data.STRef(STRef,newSTRef,readSTRef,writeSTRef)


main = do
  g <- R.newStdGen
  print $ count g 5
  print $ countState g 5
  print $ countST g 5
  print $ count' g 5
  print $ countState' g 5
  print $ countST' g 5



count :: R.RandomGen g => g -> Int -> Int
count g n = length $ takeWhile (/=n) $ R.randomRs (0::Int,9) g

--
-- trace version
--
count' :: R.RandomGen g => g -> Int -> Int
count' g n = length $ takeWhile (\x -> trace ("x: " ++ show x) (x/=n)) $ R.randomRs (0::Int,9) g



countState :: R.RandomGen g => g -> Int -> Int
countState g n = snd $ execState loop (g,0)
  where
    loop :: R.RandomGen g => State (g,Int) ()
    loop = do
      (g,i) <- get
      let (x,g') = R.randomR (0::Int,9) g
      when (x/=n) $ put (g',i+1) >> loop

--
-- trace version
--
countState' :: R.RandomGen g => g -> Int -> Int
countState' g n = snd $ execState loop (g,0)
  where
    loop :: R.RandomGen g => State (g,Int) ()
    loop = do
      (g,i) <- get
      let (x,g') = R.randomR (0::Int,9) g
      when (trace ("x: " ++ show x) (x/=n)) $ put (g',i+1) >> loop
      


countST :: R.RandomGen g => g -> Int -> Int
countST g n = runST $ do
  ref <- newSTRef (g,0)
  loop ref
  (_,i) <- readSTRef ref
  return i
    where
      loop :: R.RandomGen g => STRef s (g,Int) -> ST s ()
      loop ref = do
        (g,i) <- readSTRef ref
        let (x,g') = R.randomR (0::Int,9) g
        when (x/=n) $ writeSTRef ref (g',i+1) >> loop ref

--
-- trace version
--
countST' :: R.RandomGen g => g -> Int -> Int
countST' g n = runST $ do
  ref <- newSTRef (g,0)
  loop ref
  (_,i) <- readSTRef ref
  return i
    where
      loop :: R.RandomGen g => STRef s (g,Int) -> ST s ()
      loop ref = do
        (g,i) <- readSTRef ref
        let (x,g') = R.randomR (0::Int,9) g
        when (trace ("x: " ++ show x) (x/=n)) $ writeSTRef ref (g',i+1) >> loop ref