古いバージョンのtorchvisionで発生するOverflowErrorについてのメモ【備忘録】

お疲れ様です。

今回はPytorchでの処理を実装している際に実際に出くわしたエラーについてのメモ。 torchvisionで発生したOverflowErrorについて調べました。

エラー内容

torchvisionのGitHubのissuesに情報がありましたので、載せておきます。

github.com

ほぼ上記issueの内容の通りですが、torchvisionの0.18.x以前のバージョンとnumpyの2.x.x以降のバージョンとの相性の問題のようです。 numpy側の問題のようですね…。
issuesではColorJitterが挙げられていますが、それ以外のtorchvision.transformsの処理でも発生する可能性があります。 私がエラーに出くわしたときはColorJitterとは別の処理でした。

torchvisionの0.19.0以降では解消済みの内容です。 最新のtorchvisionを使う場合は基本的には問題ありませんが、古い環境を使いまわしている場合は注意が必要です。 どうしても0.18.x以前のバージョンのtorchvisionを使いたい場合は、numpyを1.26.x以前のバージョンにすることでエラーを回避できます。

エラーの検証

コードを作成してtorchvisonとnumpyの各バージョンの組み合わせを試してみました。 以下のようなコードを使用します。

import numpy
import torchvision
from torchvision.transforms import ColorJitter
from PIL import Image
from tqdm import tqdm

print("numpy", numpy.__version__)
print("torchvision", torchvision.__version__)

img = Image.open("image.jpg")

trans = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)

try:
    for _ in tqdm(range(10), desc="Test"):
        trans(img)
except Exception as e:
    print("Error:", e)
else:
    print("OK")
  • torchvision==0.18.0, numpy==2.0.0 -> NG
    check1

  • torchvision==0.18.0, numpy==1.26.4 -> OK
    check2

  • torchvision==0.19.0, numpy==1.26.4 -> OK
    check3

  • torchvision==0.19.0, numpy==2.0.0 -> OK
    check4

  • torchvision==0.23.0, numpy==2.2.6 -> OK
    ※記事作成現在の最新バージョン
    check5