Table of Contents
概要
ECR に登録したコンテナを用いて SageMaker にねずみ検知モデルをデプロイし、リアルタイムに推論できるエンドポイントを作成しました。
目的
画像を投げるとねずみを検知してバウンディングボックスを重ねた画像を出力してくれるモデルを SageMaker 上でリアルタイム推論の形で動かしたい。
今回は SageMaker Inference Toolkit + MMS (Multi Model Server) を用いた構成(こちらの記事におけるパターン2)をベースに開発を行います。
MMS (Multi Model Server)
OSSのモデルサービングライブラリで、SageMaker 上に関わらず様々なプラットフォーム上で動作し、特定のライブラリに依存しません。 1つのコンテナ内での複数モデルの管理を行ってくれます。
SageMaker Inference Toolkit
Amazon SageMaker におけるカスタムコンテナ実装パターン詳説 〜推論編〜 より
MMS 専用の SageMaker 連携用ライブラリで SageMaker Endpoint のコンテナ内で利用されます。* ローカルの推論コードを entry_point として指定できる機能
- ヘルスチェックへの応答や推論エラーなどをサービスのコントロールプレーンに通知する機能
- などなど
これらの機能を実装する手間がSageMaker Inference Toolkitを用いることで省かれます。
SageMaker Real-Time Inference
Black Beltの図より 今回用いる SageMaker でのリアルタイム推論の構成では、 EC2 や Load Balancer の設定をせずとも SageMaker がそれらを自動で管理してくれるところがよいところです。
用意するもの
その上で今回用意する必要のあるものが以下の3つになります。
- モデル
- コンテナイメージ
- 推論コード
今回のような MMS を用いた構成では、カスタムコンテナを独自に定義することができるため、 デフォルトのコンテナイメージには存在しないような、モデルや推論コードの依存ライブラリ を用いることができます。
モデルとしては、GitHub からダウンロードしてきた ネズミ検知モデル(DAMM) を用いました。 推論コードでは、DAMM ライブラリ内の検知用のコード (DAMM_detector) を SageMaker 上の推論コードから呼び出して使用します。 また、推論コード、コンテナイメージは multi_model_bring_your_own にカスタムコンテナを用いた MMS の実装例がのっていたため、 基本的にこれらのファイルを改変して実装しました。
今回の特例的な事項
DAMM_detector を推論コードから呼び出す場合、DAMM_detector の input、output が画像ファイルであるため、コンテナ内に input、output ディレクトリを作成し、そのディレクトリで画像データを管理します。 今回はコンテナ上で1つのモデルしか使わない & S3 等を用いずにコンテナ上にモデルをダウンロードして用いました。
コンテナイメージ
multi_model_bring_your_own より改変
FROM ubuntu:20.04
# Set a docker label to advertise multi-model support on the container
LABEL com.amazonaws.sagemaker.capabilities.multi-models=true
# Set a docker label to enable container to use SAGEMAKER_BIND_TO_PORT environment variable if present
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true
ENV DEBIAN_FRONTEND=noninteractive
# Install necessary dependencies for MMS and SageMaker Inference Toolkit
RUN apt-get update && \
apt-get -y install --no-install-recommends \
wget \
git \
build-essential \
ca-certificates \
openjdk-8-jdk-headless \
curl \
libssl-dev \
libffi-dev \
python3.9 \
python3.9-dev \
libgl1-mesa-glx \
libglib2.0-0 \
python3-distutils \
&& rm -rf /var/lib/apt/lists/*
&& curl -O https://bootstrap.pypa.io/get-pip.py \
&& python3.9 get-pip.py
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
RUN update-alternatives --install /usr/local/bin/pip pip /usr/local/bin/pip3 1
# Install MMS, and SageMaker Inference Toolkit to set up MMS
RUN pip --no-cache-dir install multi-model-server \
sagemaker-inference \
retrying
# clone and install model dependencies
RUN git clone https://github.com/backprop64/DAMM && \
pip install -r DAMM/requirements-gpu.txt
RUN cp -r DAMM/DAMM /usr/local/lib/python3.9/dist-packages/
RUN mkdir -p /home/data/weights/ \
/home/data/configs/ \
/home/data/inputs/ \
/home/data/outputs/
# download model
RUN wget -P /home/data/weights/ https://www.dropbox.com/s/39a690qldduxawz/DAMM_weights.pth
RUN wget -P /home/data/configs/ https://www.dropbox.com/s/wegw8l5zq3vqln0/DAMM_config.yaml
# set model for initialization
RUN mkdir -p /opt/ml/model
RUN tar -czvf /opt/ml/model/damm_detector.tar.gz /home/data/weights/DAMM_weights.pth
# Copy entrypoint script to the image
COPY dockerd-entrypoint.py /usr/local/bin/dockerd-entrypoint.py
RUN chmod +x /usr/local/bin/dockerd-entrypoint.py
RUN mkdir -p /home/model-server/
# Copy the default custom service file to handle incoming data and inference requests
COPY model_handler.py /home/model-server/model_handler.py
# Define an entrypoint script for the docker image
ENTRYPOINT ["python", "/usr/local/bin/dockerd-entrypoint.py"]
# Define command to be passed to the entrypoint
CMD ["serve"]
RUN cp -r DAMM/DAMM /usr/local/lib/python3.9/dist-packages/
推論コード model_handler.py での推論時にライブラリを読み込むためにここに置きます。
RUN mkdir -p /opt/ml/model
RUN tar -czvf /opt/ml/model/damm_detector.tar.gz /home/data/weights/DAMM_weights.pth
DAMM のモデルの weights を圧縮して /opt/ml/model にコピーしていますが、この場所に置かれた weights はダミーです使用されません。 MMS の serve 時に下記の推論コードの initialize() とも違うプロセスでのモデルのロードが行われるため、それ用に用意したファイルとなります。 ないとエラーが出ます。
Model Handler
multi_model_bring_your_own より改変
import glob
import os
import base64
import cv2
import numpy as np
from DAMM.detection import Detector
CONFIG_PATH = '/home/data/configs/DAMM_config.yaml'
WEIGHTS_PATH = '/home/data/weights/DAMM_weights.pth'
INPUT_DIR = '/home/data/inputs'
OUTPUT_DIR = '/home/data/outputs'
class ModelHandler(object):
def __init__(self):
self.initialized = False
self.damm_detector = None
def initialize(self, context):
self.initialized = True
self.damm_detector = Detector(
cfg_path=CONFIG_PATH,
model_path=WEIGHTS_PATH,
output_dir=OUTPUT_DIR,
)
def preprocess(self, request):
img_path_list = []
for idx, data in enumerate(request):
# Read the bytearray of the image from the input
img_arr = data.get("body")
img_np = np.frombuffer(img_arr, np.uint8)
img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
if img is None:
return None
input_path = os.path.join(INPUT_DIR, f"image_{idx}.jpg")
cv2.imwrite(input_path, img)
img_path_list.append(input_path)
return img_path_list
def inference(self, input_paths):
self.damm_detector.predict_img(
input_paths,
output_folder=OUTPUT_DIR,
)
return OUTPUT_DIR
def encode_image_to_base64(self, image_path):
img = cv2.imread(image_path)
if img is None:
raise ValueError(f"Image at path {image_path} could not be loaded.")
_, buffer = cv2.imencode('.jpg', img)
img_bytes = buffer.tobytes()
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
return img_base64
def clear_images(self):
image_patterns = ['*.jpg', '*.jpeg', '*.png']
for pattern in image_patterns:
image_files = glob.glob(os.path.join(INPUT_DIR, pattern))
image_files.extend(glob.glob(os.path.join(OUTPUT_DIR, pattern)))
for image_file in image_files:
os.remove(image_file)
print(f"Deleted: {image_file}")
print("All image files have been deleted.")
def postprocess(self, inference_output_dir):
images_json = []
for filename in os.listdir(inference_output_dir):
if filename.endswith(('.jpg', '.jpeg', '.png')):
image_path = os.path.join(inference_output_dir, filename)
img_base64 = self.encode_image_to_base64(image_path)
images_json.append({
'filename': filename,
'image_base64': img_base64
})
self.clear_images()
return images_json
def handle(self, data, context):
print("handle start")
input_paths = self.preprocess(data)
print("inference start")
model_out = self.inference(input_paths)
print("postprocess start")
return self.postprocess(model_out)
_service = ModelHandler()
def handle(data, context):
if not _service.initialized:
_service.initialize(context)
if data is None:
return None
return _service.handle(data, context)
リアルタイム推論のオーバーヘッドを下げるため、 Initialize() が最初のリクエストのみ1度呼ばれます。 handle() では、前処理、推論、後処理を行います。 リクエストにおける画像のやり取りは基本的にbase64をbodyに乗せて行います。
ローカルでの実行
イメージを build して run すると、localhostでサーバーが立つ (参考)ので、 ここに curl などでリクエストを送ると、エンドポイントのテストができます。
コンテナイメージの ECR へのプッシュ
事前に AWS マネージメントコンソールで ECR のプライベートリポジトリを作成しておきます。 今回のリポジトリ名は mouse-detection-inference
aws configure
AWS CLI にログインした後、
aws ecr get-login-password --region (リージョン名) | docker login --username AWS --password-stdin (AWSアカウントID).dkr.ecr.(リージョン名).amazonaws.com
ECR にログインし、
docker tag mouse_detection_004:latest (AWSアカウントID).dkr.ecr.(リージョン名).amazonaws.com/mouse-detection-inference:v1.0
イメージ名を(例) mouse_detection_004:latest から ECR の URI に変更。 この時、URI の末尾はリポジトリ名とそろえる必要があります。
docker push (AWSアカウントID).dkr.ecr.(リージョン名).amazonaws.com/mouse-detection-inference:v1.0
イメージを push
AWS マネージメントコンソールでのエンドポイントの作成
SageMakerのページを開いた後は、推論タブのモデル、エンドポイント設定、エンドポイントの3つを設定します。 モデル設定には適当なモデル名をつけます。 推論コードイメージ場所 には先ほど push したイメージの URI を設定します。 バリアントの設定では、コンテナのインスタンスタイプが(左にスクロールすると)設定できます。 デフォルトの ml.m4.xlarge には gpu が載っていないので注意
以上でエンドポイントの作成は完了です。
推論エンドポイントの呼び出し
import json
import boto3
import base64
# Boto3のSageMakerランタイムクライアントを取得する
sagemaker_runtime_client = boto3.client('sagemaker-runtime')
# ローカルにある画像ファイルをバイナリデータに変換する
image_path = './video_0_frame_16164.png'
with open(image_path, 'rb') as file:
fileDataBinary = file.read()
# 推論エンドポイントを呼び出す
response = sagemaker_runtime_client.invoke_endpoint(
EndpointName='mouse-detection-endpoint',
ContentType='application/octet-stream',
Body=fileDataBinary,
Accept='application/json'
)
# 推論結果がJSONで返ってくる
result = json.loads(response['Body'].read().decode())
with open("./result.jpg", "wb") as f:
f.write(base64.b64decode(result["image_base64"]))
参考になったURL
Yosuke Higuchi