画像分類モデルの特徴量を可視化してみる【備忘録】

お疲れ様です。

今回は画像分類モデルが分類の判断に使用する特徴量を可視化してみる回です。 モデルがどのように画像の特徴をとらえているかをおおまかに知ることができ、例えばデータセットの見直しなどでモデルの精度改善の検討ができるようになるかと思います。

ソースコード

例によってソースコードはこちらの画像分類モデルをまとめたリポジトリに残しています。 使用するだけなら学習済みモデルを用意し、main_inference_feature.pyを実行してもらえれば使用できます。

github.com

処理内容

処理の流れとしては以下のような感じです。

  1. データセットを使用してモデルを作成
  2. 推論を実行しその際の特徴量を取得
  3. 得られた特徴量を次元削減
  4. 次元削減後の特徴量を可視化

1. データセットを使用してモデルを作成

モデル作成は普通の学習プログラムを使用します。
今回は以前作成したSEResNeXt50の学習済みモデルを対象とします。 推論で使用するデータセットもこの時のデータセットと同じものです。

fallpoke-tech.hatenadiary.jp

2. 推論を実行しその際の特徴量を取得

学習済みモデルを使用して推論を実行します。 この時モデルから取得するのはモデルの最終出力層(画像分類モデルなら基本はLinear層)を通したスコアではなく、その直前の状態の特徴ベクトルです。

リポジトリではtimmライブラリを使用していますが、このライブラリではモデルのインスタンスに特徴ベクトルを取得するメソッドがあります。 こちらを利用してデータセットの画像1枚ごとの特徴量を取得します。

def feature_extraction(
    self,
    input_img: np.ndarray,
) -> np.ndarray:
    """1画像で特徴量抽出の処理を実行
    """
    input_img = input_img.to(self.device)
    input_img = input_img.unsqueeze(0)
    
    with torch.no_grad():
        output = self.model.forward_features(input_img)
    
    return output.detach().cpu().numpy()[0]

出力(変数output)の形状を確認すると"[1, 2048, 7, 7]"となっています。 頭の"1"はバッチサイズなので無視するとして、SEResNeXtでは画像1枚で"2048x7x7"次元の特徴ベクトルが取得できることがわかります。
shape

3. 得られた特徴量を次元削減

すべての画像の特徴ベクトルを取得したら、次は次元削減を適用します。
次元削減手法はUMAPを使用します。

  • UMAPについて(from ChatGPT)

    UMAP(Uniform Manifold Approximation and Projection)は、高次元データを2次元や3次元に圧縮して可視化する次元削減手法です。データ同士の近さ(局所構造)を保つことを重視しつつ、全体の配置関係もある程度維持します。t-SNEに比べて高速で、大規模データにも適しており、新しいデータを既存の低次元空間に写像できる点が特徴です。主に特徴量や埋め込みの可視化に用いられます。

PythonでUMAPを使用する方法について詳しくは以下を参照ください。

boritaso-blog.com

作成したプログラムでは以下のように使用しています。
画像データすべての特徴量を1つにまとめてUMAPのインスタンスに与えています。 インスタンス作成時に指定した"n_components"が次元削減後の次元数です。 今回は3Dプロットで可視化したいので各画像データの特徴ベクトルを3次元まで削減するように設定しています。

# UMAPインスタンスの作成
reducer = umap.UMAP(n_components=3)

# 画像を1枚ずつ特徴量抽出
features = []
labels = []
for i in tqdm(range(len(test_dataset)), desc="inference"):
    input_img, lbl = test_dataset[i]
    feat = infer.feature_extraction(input_img)
    features.append(feat.flatten())
    labels.append(lbl)

# UMAPによる次元削減
embeddings = reducer.fit_transform(np.array(features))

4. 次元削減後の特徴量を可視化

ここまでで作成したデータを実際に可視化します。
可視化ライブラリにはplotlyを採用しました。 今回は3Dプロットで可視化したいのでplotlyで作成した方が動作が軽かったです。 matplotlibでも作成自体は可能ですが、かなり動作が重かったです…。

コードはざっくりとこんな感じです。下記以外に細かく設定をしているので詳細はリポジトリの"modules/visualize_features.py"を参照してください。

# 3次元散布図の作成
fig = scatter_3d(
    x=features[:, 0],
    y=features[:, 1],
    z=features[:, 2],
    color=[f"{classes.index(classes[label])}-{classes[label]}" for label in labels],
    color_discrete_sequence=color_map,
    labels={"color": "Classes"},
)

