Top View


Author ishiyama

DeepLearningを用いた画像補間アプリケーションを実装してみた

2020/04/08

画像補間(inpainting)とは

 画像補間(inpaiting)とは、例えば、画像やイラストなどの一部が切り抜かれたり他の色で塗られたりしたものに対して、その部分を補完するというタスクです。破れたり傷が入ったりしてノイズがかかった写真の修復や、写真や画像の一部を除去したい時などにそこを自然に埋めることなどできます。今回のは上から白で塗られた画像に対してその部分に関して補間を行います。最近はDeepLearningを利用するのが主流で、今回はその1つであるEdgeConnect公式実装 を元に実装しました。なお、こちらのリポジトリは非商用利用のライセンスとなっているため、ご注意ください。(今回は技術理解のために利用させていただきました。会社としての利用を行う想定はありません。)

EdgeConnectの簡単な解説

EdgeConnect

(論文より引用)

EdgeConnectは以下のように2段構えのganベースのネットワーク構成になっている。

  1. Edge Generator: 上からマスクされた画像とその画像のエッジとマスク画像からG1を通しマスクされた部分を含むエッジを作成するネットワーク
  2. Image Completion Network: 1で作成されたエッジの画像とマスクされた画像からG2を通し修復された画像を作成するネットワーク

学習はEdge GeneratorとImage Completion Networkの別々に行い、それぞれが学習し終わったら、Edge GeneratorとImage Completion Networkのつなぎこみを行い、さらに学習すると言う流れです。

エッジ作成

 Edge Generatorにデータ渡す際において、作成したアプリケーションではブラウザからマスクされた画像とマスクのみの画像を送るようにしています。そのため、エッジについてはマスクされた画像から別で作成する必要があります。今回は論文に従って、canny法(canny edge detector)を用いて実装します。  canny法はガウシアンフィルタで畳み込んだあとxとy方向に微分した勾配についてさらに細線化と閾値処理を施すことで未検出・誤検出が少なく、検出位置が正確と言われています。  canny法はscikit-imageを利用することで簡単に実装できます。

from skimage.feature import canny
edge_masked = canny(image_masked_gray, sigma=2, mask=(1 - mask_gray).astype(np.bool))
  • image_masked_gray:マスクされた画像(1ch)
  • image_masked_gray:マスク画像(1ch) (※マスクされていない部分をtrueとして渡す)

ネットワーク実装

 webアプリケーション化するに当たって、今回必要なのは修復された画像を推論するためのGenerator部分のニューラルネットワークとそのネットワークの学習済みモデルです。 学習済みモデルについてはEdgeConnectのgithubで公開されているためそれをダウンロードします。(dicriminatorの学習モデルは必要ないです) 学習済みモデルをダウンロードしたら、EdgeConnectソースコードから該当箇所を参考にしてネットワークを作成します。

Edge Generator (code)

import torch
import torch.nn as nn

class EdgeNet(nn.Module):
    def __init__(self, residual_blocks=8, use_spectral_norm=True):
        super(EdgeNet, self).__init__()

        #学習済みモデルのセッティング
        self.weight = torch.load("./src/weight/EdgeModel_gen.pth", map_location=lambda storage, loc: storage)

        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm),
            nn.InstanceNorm2d(256, track_running_stats=False),
            nn.ReLU(True)
        )

        blocks = [ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm) for _ in range(residual_blocks)]

        self.middle = nn.Sequential(*blocks)

        self.decoder = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = torch.sigmoid(x)
        return x

