fix(train): make NaN fast-fail check explicit
This commit is contained in:
parent
b5ba8ac00d
commit
ebf357841b
3
train.py
3
train.py
@@ -9,6 +9,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
|
||||
import gc
|
||||
+import math
|
||||
import time
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
@@ -566,7 +567,7 @@ while True:
|
||||
train_loss_f = train_loss.item()
|
||||
|
||||
# Fast fail: abort if loss is exploding or NaN
|
||||
-if not train_loss_f <= 100:
|
||||
+if math.isnan(train_loss_f) or train_loss_f > 100:
|
||||
print("FAIL")
|
||||
exit(1)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user