結果の確認

1~4を実行した結果が以下のようになります。 ブラウザ上で表示され、マウス操作で拡大や軸を回転させることができます。

result

このように右の判例から表示するデータを絞ることもできます。 「bell pepper(ピーマン)」と「capsicum(唐辛子)」に絞っていますが、これらは近い3次元空間上で近いところに位置しています。 少し見にくいかもしれませんが、混同行列を見てもラベル3(bell pepper)とラベル5(capsicum)の間で誤分類しているのでしっかりとモデルが得た特徴量の傾向を可視化できていそうです。

result2

confusion_matrix

画像分類モデルSEResNeXtについて調べたまとめ

お疲れ様です。

画像分類モデルのSEResNeXtについてのメモです。
個人的にCNNベースの画像分類のモデルアーキテクチャとしてはEfficientNetV2と並んでよく使います。

論文

SEResNeXtの重要なアーキテクチャであるSEブロックに関する論文です。
後の概要にも記載がありますがSEブロックと画像分類モデルResNeXtを組み合わせたものがSEResNeXtと呼ばれています。

arxiv.org

概要(from ChatGPT)

SEResNeXt(Squeeze-and-Excitation ResNeXt)は、 ResNeXtSE(Squeeze-and-Excitation)ブロックを組み込んだ画像分類モデルです。 ResNet / ResNeXt / SENet の長所を組み合わせた構造になっています。


1. ベースとなる3つの考え方

① ResNet(残差学習)

  • skip connection(残差接続) により勾配消失を防ぐ
  • 非常に深いネットワークでも学習が安定
y = F(x) + x

② ResNeXt(Cardinality)

ResNetを拡張し、並列な畳み込みの分岐数(cardinality)を増やす設計。

  • チャンネル数や層を増やすより 効率的に表現力を向上
  • 32x4d のような表記が特徴

例:

  • 32x4d

    • cardinality = 32(分岐数)
    • 各分岐のチャネル幅 = 4

③ SENet(チャンネル注意機構)

重要なチャネルを自動で強調する仕組み。

SEブロックの流れ

  1. Squeeze

    • Global Average Pooling
  2. Excitation

    • 全結合層でチャネルごとの重みを計算
  3. Scale

    • 元の特徴マップに重みを掛ける
channel-wise attention

2. SEResNeXtの構造

ResNeXtの各Residual BlockにSEブロックを追加

Input
 ↓
Grouped Convolution (ResNeXt)
 ↓
SE Block(チャネル注意)
 ↓
Residual Add

特徴

  • ResNeXtの 高い表現力
  • SENetの 重要特徴の強調
  • パラメータ増加は比較的少ない

3. モデル名の読み方

例:seresnext50_32x4d

要素 意味
50 ネットワークの深さ
32 cardinality(分岐数)
4d 各分岐のチャネル幅
se Squeeze-and-Excitation

4. メリット・デメリット

✅ メリット

  • 高精度(ImageNetで実績あり)
  • ResNetより 効率的
  • 少量のパラメータ増加で性能向上
  • 転移学習に強い

❌ デメリット

  • ResNetより計算量がやや増加
  • 軽量モデル(MobileNet系)ほど速くない

5. どんなタスクに向いているか

  • 一般的な 画像分類
  • 医用画像・工業検査
  • 少〜中規模データセットの転移学習
  • 高精度が求められるタスク

OSSライセンスはtimmから利用すると基本的にはApache-2.0なので、商用利用も問題なさそうです。
timm以外から利用する場合は要確認です。 (例えば、ResNeXtの公式実装ではBSDライセンスになっています。)

実装

SEResNeXt自体の公式実装は探してみた感じはなさそうだったので、実際にモデルを使用したい場合はtimmライブラリから使用することになると思います。

モデルはHuggingFaceのモデルページで探すことができます。
.racm_in1k.gluon_in1kの2パターンありますが、.racm_in1kの方を使うのが無難のようです。

huggingface.co

お試し

普段使用している画像分類モデルのリポジトリで使用できるように実装しました。

github.com

コンフィグファイル(config/train_config.toml)のmodel_nameに"SEResNeXt"を指定すると使用できます。

model_name = "SEResNeXt"

上記のコードを使って実際に動かしてみました。
EfficientNetV2とVisionTransformerでも同じ設定で実行し、結果を比較してみます。