Edge Generator(summary)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
   ReflectionPad2d-1          [-1, 3, 518, 518]               0
            Conv2d-2         [-1, 64, 512, 512]           9,472
    InstanceNorm2d-3         [-1, 64, 512, 512]               0
              ReLU-4         [-1, 64, 512, 512]               0
            Conv2d-5        [-1, 128, 256, 256]         131,200
    InstanceNorm2d-6        [-1, 128, 256, 256]               0
              ReLU-7        [-1, 128, 256, 256]               0
            Conv2d-8        [-1, 256, 128, 128]         524,544
    InstanceNorm2d-9        [-1, 256, 128, 128]               0
             ReLU-10        [-1, 256, 128, 128]               0
  ReflectionPad2d-11        [-1, 256, 132, 132]               0
           Conv2d-12        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-13        [-1, 256, 128, 128]               0
             ReLU-14        [-1, 256, 128, 128]               0
  ReflectionPad2d-15        [-1, 256, 130, 130]               0
           Conv2d-16        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-17        [-1, 256, 128, 128]               0
      ResnetBlock-18        [-1, 256, 128, 128]               0
  ReflectionPad2d-19        [-1, 256, 132, 132]               0
           Conv2d-20        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-21        [-1, 256, 128, 128]               0
             ReLU-22        [-1, 256, 128, 128]               0
  ReflectionPad2d-23        [-1, 256, 130, 130]               0
           Conv2d-24        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-25        [-1, 256, 128, 128]               0
      ResnetBlock-26        [-1, 256, 128, 128]               0
  ReflectionPad2d-27        [-1, 256, 132, 132]               0
           Conv2d-28        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-29        [-1, 256, 128, 128]               0
             ReLU-30        [-1, 256, 128, 128]               0
  ReflectionPad2d-31        [-1, 256, 130, 130]               0
           Conv2d-32        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-33        [-1, 256, 128, 128]               0
      ResnetBlock-34        [-1, 256, 128, 128]               0
  ReflectionPad2d-35        [-1, 256, 132, 132]               0
           Conv2d-36        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-37        [-1, 256, 128, 128]               0
             ReLU-38        [-1, 256, 128, 128]               0
  ReflectionPad2d-39        [-1, 256, 130, 130]               0
           Conv2d-40        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-41        [-1, 256, 128, 128]               0
      ResnetBlock-42        [-1, 256, 128, 128]               0
  ReflectionPad2d-43        [-1, 256, 132, 132]               0
           Conv2d-44        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-45        [-1, 256, 128, 128]               0
             ReLU-46        [-1, 256, 128, 128]               0
  ReflectionPad2d-47        [-1, 256, 130, 130]               0
           Conv2d-48        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-49        [-1, 256, 128, 128]               0
      ResnetBlock-50        [-1, 256, 128, 128]               0
  ReflectionPad2d-51        [-1, 256, 132, 132]               0
           Conv2d-52        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-53        [-1, 256, 128, 128]               0
             ReLU-54        [-1, 256, 128, 128]               0
  ReflectionPad2d-55        [-1, 256, 130, 130]               0
           Conv2d-56        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-57        [-1, 256, 128, 128]               0
      ResnetBlock-58        [-1, 256, 128, 128]               0
  ReflectionPad2d-59        [-1, 256, 132, 132]               0
           Conv2d-60        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-61        [-1, 256, 128, 128]               0
             ReLU-62        [-1, 256, 128, 128]               0
  ReflectionPad2d-63        [-1, 256, 130, 130]               0
           Conv2d-64        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-65        [-1, 256, 128, 128]               0
      ResnetBlock-66        [-1, 256, 128, 128]               0
  ReflectionPad2d-67        [-1, 256, 132, 132]               0
           Conv2d-68        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-69        [-1, 256, 128, 128]               0
             ReLU-70        [-1, 256, 128, 128]               0
  ReflectionPad2d-71        [-1, 256, 130, 130]               0
           Conv2d-72        [-1, 256, 128, 128]         589,824
   InstanceNorm2d-73        [-1, 256, 128, 128]               0
      ResnetBlock-74        [-1, 256, 128, 128]               0
  ConvTranspose2d-75        [-1, 128, 256, 256]         524,416
   InstanceNorm2d-76        [-1, 128, 256, 256]               0
             ReLU-77        [-1, 128, 256, 256]               0
  ConvTranspose2d-78         [-1, 64, 512, 512]         131,136
   InstanceNorm2d-79         [-1, 64, 512, 512]               0
             ReLU-80         [-1, 64, 512, 512]               0
  ReflectionPad2d-81         [-1, 64, 518, 518]               0
           Conv2d-82          [-1, 1, 512, 512]           3,137
================================================================
Total params: 10,761,089
Trainable params: 10,761,089
Non-trainable params: 0
----------------------------------------------------------------

Image Completion Network(code)

class InpaintNet(nn.Module):
    def __init__(self, residual_blocks=8):
        super(InpaintNet, self).__init__()
        #学習済みモデルのセット
        self.weight = torch.load("./src/weight/InpaintingModel_gen.pt", map_location=lambda storage, loc: storage)

        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=4, out_channels=64, kernel_size=7, padding=0),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256, track_running_stats=False),
            nn.ReLU(True)
        )

        blocks = [ResnetBlock(256, 2) for _ in range(residual_blocks)]

        self.middle = nn.Sequential(*blocks)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, padding=0),
        )
    def forward(self, x):
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = (torch.tanh(x) + 1) / 2

        return x

