ころがる狸

ころがる狸のデータ解析ブログ

【PyTorch】多入力多出力モデルの作り方

こんにちは、Dajiroです。今回は、PyTorchを使った複雑なネットワークの構築についてご紹介します。機械学習モデルを組んでいると、複数の種類の入力(画像と1次元配列状のデータなど)を使ったり、複数の種類の出力を得たい場合などがあります。そんなときに必要となる多入力多出力モデルの作り方を見ていきます。

【目次】

多入力多出力モデル

ここでは、以下のようなモデルの構築を目的とします。入力には画像とそのRBGのヒストグラムを表す3つの1次元配列を用います。出力には、その画像に何が移っているか(大聖堂、食べ物、花など)、それが天然物か人工物かの2種類のクラス分類を行います。なお、実装方法の紹介がメインなので、予測精度向上にはここでは拘らないこととします。

f:id:Dajiro:20200627160058p:plain
今回作成する多入力多出力モデル

データの読み込み

率直に言って、PyTorchでの多入力・多出力モデルの構築自体は難しくはありません。むしろ複数種類のデータを読み込みの準備が面倒です。ここでは以下のようなファイル構成を前提とします。各フォルダにはそのラベルに対応する画像が保存されています。

image_data
├── cathedral
├── flower
├── food
├── lake
├── mountain
└── shrine

PyTorchで画像の分類問題を解く場合にはImageFolderという便利がクラスが存在しますが、ここでは自前のデータセットクラスを実装しそれぞれの画像からRGB情報を抽出したり、ラベル情報に修正を加えたりします。MyDatasetは辞書型オブジェクトであり、適当なキーを入れると画像・RGBヒストグラムの4種類の入力2種類のラベルが返ってきます。このあたりの実装は入力や予測対象によって大いに変わってくるため、参考程度にご覧下さい。

#画像をpillow形式で読み込み出力
def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

#読み込み方法の選択。ここではpillowの読み込み器のみ実装
def default_loader(path):
    return pil_loader(path)

#画像へのパスとラベルをまとめて出力
def make_dataset(directory, class_to_idx):
    instances = []
    directory = os.path.expanduser(directory)
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                item = path, class_index
                instances.append(item)
    return instances
    
#自作のデータセットクラス
class MyDataset(Dataset):
    def __init__(self, root, loader = default_loader,
                 transform = None, target_transform=None):
        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform
        classes = [d.name for d in os.scandir(root) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        
        samples = make_dataset(root, class_to_idx)
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]
        
    def __getitem__(self, index):
        path, target1 = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            #transformを利用してpillow形式の出力をテンソル化、標準化する
            sample = self.transform(sample)
            #ヒストグラムはここで生成する
            hist1 = torch.histc(sample[0], bins=256) / 50176
            hist2 = torch.histc(sample[1], bins=256) / 50176
            hist3 = torch.histc(sample[2], bins=256) / 50176
        if self.target_transform is not None:
            target1 = self.target_transform(target1)
        #1つ目のラベルから、人工物・天然物を分ける新しいラベルを生成する
        if target1 == 0 or target1 == 5: #大聖堂or神社
            target2 = 0 # 人工物ラベル
        else:
            target2 = 1 # 天然物ラベル

        return sample, hist1, hist2, hist3, target1, target2

    def __len__(self):
        return len(self.samples)

データローダーの定義

ここは一般的なPyTorchのデータローダーの定義と全く同じなので省略します。なお、今回は深入りしませんがRBGデータの抽出等はデータセットクラスを自前で準備するのでなく、データローダーのcollate_fnを使っても対応できます。

モデルの定義

画像の入力はResnet50の学習済みモデルで処理し、RGBヒストグラムは多層パーセプトロンによって処理します。入力が増えた分はforward関数の引数を増やせば良いですし、増えた出力は返り値を増やせば良いのです。非常にシンプルで実装しやすいと思いませんか?凝ったモデルを作りたいとき、PyTorchの記述のしやすさを実感します。

