# Patch summary (commentary only; sits before the `diff --git` header and is not part of any hunk).
#
# Target file: train.py. Two hunks:
#   1. Adds `import math` between the existing `import gc` and `import time` lines.
#   2. Under the `while True:` hunk context, rewrites the "fast fail" guard from
#        if not train_loss_f <= 100:
#      to
#        if math.isnan(train_loss_f) or train_loss_f > 100:
#      The guarded body is unchanged: it prints "FAIL" and calls exit(1).
#
# Why the rewrite: the old guard caught NaN only implicitly, because every ordered
# comparison with NaN is false, so `not (nan <= 100)` is true. The new guard names
# the NaN case explicitly via math.isnan and tests the "> 100" case directly.
# NOTE(review): assuming train_loss_f is a plain Python float (it is assigned from
# `train_loss.item()`; the type of `train_loss` is not visible here), both forms
# fire on exactly the same values, so this looks like a readability change with no
# behavioural difference -- confirm.
#
# NOTE(review): as shown here the whole patch is on one physical line; the line
# breaks, blank context lines and leading indentation of the hunk lines are not
# visible. Restore them from the original patch before feeding this to
# `git apply` / `patch` -- do not guess the whitespace.
diff --git a/train.py b/train.py index 1378bab..2e74397 100644 --- a/train.py +++ b/train.py @@ -9,6 +9,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" import gc +import math import time from dataclasses import dataclass, asdict @@ -566,7 +567,7 @@ while True: train_loss_f = train_loss.item() # Fast fail: abort if loss is exploding or NaN - if not train_loss_f <= 100: + if math.isnan(train_loss_f) or train_loss_f > 100: print("FAIL") exit(1)