Fusic Tech Blog

Fusion of Society, IT and Culture

torchvisionを使って人のキーポイント検出をする
2019/07/20

torchvisionを使って人のキーポイント検出をする

前回はchainercvが便利という話だったのですが、今回はtorchvisionのTipsです。

人物のキーポイント検出はOpenPose・PoseNetなどが有名ですが、Keypoint R-CNN(Mask R-CNN)でも可能です。(細部はもちろん異なります) torchvisonを使うと簡単に利用できるのでご紹介致します。

torchvisionとは

torchvisionとはPyTorchのパッケージで、コンピュータビジョンにおける有名なデータセット(MNIST、COCOなど)・モデルアーキテクチャ・画像変換処理から構成されます。

キーポイント検出をやってみる

※本実験で使用する画像はPAKUTASO様より取得したものとなります。

早速やってみます。以下説明しながらコードを記載します。(JupyterNotebookで実行する前提です)

まずパッケージ・テスト画像の読み込みを行います。

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from chainercv.visualizations import vis_bbox, vis_point

image_path = 'yuseiIMGL2349_TP_V.jpg'
image = Image.open(image_path).convert('RGB')

使用するデバイス(CPU or GPU)を指定します。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

学習済のKeyPoint R-CNNのモデルを読み込んでデバイスに渡します。model.eval()にてモデルのモードを推論モードに切り替えます。

model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
model = model.to(device)
model.eval()

ImageオブジェクトをTensorにしてデバイスに渡します。ついでにモデルに渡す形(リスト)にしておきます。

image_tensor = torchvision.transforms.functional.to_tensor(image)
x = [image_tensor.to(device)]

推論して、物体検出の結果・キーポイント検出の結果を取得します。

prediction = model(x)[0]

bboxes_np = prediction['boxes'].to(torch.int16).cpu().numpy()
labels_np = prediction['labels'].byte().cpu().numpy()
scores_np = prediction['scores'].cpu().detach().numpy()
keypoints_np = prediction['keypoints'].to(torch.int16).cpu().numpy()
keypoints_scores_np = prediction['keypoints_scores'].cpu().detach().numpy()

今回は結果を描画するのに、chainercv.visualizations.vis_bbox, chainercv.visualizations.vis_pointを使用します。

上記のメソッドを使うためにデータを整形します。詳しい形はこちらに書いてあるので気になる方はご確認ください。 ついでにscoreが0.8以上のもののみ採用するようにしてます。

bboxes = []
labels = []
scores = []
keypoints = []

for i, bbox in enumerate(bboxes_np):
    score = scores_np[i]
    if score < 0.8:
        continue

    label = labels_np[i]
    keypoint = keypoints_np[i]
    
    bboxes.append([bbox[1], bbox[0], bbox[3], bbox[2]])
    labels.append(label - 1)
    scores.append(score)
    keypoints.append(keypoint)

bboxes = np.array(bboxes)
labels = np.array(labels)
scores = np.array(scores)
keypoints = np.array(keypoints)

以下、結果です!

points = np.dstack([keypoints[:, :, 1], keypoints[:, :, 0]])

img = image_tensor.mul(255).byte().numpy()
vis_bbox(img, bboxes, labels, scores, label_names=('person',))
vis_point(img, points)
plt.show()

物体検出結果

キーポイント検出結果

こんな感じで簡単に使えます。人物の姿勢検出などに使えそうです。

以上です!

shimao

shimao

福岡在住です。Perl→PHP→Python、ここ1年ほど機械学習をやっています。福岡でKaggleもくもく会を主催しているので、興味がある方は是非ご参加ください。