作業環境
- Windows 10
- Visual Studio Code
- Python 3.11
- PyTorch Lightning 2.2.1
始まり
指定された要素(精度や損失)を元に上位や下位Nつのチェックポイントを保存してくれる ModelCheckpoint
さん。
使用箇所を抜粋するとこんな感じ。
def on_validation_epoch_end(self) -> None: super().on_validation_epoch_end() loss = torch.stack(self.valid_outputs["loss"]).mean() accuracy = torch.cat(self.valid_outputs["accuracy"]).mean() self.log("val_loss", loss) self.log("val_accuracy", accuracy) # free up the memory self.valid_outputs["loss"].clear() self.valid_outputs["accuracy"].clear()
ModelCheckpoint( monitor="val_accuracy", filename="checkpoint-{epoch}-{val_accuracy:.8f}-{val_loss:.8f}", save_top_k=3, mode="max", save_last=True, )
引数名 | 説明 |
---|---|
monitor | 監視対象を指定 |
filename | 保存するチェックポイント名のフォーマットを指定 |
save_top_k | 最大で幾つ保存するかを指定 |
mode | 精度なら上位(max)、損失なら下位(min)を指定 |
save_last | 最後/エポック毎のチェックポイントをlast.ckpt として保存するか |
たったこれだけ。めっちゃシンプルで便利です。
※他にも指定可能な引数はありますが必要最小限が上記。
ちょっと寄り道します。
PyTorch Lightningのログはステップごとに記録されます。ステップごとの記録の欠点はデータセットが増減した際に横軸がズレるため、ぱっと見で比較がしにくいことです。そのため筆者はステップごととは別にエポックごとにも記録を付けています。
愚直にLoggerに書き込むと画像のように縦長に展開されてしまいます。気にならない方はいいですが、ちょっと縦長過ぎる気がします。
そんな時は グループ・カテゴリ名/名称 (例: "train/loss"
, "valid/loss"
) と指定することでグループ・カテゴリ別けができます。
def on_train_epoch_end(self) -> None: super().on_train_epoch_end() loss = torch.stack(self.train_outputs["loss"]).mean() accuracy = torch.cat(self.train_outputs["accuracy"]).mean() # train groups self.log("train/loss", loss) self.log("train/accuracy", accuracy) # free up the memory self.train_outputs["loss"].clear() self.train_outputs["accuracy"].clear() def on_validation_epoch_end(self) -> None: super().on_validation_epoch_end() loss = torch.stack(self.valid_outputs["loss"]).mean() accuracy = torch.cat(self.valid_outputs["accuracy"]).mean() # valid groups self.log("valid/loss", loss) self.log("valid/accuracy", accuracy) # free up the memory self.valid_outputs["loss"].clear() self.valid_outputs["accuracy"].clear()
trainとvalidで纏めて表示されるため見やすいですね。
問題
本題に戻りますか。
スラッシュ(/)を用いてグループ・カテゴリ別けしたものを監視対象に指定するとどうなるでしょうか。
ModelCheckpoint( monitor="valid/accuracy", filename="checkpoint-{epoch}-{valid/accuracy:.8f}-{valid/loss:.8f}", save_top_k=3, mode="max", save_last=True, )
はい、ぶっ壊れました。
スラッシュがディレクトリの区切り記号として認識されてしまい、ファイル名のフォーマットが挙動不審になってしまうのです。
これの対処法です。
解決
とっても簡単です。
auto_insert_metric_name=False
を指定するだけです。
ModelCheckpoint( monitor="valid/accuracy", filename="checkpoint-{epoch}-{valid/accuracy:.8f}-{valid/loss:.8f}", save_top_k=3, mode="max", save_last=True, auto_insert_metric_name=False, )
引数名 | 説明 |
---|---|
monitor | 監視対象を指定 |
filename | 保存するチェックポイント名のフォーマットを指定 |
save_top_k | 最大で幾つ保存するかを指定 |
mode | 精度なら上位(max)、損失なら下位(min)を指定 |
save_last | 最後/エポック毎のチェックポイントをlast.ckpt として保存するか |
auto_insert_metric_name | ファイル名にメトリック名(たぶんlossとかaccuracyの総称)を含めるか |
auto_insert_metric_name
のデフォルト値はTrue
です。
そのためfilename="checkpoint-{epoch}-{val_accuracy:.8f}-{val_loss:.8f}"
と指定するだけでファイル名にepochやval_accuracy、val_lossなどのメトリック名が自動的に挿入されていました。
要はディレクトリの区切り記号とauto_insert_metric_name
の相性が悪いため、発生していた問題なのでした。
auto_insert_metric_name=False
を指定するとepochやaccuracy、lossなどのメトリック名を自動で挿入してくれないためfilename
に明示的に指定しましょう。
ModelCheckpoint( monitor="valid/accuracy", filename="checkpoint-epoch={epoch}-accuracy={valid/accuracy:.8f}-loss={valid/loss:.8f}", save_top_k=3, mode="max", save_last=True, auto_insert_metric_name=False, )
大 解 決
おわり!!!
かれこれ2年ほどこの問題を放置していました。
業務で扱っている訳ではないのでぶっちゃけ未来の自分が汲み取れる範囲なら、ある程度適当であったり、問題を放置したりしても困ることなかったんですよね。
直したいなと思ったきっかけは ReinVisionOCRの大規模改修 第3弾ですね。セレオブ発売を密かに備えています。
思い返すとReinVisionOCRを作り始めて1年ちょいぐらい経つのですね。好きなゲームジャンルである恋愛ADVとPythonと深層学習を絡めた最強の暇つぶしになっています。それ故に好き過ぎて周期的に倦怠期が発生して毎回親密度がリセット状態ですが。
動作させるのも楽しいけど、なによりコードを書くのが楽しいわ。あと今年中にはReinVisionOCRの記事を移転しなければ。やりたいことが多すぎるンゴねぇ。たのぢぃ。