OCRモデルTrOCRの実装まとめ

お疲れ様です。

前回の記事でTrOCRの調査内容をまとめたので、今回は実装のお話。 fallpoke-tech.hatenadiary.jp



ソースコード

ソースコードはこちらのGithubリポジトリにまとめています。mainブランチの方を使用する想定です。
プロジェクト全体の構造や使用方法などはREADMEを参照ください。 GitHub Copilotに作成させたのでおかしな点などあるかもしれませんが…。 github.com

下記のサイトのソースコードをベースに日本語対応をしたものになります。
https://qiita.com/relu/items/c027c486758525c0b6b9

コード内容補足

ソースコードの細かい内容はこちらで補足として記載します。

モデル

使用したモデルはこちら。HuggingFaceで利用できる日本語対応の事前学習モデルの中で使えそうなもので採用しました。 日本語の漫画のコマのデータセットを使用して作成したモデルのようです。 huggingface.co

# model_name = "kha-white/manga-ocr-base"
model = VisionEncoderDecoderModel.from_pretrained(model_name)

今回は上記を採用しましたがtransformersのVisionEncoderDecoderModelではEncoderとDecoderに個別にモデルを指定することができます。

  • Encoder -> 画像系Transformer(ViTやDeiTなど)
  • Decoder -> 自然言語系Transformer(BERTやGPTなど)

コードにすると以下のような形になります。encoderとdecoderはそれぞれ以下を使用しています。 日本語対応にしたい場合はDecoder側を日本語対応の生成モデルにすると良さそうです。

Encoder -> https://huggingface.co/facebook/deit-base-distilled-patch16-224
Decoder -> https://huggingface.co/rinna/japanese-gpt2-xsmall

# encoder_name = "facebook/deit-base-distilled-patch16-224"
# decoder_name = "rinna/japanese-gpt2-xsmall" 
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_name, decoder_name)

データセット

画像と正解テキストの1組を返すPutorchのDatasetクラスを作成すればOKです。 用意したデータセットに合わせて処理を書き換えてください。

デフォルトでは下記のようなフォルダ構造に対応しています。 annotations.csvに画像と正解テキストの対応を記載してそれをDatasetクラスに登録するイメージです。

.
└──dataset
  ├──annotations.csv
  ├──image1.png
  ├──image2.png
  ⋮

  • annotations.csvの中身(例)
img text
PXL_20251004_070420578_30.png 必ずまた帰ってくる
PXL_20251004_070420578_31.png 東京からまんまで宇宙へ
PXL_20251004_070420578_32.png 今年も似合いの夏

前処理

前処理は基本的にはモデルに対応するProcessorを使用することになります。 今回の例だとこのようになります。
このProcessorでは224x224のリサイズと画像の正規化、torch.Tensorへの変換処理が含まれています。 224x224のリサイズがある関係で、横長や縦長の画像を対象とする場合にデータの欠損が起こりそうではあります…。

# model_name = "kha-white/manga-ocr-base"
processor = TrOCRProcessor.from_pretrained(model_name)

ソースコードではこのProcessorのほか、Albumentationsを使った画像の前処理も加えています。 デフォルトだと以下の処理が含まれます。 modules/loader/augmentation.pyaug_listにAlbumentationsのお好きな処理を追加することもできます。 またデフォルトには含めていませんが、同じソースにカスタムの前処理をいくつか実装しています。

aug_list = [
    # 色調変化のうちいずれか1つをランダムに適用
    A.OneOf([
        A.HueSaturationValue(),  # 画像の色相・彩度・明度をランダムに変化させる(全体の色味を調整)
        A.RGBShift(),            # 各RGBチャンネルをランダムにシフトさせて色のバランスを変える(照明やカメラ特性の変化を模倣)
        A.InvertImg()            # 画像の色を反転(白→黒、黒→白)させる(ネガ画像などへの耐性向上)
    ]),
    
    # 明るさとコントラストをランダムに変化させる(明暗差や照明条件の変動に対応)
    A.RandomBrightnessContrast(
        brightness_limit=0.2,   # 明るさを±20%の範囲で変化
        contrast_limit=0.2,     # コントラストを±20%の範囲で変化
        p=0.5                   # 50%の確率で適用
    ),
    
    # CLAHE(ヒストグラム均等化)を適用し、コントラストを局所的に強調(文字や模様の視認性向上)
    A.CLAHE(p=0.2),             # 20%の確率で適用
    
    # JPEG圧縮をシミュレートして画質を劣化させる(低品質画像に対するロバスト性を向上)
    A.ImageCompression(p=0.3)   # 30%の確率で適用
]

実行結果

実際にソースコードを使用して作成した結果を載せておきます。 結果としてはあまりうまくいっていないようにも思いますが参考程度に。

学習に使用したデータセットはHuggingFaceから利用できる以下のデータセットと自作のデータセットを混ぜたものです。 huggingface.co

自作データセットの方は以前記事で紹介したYomiTokuを使ったプログラムを使って作成しています。 fallpoke-tech.hatenadiary.jp

データとしては15万枚+2000枚程度あり100epochの学習が完了するのに1週間ほどかかりました…。 それでいてあまり精度も良いとは言えない感じなのが残念です。いろいろ見直しは必要になりそうです…。

学習

学習曲線がこのようになりました。 TrainとValidationの段階でそれぞれLossを計算、またValidationではCER(Character Error Rate)も計算して表示しています。

CERについて詳しくはこちら参照。
https://qiita.com/Kchan/items/7bba1f066234ba24898b

learning_curve

推論

作成したモデルを使って推論を実行した結果を一部載せておきます。 プログラムの出力としては以下のような入力画像と正解テキスト、検出テキストをcsvファイルにまとめています。

result

  • PXL_20251001_141512690_12.png
    正解テキスト: 桜の花、舞い上がる道を
    検出テキスト: 桜 の 、 い 上 本 道
    sakuranohana

文字の一部は検出できていますが、全体で正解してはいない感じですね…。 横長の画像が多いので224x224のリサイズで文字がつぶれてしまっていそうです。