お疲れ様です。
今回はPytorchでの処理を実装している際に実際に出くわしたエラーについてのメモ。 torchvisionで発生したOverflowErrorについて調べました。
エラー内容
torchvisionのGitHubのissuesに情報がありましたので、載せておきます。
ほぼ上記issueの内容の通りですが、torchvisionの0.18.x以前のバージョンとnumpyの2.x.x以降のバージョンとの相性の問題のようです。
numpy側の問題のようですね…。
issuesではColorJitterが挙げられていますが、それ以外のtorchvision.transformsの処理でも発生する可能性があります。
私がエラーに出くわしたときはColorJitterとは別の処理でした。
torchvisionの0.19.0以降では解消済みの内容です。 最新のtorchvisionを使う場合は基本的には問題ありませんが、古い環境を使いまわしている場合は注意が必要です。 どうしても0.18.x以前のバージョンのtorchvisionを使いたい場合は、numpyを1.26.x以前のバージョンにすることでエラーを回避できます。
エラーの検証
コードを作成してtorchvisonとnumpyの各バージョンの組み合わせを試してみました。 以下のようなコードを使用します。
import numpy import torchvision from torchvision.transforms import ColorJitter from PIL import Image from tqdm import tqdm print("numpy", numpy.__version__) print("torchvision", torchvision.__version__) img = Image.open("image.jpg") trans = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1) try: for _ in tqdm(range(10), desc="Test"): trans(img) except Exception as e: print("Error:", e) else: print("OK")
torchvision==0.18.0, numpy==2.0.0 -> NG

torchvision==0.18.0, numpy==1.26.4 -> OK

torchvision==0.19.0, numpy==1.26.4 -> OK

torchvision==0.19.0, numpy==2.0.0 -> OK

torchvision==0.23.0, numpy==2.2.6 -> OK
※記事作成現在の最新バージョン
