Table of Contents
torchvisionとは
torchvisionとはPyTorchのパッケージで、コンピュータビジョンにおける有名なデータセット(MNIST、COCOなど)・モデルアーキテクチャ・画像変換処理から構成されます。
キーポイント検出をやってみる
※本実験で使用する画像はPAKUTASO様より取得したものとなります。
早速やってみます。以下説明しながらコードを記載します。(JupyterNotebookで実行する前提です)
まずパッケージ・テスト画像の読み込みを行います。
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from chainercv.visualizations import vis_bbox, vis_point
image_path = 'yuseiIMGL2349_TP_V.jpg'
image = Image.open(image_path).convert('RGB')
使用するデバイス(CPU or GPU)を指定します。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
学習済のKeyPoint R-CNNのモデルを読み込んでデバイスに渡します。model.eval()
にてモデルのモードを推論モードに切り替えます。
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
model = model.to(device)
model.eval()
ImageオブジェクトをTensorにしてデバイスに渡します。ついでにモデルに渡す形(リスト)にしておきます。
image_tensor = torchvision.transforms.functional.to_tensor(image)
x = [image_tensor.to(device)]
推論して、物体検出の結果・キーポイント検出の結果を取得します。
prediction = model(x)[0]
bboxes_np = prediction['boxes'].to(torch.int16).cpu().numpy()
labels_np = prediction['labels'].byte().cpu().numpy()
scores_np = prediction['scores'].cpu().detach().numpy()
keypoints_np = prediction['keypoints'].to(torch.int16).cpu().numpy()
keypoints_scores_np = prediction['keypoints_scores'].cpu().detach().numpy()
今回は結果を描画するのに、chainercv.visualizations.vis_bbox, chainercv.visualizations.vis_point
を使用します。
上記のメソッドを使うためにデータを整形します。詳しい形はこちらに書いてあるので気になる方はご確認ください。 ついでにscoreが0.8以上のもののみ採用するようにしてます。
bboxes = []
labels = []
scores = []
keypoints = []
for i, bbox in enumerate(bboxes_np):
score = scores_np[i]
if score < 0.8:
continue
label = labels_np[i]
keypoint = keypoints_np[i]
bboxes.append([bbox[1], bbox[0], bbox[3], bbox[2]])
labels.append(label - 1)
scores.append(score)
keypoints.append(keypoint)
bboxes = np.array(bboxes)
labels = np.array(labels)
scores = np.array(scores)
keypoints = np.array(keypoints)
以下、結果です!
points = np.dstack([keypoints[:, :, 1], keypoints[:, :, 0]])
img = image_tensor.mul(255).byte().numpy()
vis_bbox(img, bboxes, labels, scores, label_names=('person',))
vis_point(img, points)
plt.show()
こんな感じで簡単に使えます。人物の姿勢検出などに使えそうです。
以上です!
shimao
福岡在住。機械学習エンジニア。AWS SAP・MLS。趣味はKaggleで、現在ソロ銀1。福岡でKaggleもくもく会を主催しているので、どなたでもぜひご参加ください。
Related Posts
Teodor TOSHKOV
2022/06/13
Teodor TOSHKOV
2021/09/28
Teodor TOSHKOV
2021/08/24