画像分類モデルDeiTについて調べたまとめ

お疲れ様です。

画像分類モデルのDeiTについて、使う機会があり内容を調べてみたのでメモ的にまとめておきます。

論文

arxiv.org

要約(from ChatGPT)

  • 背景 Vision Transformer (ViT) は自然言語処理のTransformerを画像分類に応用したモデルだが、従来は数億枚規模のデータセット(例: JFT-300M)と大規模計算資源が必要で、一般的な利用は難しかった。

  • 提案手法(DeiT: Data-efficient image Transformers)

    • ImageNet (1.3M枚) のみを使い、単一のGPUノード(8GPU)で 3日以内 に高精度なTransformerを学習。
    • 異なるサイズのモデル(DeiT-Ti, DeiT-S, DeiT-B)を設計。ResNet-18/50に対応する軽量版も提供。
    • 蒸留 (Distillation) の工夫

      • 通常のラベル蒸留ではなく、Transformerに特化した「蒸留トークン (distillation token)」を導入。
      • このトークンはクラス分類用トークンと同様に自己注意機構を通じて学習され、教師モデル(主にCNN)の予測を模倣する。
      • ConvNet教師から学ぶ方がTransformer教師より有効であり、CNNの持つ帰納バイアスがTransformerに移植される。
  • 性能

    • ImageNet で ViT を大幅に改善し、最大 Top-1精度 85.2% を達成(ViT-BをJFT-300Mで事前学習したモデルより高精度)。
    • EfficientNet 系ConvNetに匹敵する効率性を実現。
    • CIFAR-10/100、Oxford-102 Flowers、Stanford Cars、iNaturalist などの転移学習タスクでも強力な結果。
  • 学習上の工夫

    • RandAugment, Mixup, CutMix, Stochastic Depth, Repeated Augmentation など強力なデータ拡張と正則化を組み合わせ。
    • AdamW最適化手法を利用し、重み減衰を小さめに調整。
    • 224解像度で学習 → 384解像度でファインチューニング。
  • 結論 DeiTは、大規模データや計算資源がなくてもViTを高精度に訓練可能にした。CNNに依存しない新しい標準的アーキテクチャになる可能性がある。

ざっくりと、アーキテクチャ自体はVisionTransformer(ViT)と同じで学習方法が異なるモデルのようです。 ViTでは大量のデータが必要になるところ、DeiTでは知識蒸留を使い少ないデータで効率よく学習することを可能にしたという感じでしょうか。

ライセンスはApache-2.0なので商用でも使いやすいです。

実装

Metaによる公式実装は以下。
github.com

実際にモデルを使用したい場合はtimmtransformersなどのライブラリから使用可能。

お試し

transformersのモデルページにあるデモコードを使って試してみました。
huggingface.co

import os
if not os.path.exists("./pretrained"):
    os.makedirs("./pretrained")
os.environ["HF_HOME"] = "./pretrained"
import warnings
warnings.filterwarnings("ignore")

from transformers import AutoFeatureExtractor, DeiTForImageClassificationWithTeacher
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224')
model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224')

inputs = feature_extractor(images=image, return_tensors="pt")

# forward pass
outputs = model(**inputs)
logits = outputs.logits

# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

入力画像と結果がこのようになりました。
判定結果が「tabby, tabby cat」とのこと。これは縞模様のある猫を指すので合っていそうですね。

input

alt text

コードはこちらに残しています。

github.com

情報収集で参考にしたサイト

追記: コード実装

fallpoke-tech.hatenadiary.jp