Fusic Tech Blog

Fusion of Society, IT and Culture

SE-ResNeXtが簡単に使えるよ、そうChainerCVならね
2019/06/28

SE-ResNeXtが簡単に使えるよ、そうChainerCVならね

こんにちは、嶋生です。

ブログのリニューアルに伴い投稿が滞っていたのですが、無事完了しましたので早速投稿します。

今回は、ChainerCVの小ネタになります。 ChainerCVを使うとSE-ResNeXtなどの自分で作ると複雑なモデルも簡単に使えるので便利だよ、という話です。

本記事の対象

  • 機械学習に興味がある人
  • Chainer(CV)に興味がある人

ChainerCVとは?

  • Githubのコピペですが、以下のような素晴らしいツールです。

    • ChainerCV is a collection of tools to train and run neural networks for computer vision tasks using Chainer.
    • (Google翻訳さんによると) ChainerCVはChainerを使用してコンピュータビジョンタスクのためのニューラルネットワークを訓練し実行するためのツールのコレクションです。

SE-ResNeXtとは?

  • Squeeze-and-Excitation Networks(以下、SENet)ResNeXtに適用させたものです。
  • SENetは、ILSVRC2017画像分類コンペティションで1位になったモデルです。※以下、こちらの記事より抜粋。

    • 任意のCNNモデルにAttension&gating機構を追加できるSE blockを提案
    • 畳み込み層で計算された特徴量をそのまま使用するのではなく、入力画像に応じて有効な特徴量チャンネルを制御
  • ResNeXt

    • ResNetのbottleneck blockをいくつか枝分かれさせたあとに足し合わせる構造を導入したモデル(こちらのページより引用)
    • ResNeXtについてはちゃんと勉強したことがなく、詳しくありません…🙇

コード

とりあえず貼ります。後ほど一部説明します。

学習

import chainer
import chainercv
import numpy as np
import pandas as pd

from pathlib import Path
from chainer.links import Classifier
from chainer.training import extensions
from chainercv.links.model.resnet import Bottleneck
from chainercv.links import SEResNeXt50
from chainer.datasets import LabeledImageDataset
from sklearn.model_selection import train_test_split

BATCHSIZE = 64
EPOCH = 80

# 画像データは images/{LabelID}/{画像名}.jpg という感じで置いている
image_dir = Path('./images')
image_paths = sorted(list(image_dir.glob('**/*.jpg')))


label_dict = {}
for i, label_path in enumerate(sorted(list(image_dir.glob('*')))):
    label_dict[i] = label_path.name

label_dict_inv = {v: k for k, v in label_dict.items()}

class_num = len(label_dict.keys())


labels = []
for i, path in enumerate(image_paths):
    label_id = path.parent.name
    labels.append(label_dict_inv[label_id])

dataset_df = pd.DataFrame({'path': image_paths, 'label_id': labels})
dataset_df.head()

train_df, test_df = train_test_split(dataset_df, test_size=0.1, stratify=dataset_df['label_id'])

train_ = LabeledImageDataset(train_df.values)
test_ = LabeledImageDataset(test_df.values)

def transform(data):
    img, label = data
    img = chainercv.transforms.resize(img, (128, 128))
    img = img / 255.0
    img = np.array(img, dtype=np.float32)
    return img, label

train = chainer.datasets.TransformDataset(train_, transform)
test = chainer.datasets.TransformDataset(test_, transform)


log_interval = 1, 'epoch'
gpu = 0
train_iter = chainer.iterators.MultiprocessIterator(train, BATCHSIZE, n_processes=4, repeat=True, shuffle=True)
test_iter = chainer.iterators.MultiprocessIterator(test, BATCHSIZE, n_processes=4, repeat=False, shuffle=False)

extractor = SEResNeXt50(n_class=class_num)
extractor.pick = 'fc6'
extractor.to_gpu(gpu)

model = Classifier(extractor)
for l in model.links():
    if isinstance(l, Bottleneck):
        l.conv3.bn.gamma.data[:] = 0

optimizer = chainer.optimizers.Adam(alpha=1e-3)
optimizer.setup(model)

updater = chainer.training.StandardUpdater(train_iter, optimizer, device=gpu)
trainer = chainer.training.Trainer(updater, (EPOCH, 'epoch'), 'add_exp_classification_seresnext_results')

trainer.extend(extensions.LogReport(trigger=log_interval))
trainer.extend(extensions.Evaluator(test_iter, model, device=gpu), trigger=log_interval)
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']), trigger=log_interval)
trainer.extend(extensions.snapshot_object(extractor, 'model_epoch_{.updater.epoch}'), trigger=log_interval)

trainer.run()

推論

import matplotlib.pyplot as plt

test_iter = chainer.iterators.SerialIterator(test, BATCHSIZE, repeat=False, shuffle=False)
batch = test_iter.next()
test_x, test_y = chainer.dataset.concat_examples(batch)

test_x_ = extractor.xp.asarray(test_x)
preds = extractor(test_x_)

# 10個だけ表示
for i in range(10):
    img, label = test_x[i], test_y[i]
    plt.imshow(img.transpose((1,2,0)))
    plt.title(label)
    plt.show()
    
    print('予測ラベル:', preds[i].array.argmax())

一部説明

train_df, test_df = train_test_split(dataset_df, test_size=0.1, stratify=dataset_df['label_id'])
  • 今回データセットのラベル毎のデータ数に偏りがったため、train・test(validationとして使ってますが)に分ける際にstratify指定しています。stratifyを指定すると、ラベル毎のデータ数割合を保ったままデータ分割ができます。
train_iter = chainer.iterators.MultiprocessIterator(train, BATCHSIZE, n_processes=4, repeat=True, shuffle=True)
  • 上記の学習のiterator作成時ですが、SerialIteratorではなくMultiprocessIteratorを使っています。こちらは指定したn_processes分のCPUでデータ取得処理を並列化してくれます。n_processesを指定しないと、実行環境の全てのCPUを使用します。
  • 最初、説明を読まずに名前から「複数GPU時に使うものだ!」と勘違いしていました…。MultiprocessIteratorを使うことで学習速度がかなり早くなりました。

所感

  • 2年ほど前にSE-ResNetを自分で実装しようとして難航した記憶があったので、良い時代になったなぁと思いました。
  • 便利なものは率先して使って楽をして、実験の本質的な所などに注力したいと思います。

shimao

shimao

福岡在住です。Perl→PHP→Python、ここ1年ほど機械学習をやっています。福岡でKaggleもくもく会を主催しているので、興味がある方は是非ご参加ください。