Core ML 3 で On-Device Model Personalization
これはInfocom Advent Calendar 2019 22日目の記事です.
WWDC2019で発表され,Core ML 3から可能になったOn-Device Model Personalizationについてまとめ,アップデート可能なMLモデルを作ってみました.
On-Device Model Personalizationとは
ユーザに共通だったMLモデルを,デバイス上でMLモデルの追加学習をすることで,ユーザごとに個別化することを言います.
従来,MLモデルは学習済みのものをアプリにバンドルしておいたり, 更新版をダウンロードして入れ替えたりしてアプリに適用していましたが, いづれもモデルはユーザに共通のものでした. しかし,Core ML 3でサポートされたMLモデルのアップデート(追加学習)の仕組みにより, iOSデバイス上でMLモデルを追加学習し,ユーザ個別のMLモデルへとアップデートできるようになりました.
端末上でスタンドアローンにアップデートが行えるため,サーバが必要ない,オフライン動作が可能,プライバシが保護できる,といった利点があります. 追加学習には,iOS13以降で使えるMLUpdateTaskクラスを利用します.
WWDCのセッションでは,手書きの図形から絵文字に変換するアプリの機能において,ユーザが図形→絵文字の対応を任意に追加できる,というデモをしていました. 詳しくはWWDCの”Core ML 3 Framework”セッションを参照してください.
Core ML 3 Framework - WWDC 2019 - Videos - Apple Developer
MLモデル作成ツール
MLモデルを作るためにTuri Create(トゥーリクリエイト)を使いました. Turi CreateはPythonのライブラリとして提供されていて,Kerasのようなアルゴリズムベースではなく,タスクベース(画像分類,オブジェクト検出,回帰など)で簡単に早くMLモデルを作成できるツールです. CreateMLを使ってもMLモデルを簡単に作ることはできますが,CreateMLを使って作ったMLモデルはアップデート可能なMLモデルに変換できないため,今回のケースでは使えません.
また,MLモデルをアップデート可能なものにするためにcoremltoolsを使いました. こちらもPythonのライブラリとして提供されています.
他にも手段はあると思いますが,どちらもAppleが公式に提供しているツールなので,今回はこれらを使いました.
アップデート可能なMLモデルの作成
On-Device Model Personalizationを行う,猫種分類アプリを作る想定でMLモデルを作ってみます. 猫種分類アプリは猫の画像を入力として,猫種(今回はべンガル,バーマン,三毛猫,ペルシャの4種)を分類するアプリとします. それだけだと通常のMLアプリですが,これに自分の猫を追加学習できるようにし,自分の猫も分類できるようにする想定です.
アップデート可能なMLモデルの作成は次のステップで行います.
- 猫種分類MLモデルの作成(CatsClassifier.mlmodel)
- MLモデルをアップデート可能なものに変換(UpdatableCatsClassifier.mlmodel)
アップデート可能モデルの拡張子はmlmodelですが,アプリ実行時に追加学習のためにDocumentsディレクトリ等にコピーする際はUpdatableCatsClassifier.mlmodelcをコピーする点に注意が必要です. アップデート可能なMLモデルをアプリに組み込んで追加学習を行う際は気をつけてください. mlmodelcファイルはmlmodelファイルがコンパイルされたもので,アプリビルド時に自動的にXcodeが作成します.
環境
$ sw_vers
ProductName: Mac OS X
ProductVersion: 10.14.5
BuildVersion: 18F132
$ python --version
Python 3.7.1
$ pip show turicreate | grep Version
Version: 6.0
$ pip show coremltools | grep Version
Version: 3.1
猫種分類MLモデルの作成
TuriCreateで作成していきます. ちなみに,以前CreateMLを使ってMLモデルを作ったことはあったのですが,それ以外で自分で作ったことはありませんでした. しかし,Turi Createはとても簡単で,チュートリアルを読めばすぐ自分でMLモデルを作れました!
まずは,学習用画像を集めます.
詳細は本記事では割愛しますが,icrawlerなどを使えば楽に画像収集ができます.
Turi Createの学習用画像はJPEGのみ有効です(他の形式がある場合,その画像はスキップされワーニングが表示されます).
収集した画像はtraining-images
ディレクリにラベル名をディレクトリ名として保存します.
各ラベルごとの画像は50枚程度用意しました.
training-images
├── bengal
│ ├── 000001.jpg
│ ├── ...
├── birman
│ ├── 000001.jpg
│ ├── ...
├── mikeneko
│ ├── 000001.jpg
│ ├── ...
└── persian
├── 000001.jpg
├── ...
学習用画像が準備できたら,create-sframe.py
を実行しSFrameファイルを作成します.
SFrameはTuri Createで共通的に使われるデータフレームオブジェクトです.
import turicreate as tc
import re
# 学習用画像をロードする(JPEG以外はスキップされる)
data = tc.image_analysis.load_images('training-images', with_path=True)
# ディレクトリ名からラベルを作成する
data['label'] = data['path'].apply(lambda path: re.match('training-images/(.*)/.*', path).group(1))
# SFrameをファイルに保存する
data.save('cats.sframe')
# データの内容をビジュアルで確認したい場合は,インタラクティブモードでPythonを動かして次を実行する
# data.explore()
$ python create-sframe.py
python create-sframe.py 11.01s user 3.21s system 180% cpu 7.859 total
$ ll -d cats.sframe
drwxr-xr-x 7 otti staff 224B Dec 22 01:52 cats.sframe/
次に,create-model.py
を実行し,MLモデルを作成します.
ここでは,さきほど作成したSFrameを元に学習を実行し,Turi Createのモデルフォーマットである.model
と,CoreML用のモデルフォーマットである.mlmodel
を生成しています.
デフォルトではResnet50をベースに転移学習が行われるため,学習が早いです(30秒くらいでした).
import turicreate as tc
# SFrameの読み込み
data = tc.SFrame('cats.sframe')
# 学習用画像を8:2の割合でランダムに訓練用とテスト用に分ける
train_data, test_data = data.random_split(0.8)
# モデルの作成
model = tc.image_classifier.create(train_data, target='label')
# モデルを評価する
metrics = model.evaluate(test_data)
print(metrics['accuracy'])
# Turi Create形式のモデルファイルを保存する
model.save('cats.model')
# Core ML形式のモデルファイルを保存する
model.export_coreml('CatsClassifier.mlmodel')
$ python create-model.py
Analyzing and extracting image features.
+------------------+--------------+------------------+
| Images Processed | Elapsed Time | Percent Complete |
+------------------+--------------+------------------+
| 64 | 1.73s | 33.25% |
| 128 | 4.34s | 66.5% |
| 159 | 6.94s | 100% |
+------------------+--------------+------------------+
PROGRESS: Creating a validation set from 5 percent of training data. This may take a while.
You can set ``validation_set=None`` to disable validation tracking.
Logistic regression:
--------------------------------------------------------
Number of examples : 151
Number of classes : 4
Number of feature columns : 1
Number of unpacked features : 2048
Number of coefficients : 6147
Starting L-BFGS
--------------------------------------------------------
+-----------+----------+-----------+--------------+-------------------+---------------------+
| Iteration | Passes | Step size | Elapsed Time | Training Accuracy | Validation Accuracy |
+-----------+----------+-----------+--------------+-------------------+---------------------+
| 0 | 2 | 1.000000 | 0.043686 | 0.814570 | 0.625000 |
| 1 | 5 | 0.500000 | 0.144851 | 0.887417 | 0.625000 |
| 2 | 7 | 0.319723 | 0.229275 | 0.933775 | 0.750000 |
| 3 | 10 | 1.598616 | 0.333527 | 0.980132 | 0.875000 |
| 4 | 12 | 1.253864 | 0.411987 | 0.993377 | 1.000000 |
| 9 | 19 | 1.000000 | 0.737624 | 1.000000 | 0.875000 |
+-----------+----------+-----------+--------------+-------------------+---------------------+
Analyzing and extracting image features.
+------------------+--------------+------------------+
| Images Processed | Elapsed Time | Percent Complete |
+------------------+--------------+------------------+
| 34 | 1.73s | 100% |
+------------------+--------------+------------------+
0.9705882352941176
WARNING:root:TensorFlow version 2.0.0 detected. Last version known to be fully compatible is 1.14.0 .
WARNING:tensorflow:From /Users/otti/.pyenv/versions/odmp/lib/python3.7/site-packages/tensorflow_core/python/compat/v2_compat.py:65: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
WARNING:tensorflow:From /Users/otti/.pyenv/versions/odmp/lib/python3.7/site-packages/tensorflow_core/python/compat/v2_compat.py:65: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
Downloading https://docs-assets.developer.apple.com/turicreate/models/resnet-50-TuriCreate-6.0.h5
Download completed: /var/folders/hj/26l3gdxj1jz7bl1pmr0b_j040000gn/T/model_cache/resnet-50-TuriCreate-6.0.h5
Downloading https://docs-assets.developer.apple.com/turicreate/models/resnet-50-TuriCreate-6.0.mlmodel
Download completed: /var/folders/hj/26l3gdxj1jz7bl1pmr0b_j040000gn/T/model_cache/resnet-50-TuriCreate-6.0.mlmodel
Input name(s) and shape(s):
data : (C,H,W) = (3, 224, 224)
Neural Network compiler 0: 160 , name = bn_data, output shape : (C,H,W) = (3, 224, 224)
(中略)
Neural Network compiler 175: 175 , name = labelProbability, output shape : (C,H,W) = (4, 1, 1)
python create-model.py 33.12s user 8.71s system 105% cpu 39.781 total
$ ll -d cats.model CatsClassifier.mlmodel
-rw-r--r-- 1 otti staff 90M Dec 22 01:53 CatsClassifier.mlmodel
drwxr-xr-x 31 otti staff 992B Dec 22 01:53 cats.model/
途中で出力しているモデルの精度は 0.9705882352941176 でした.
MLモデルをアップデート可能なものに変換
作成したCatsClassifier.mlmodel
を使えば学習した4種の猫種を分類することはできますが,
アップデート可能なMLモデルにするにはもうひと手間必要になります.
make-model-updatable.py
を実行すると,主に次のことをMLモデルに対して行います.
- MLモデルとニューラルネットワークレイヤの
isUpdatable
プロパティをtrueにセット - ロスファンクションの追加(学習時に精度を評価するもの)
- オプティマイザを追加(学習時にニューラルネットワークのウェイトを調整するもの)
- 学習用のハイパーパラメタを追加
なお,ここではロスファンクションにcategorical cross-entropyを,オプティマイザにSGD (Stochastic Gradient Descent)を使っていますが, MSE (mean squared error)というロスファンクションや,Adamというオプティマイザを使うこともできます. これらの違いはちょっとよくわかっていません.
import coremltools
import numpy as np
from coremltools.models.neural_network import NeuralNetworkBuilder, SgdParams
# モデルの読み込み
model = coremltools.models.MLModel("CatsClassifier.mlmodel")
spec = model._spec
# アップデート可能にするレイヤ(最後の全結合レイヤ)
layer = spec.neuralNetworkClassifier.layers[-2]
# 全結合レイヤをユーザの猫を1つ登録できるように拡張する
layer.innerProduct.outputChannels = 5 # 既存4猫種+ユーザ用の1つ
weights = np.zeros(1 * 2048) # Netronで確認すると,1チャネルにつき2048であるため
biases = np.zeros(1)
layer.innerProduct.weights.floatValue.extend(weights)
layer.innerProduct.bias.floatValue.extend(biases)
# モデルの分類ラベルを追加
labels = ["my cat"]
spec.neuralNetworkClassifier.stringClassLabels.vector.extend(labels)
# モデルをベースに,新しいモデルを作成
builder = NeuralNetworkBuilder(spec=model._spec)
spec.specificationVersion = 4
# 全結合レイヤをアップデート可能にする
builder.make_updatable(["fc1"]) # layer.nameの値
# ロスファンクションの追加
builder.set_categorical_cross_entropy_loss(name="lossLayer", input="labelProbability")
# SDGオプティマイザの追加
sgd_params = SgdParams(lr=0.001, batch=8, momentum=0)
sgd_params.set_batch(8, [1, 2, 8, 16])
builder.set_sgd_optimizer(sgd_params)
# 学習用インプットに名前をつけておく
builder.spec.description.trainingInput[0].shortDescription = "Example image"
builder.spec.description.trainingInput[1].shortDescription = "True label"
# ハイパーパラメタ:エポックの追加
builder.set_epochs(10, [1, 10, 50])
# アップデート可能モデルの保存
coremltools.utils.save_spec(builder.spec, "UpdatableCatsClassifier.mlmodel")
$ python make-model-updatable.py
WARNING:root:TensorFlow version 2.0.0 detected. Last version known to be fully compatible is 1.14.0 .
Now adding input labelProbability_true as target for categorical cross-entropy loss layer.
Input name(s) and shape(s):
image : (C,H,W) = (3, 224, 224)
Neural Network compiler 0: 160 , name = bn_data, output shape : (C,H,W) = (3, 224, 224)
(中略)
Neural Network compiler 175: 175 , name = labelProbability, output shape : (C,H,W) = (4, 1, 1)
$ ll UpdatableCatsClassifier.mlmodel
-rw-r--r-- 1 otti staff 90M Dec 22 02:03 UpdatableCatsClassifier.mlmodel
具体的には,CatsClassifier.mlmodel
の最後の全結合レイヤ(猫種を決定する部分)をアップデート可能にしています.
ちなみにこのMLモデルの中身の画像化にはNetronを使っています.
出力されたUpdatableCatsClassifier.mlmodel
をXcodeで開くと,Updateというセクションがあり,追加学習が可能なモデルになっていることが確認できました.
ただし,ユーザ猫用のチャネルはweight: 0, bias: 0で初期化しており,これは学習済みモデルの1番目のチャネル(ベンガル)と同じになるので, 追加学習していないデフォルトのMLモデルで推論すると,Probabilityがベンガルと同じになってしまいます. そのため,実際アプリで使用する際は,追加学習の実施有無などを見て推論結果の表示をコントロールしたりする必要がありそうです.
※ この猫は昔の飼い猫でキジトラです.12年前の今日(2007/12/22)天国に行ってしまいました(TωT)
ハマったところ
はじめ,pip install coremltools
で入ったバージョンが3.0b3だったので,make-model-updatable.py
でアップデート可能なモデルの作成ができませんでした.
pip install -U coremltools
とすることで,3.1がインストールされ,正しく動作するようになりました.
turicreateのバージョンも同様に古いものが入ってしまっていたので,pip install -U turicreate
としました.