画像分類モデルの特徴量を可視化してみる【備忘録】

お疲れ様です。

今回は画像分類モデルが分類の判断に使用する特徴量を可視化してみる回です。 モデルがどのように画像の特徴をとらえているかをおおまかに知ることができ、例えばデータセットの見直しなどでモデルの精度改善の検討ができるようになるかと思います。

ソースコード

例によってソースコードはこちらの画像分類モデルをまとめたリポジトリに残しています。 使用するだけなら学習済みモデルを用意し、main_inference_feature.pyを実行してもらえれば使用できます。

github.com

処理内容

処理の流れとしては以下のような感じです。

  1. データセットを使用してモデルを作成
  2. 推論を実行しその際の特徴量を取得
  3. 得られた特徴量を次元削減
  4. 次元削減後の特徴量を可視化

1. データセットを使用してモデルを作成

モデル作成は普通の学習プログラムを使用します。
今回は以前作成したSEResNeXt50の学習済みモデルを対象とします。 推論で使用するデータセットもこの時のデータセットと同じものです。

fallpoke-tech.hatenadiary.jp

2. 推論を実行しその際の特徴量を取得

学習済みモデルを使用して推論を実行します。 この時モデルから取得するのはモデルの最終出力層(画像分類モデルなら基本はLinear層)を通したスコアではなく、その直前の状態の特徴ベクトルです。

リポジトリではtimmライブラリを使用していますが、このライブラリではモデルのインスタンスに特徴ベクトルを取得するメソッドがあります。 こちらを利用してデータセットの画像1枚ごとの特徴量を取得します。

def feature_extraction(
    self,
    input_img: np.ndarray,
) -> np.ndarray:
    """1画像で特徴量抽出の処理を実行
    """
    input_img = input_img.to(self.device)
    input_img = input_img.unsqueeze(0)
    
    with torch.no_grad():
        output = self.model.forward_features(input_img)
    
    return output.detach().cpu().numpy()[0]

出力(変数output)の形状を確認すると"[1, 2048, 7, 7]"となっています。 頭の"1"はバッチサイズなので無視するとして、SEResNeXtでは画像1枚で"2048x7x7"次元の特徴ベクトルが取得できることがわかります。
shape

3. 得られた特徴量を次元削減

すべての画像の特徴ベクトルを取得したら、次は次元削減を適用します。
次元削減手法はUMAPを使用します。

  • UMAPについて(from ChatGPT)

    UMAP(Uniform Manifold Approximation and Projection)は、高次元データを2次元や3次元に圧縮して可視化する次元削減手法です。データ同士の近さ(局所構造)を保つことを重視しつつ、全体の配置関係もある程度維持します。t-SNEに比べて高速で、大規模データにも適しており、新しいデータを既存の低次元空間に写像できる点が特徴です。主に特徴量や埋め込みの可視化に用いられます。

PythonでUMAPを使用する方法について詳しくは以下を参照ください。

boritaso-blog.com

作成したプログラムでは以下のように使用しています。
画像データすべての特徴量を1つにまとめてUMAPのインスタンスに与えています。 インスタンス作成時に指定した"n_components"が次元削減後の次元数です。 今回は3Dプロットで可視化したいので各画像データの特徴ベクトルを3次元まで削減するように設定しています。

# UMAPインスタンスの作成
reducer = umap.UMAP(n_components=3)

# 画像を1枚ずつ特徴量抽出
features = []
labels = []
for i in tqdm(range(len(test_dataset)), desc="inference"):
    input_img, lbl = test_dataset[i]
    feat = infer.feature_extraction(input_img)
    features.append(feat.flatten())
    labels.append(lbl)

# UMAPによる次元削減
embeddings = reducer.fit_transform(np.array(features))

4. 次元削減後の特徴量を可視化

ここまでで作成したデータを実際に可視化します。
可視化ライブラリにはplotlyを採用しました。 今回は3Dプロットで可視化したいのでplotlyで作成した方が動作が軽かったです。 matplotlibでも作成自体は可能ですが、かなり動作が重かったです…。

コードはざっくりとこんな感じです。下記以外に細かく設定をしているので詳細はリポジトリの"modules/visualize_features.py"を参照してください。

# 3次元散布図の作成
fig = scatter_3d(
    x=features[:, 0],
    y=features[:, 1],
    z=features[:, 2],
    color=[f"{classes.index(classes[label])}-{classes[label]}" for label in labels],
    color_discrete_sequence=color_map,
    labels={"color": "Classes"},
)

結果の確認

1~4を実行した結果が以下のようになります。 ブラウザ上で表示され、マウス操作で拡大や軸を回転させることができます。

result

このように右の判例から表示するデータを絞ることもできます。 「bell pepper(ピーマン)」と「capsicum(唐辛子)」に絞っていますが、これらは近い3次元空間上で近いところに位置しています。 少し見にくいかもしれませんが、混同行列を見てもラベル3(bell pepper)とラベル5(capsicum)の間で誤分類しているのでしっかりとモデルが得た特徴量の傾向を可視化できていそうです。

result2

confusion_matrix