モデルはそれぞれtimmから利用しています。

データセットについては、過去にも使用したFood_and_Vegetablesを使用しています。 huggingface.co

実際の結果を見てみます。上記のデータセットで試した結果なので参考程度に。

学習曲線が以下です。 20epochまでで比較しましたが、SEResNeXtだけまだ収束していない感じですね…。

learning_curve

SEResNeXtのみ追加で50epochまで学習を回してみましたが、最終的なベストの精度は同じくらい(97%程)になりました。

seresnext_learning_curve

また、全体の実行時間とGPUメモリの使用量を計測した結果が以下になります。
学習の収束が遅いですが、モデル自体が軽いので同条件だと実行時間は早く、GPUメモリも余裕があります。 バッチサイズを上げる余裕があるので更なる高速化もできそうです。

実行時間(s) GPUメモリ使用量(GB)
SEResNeXt 4006.2 6.379
EfficientNetV2 4216.3 13.692
VisionTransformer 4285.8 6.421

【FastAPI】日本語を含むファイル名のファイルをダウンロードする際のエラーと対処【備忘録】

お疲れ様です。

今回はFastAPIのファイルダウンロードで日本語のファイル名をダウンロードする際の注意点についてまとめました。

ちょうど1年前くらいにFastAPIでファイルダウンロードをするAPIを作成していました。 その際はファイル名を半角英数字のみで作成していたため気づかなかったのですが、この時のコードだと日本語を含むファイル名のファイルをダウンロードしようとするとエラーが起こります。

fallpoke-tech.hatenadiary.jp

ソースコード

今回もエラー再現のためにデモページとAPIを作成しました。
以下GitHubに残してありますので必要があればご確認ください。

github.com

実行してページを開くとこんな感じです。 demo_page

エラー内容と対処方法

エラーはFastAPIで作成したAPIエンドポイントのレスポンスの内容によるものです。
以下エラーが起こるパターンとエラーを解消したパターンを記載します。

エラーが発生する書き方

import io
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/download")
def download_from_df() -> StreamingResponse:
    """DataFrameを指定ファイル形式でダウンロードするAPI
       (ローカルにファイル保存せず、データをファイル化して返す)
    """
    stream = io.StringIO()
    sample_df.to_csv(stream, encoding='utf-8', index=False)
    stream.seek(0)

    filename = "日本語ファイルサンプル.csv"
    media_type = "text/csv"
    
    return StreamingResponse(
        content=stream, 
        media_type=media_type,
        headers={"Content-Disposition": f"attachment; filename={filename}"}
    )

ページ上のボタンを押してダウンロードしようとするとこのように「Internal Server Error」で表示されます。

error_app

Python側の表示を見ると下記のエラーが出ています。 これはStreamingResponseの処理中に起こっており、ファイル名に日本語を含む場合に再現します。
非ASCII文字が含まれていることが原因のようです。

UnicodeEncodeError: 'latin-1' codec can't encode characters in position 21-31: ordinal not in range(256)

error_fastapi

正しい書き方

import io
from urllib.parse import quote # 追加
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/download")
def download_from_df() -> StreamingResponse:
    """DataFrameを指定ファイル形式でダウンロードするAPI
       (ローカルにファイル保存せず、データをファイル化して返す)
    """
    stream = io.StringIO()
    sample_df.to_csv(stream, encoding='utf-8', index=False)
    stream.seek(0)

    filename = "日本語ファイルサンプル.csv"
    media_type = "text/csv"
    
    return StreamingResponse(
        content=stream, 
        media_type=media_type,
        headers={"Content-Disposition": f"attachment; filename*=UTF-8''{quote(filename)}"} # ここを修正
    )

urllib.parse.quote()を使用します。 こちらを適用させることで日本語を含む文字列(非ASCII文字を含む文字列)をURLエンコーディングしてファイル名に設定しています。

実際に実行するとこんな感じ。
日本語の部分が%から始まる特殊な文字に置き換えられています。

urllib.parse.quote

これで正常にダウンロードできます。実際にダウンロードすると変換する前の元の日本語ファイル名なっていることがわかります。

download

参考

下記を参考にしました。併せてご参照ください。
ありがとうございました。

qiita.com

WSL+DockerでSAM3の環境構築をしてお試し実行

お疲れ様です。

SAM3の実行環境をWSL+Dockerで作成し、実際に実行して試してみた記録です。

