numpyのRandomStateとmultiprocessing使用時のseed値の罠
numpyとmultiprocessingでシミュレーション的なことをしていた。思ったより良い感じにならなかったしなんとなくおかしいなと思って調査。
numpy.randomとmultiprocessingを使って並列に乱数生成する。適当にその乱数を出力してみる。
コード
import os import time from multiprocessing import Pool import numpy as np def func(x: int) -> None: print(f"x={x}, np.random={np.random.rand()}, pid={os.getpid()}") if __name__ == '__main__': with Pool(10) as p: p.map(func, range(10))
結果
python test.py x=0, np.random=0.5962941676890798, pid=29550 x=1, np.random=0.5962941676890798, pid=29551 x=2, np.random=0.5962941676890798, pid=29552 x=3, np.random=0.5962941676890798, pid=29553 x=4, np.random=0.5962941676890798, pid=29554 x=5, np.random=0.5962941676890798, pid=29555 x=6, np.random=0.5962941676890798, pid=29556 x=9, np.random=0.030654881568257686, pid=29550 x=7, np.random=0.5962941676890798, pid=29557 x=8, np.random=0.5962941676890798, pid=29558
ほとんど同じやんけ。
調べてみるとnumpyのRandomStateは生成時に現在時刻からseed値を決定するらしい(しっかり調べてない)。なんとなくわかった。
回避策としては親プロセスが子プロセスにseed値を与えるぐらい?
変更後
def func(y: tuple) -> None: x, seed = y np.random.seed(seed) print(f"x={x}, np.random={np.random.rand()}, pid={os.getpid()}") if __name__ == '__main__': with Pool(10) as p: p.map(func, zip(range(10), np.random.randint(0, 2 ** 32 -1, 10)))
変更後結果
python test.py x=0, np.random=0.9123107234323904, pid=30126 x=1, np.random=0.01700757843088274, pid=30127 x=2, np.random=0.15402676571119933, pid=30128 x=3, np.random=0.8384796480668341, pid=30129 x=4, np.random=0.03180545379613742, pid=30130 x=5, np.random=0.9923826065352217, pid=30131 x=6, np.random=0.8970540102911799, pid=30132 x=7, np.random=0.9861043353661219, pid=30133 x=8, np.random=0.6360555488868728, pid=30134 x=9, np.random=0.9287125485831335, pid=30135
もっとエレガントな方法はないのか。