Top View


Author naoya

Deep Learningで背景削除をしてみる

2020/01/20

背景削除とは

背景削除とはその名の通り、画像から対象物を検出し背景を削除するタスクです。

背景削除の有名なサービスとしては「remove.bg」というものがあります。下の画像はremove.bgを使って背景削除した例です。

original image

removebg

右の男性に注目すると髪の毛の部分まできちんと背景削除されていることが確認できます。。   今回はこれを目標にしていきたいと思います。

※本実験で使用する画像はPAKUTASO様より取得したものとなります。

使う技術

  • Image Matting

    • 機械学習のタスクの1つ
    • 画像を「前景」「背景」「そのどちらか」に粗く分割し、それぞれの領域における画像特徴を見て境界を決定する
    • 今回はImage MattingのうちIndexNet Mattingというモデルを利用します
  • Semantic Segmentation

    • Image Mattingを行うための前処理として利用する

trimapの作成

  • trimap

    • 「前景」「背景」「そのどちらか」の3つに分割した画像のこと,Image Mattingを実行するために必要
    • Semantic Segmentationにより作成する(手作業で作成してもいい)

実装

  1. 各種パッケージをimportします。今回はtorchvisionを使ってsegmentationを行います。
import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision import transforms
  1. 画像を読み込み、DeepLabv3の入力サイズに合わせてリサイズします。
image_path = 'men.png'
img = cv2.imread(image_path)
img = img[...,::-1] #BGR->RGB
h,w,_ = img.shape
img = cv2.resize(img,(320,320))
  1. モデルをデバイスに渡し、推論モードに切り替えます。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
model = model.to(device)
model.eval();
  1. 画像のnumpy配列をtensor型にし、正規化します。また、バッチの次元を追加します。
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0).to(device)
  1. 推論すると下の右のような画像が得られます。
with torch.no_grad():
    output = model(input_batch)['out'][0]
output = output.argmax(0)
mask = output.byte().cpu().numpy()
mask = cv2.resize(mask,(w,h))
img = cv2.resize(img,(w,h))
plt.gray()
plt.figure(figsize=(20,20))
plt.subplot(1,2,1)
plt.imshow(img)
plt.subplot(1,2,2)
plt.imshow(mask);

deeplab out

  1. OpenCVで膨張収縮処理をしてtrimapを生成し、適当な場所に保存します。下の右のようなtrimapが得られます。
def gen_trimap(mask,k_size=(5,5),ite=1):
    kernel = np.ones(k_size,np.uint8)
    eroded = cv2.erode(mask,kernel,iterations = ite)
    dilated = cv2.dilate(mask,kernel,iterations = ite)
    trimap = np.full(mask.shape,128)
    trimap[eroded == 255] = 255
    trimap[dilated == 0] = 0
    return trimap
trimap = gen_trimap(mask,k_size=(10,10),ite=5)
cv2.imwrite('./examples/trimaps/'+id,trimap)
plt.figure(figsize=(20,20))
plt.subplot(1,2,1)
plt.imshow(img)
plt.subplot(1,2,2)
plt.imshow(trimap)

trimap

Image Mattingでマスクを取得

IndexNet Mattingの論文の公式実装を利用します。

  1. 公式実装のリポジトリをcloneする
  2. 背景削除したい画像を"./examples/images/"に、先ほど作成したtrimapを"./examples/trimaps/"に置く
  3. script/demo.pyのimage_pathに"./examples/images/men.png" ,trimap_pathに"./examples/trimaps/men.png" を追加する
  4. script/demo.pyを実行、"./example/mattes/"に出力が格納される。

mask sample

後処理

Image Mattingにより画像に対するマスクが出力として得られます。このマスクを画像の4番目のチャンネル(アルファチャンネル)として利用します。これは画像の透過度情報を扱うためのチャンネルです。このアルファチャンネル を利用して背景を合成します。

実装

  1. 各種パッケージをimportします。
import numpy as np
import cv2
import matplotlib.pyplot as plt
  1. 画像、マスク(アルファチャンネル)を読み込み、背景を準備します。
id = 'men.png'
img = cv2.imread('./examples/images/'+id)
img = img[...,::-1]
matte = cv2.imread('./examples/mattes/'+id)
h,w,_ = img.shape
bg = np.full_like(img,255) #white background
  1. マスクを0~1.に標準化し、画像、背景に下のようにそれぞれ掛けます。それらを足し合わせると最終的な出力が得られます。
img = img.astype(float)
bg = bg.astype(float)

matte = matte.astype(float)/255
img = cv2.multiply(img, matte)
bg = cv2.multiply(bg, 1.0 - matte)
outImage = cv2.add(img, bg)
plt.imshow(outImage/255)

result people

結構上手くできました。右の男性に注目すると髪の毛の部分まできちんと抽出できていることが分かります。しかし、右の男性の手の部分など少し背景が残っているところもありました。

ちなみに猫でやってみたところなかなか上手くいったため人間以外にも応用可能なことが分かりました。

result cat

まとめ

機械学習を用いて背景削除(前景抽出)を実現することができました。上記のパイプラインを改良することで全自動で背景削除を行うこともできそうです。

naoya

naoya

学部4年のインターン生です。 kaggle expert