Image Completion Network(summary)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
   ReflectionPad2d-1          [-1, 4, 518, 518]               0
            Conv2d-2         [-1, 64, 512, 512]          12,608
    InstanceNorm2d-3         [-1, 64, 512, 512]               0
              ReLU-4         [-1, 64, 512, 512]               0
            Conv2d-5        [-1, 128, 256, 256]         131,200
    InstanceNorm2d-6        [-1, 128, 256, 256]               0
              ReLU-7        [-1, 128, 256, 256]               0
            Conv2d-8        [-1, 256, 128, 128]         524,544
    InstanceNorm2d-9        [-1, 256, 128, 128]               0
             ReLU-10        [-1, 256, 128, 128]               0
  ReflectionPad2d-11        [-1, 256, 132, 132]               0
           Conv2d-12        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-13        [-1, 256, 128, 128]               0
             ReLU-14        [-1, 256, 128, 128]               0
  ReflectionPad2d-15        [-1, 256, 130, 130]               0
           Conv2d-16        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-17        [-1, 256, 128, 128]               0
      ResnetBlock-18        [-1, 256, 128, 128]               0
  ReflectionPad2d-19        [-1, 256, 132, 132]               0
           Conv2d-20        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-21        [-1, 256, 128, 128]               0
             ReLU-22        [-1, 256, 128, 128]               0
  ReflectionPad2d-23        [-1, 256, 130, 130]               0
           Conv2d-24        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-25        [-1, 256, 128, 128]               0
      ResnetBlock-26        [-1, 256, 128, 128]               0
  ReflectionPad2d-27        [-1, 256, 132, 132]               0
           Conv2d-28        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-29        [-1, 256, 128, 128]               0
             ReLU-30        [-1, 256, 128, 128]               0
  ReflectionPad2d-31        [-1, 256, 130, 130]               0
           Conv2d-32        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-33        [-1, 256, 128, 128]               0
      ResnetBlock-34        [-1, 256, 128, 128]               0
  ReflectionPad2d-35        [-1, 256, 132, 132]               0
           Conv2d-36        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-37        [-1, 256, 128, 128]               0
             ReLU-38        [-1, 256, 128, 128]               0
  ReflectionPad2d-39        [-1, 256, 130, 130]               0
           Conv2d-40        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-41        [-1, 256, 128, 128]               0
      ResnetBlock-42        [-1, 256, 128, 128]               0
  ReflectionPad2d-43        [-1, 256, 132, 132]               0
           Conv2d-44        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-45        [-1, 256, 128, 128]               0
             ReLU-46        [-1, 256, 128, 128]               0
  ReflectionPad2d-47        [-1, 256, 130, 130]               0
           Conv2d-48        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-49        [-1, 256, 128, 128]               0
      ResnetBlock-50        [-1, 256, 128, 128]               0
  ReflectionPad2d-51        [-1, 256, 132, 132]               0
           Conv2d-52        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-53        [-1, 256, 128, 128]               0
             ReLU-54        [-1, 256, 128, 128]               0
  ReflectionPad2d-55        [-1, 256, 130, 130]               0
           Conv2d-56        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-57        [-1, 256, 128, 128]               0
      ResnetBlock-58        [-1, 256, 128, 128]               0
  ReflectionPad2d-59        [-1, 256, 132, 132]               0
           Conv2d-60        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-61        [-1, 256, 128, 128]               0
             ReLU-62        [-1, 256, 128, 128]               0
  ReflectionPad2d-63        [-1, 256, 130, 130]               0
           Conv2d-64        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-65        [-1, 256, 128, 128]               0
      ResnetBlock-66        [-1, 256, 128, 128]               0
  ReflectionPad2d-67        [-1, 256, 132, 132]               0
           Conv2d-68        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-69        [-1, 256, 128, 128]               0
             ReLU-70        [-1, 256, 128, 128]               0
  ReflectionPad2d-71        [-1, 256, 130, 130]               0
           Conv2d-72        [-1, 256, 128, 128]         590,080
   InstanceNorm2d-73        [-1, 256, 128, 128]               0
      ResnetBlock-74        [-1, 256, 128, 128]               0
  ConvTranspose2d-75        [-1, 128, 256, 256]         524,416
   InstanceNorm2d-76        [-1, 128, 256, 256]               0
             ReLU-77        [-1, 128, 256, 256]               0
  ConvTranspose2d-78         [-1, 64, 512, 512]         131,136
   InstanceNorm2d-79         [-1, 64, 512, 512]               0
             ReLU-80         [-1, 64, 512, 512]               0
  ReflectionPad2d-81         [-1, 64, 518, 518]               0
           Conv2d-82          [-1, 3, 512, 512]           9,411
================================================================
Total params: 10,774,595
Trainable params: 10,774,595
Non-trainable params: 0
----------------------------------------------------------------

実際に動かしてみた

画像補完の結果

以上のようになかなかの精度で画像修復できました。

補足

 アプリケーションの全体像は、Nvidia社の(https://www.nvidia.com/research/inpainting/) とほぼ同じですが、Nvidia社のものはPconv(Image Inpainting for Irregular Holes Using Partial Convolutions)と言う手法を使っています。(論文)  PConvはマスクされた場所の付近から推測するものです。ネットワークの特徴としてはganベースでなくPConv-UnetというUnetライクなニューラルネットワークで構成されていて、conv-Unet普通の畳み込み処理ではなく、マスクされた画像とマスク画像に対してそれぞれに畳み込み処理を行い、マスクを掛け直していくため、CPUとメモリの使用量が多くなり、webアプリケーションとして動かすにはマシン的にかなりコストがかかりそうだったので断念しました。

実装してみての感想

 画像として存在していない箇所が新しく自然に埋め合わされるという仕組みは非常に興味深く、処理結果も想像していたより精度が高くてすごいなと思いました。ただ、今回のEdge-ConnectはPConvより処理が軽いと行っても、webアプリケーションとして実際に運用するにはかなり重たいので、もっと軽量化する必要があるし、ディープラーニングを使ったwebアプリーションを実際に運用するのはかなりコストがかかるなといった印象を受けました。

ishiyama

ishiyama

学部3年 インターン生