Top View


Author Yosuke Higuchi

SageMaker Inference Toolkit + MMS の構成を用いてねずみ検知モデルDAMMをリアルタイムで動かしてみる

2024/08/22

概要

ECR に登録したコンテナを用いて SageMaker にねずみ検知モデルをデプロイし、リアルタイムに推論できるエンドポイントを作成しました。

目的

画像を投げるとねずみを検知してバウンディングボックスを重ねた画像を出力してくれるモデルを SageMaker 上でリアルタイム推論の形で動かしたい。

今回は SageMaker Inference Toolkit + MMS (Multi Model Server) を用いた構成(こちらの記事におけるパターン2)をベースに開発を行います。

MMS (Multi Model Server)

OSSのモデルサービングライブラリで、SageMaker 上に関わらず様々なプラットフォーム上で動作し、特定のライブラリに依存しません。 1つのコンテナ内での複数モデルの管理を行ってくれます。

SageMaker Inference Toolkit

Amazon SageMaker におけるカスタムコンテナ実装パターン詳説 〜推論編〜 より

MMS 専用の SageMaker 連携用ライブラリで SageMaker Endpoint のコンテナ内で利用されます。* ローカルの推論コードを entry_point として指定できる機能

  • ヘルスチェックへの応答や推論エラーなどをサービスのコントロールプレーンに通知する機能
  • などなど

これらの機能を実装する手間がSageMaker Inference Toolkitを用いることで省かれます。

SageMaker Real-Time Inference

Black Beltの図より alt text 今回用いる SageMaker でのリアルタイム推論の構成では、 EC2 や Load Balancer の設定をせずとも SageMaker がそれらを自動で管理してくれるところがよいところです。

用意するもの

その上で今回用意する必要のあるものが以下の3つになります。

  • モデル
  • コンテナイメージ
  • 推論コード

今回のような MMS を用いた構成では、カスタムコンテナを独自に定義することができるため、 デフォルトのコンテナイメージには存在しないような、モデルや推論コードの依存ライブラリ を用いることができます。

モデルとしては、GitHub からダウンロードしてきた ネズミ検知モデル(DAMM) を用いました。 推論コードでは、DAMM ライブラリ内の検知用のコード (DAMM_detector) を SageMaker 上の推論コードから呼び出して使用します。 また、推論コード、コンテナイメージは multi_model_bring_your_own にカスタムコンテナを用いた MMS の実装例がのっていたため、 基本的にこれらのファイルを改変して実装しました。

今回の特例的な事項

DAMM_detector を推論コードから呼び出す場合、DAMM_detector の input、output が画像ファイルであるため、コンテナ内に input、output ディレクトリを作成し、そのディレクトリで画像データを管理します。 今回はコンテナ上で1つのモデルしか使わない & S3 等を用いずにコンテナ上にモデルをダウンロードして用いました。

コンテナイメージ

multi_model_bring_your_own より改変

FROM ubuntu:20.04

# Set a docker label to advertise multi-model support on the container
LABEL com.amazonaws.sagemaker.capabilities.multi-models=true
# Set a docker label to enable container to use SAGEMAKER_BIND_TO_PORT environment variable if present
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true

ENV DEBIAN_FRONTEND=noninteractive

