diff --git a/src/train.py b/src/train.py
index 0e47a85..27f8e62 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,4 +1,4 @@
-from unsloth import FastLanguageModel
+from unsloth import FastLanguageModel, FastModel
 import torch
 from trl import SFTTrainer, SFTConfig
 from datasets import load_dataset
@@ -29,7 +29,7 @@ fourbit_models = [
 ]
 # More models at https://huggingface.co/unsloth
 model, tokenizer = FastModel.from_pretrained(
-    model_name = "unsloth/gemma-3-4B-it",
+    model_name = "unsloth/gemma-3-1B-it",
     max_seq_length = 2048, # Choose any for long context!
     load_in_4bit = True,  # 4 bit quantization to reduce memory
     load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory