fix NaN loss not caught by fast-fail check

This commit is contained in:
Andrej 2026-03-10 22:31:43 -07:00 committed by GitHub
commit 0be1e4fdf9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,6 +9,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
import gc
import math
import time
from dataclasses import dataclass, asdict
@ -565,8 +566,8 @@ while True:
train_loss_f = train_loss.item()
# Fast fail: abort if loss is exploding
if train_loss_f > 100:
# Fast fail: abort if loss is exploding or NaN
if math.isnan(train_loss_f) or train_loss_f > 100:
print("FAIL")
exit(1)