2025年11月にリリースされたSAM(Segment Anything Model)シリーズの最新モデルです。 SAM3では、プロンプトで画像内の検出したい物体を指示することで目的の物体のセグメンテーションとBBoxの出力ができます。

(他にも3Dオブジェクトに対応したSAM3Dもありますが今回は扱いません。)

環境構築

  • 実行環境

OS: Windows 11 Pro
CPU: Intel Core i7-13700
メモリ: 32GB
GPU: NVIDIA GeForce RTX 4060 Ti (VRAM: 16GB)

環境は上述の通りWSL+Dockerを使用しました。また、Python環境はuvを使用しています。
ベースの環境の作成については過去記事をご参考ください。

fallpoke-tech.hatenadiary.jp

Windows環境の場合、一部のライブラリがLinuxでしか使えず自前でビルドする必要があるのでWSLを使う方が良いと思います。

今回使用した環境設定を含めたリポジトリGitHubに残しています。 SAM3の公式リポジトリをforkして環境設定ファイルを追加したのみですが…。

github.com

実行

公式があげているデモ用のコードを参考に作成した下記のソースコードを実行しました。

注意点として、モデルの重みのダウンロードにはHuggingFaceのモデルページで利用申請が必要になります。

import os

from PIL import Image
import matplotlib.pyplot as plt
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results
from huggingface_hub import login
from dotenv import load_dotenv

load_dotenv()

login(token=os.getenv("HF_TOKEN"))

# モデルの準備
model = build_sam3_image_model()
processor = Sam3Processor(model)
# 画像の読み込み
image = Image.open("data/1624777685449_985774_photo1.jpeg")
inference_state = processor.set_image(image)
# テキストプロンプトを設定して推論を実行
output = processor.set_text_prompt(state=inference_state, prompt="tomato")

plot_results(image, output)
plt.show()
plt.close()

上記を実行するとこんな感じで出力されます。
tomato

プロンプトの指示である程度検出したい物体を絞ることも可能です。例えばprompt="red tomato"と変更すると出力が変わります。
red tomato

私の環境での話にはなりますが、VRAMを大体5GBくらい使用しているので比較的軽そうです。
また、画像1枚あたりの推論時間は0.20sほどだったのでこちらもなかなか速いです。
gpu state

参考サイト

Pytorchのモデル学習の中断と再開の処理を実装したメモ【備忘録】

お疲れ様です。

今回はPytorchで学習の途中再開をするためのコードのメモです。 長期間モデル学習を実行する際に予期せぬトラブルで処理が止まってしまった場合などにも使えると思います。

ソースコード

以前作成した画像分類のプロジェクトに実装しています。 github.com

実装

  • モデルの保存時
    下記modules/trainer.pyのTrainerクラスのモデル保存用メソッドです。 modelのパラメータに加えてoptimizerのパラメータを保存するのがポイントになっています。
    ちなみに下記は最新epochのモデルを保存するためのメソッドですが、これを毎epoch行うようにしています。 こうすることで処理を止めたor止まったタイミングから再開することができます。
