ころがる狸

ころがる狸のデータ解析ブログ

【PyTorch+Numpy】Dataloaderに潜むありがちなバグ

PyTorchは素晴らしい機械学習フレームワークですが、データ読み込みに使うDatasetとNumpyによる乱数発生の組み合わせは思わぬバグの発生源となっているようです。2021年4月10日に投稿されたこちらの記事がTwitter上で話題になっています。

tanelp.github.io

一言で要約するなら:PyTorchでデータを読み込む際にマルチプロセス処理を行うと、親プロセスのNumpyの乱数生成器の状態が子プロセスに継承されるため、ランダムであるべき配列の値がすべて同一になる。上記の記事にはコードもついているので、どういうことか手を動かして確認してみましょう。

悪い例その1と解決策

上の記事でも紹介されているコードがこちらです。numpy.random.randint(0, 1000, 3)で0以上1000未満の3要素からなるランダムな配列を返すようDatasetを定義していますね。このデータセットから4つのプロセスを用いて2バッチずつ取り出すコードです。
#悪例その1
import numpy as np
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __getitem__(self, index):
        return np.random.randint(0, 1000, 3)

    def __len__(self):
        return 16
    
dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
for batch in dataloader:
    print(batch)
実行結果
tensor([[109, 172, 650],
        [962, 630, 762]]) #プロセス0の1バッチ目
tensor([[109, 172, 650],
        [962, 630, 762]]) #プロセス1の2バッチ目
tensor([[109, 172, 650],
        [962, 630, 762]]) #プロセス2の3バッチ目
tensor([[109, 172, 650],
        [962, 630, 762]]) #プロセス3の4バッチ目
tensor([[571,  66, 602],
        [568,  92, 400]]) #プロセス0の5バッチ目
tensor([[571,  66, 602],
        [568,  92, 400]]) #プロセス1の6バッチ目
tensor([[571,  66, 602],
        [568,  92, 400]]) #プロセス2の7バッチ目
tensor([[571,  66, 602],
        [568,  92, 400]]) #プロセス3の8バッチ目
結果を見ると明らかなように、それぞれのプロセスで返される配列の要素が同じになっています。この原因は、マルチプロセスでデータを取り出す際のプロセスの起動方法がforkであることにあります。プロセスの起動にはspawn, fork, forkserverとありますが、forkによってプロセスを起動する場合、親プロセスの乱数生成の状態が子プロセスに継承されます。つまり、Dataloaderからデータを取り出す前にRandomDatasetで定義された同じ状態が子プロセスに渡るため、それぞれのプロセスは同じランダム配列を返すという仕組みになっています。

1つの解決策は、Dataloaderのオプション引数としてworker_init_fnを定義することです。これはDataloaderを起動する前に各プロセスの初期化条件を指定できます。実装するとこのようになります。
#解決方法
import numpy as np
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __getitem__(self, index):
        return np.random.randint(0, 1000, 3)

    def __len__(self):
        return 16
 
worker_init_fnの定義
def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=2, num_workers=4, 
            worker_init_fn=worker_init_fn)
for batch in dataloader:
    print(batch)
実行結果
tensor([[282,   4, 785],
        [ 35, 581, 521]])
tensor([[684,  17,  95],
        [774, 794, 420]])
tensor([[939, 988,  37],
        [983, 933, 821]])
tensor([[832,  50, 453],
        [ 37, 322, 981]])
tensor([[180, 413,  50],
        [894, 318, 729]])
tensor([[530, 594, 116],
        [636, 468, 264]])
tensor([[142,  88, 429],
        [407, 499, 422]])
tensor([[ 69, 965, 760],
        [360, 872,  22]])
確かにそれぞれのプロセスでランダムに異なる配列が返ってきました。しかしまだめでたしめでたしとはいきません。

悪い例その2と解決策

上のコードのprint部分を、エポックでループを回すように変更します。すると各エポックで同じ配列が返ってきます。

#エポックでループを回すと同じ配列が返る
for epoch in range(3):
    print(f"epoch: {epoch}")
    for batch in dataloader:
        print(batch)
    print("-"*25)
実行結果
epoch: 0
tensor([[282,   4, 785],
        [ 35, 581, 521]])
tensor([[684,  17,  95],
        [774, 794, 420]])
tensor([[939, 988,  37],
        [983, 933, 821]])
tensor([[832,  50, 453],
        [ 37, 322, 981]])
