PyTorch Lightning ModelCheckpointのfilename引数にスラッシュを含めるとぶっ壊れる問題

作業環境

始まり

指定された要素(精度や損失)を元に上位や下位Nつのチェックポイントを保存してくれる ModelCheckpoint さん。

使用箇所を抜粋するとこんな感じ。

def on_validation_epoch_end(self) -> None:
    super().on_validation_epoch_end()

    loss = torch.stack(self.valid_outputs["loss"]).mean()
    accuracy = torch.cat(self.valid_outputs["accuracy"]).mean()

    self.log("val_loss", loss)
    self.log("val_accuracy", accuracy)

    # free up the memory
    self.valid_outputs["loss"].clear()
    self.valid_outputs["accuracy"].clear()
ModelCheckpoint(
    monitor="val_accuracy",
    filename="checkpoint-{epoch}-{val_accuracy:.8f}-{val_loss:.8f}",
    save_top_k=3,
    mode="max",
    save_last=True,
)


引数名 説明
monitor 監視対象を指定
filename 保存するチェックポイント名のフォーマットを指定
save_top_k 最大で幾つ保存するかを指定
mode 精度なら上位(max)、損失なら下位(min)を指定
save_last 最後/エポック毎のチェックポイントをlast.ckptとして保存するか

たったこれだけ。めっちゃシンプルで便利です。
※他にも指定可能な引数はありますが必要最小限が上記。


ちょっと寄り道します。

PyTorch Lightningのログはステップごとに記録されます。ステップごとの記録の欠点はデータセットが増減した際に横軸がズレるため、ぱっと見で比較がしにくいことです。そのため筆者はステップごととは別にエポックごとにも記録を付けています。

愚直にLoggerに書き込むと画像のように縦長に展開されてしまいます。気にならない方はいいですが、ちょっと縦長過ぎる気がします。

そんな時は グループ・カテゴリ名/名称 (例: "train/loss", "valid/loss") と指定することでグループ・カテゴリ別けができます。

def on_train_epoch_end(self) -> None:
    super().on_train_epoch_end()

    loss = torch.stack(self.train_outputs["loss"]).mean()
    accuracy = torch.cat(self.train_outputs["accuracy"]).mean()

    # train groups
    self.log("train/loss", loss)
    self.log("train/accuracy", accuracy)

    # free up the memory
    self.train_outputs["loss"].clear()
    self.train_outputs["accuracy"].clear()

def on_validation_epoch_end(self) -> None:
    super().on_validation_epoch_end()

    loss = torch.stack(self.valid_outputs["loss"]).mean()
    accuracy = torch.cat(self.valid_outputs["accuracy"]).mean()

    # valid groups
    self.log("valid/loss", loss)
    self.log("valid/accuracy", accuracy)

    # free up the memory
    self.valid_outputs["loss"].clear()
    self.valid_outputs["accuracy"].clear()

trainとvalidで纏めて表示されるため見やすいですね。

問題

本題に戻りますか。

スラッシュ(/)を用いてグループ・カテゴリ別けしたものを監視対象に指定するとどうなるでしょうか。

ModelCheckpoint(
    monitor="valid/accuracy",
    filename="checkpoint-{epoch}-{valid/accuracy:.8f}-{valid/loss:.8f}",
    save_top_k=3,
    mode="max",
    save_last=True,
)

はい、ぶっ壊れました。

スラッシュがディレクトリの区切り記号として認識されてしまい、ファイル名のフォーマットが挙動不審になってしまうのです。

これの対処法です。

解決

とっても簡単です。
auto_insert_metric_name=Falseを指定するだけです。

ModelCheckpoint(
    monitor="valid/accuracy",
    filename="checkpoint-{epoch}-{valid/accuracy:.8f}-{valid/loss:.8f}",
    save_top_k=3,
    mode="max",
    save_last=True,
    auto_insert_metric_name=False,
)


引数名 説明
monitor 監視対象を指定
filename 保存するチェックポイント名のフォーマットを指定
save_top_k 最大で幾つ保存するかを指定
mode 精度なら上位(max)、損失なら下位(min)を指定
save_last 最後/エポック毎のチェックポイントをlast.ckptとして保存するか
auto_insert_metric_name ファイル名にメトリック名(たぶんlossとかaccuracyの総称)を含めるか

auto_insert_metric_nameのデフォルト値はTrueです。

そのためfilename="checkpoint-{epoch}-{val_accuracy:.8f}-{val_loss:.8f}"と指定するだけでファイル名にepochやval_accuracy、val_lossなどのメトリック名が自動的に挿入されていました。

要はディレクトリの区切り記号とauto_insert_metric_nameの相性が悪いため、発生していた問題なのでした。


auto_insert_metric_name=Falseを指定するとepochaccuracylossなどのメトリック名を自動で挿入してくれないためfilenameに明示的に指定しましょう。

ModelCheckpoint(
    monitor="valid/accuracy",
    filename="checkpoint-epoch={epoch}-accuracy={valid/accuracy:.8f}-loss={valid/loss:.8f}",
    save_top_k=3,
    mode="max",
    save_last=True,
    auto_insert_metric_name=False,
)

大 解 決

おわり!!!

かれこれ2年ほどこの問題を放置していました。

業務で扱っている訳ではないのでぶっちゃけ未来の自分が汲み取れる範囲なら、ある程度適当であったり、問題を放置したりしても困ることなかったんですよね。

直したいなと思ったきっかけは ReinVisionOCRの大規模改修 第3弾ですね。セレオブ発売を密かに備えています。

思い返すとReinVisionOCRを作り始めて1年ちょいぐらい経つのですね。好きなゲームジャンルである恋愛ADVとPythonと深層学習を絡めた最強の暇つぶしになっています。それ故に好き過ぎて周期的に倦怠期が発生して毎回親密度がリセット状態ですが。

動作させるのも楽しいけど、なによりコードを書くのが楽しいわ。あと今年中にはReinVisionOCRの記事を移転しなければ。やりたいことが多すぎるンゴねぇ。たのぢぃ。