class Net(nn.Module):
    def __init__(self, pretrained):
        super(Net, self).__init__()
        self.pretrained = pretrained
        self.pretrained.fc = nn.Linear(2048, 64)
        self.fc1 = nn.Linear(256, 64)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(256, 64)
        self.fc_y1 = nn.Linear(64*2, 6)
        self.fc_y2 = nn.Linear(64*2, 2)
    
 #複数の入力値を引数に与える
    def forward(self, x0, x1, x2, x3):
        x0 = self.pretrained(x0)
        x1 = F.relu(self.fc1(x1))
        x2 = F.relu(self.fc2(x2))
        x3 = F.relu(self.fc3(x3))
        
        x123 = x1 + x2 + x3
        x = torch.cat((x0, x123), 1)
        
        y1 = self.fc_y1(x)
        y2 = self.fc_y2(x)
    
    #複数の出力を返す   
    return y1, y2

#モデルを実体化
use_pretrained = True
pretrained = models.resnet50(pretrained=use_pretrained)
model = Net(pretrained)

オプティマイザ

学習済みモデルの畳み込み層の重みはそのまま活用し、畳み込み後の全結合層や新たに追加したRBGデータを処理するための全結合層を更新対象として設定を行います。

#Optimizer(Adamを使用)
pram_to_update = []
#更新対象の重みパラメータを設定する
update_param_names = ["fc.weight", "fc.bias",
                      "fc1.weight", "fc1.bias",
                      "fc2.weight", "fc2.bias",
                      "fc3.weight", "fc3.bias",
                      "fc_y1.weight", "fc_y1.bias",
                      "fc_y2.weight", "fc_y2.bias"]
for name, param in model.named_parameters():
    if name in update_param_names:
        param.reqiores_grad = True
        pram_to_update.append(param)
        print(name)
    else:
        param.reqiores_grad = False
        
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

損失関数

損失関数はCrossEntropyLossとします。

criterion = nn.CrossEntropyLoss()

学習実行部分

さて、作成したデータセットやモデルを用いて学習を行います(正解率の実装などは省略しています)。データローダーからの取り出しパラメータが増えること、モデルの引数や戻り値が増えることに注意します。また、それぞれの出力に対して損失を定義しますが、学習全体の損失をその加算として定義します(loss = loss1 + loss2)。このとき、loss1, loss2に適当な重みを掛けることにより、どちらの損失をより重視して学習を行うかを制御できます。この係数が新たなハイパーパラメータとなるので、うまく制御すればより良いモデルが作れるかもしれません。それ以外は一般的な学習とほとんど同じです。

dataloaders_dict = {"train":train_loader, "val":val_loader}
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    print("-----------")
    for phase in ["train", "val"]:
        if phase == "train":
            model.train()
        else:
            model.eval()
        
        epoch_loss = 0
        
        if (epoch == 0) and (phase == "train"):
            continue
        
            #入力、ラベルを一挙に取り出し
            for inp1, inp2, inp3, inp4, label1, label2 in dataloaders_dict[phase]:
            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == "train"):
                out1, out2 = model(inp1, inp2, inp3, inp4)
                loss1 = criterion(out1, label1)
                loss2 = criterion(out2, label2)
                #2つの損失を適当な重みをかけて束ねる 
                loss = loss_ratio * loss1 + (1 - loss_ratio) * loss2
                
                if phase == "train":
                    loss.backward()
                    optimizer.step()
                
                epoch_loss += loss.item() * inp1.size(0)
        
        epoch_loss = epoch_loss/len(dataloaders_dict[phase].dataset)
        print(f"{phase} Loss:{epoch_loss :.4f}")

おわりに

PyTorchの多入力、多出力モデルの構築方法を見てきました。モデルの構築は非常に簡単で、複雑なネットワークも楽に組むことができます。一方でデータの読み込みに関してはImageFolderなどの便利なツールがそのまま使えないことも多く、自前のデータセットを準備する必要があるなどやや面倒です。自分で実装するとデータの読み込みが妙に遅かったりするので、そこらへんはPyTorchのソースコードを参考に実装するのが良いと思います。