Fusic Tech Blog

Fusion of Society, IT and Culture

Simple Transformersを使ってみた
2021/03/30

Simple Transformersを使ってみた

こんにちは。機械学習チームの佐藤です。テキスト要約のタスクのモデルでMultilingual-T5がありますが、オリジナルリポジトリのTensorFlow版Multilingual-T5だと少々使いづらく感じたので(特にpredictの際に)、今回Simple Transformersを使ってみました。
また今回、KaggleにてSimple Transformers T5を使ったノートブックも併せて公開しています。

https://simpletransformers.ai/

Simple Transformersはその名前の通り、自然言語処理分野のTransformer系のモデルが、 非常にシンプルな実装で使え、そのシンプルさからビギナーにやさしいライブラリとのことです。 他のライブラリ同様、テキスト分類や会話型AI等にも対応しています。

pipでインストールする場合は以下の通りです。


$ pip install simpletransformers==0.60.9

Convert

既にTensorFlow版Multilingual-T5のファインチューニング済みモデルが手元にある場合は、TensorFlowからPytorchへのコンバートをすればそのまま使えます。

コンバートするには、以下のシェルスクリプトを実行します。


参考:https://huggingface.co/transformers/converting_tensorflow_models.html


# コンバートするにあたって必要な、config.jsonを用意しておく
wget -P ./finetuned_model/ \
    https://huggingface.co/google/mt5-large/resolve/main/config.json


export T5=./finetuned_model/

transformers-cli convert --model_type t5 \
      --tf_checkpoint $T5/model.ckpt-1100000 \
        --config $T5/config.json \
          --pytorch_dump_output $T5/model.bin


# コンバート後、Pytorchとして実行するのに必要なファイルを揃える
wget -P ./finetuned_model/model.bin/ \
    https://huggingface.co/google/mt5-small/resolve/main/special_tokens_map.json

wget -P ./finetuned_model/model.bin/ \
    https://huggingface.co/google/mt5-small/resolve/main/spiece.model

wget -P ./finetuned_model/model.bin/ \
    https://huggingface.co/google/mt5-small/resolve/main/tokenizer_config.json

Fine Tuning (Training)

モデルのファインチューニングから始める場合は以下のように実装します。


import pandas as pd
from simpletransformers.t5 import T5Model

# 入力データをinput_text、教師データをtarget_text、とカラム名を変更しておく
train = pd.read_csv('train.csv').rename(columns={'inputs': 'input_text', 'targets': 'target_text'})
eval = pd.read_csv('eval.csv').rename(columns={'inputs': 'input_text', 'targets': 'target_text'})
train['prefix'] = ''
eval['prefix'] = ''

train_params = {
    'max_seq_length': 96,
    'max_length': 64,
    'train_batch_size': 16,
    'eval_batch_size': 16,
    'num_train_epochs': 2,
    'evaluate_during_training': True,
    'use_multiprocessing': False,
    'fp16': False,
    'save_steps': -1,
    'save_eval_checkpoints': False,
    'save_model_every_epoch': False,
    'no_cache': True,
    'reprocess_input_data': True,
    'overwrite_output_dir': True,
    'preprocess_inputs': False,
    'num_return_sequences': 1 
}

model = T5Model('t5', 't5-small', args=train_params, use_cuda=cuda)
model.train_model(train, eval_data=valid)

ファインチューニング中にディレクトリが自動生成され、ベストモデルはそこに格納されます。



Predict

TensorFlow版のMultilingual-T5は、predictとモデルのロードが切り離せない仕様の為、predictを繰り返し行う際に都度モデルのロードもされてしまい、余計に時間が掛かってしまうのですが、Simple Transformersでは通常通りこれらは分かれています。


pred_params = {
        'max_seq_length': 512,
        'use_multiprocessed_decoding': False
        }

model = T5Model('t5', 'outputs/best_model', args=pred_params, use_cuda=cuda) 

# テストデータはlistで渡す
pred = model.predict(list(test['input_text']))
print(pred)

元々は上記のデメリットの関係で、外部ライブラリをいろいろと探してしたのですが、Simple Transformersの場合、実装自体もこのようにシンプルにまとめられる為、今回取り上げました。
他のタスクに取り組む際にもまた使ってみようと思います。

Koshiro Sato

Koshiro Sato

仕事では自然言語処理、趣味ではKaggleをやっています。