tensor([[180, 413,  50],
        [894, 318, 729]])
tensor([[530, 594, 116],
        [636, 468, 264]])
tensor([[142,  88, 429],
        [407, 499, 422]])
tensor([[ 69, 965, 760],
        [360, 872,  22]])
-------------------------
epoch: 1
tensor([[282,   4, 785],
        [ 35, 581, 521]])
tensor([[684,  17,  95],
        [774, 794, 420]])
tensor([[939, 988,  37],
        [983, 933, 821]])
tensor([[832,  50, 453],
        [ 37, 322, 981]])
tensor([[180, 413,  50],
        [894, 318, 729]])
tensor([[530, 594, 116],
        [636, 468, 264]])
tensor([[142,  88, 429],
        [407, 499, 422]])
tensor([[ 69, 965, 760],
        [360, 872,  22]])
-------------------------
epoch: 2
tensor([[282,   4, 785],
        [ 35, 581, 521]])
tensor([[684,  17,  95],
        [774, 794, 420]])
tensor([[939, 988,  37],
        [983, 933, 821]])
tensor([[832,  50, 453],
        [ 37, 322, 981]])
tensor([[180, 413,  50],
        [894, 318, 729]])
tensor([[530, 594, 116],
        [636, 468, 264]])
tensor([[142,  88, 429],
        [407, 499, 422]])
tensor([[ 69, 965, 760],
        [360, 872,  22]])
-------------------------
この原因は、それぞれの子プロセスの状態が各エポックが終わるごとにkillされる一方で親プロセスの状態が保持されることにあります。これにより、親プロセスは同じ状態に基づいて子プロセスを同じ乱数シードを用いてエポックごとに子プロセスを初期化するため、各エポックで得られる配列が同じになります。 これを解決するには、エポックのループが回るごとにnumpyの乱数シードを変更します。
#エポックごとに乱数シードを変更する
initial_seed = 42
for epoch in range(3):
    np.random.seed(initial_seed + epoch)
    print(f"epoch: {epoch}")
    for batch in dataloader:
        print(batch)
    print("-"*25)
このように一行追加することで、各エポックで異なる配列が得られることを検証することが出来ます。

torch.randintの力に頼ろう

最後にもう一つ、最も単純な解決策はnumpy.random.randintではなくtorch.randintの力に頼ることです。RandomDatasetの戻り値をtorch.randintに変更するだけで上記の問題はすべて解決します。

#悪例その1
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch

class RandomDataset(Dataset):
    def __getitem__(self, index):
        #ここをnumpyでなくtorchに変更する
        return torch.randint(0, 1000, (3,))

    def __len__(self):
        return 8

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=1, num_workers=4,worker_init_fn=worker_init_fn)

for epoch in range(3):
    print(f"epoch: {epoch}")
    for batch in dataloader:
        print(batch)
    print("-"*25)
実行結果
epoch: 0
tensor([[287,  92, 420]])
tensor([[  5, 657,  59]])
tensor([[756, 153, 295]])
tensor([[614, 414, 475]])
tensor([[797, 189, 356]])
tensor([[169, 632,  88]])
tensor([[968, 105,  56]])
tensor([[276, 462, 900]])
-------------------------
epoch: 1
tensor([[681, 980, 294]])
tensor([[275, 343, 287]])
tensor([[111, 191, 580]])
tensor([[890, 615,  62]])
tensor([[553, 665, 775]])
tensor([[786,  46, 419]])
tensor([[398, 359, 288]])
tensor([[391, 322, 163]])
-------------------------
epoch: 2
tensor([[358, 967, 823]])
tensor([[637, 881, 759]])
tensor([[458,  59, 107]])
tensor([[567, 293,   4]])
tensor([[752, 267, 728]])
tensor([[ 43, 684, 590]])
tensor([[145, 415, 425]])
tensor([[207,  52, 279]])
-------------------------
以上のことから、PyTorchとNumpy、混ぜるな危険・・・とまではもちろん言えませんが、組み合わせるときは結果の挙動について十分な注意を払う必要がありそうです。実際PyTorchの公式チュートリアルやPyTorchを使っている数多のGithubレポジトリでこうしたバグが見出されたと報告されています。元の記事でも触れられていますが、数値シミュレーションと違って機械学習のバグは予測結果の不確実性に飲み込まれてしまうため検出が難しいです。このような事例があることはしっかり頭に入れておきたいですね。