fix(train): make NaN fast-fail check explicit
This commit is contained in:
parent
b5ba8ac00d
commit
ebf357841b
3
train.py
3
train.py
@ -9,6 +9,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
|||||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
import math
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
|
|
||||||
@ -566,7 +567,7 @@ while True:
|
|||||||
train_loss_f = train_loss.item()
|
train_loss_f = train_loss.item()
|
||||||
|
|
||||||
# Fast fail: abort if loss is exploding or NaN
|
# Fast fail: abort if loss is exploding or NaN
|
||||||
if not train_loss_f <= 100:
|
if math.isnan(train_loss_f) or train_loss_f > 100:
|
||||||
print("FAIL")
|
print("FAIL")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user