【Pytorch】load_state_dictの重み読み込みについてメモ【備忘録】

お疲れ様です。

最近の実装で知ったPytorchにおけるモデルのload_state_dict時にstrict=Falseを指定したときの仕様についてメモを残しておきます。

docs.pytorch.org

strict=Falseの指定について

load_state_dictでは基本的にモデルアーキテクチャ(nn.Module)と読み込もうしたモデルパラメータ(state_dict)の間に違いがあると読み込むことができません。 これはload_state_dictはデフォルトでstrict=Trueが設定されていることに起因します。

こんな感じのエラーが出ます。 モデルのアーキテクチャ、パラメータにはそれぞれキーが設定されており、その名称が一致していないことがエラーで書かれています。 モデル違い

strict=Falseに変更した場合はこのエラーを無視して読み込んでくれます。
具体的には、モデルアーキテクチャとモデルパラメータの間でキーの名称が一致するものは読み込み、それ以外は無視する処理となります。

model.load_state_dict(state_dict, strict=False)

出力層のクラス数が違う場合の対応

ただし、strict=False指定の場合キー名称が異なる場合のみ無視される処理なので、キー名称が同じで中身のパラメータの形状が違う場合は普通にエラーになります。

例えば、分類問題ならモデルは同じでも出力層のクラス数だけ変更する場合があると思います。 こういう場合はstrict=Falseでは対応できないので別の方法を検討する必要があります。 形状違い

転移学習やファインチューニングの際にこの問題にはよく当たるのでその対処方法もメモを残しておきます。

やり方としては読み込んだパラメータ(state_dict)から問題になっているキーを削除するのが簡単そうです。 キー名称は、先ほどのエラー文の中に記載がありますし、単純にstate_dictをprintするだけでも確認できるのでそれで調べることができますね。

state_dict = torch.load("your_checkpoint.pth")

# 不要なキーを削除
for key in ["class_labels_classifier.weight", "class_labels_classifier.bias"]:
    if key in state_dict:  
        del state_dict[key]

model.load_state_dict(state_dict, strict=False)