def save_weight_latest(
    self,
    epoch: int
) -> None:
    """モデルの重みを保存
    """        
    # 最終epochのモデル
    model_name = "model_latest.pth"
    checkpoint = {
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(checkpoint, self.output_path.joinpath(model_name))
    print(f"model saved: {model_name}")
  • 途中再開時
    学習の途中再開用にmain_train_resume.pyを作成しています。 torch.loadで読み込みmodelとoptimizerそれぞれで読み込みをする形になります。
# 途中保存した重みの読み込み
checkpoint_path = output_path.joinpath("model_latest.pth")
checkpoint = torch.load(checkpoint_path, map_location=device)

# モデルの定義
model, _ = get_model_train(
    model_name=model_name, 
    num_classes=train_dataset.num_classes, 
    use_pretrained=use_pretrained
)
# 途中保存の重みに更新
model.load_state_dict(checkpoint["model_state_dict"])
params = model.parameters()

# optimizerの定義
optimizer = RAdamScheduleFree(params, lr=lr)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

実行するとこんな感じになります。 以下、「①中断せずに20epoch回した結果」と「②10epochで中断して再開した結果」を並べてみました。
似たような曲線を描いているので問題なく途中再開できていそうです。


  • 1


  • 2

参考

参考にさせていただいた記事を載せておきます。
qiita.com

再現性も保持したい場合は乱数の状態も一緒に保存し、再開の際に読み込むことで実現できるようです。 今回作成のコードでは実装していませんが。 qiita.com

HuggingFace Datasetsのload_datasetでダウンロードに失敗するときの対処【備忘録】

お疲れ様です。

HuggingFaceのDatasetsからデータセットを読み込み時にエラーが出たときの対応方法のメモです。

今回問題のあったデータセット

  • Food_and_Vegetables(画像分類用のデータセット

huggingface.co

通常のやり方

通常Datasetsからデータセットを読み込む際はload_dataset()を使用します。
以前作成したプログラムのソースからですが、このようにして問題なくデータセットのダウンロードができていました。

# Datasetの読み込み
# https://huggingface.co/datasets/Bingsu/Human_Action_Recognition
dataset = load_dataset("Bingsu/Human_Action_Recognition")

今回対象とした"Food_and_Vegetables"では、以下のようにエラーが出ました。 エラー内容としてはHTTP ErrorでREADMEや画像データなどDatasetに含まれるファイルをダウンロードする際にエラーになっているようです。

dataset = load_dataset("SunnyAgarwal4274/Food_and_Vegetables")

HTTP error

対処方法

load_dataset()を使用せず、直接ダウンロード(git clone)することでデータ自体をダウンロードしました。 Datasetsのページ自体がGitリポジトリとなっているのでgitがインストールされていればURLを指定してcloneすることができます。

git clone https://huggingface.co/datasets/SunnyAgarwal4274/Food_and_Vegetables

git clone

このようにページ内にあるデータをそのままクローン出来ました。データの欠損等もなさそうです。 この方法を使用した場合、load_dataset()を使用したときのように自動でHuggingFaceのDatasetクラスの形式にすることはできないので、データセットの形式に合わせてDatasetクラスを自作する必要はありそうです。

dataset

今回のデータセット(Food_and_Vegetables)の場合ImageFolder形式のフォルダ構造になっているので、load_dataset()で下記のような書き方をすればHuggingFaceのDatasetクラスで読み込むことができるようです。

dataset = load_dataset(
    "imagefolder", 
    data_dir="./dataset" # git cloneしたデータセットのパスを指定(ImageFolder形式)
)

load_datasetでエラーになった原因(推測)

対象とした"Food_and_Vegetables"ではデータセットの画像が生データのままリポジトリに格納されていました。 このデータを一挙にダウンロードしようとして一部のデータでエラーが起こっていたと推測しています。

page_Food_and_Vegetables

また正常にダウンロードできた"Human_Action_Recognition"の方はデータセットをparquetファイルに変換していることがわかりました。 ざっと見た感じ他の多くのデータセットもparquetファイルかつ決まったフォルダ構造になっているように見受けられたので、 load_dataset()がこの形式を推奨しているということもありそうです…。

page_Human_Action_Recognition

OCRモデルTrOCRの実装まとめ

お疲れ様です。

前回の記事でTrOCRの調査内容をまとめたので、今回は実装のお話。 fallpoke-tech.hatenadiary.jp



ソースコード

ソースコードはこちらのGithubリポジトリにまとめています。mainブランチの方を使用する想定です。
プロジェクト全体の構造や使用方法などはREADMEを参照ください。 GitHub Copilotに作成させたのでおかしな点などあるかもしれませんが…。 github.com

下記のサイトのソースコードをベースに日本語対応をしたものになります。
https://qiita.com/relu/items/c027c486758525c0b6b9

コード内容補足

ソースコードの細かい内容はこちらで補足として記載します。

モデル

使用したモデルはこちら。HuggingFaceで利用できる日本語対応の事前学習モデルの中で使えそうなもので採用しました。 日本語の漫画のコマのデータセットを使用して作成したモデルのようです。 huggingface.co

# model_name = "kha-white/manga-ocr-base"
model = VisionEncoderDecoderModel.from_pretrained(model_name)

今回は上記を採用しましたがtransformersのVisionEncoderDecoderModelではEncoderとDecoderに個別にモデルを指定することができます。

  • Encoder -> 画像系Transformer(ViTやDeiTなど)
  • Decoder -> 自然言語系Transformer(BERTやGPTなど)

コードにすると以下のような形になります。encoderとdecoderはそれぞれ以下を使用しています。 日本語対応にしたい場合はDecoder側を日本語対応の生成モデルにすると良さそうです。

Encoder -> https://huggingface.co/facebook/deit-base-distilled-patch16-224
Decoder -> https://huggingface.co/rinna/japanese-gpt2-xsmall

# encoder_name = "facebook/deit-base-distilled-patch16-224"
# decoder_name = "rinna/japanese-gpt2-xsmall" 
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_name, decoder_name)

データセット

画像と正解テキストの1組を返すPutorchのDatasetクラスを作成すればOKです。 用意したデータセットに合わせて処理を書き換えてください。

デフォルトでは下記のようなフォルダ構造に対応しています。 annotations.csvに画像と正解テキストの対応を記載してそれをDatasetクラスに登録するイメージです。

.
└──dataset
  ├──annotations.csv
  ├──image1.png
  ├──image2.png
  ⋮

  • annotations.csvの中身(例)
img text
PXL_20251004_070420578_30.png 必ずまた帰ってくる
PXL_20251004_070420578_31.png 東京からまんまで宇宙へ
PXL_20251004_070420578_32.png 今年も似合いの夏

前処理

前処理は基本的にはモデルに対応するProcessorを使用することになります。 今回の例だとこのようになります。
このProcessorでは224x224のリサイズと画像の正規化、torch.Tensorへの変換処理が含まれています。 224x224のリサイズがある関係で、横長や縦長の画像を対象とする場合にデータの欠損が起こりそうではあります…。

# model_name = "kha-white/manga-ocr-base"
processor = TrOCRProcessor.from_pretrained(model_name)

ソースコードではこのProcessorのほか、Albumentationsを使った画像の前処理も加えています。 デフォルトだと以下の処理が含まれます。 modules/loader/augmentation.pyaug_listにAlbumentationsのお好きな処理を追加することもできます。 またデフォルトには含めていませんが、同じソースにカスタムの前処理をいくつか実装しています。

aug_list = [
    # 色調変化のうちいずれか1つをランダムに適用
    A.OneOf([
        A.HueSaturationValue(),  # 画像の色相・彩度・明度をランダムに変化させる(全体の色味を調整)
        A.RGBShift(),            # 各RGBチャンネルをランダムにシフトさせて色のバランスを変える(照明やカメラ特性の変化を模倣)
        A.InvertImg()            # 画像の色を反転(白→黒、黒→白)させる(ネガ画像などへの耐性向上)
    ]),
    
    # 明るさとコントラストをランダムに変化させる(明暗差や照明条件の変動に対応)
    A.RandomBrightnessContrast(
        brightness_limit=0.2,   # 明るさを±20%の範囲で変化
        contrast_limit=0.2,     # コントラストを±20%の範囲で変化
        p=0.5                   # 50%の確率で適用
    ),
    
    # CLAHE(ヒストグラム均等化)を適用し、コントラストを局所的に強調(文字や模様の視認性向上)
    A.CLAHE(p=0.2),             # 20%の確率で適用
    
    # JPEG圧縮をシミュレートして画質を劣化させる(低品質画像に対するロバスト性を向上)
    A.ImageCompression(p=0.3)   # 30%の確率で適用
]

実行結果

実際にソースコードを使用して作成した結果を載せておきます。 結果としてはあまりうまくいっていないようにも思いますが参考程度に。

学習に使用したデータセットはHuggingFaceから利用できる以下のデータセットと自作のデータセットを混ぜたものです。 huggingface.co

自作データセットの方は以前記事で紹介したYomiTokuを使ったプログラムを使って作成しています。 fallpoke-tech.hatenadiary.jp

データとしては15万枚+2000枚程度あり100epochの学習が完了するのに1週間ほどかかりました…。 それでいてあまり精度も良いとは言えない感じなのが残念です。いろいろ見直しは必要になりそうです…。

学習

学習曲線がこのようになりました。 TrainとValidationの段階でそれぞれLossを計算、またValidationではCER(Character Error Rate)も計算して表示しています。

CERについて詳しくはこちら参照。
https://qiita.com/Kchan/items/7bba1f066234ba24898b

learning_curve

推論

作成したモデルを使って推論を実行した結果を一部載せておきます。 プログラムの出力としては以下のような入力画像と正解テキスト、検出テキストをcsvファイルにまとめています。

result

  • PXL_20251001_141512690_12.png
    正解テキスト: 桜の花、舞い上がる道を
    検出テキスト: 桜 の 、 い 上 本 道
    sakuranohana

文字の一部は検出できていますが、全体で正解してはいない感じですね…。 横長の画像が多いので224x224のリサイズで文字がつぶれてしまっていそうです。