# Install necessary dependencies for MMS and SageMaker Inference Toolkit
RUN apt-get update && \
    apt-get -y install --no-install-recommends \
    wget \
    git \
    build-essential \
    ca-certificates \
    openjdk-8-jdk-headless \
    curl \
    libssl-dev \
    libffi-dev \
    python3.9 \
    python3.9-dev \
    libgl1-mesa-glx \
    libglib2.0-0 \
    python3-distutils \
    && rm -rf /var/lib/apt/lists/* 
    && curl -O https://bootstrap.pypa.io/get-pip.py \
    && python3.9 get-pip.py

RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
RUN update-alternatives --install /usr/local/bin/pip pip /usr/local/bin/pip3 1

# Install  MMS, and SageMaker Inference Toolkit to set up MMS
RUN pip --no-cache-dir install multi-model-server \
                                sagemaker-inference \
                                retrying

# clone and install model dependencies
RUN git clone https://github.com/backprop64/DAMM && \
    pip install -r DAMM/requirements-gpu.txt

RUN cp -r DAMM/DAMM /usr/local/lib/python3.9/dist-packages/

RUN mkdir -p /home/data/weights/ \
             /home/data/configs/ \
             /home/data/inputs/ \
             /home/data/outputs/

# download model
RUN wget -P /home/data/weights/ https://www.dropbox.com/s/39a690qldduxawz/DAMM_weights.pth
RUN wget -P /home/data/configs/ https://www.dropbox.com/s/wegw8l5zq3vqln0/DAMM_config.yaml

# set model for initialization
RUN mkdir -p /opt/ml/model
RUN tar -czvf /opt/ml/model/damm_detector.tar.gz /home/data/weights/DAMM_weights.pth

# Copy entrypoint script to the image
COPY dockerd-entrypoint.py /usr/local/bin/dockerd-entrypoint.py
RUN chmod +x /usr/local/bin/dockerd-entrypoint.py

RUN mkdir -p /home/model-server/

# Copy the default custom service file to handle incoming data and inference requests
COPY model_handler.py /home/model-server/model_handler.py

# Define an entrypoint script for the docker image
ENTRYPOINT ["python", "/usr/local/bin/dockerd-entrypoint.py"]

# Define command to be passed to the entrypoint
CMD ["serve"]

RUN cp -r DAMM/DAMM /usr/local/lib/python3.9/dist-packages/

推論コード model_handler.py での推論時にライブラリを読み込むためにここに置きます。

RUN mkdir -p /opt/ml/model
RUN tar -czvf /opt/ml/model/damm_detector.tar.gz /home/data/weights/DAMM_weights.pth

DAMM のモデルの weights を圧縮して /opt/ml/model にコピーしていますが、この場所に置かれた weights はダミーです使用されません。 MMS の serve 時に下記の推論コードの initialize() とも違うプロセスでのモデルのロードが行われるため、それ用に用意したファイルとなります。 ないとエラーが出ます。

Model Handler

multi_model_bring_your_own より改変

import glob
import os
import base64

import cv2
import numpy as np

from DAMM.detection import Detector

CONFIG_PATH = '/home/data/configs/DAMM_config.yaml'
WEIGHTS_PATH = '/home/data/weights/DAMM_weights.pth'
INPUT_DIR = '/home/data/inputs'
OUTPUT_DIR = '/home/data/outputs'


class ModelHandler(object):
    def __init__(self):
        self.initialized = False
        self.damm_detector = None

    def initialize(self, context):
        self.initialized = True
        self.damm_detector = Detector(
            cfg_path=CONFIG_PATH,
            model_path=WEIGHTS_PATH,
            output_dir=OUTPUT_DIR,
        )

    def preprocess(self, request):
        img_path_list = []
        for idx, data in enumerate(request):
            # Read the bytearray of the image from the input
            img_arr = data.get("body")
            img_np = np.frombuffer(img_arr, np.uint8)
            img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
            if img is None:
                return None

            input_path = os.path.join(INPUT_DIR, f"image_{idx}.jpg")
            cv2.imwrite(input_path, img)

            img_path_list.append(input_path)

        return img_path_list

    def inference(self, input_paths):
        self.damm_detector.predict_img(
            input_paths,
            output_folder=OUTPUT_DIR,
        )
        return OUTPUT_DIR
    
    def encode_image_to_base64(self, image_path):
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Image at path {image_path} could not be loaded.")
        
        _, buffer = cv2.imencode('.jpg', img)
        img_bytes = buffer.tobytes()
        
        img_base64 = base64.b64encode(img_bytes).decode('utf-8')

        return img_base64

    def clear_images(self):
        image_patterns = ['*.jpg', '*.jpeg', '*.png']
        for pattern in image_patterns:
            image_files = glob.glob(os.path.join(INPUT_DIR, pattern))
            image_files.extend(glob.glob(os.path.join(OUTPUT_DIR, pattern)))
            for image_file in image_files:
                os.remove(image_file)
                print(f"Deleted: {image_file}")
        print("All image files have been deleted.")

    def postprocess(self, inference_output_dir):
        images_json = []
        for filename in os.listdir(inference_output_dir):
            if filename.endswith(('.jpg', '.jpeg', '.png')):
                image_path = os.path.join(inference_output_dir, filename)
                img_base64 = self.encode_image_to_base64(image_path)
                images_json.append({
                    'filename': filename,
                    'image_base64': img_base64
                })

        self.clear_images()
        return images_json

    def handle(self, data, context):
        print("handle start")
        input_paths = self.preprocess(data)
        print("inference start")
        model_out = self.inference(input_paths)
        print("postprocess start")
        return self.postprocess(model_out)


_service = ModelHandler()


def handle(data, context):
    if not _service.initialized:
        _service.initialize(context)

    if data is None:
        return None

    return _service.handle(data, context)

リアルタイム推論のオーバーヘッドを下げるため、 Initialize() が最初のリクエストのみ1度呼ばれます。 handle() では、前処理、推論、後処理を行います。 リクエストにおける画像のやり取りは基本的にbase64をbodyに乗せて行います。

ローカルでの実行

イメージを build して run すると、localhostでサーバーが立つ (参考)ので、 ここに curl などでリクエストを送ると、エンドポイントのテストができます。

コンテナイメージの ECR へのプッシュ

事前に AWS マネージメントコンソールで ECR のプライベートリポジトリを作成しておきます。 今回のリポジトリ名は mouse-detection-inference

aws configure

AWS CLI にログインした後、

aws ecr get-login-password --region (リージョン名) | docker login --username AWS --password-stdin (AWSアカウントID).dkr.ecr.(リージョン名).amazonaws.com

ECR にログインし、

docker tag mouse_detection_004:latest (AWSアカウントID).dkr.ecr.(リージョン名).amazonaws.com/mouse-detection-inference:v1.0

イメージ名を(例) mouse_detection_004:latest から ECR の URI に変更。 この時、URI の末尾はリポジトリ名とそろえる必要があります。

docker push (AWSアカウントID).dkr.ecr.(リージョン名).amazonaws.com/mouse-detection-inference:v1.0

イメージを push

AWS マネージメントコンソールでのエンドポイントの作成

alt text SageMakerのページを開いた後は、推論タブのモデル、エンドポイント設定、エンドポイントの3つを設定します。 alt text モデル設定には適当なモデル名をつけます。 推論コードイメージ場所 には先ほど push したイメージの URI を設定します。 alt text バリアントの設定では、コンテナのインスタンスタイプが(左にスクロールすると)設定できます。 デフォルトの ml.m4.xlarge には gpu が載っていないので注意

以上でエンドポイントの作成は完了です。

推論エンドポイントの呼び出し

import json
import boto3
import base64

# Boto3のSageMakerランタイムクライアントを取得する
sagemaker_runtime_client = boto3.client('sagemaker-runtime')

# ローカルにある画像ファイルをバイナリデータに変換する
image_path = './video_0_frame_16164.png'
with open(image_path, 'rb') as file:
    fileDataBinary = file.read()

# 推論エンドポイントを呼び出す
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName='mouse-detection-endpoint',
    ContentType='application/octet-stream',
    Body=fileDataBinary,
    Accept='application/json'
)

# 推論結果がJSONで返ってくる
result = json.loads(response['Body'].read().decode())

with open("./result.jpg", "wb") as f:
    f.write(base64.b64decode(result["image_base64"]))

参考になったURL

Yosuke Higuchi

Yosuke Higuchi