fix(train): make NaN fast-fail check explicit

This commit is contained in:
Contributor 2026-03-11 04:28:08 +00:00
parent b5ba8ac00d
commit ebf357841b

View File

@ -9,6 +9,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
import gc
import math
import time
from dataclasses import dataclass, asdict
@ -566,7 +567,7 @@ while True:
train_loss_f = train_loss.item()
# Fast fail: abort if loss is exploding or NaN
if not train_loss_f <= 100:
if math.isnan(train_loss_f) or train_loss_f > 100:
print("FAIL")
exit(1)