initial commit
This commit is contained in:
commit
b11d6f283f
20
.gitignore
vendored
Normal file
20
.gitignore
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
# Python-generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
worktrees/
|
||||
results/
|
||||
queue/
|
||||
|
||||
# Agent prompt files (generated per-session by launchers)
|
||||
CLAUDE.md
|
||||
AGENTS.md
|
||||
|
||||
# Experimental code/artifacts
|
||||
dev/
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
||||
3.10
|
||||
61
README.md
Normal file
61
README.md
Normal file
@ -0,0 +1,61 @@
|
||||
# autoresearch
|
||||
|
||||
Autonomous LLM pretraining research, driven by AI agents.
|
||||
|
||||
The idea: give an AI agent a small but real LLM training setup and let it run experiments overnight. It modifies the code, trains for 5 minutes, checks if the result improved, keeps or discards, and repeats. You wake up in the morning to a log of experiments and (hopefully) a better model.
|
||||
|
||||
This particular implementation is trying to be the least fancy baseline, but it's clear how one would adjust the `program.md` file to run more sophisticated research programs with more elaborate instructions. For example, the agent can actively do little experiments on research while the job is running.
|
||||
|
||||
## How it works
|
||||
|
||||
The repo is deliberately small and only has a few files:
|
||||
|
||||
- **`constants.py`** — fixed rules: sequence length, time budget, eval tokens. Not modified.
|
||||
- **`prepare.py`** — one-time data prep (downloads training data, trains a BPE tokenizer) and runtime utilities (dataloader, evaluation). Not modified.
|
||||
- **`train.py`** — the single file the agent edits. Contains the full GPT model, optimizer (Muon + AdamW), and training loop. Everything is fair game: architecture, hyperparameters, optimizer, batch size, etc.
|
||||
- **`program.md`** — instructions for the agent. Point your agent here and let it go.
|
||||
|
||||
Training runs for a **fixed 5-minute time budget** (wall clock, excluding startup/compilation). The metric is **val_bpb** (validation bits per byte) — lower is better, and vocab-size-independent so architectural changes are fairly compared.
|
||||
|
||||
## Quick start
|
||||
|
||||
**Requirements:** A single NVIDIA GPU (tested on H100), Python 3.10+, [uv](https://docs.astral.sh/uv/).
|
||||
|
||||
```bash
|
||||
# 1. Install dependencies
|
||||
uv sync
|
||||
|
||||
# 2. Download data and train tokenizer (one-time, ~5 min)
|
||||
uv run prepare.py
|
||||
|
||||
# 3. Run a single training experiment (5 min + startup)
|
||||
uv run train.py
|
||||
```
|
||||
|
||||
## Running the agent
|
||||
|
||||
Simply spin up your Claude/Codex or whatever you want in this repo, then you can something like:
|
||||
|
||||
```
|
||||
Hi have a look at program.md and let's kick off a new experiment! let's do the setup first.
|
||||
```
|
||||
|
||||
The `program.md` file is essentially a super lightweight "skill".
|
||||
|
||||
## Project structure
|
||||
|
||||
```
|
||||
constants.py — fixed constants (do not modify)
|
||||
prepare.py — data prep + runtime utilities (do not modify)
|
||||
train.py — model, optimizer, training loop (agent modifies this)
|
||||
program.md — agent instructions
|
||||
spawn.sh — multi-agent launcher
|
||||
pyproject.toml — dependencies
|
||||
```
|
||||
|
||||
## Design choices
|
||||
|
||||
- **Single file to modify.** The agent only touches `train.py`. This keeps the scope manageable and diffs reviewable.
|
||||
- **Fixed time budget.** Training always runs for exactly 5 minutes. This makes experiments directly comparable regardless of what the agent changes (model size, batch size, architecture, etc).
|
||||
- **BPB metric.** Bits per byte is independent of tokenizer vocabulary size, so the agent could in principle change the vocab size and still get a fair comparison.
|
||||
- **Self-contained.** No external dependencies beyond PyTorch and a few small packages. No distributed training, no complex configs. One GPU, one file, one metric.
|
||||
7
constants.py
Normal file
7
constants.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""
|
||||
Fixed constants for autoresearch. Do not modify.
|
||||
"""
|
||||
|
||||
MAX_SEQ_LEN = 2048 # context length
|
||||
TIME_BUDGET = 300 # training time budget in seconds (5 minutes)
|
||||
EVAL_TOKENS = 40 * 524288 # number of tokens for val eval
|
||||
373
prepare.py
Normal file
373
prepare.py
Normal file
@ -0,0 +1,373 @@
|
||||
"""
|
||||
One-time data preparation for autoresearch experiments.
|
||||
Downloads data shards and trains a BPE tokenizer.
|
||||
|
||||
Usage:
|
||||
python prepare.py # full prep (download + tokenizer)
|
||||
python prepare.py --num-shards 8 # download only 8 shards (for testing)
|
||||
|
||||
Data and tokenizer are stored in ~/.cache/autoresearch/.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import math
|
||||
import argparse
|
||||
import pickle
|
||||
from multiprocessing import Pool
|
||||
|
||||
import requests
|
||||
import pyarrow.parquet as pq
|
||||
import rustbpe
|
||||
import tiktoken
|
||||
import torch
|
||||
|
||||
from constants import MAX_SEQ_LEN, EVAL_TOKENS
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
|
||||
DATA_DIR = os.path.join(CACHE_DIR, "data")
|
||||
TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
|
||||
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
|
||||
MAX_SHARD = 6542 # the last datashard is shard_06542.parquet
|
||||
VOCAB_SIZE = 8192
|
||||
|
||||
# BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3})
|
||||
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
|
||||
SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(16)]
|
||||
BOS_TOKEN = "<|reserved_0|>"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data download
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def download_single_shard(index):
|
||||
"""Download one parquet shard with retries. Returns True on success."""
|
||||
filename = f"shard_{index:05d}.parquet"
|
||||
filepath = os.path.join(DATA_DIR, filename)
|
||||
if os.path.exists(filepath):
|
||||
return True
|
||||
|
||||
url = f"{BASE_URL}/{filename}"
|
||||
max_attempts = 5
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
response = requests.get(url, stream=True, timeout=30)
|
||||
response.raise_for_status()
|
||||
temp_path = filepath + ".tmp"
|
||||
with open(temp_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=1024 * 1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
os.rename(temp_path, filepath)
|
||||
print(f" Downloaded {filename}")
|
||||
return True
|
||||
except (requests.RequestException, IOError) as e:
|
||||
print(f" Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
|
||||
for path in [filepath + ".tmp", filepath]:
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
except OSError:
|
||||
pass
|
||||
if attempt < max_attempts:
|
||||
time.sleep(2 ** attempt)
|
||||
return False
|
||||
|
||||
|
||||
def download_data(num_shards):
|
||||
"""Download data shards in parallel."""
|
||||
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
ids = list(range(min(num_shards, MAX_SHARD + 1)))
|
||||
|
||||
# Count what's already downloaded
|
||||
existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet")))
|
||||
if existing == len(ids):
|
||||
print(f"Data: all {len(ids)} shards already downloaded at {DATA_DIR}")
|
||||
return
|
||||
|
||||
needed = len(ids) - existing
|
||||
print(f"Data: downloading {needed} shards ({existing} already exist)...")
|
||||
|
||||
workers = min(8, needed)
|
||||
with Pool(processes=workers) as pool:
|
||||
results = pool.map(download_single_shard, ids)
|
||||
|
||||
ok = sum(1 for r in results if r)
|
||||
print(f"Data: {ok}/{len(ids)} shards ready at {DATA_DIR}")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tokenizer training
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def list_parquet_files():
|
||||
"""Return sorted list of parquet file paths in the data directory."""
|
||||
files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".parquet") and not f.endswith(".tmp"))
|
||||
return [os.path.join(DATA_DIR, f) for f in files]
|
||||
|
||||
|
||||
def text_iterator(max_chars=2_000_000_000, doc_cap=10_000):
|
||||
"""Yield documents from training split (all shards except last)."""
|
||||
parquet_paths = list_parquet_files()[:-1] # last shard is val
|
||||
nchars = 0
|
||||
for filepath in parquet_paths:
|
||||
pf = pq.ParquetFile(filepath)
|
||||
for rg_idx in range(pf.num_row_groups):
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
for text in rg.column("text").to_pylist():
|
||||
doc = text[:doc_cap] if len(text) > doc_cap else text
|
||||
nchars += len(doc)
|
||||
yield doc
|
||||
if nchars >= max_chars:
|
||||
return
|
||||
|
||||
|
||||
def train_tokenizer():
|
||||
"""Train BPE tokenizer using rustbpe, save as tiktoken pickle."""
|
||||
tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
|
||||
token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
|
||||
|
||||
if os.path.exists(tokenizer_pkl) and os.path.exists(token_bytes_path):
|
||||
print(f"Tokenizer: already trained at {TOKENIZER_DIR}")
|
||||
return
|
||||
|
||||
os.makedirs(TOKENIZER_DIR, exist_ok=True)
|
||||
|
||||
parquet_files = list_parquet_files()
|
||||
if len(parquet_files) < 2:
|
||||
print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.")
|
||||
sys.exit(1)
|
||||
|
||||
# --- Train with rustbpe ---
|
||||
print("Tokenizer: training BPE tokenizer...")
|
||||
t0 = time.time()
|
||||
|
||||
tokenizer = rustbpe.Tokenizer()
|
||||
vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)
|
||||
tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN)
|
||||
|
||||
# Build tiktoken encoding from trained merges
|
||||
pattern = tokenizer.get_pattern()
|
||||
mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()}
|
||||
tokens_offset = len(mergeable_ranks)
|
||||
special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
|
||||
enc = tiktoken.Encoding(
|
||||
name="rustbpe",
|
||||
pat_str=pattern,
|
||||
mergeable_ranks=mergeable_ranks,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
# Save tokenizer
|
||||
with open(tokenizer_pkl, "wb") as f:
|
||||
pickle.dump(enc, f)
|
||||
|
||||
t1 = time.time()
|
||||
print(f"Tokenizer: trained in {t1 - t0:.1f}s, saved to {tokenizer_pkl}")
|
||||
|
||||
# --- Build token_bytes lookup for BPB evaluation ---
|
||||
print("Tokenizer: building token_bytes lookup...")
|
||||
special_set = set(SPECIAL_TOKENS)
|
||||
token_bytes_list = []
|
||||
for token_id in range(enc.n_vocab):
|
||||
token_str = enc.decode([token_id])
|
||||
if token_str in special_set:
|
||||
token_bytes_list.append(0)
|
||||
else:
|
||||
token_bytes_list.append(len(token_str.encode("utf-8")))
|
||||
token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32)
|
||||
torch.save(token_bytes_tensor, token_bytes_path)
|
||||
print(f"Tokenizer: saved token_bytes to {token_bytes_path}")
|
||||
|
||||
# Sanity check
|
||||
test = "Hello world! Numbers: 123. Unicode: 你好"
|
||||
encoded = enc.encode_ordinary(test)
|
||||
decoded = enc.decode(encoded)
|
||||
assert decoded == test, f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}"
|
||||
print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Runtime utilities (imported by train.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Tokenizer:
|
||||
"""Minimal tokenizer wrapper. Training is handled above."""
|
||||
|
||||
def __init__(self, enc):
|
||||
self.enc = enc
|
||||
self.bos_token_id = enc.encode_single_token(BOS_TOKEN)
|
||||
|
||||
@classmethod
|
||||
def from_directory(cls, tokenizer_dir=TOKENIZER_DIR):
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer.pkl"), "rb") as f:
|
||||
enc = pickle.load(f)
|
||||
return cls(enc)
|
||||
|
||||
def get_vocab_size(self):
|
||||
return self.enc.n_vocab
|
||||
|
||||
def get_bos_token_id(self):
|
||||
return self.bos_token_id
|
||||
|
||||
def encode(self, text, prepend=None, num_threads=8):
|
||||
if prepend is not None:
|
||||
prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
|
||||
if isinstance(text, str):
|
||||
ids = self.enc.encode_ordinary(text)
|
||||
if prepend is not None:
|
||||
ids.insert(0, prepend_id)
|
||||
elif isinstance(text, list):
|
||||
ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
|
||||
if prepend is not None:
|
||||
for row in ids:
|
||||
row.insert(0, prepend_id)
|
||||
else:
|
||||
raise ValueError(f"Invalid input type: {type(text)}")
|
||||
return ids
|
||||
|
||||
def decode(self, ids):
|
||||
return self.enc.decode(ids)
|
||||
|
||||
|
||||
def get_token_bytes(device="cpu"):
|
||||
path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
|
||||
with open(path, "rb") as f:
|
||||
return torch.load(f, map_location=device)
|
||||
|
||||
|
||||
def _document_batches(split, tokenizer_batch_size=128):
|
||||
"""Infinite iterator over document batches from parquet files."""
|
||||
parquet_paths = list_parquet_files()
|
||||
assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first."
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
epoch = 1
|
||||
while True:
|
||||
for filepath in parquet_paths:
|
||||
pf = pq.ParquetFile(filepath)
|
||||
for rg_idx in range(pf.num_row_groups):
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
batch = rg.column('text').to_pylist()
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i:i+tokenizer_batch_size], epoch
|
||||
epoch += 1
|
||||
|
||||
|
||||
def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
|
||||
"""
|
||||
BOS-aligned dataloader with best-fit packing.
|
||||
Every row starts with BOS. Documents packed using best-fit to minimize cropping.
|
||||
When no document fits remaining space, crops shortest doc to fill exactly.
|
||||
100% utilization (no padding).
|
||||
"""
|
||||
assert split in ["train", "val"]
|
||||
row_capacity = T + 1
|
||||
batches = _document_batches(split)
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
doc_buffer = []
|
||||
epoch = 1
|
||||
|
||||
def refill_buffer():
|
||||
nonlocal epoch
|
||||
doc_batch, epoch = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
|
||||
doc_buffer.extend(token_lists)
|
||||
|
||||
# Pre-allocate buffers: [inputs (B*T) | targets (B*T)]
|
||||
row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
|
||||
cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True)
|
||||
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device="cuda")
|
||||
cpu_inputs = cpu_buffer[:B * T].view(B, T)
|
||||
cpu_targets = cpu_buffer[B * T:].view(B, T)
|
||||
inputs = gpu_buffer[:B * T].view(B, T)
|
||||
targets = gpu_buffer[B * T:].view(B, T)
|
||||
|
||||
while True:
|
||||
for row_idx in range(B):
|
||||
pos = 0
|
||||
while pos < row_capacity:
|
||||
while len(doc_buffer) < buffer_size:
|
||||
refill_buffer()
|
||||
|
||||
remaining = row_capacity - pos
|
||||
|
||||
# Find largest doc that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, doc in enumerate(doc_buffer):
|
||||
doc_len = len(doc)
|
||||
if doc_len <= remaining and doc_len > best_len:
|
||||
best_idx = i
|
||||
best_len = doc_len
|
||||
|
||||
if best_idx >= 0:
|
||||
doc = doc_buffer.pop(best_idx)
|
||||
row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
|
||||
pos += len(doc)
|
||||
else:
|
||||
# No doc fits — crop shortest to fill remaining
|
||||
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
|
||||
doc = doc_buffer.pop(shortest_idx)
|
||||
row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
|
||||
pos += remaining
|
||||
|
||||
cpu_inputs.copy_(row_buffer[:, :-1])
|
||||
cpu_targets.copy_(row_buffer[:, 1:])
|
||||
gpu_buffer.copy_(cpu_buffer, non_blocking=True)
|
||||
yield inputs, targets, epoch
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Evaluation (DO NOT CHANGE — this is the fixed metric)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_bpb(model, tokenizer, batch_size):
|
||||
"""
|
||||
Bits per byte (BPB): vocab size-independent evaluation metric.
|
||||
Sums per-token cross-entropy (in nats), sums target byte lengths,
|
||||
then converts nats/byte to bits/byte. Special tokens (byte length 0)
|
||||
are excluded from both sums.
|
||||
Uses fixed MAX_SEQ_LEN so results are comparable across configs.
|
||||
"""
|
||||
token_bytes = get_token_bytes(device="cuda")
|
||||
val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
|
||||
steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
|
||||
total_nats = 0.0
|
||||
total_bytes = 0
|
||||
for _ in range(steps):
|
||||
x, y, _ = next(val_loader)
|
||||
loss_flat = model(x, y, reduction='none').view(-1)
|
||||
y_flat = y.view(-1)
|
||||
nbytes = token_bytes[y_flat]
|
||||
mask = nbytes > 0
|
||||
total_nats += (loss_flat * mask).sum().item()
|
||||
total_bytes += nbytes.sum().item()
|
||||
return total_nats / (math.log(2) * total_bytes)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
|
||||
parser.add_argument("--num-shards", type=int, default=8, help="Number of shards to download (-1 = all 1823)")
|
||||
parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers")
|
||||
args = parser.parse_args()
|
||||
|
||||
num_shards = MAX_SHARD + 1 if args.num_shards == -1 else args.num_shards
|
||||
|
||||
print(f"Cache directory: {CACHE_DIR}")
|
||||
print()
|
||||
|
||||
# Step 1: Download data
|
||||
download_data(num_shards)
|
||||
print()
|
||||
|
||||
# Step 2: Train tokenizer
|
||||
train_tokenizer()
|
||||
print()
|
||||
print("Done! Ready to train.")
|
||||
118
program.md
Normal file
118
program.md
Normal file
@ -0,0 +1,118 @@
|
||||
# autoresearch
|
||||
|
||||
This is an experiment to have the LLM do its own research.
|
||||
|
||||
## Setup
|
||||
|
||||
Check your git state. There are two cases:
|
||||
|
||||
**If launched via `spawn.sh` (multi-agent):** The branch and worktree already exist. You are already on the right branch. Skip to step 3.
|
||||
|
||||
**If launched manually (single agent):**
|
||||
|
||||
1. **Agree on a run tag** with the human: propose a tag based on today's date (e.g. `mar5`). The branch `autoresearch/<tag>` must not already exist — this is a fresh run.
|
||||
2. **Create the branch**: `git checkout -b autoresearch/<tag>` from current master.
|
||||
|
||||
**Then, in both cases:**
|
||||
|
||||
3. **Read the in-scope files**: The repo is small. Read these files for full context:
|
||||
- `constants.py` — fixed constants (`MAX_SEQ_LEN`, `TIME_BUDGET`, `EVAL_TOKENS`). Do not modify.
|
||||
- `prepare.py` — data prep, tokenizer, dataloader, evaluation. Do not modify.
|
||||
- `train.py` — the file you modify. Model architecture, optimizer, training loop.
|
||||
4. **Verify data exists**: Check that `~/.cache/autoresearch/` contains data shards and a tokenizer. If not, tell the human to run `uv run prepare.py`.
|
||||
5. **Initialize results.tsv**: Create `results.tsv` with header row and baseline entry. The baseline results are already known from the output format section below (val_bpb: 0.997900, peak_vram_mb: 45060.2). Do NOT re-run the baseline — just record it.
|
||||
6. **Confirm and go**: If a human is present, confirm setup looks good. If launched via `spawn.sh`, proceed directly into the autonomous experiment loop.
|
||||
|
||||
## Experimentation
|
||||
|
||||
Each experiment runs on a single GPU. The training script runs for a **fixed time budget of 5 minutes** (wall clock training time, excluding startup/compilation). You launch it simply as: `uv run train.py`.
|
||||
|
||||
**What you CAN do:**
|
||||
- Modify `train.py` — this is the only file you edit. Everything is fair game: model architecture, optimizer, hyperparameters, training loop, batch size, model size, etc.
|
||||
|
||||
**What you CANNOT do:**
|
||||
- Modify `constants.py` or `prepare.py`. These are read-only. They contain the fixed evaluation, data loading, tokenizer, and training constants (time budget, sequence length, etc).
|
||||
- Install new packages or add dependencies. You can only use what's already in `pyproject.toml`.
|
||||
- Modify the evaluation harness. The `evaluate_bpb` function in `prepare.py` is the ground truth metric.
|
||||
|
||||
**The goal is simple: get the lowest val_bpb.** Since the time budget is fixed, you don't need to worry about training time — it's always 5 minutes. Everything is fair game: change the architecture, the optimizer, the hyperparameters, the batch size, the model size. The only constraint is that the code runs without crashing and finishes within the time budget.
|
||||
|
||||
**VRAM** is a soft constraint. Some increase is acceptable for meaningful val_bpb gains, but it should not blow up dramatically.
|
||||
|
||||
**Simplicity criterion**: All else being equal, simpler is better. A small improvement that adds ugly complexity is not worth it. Conversely, removing something and getting equal or better results is a great outcome — that's a simplification win. When evaluating whether to keep a change, weigh the complexity cost against the improvement magnitude. A 0.001 val_bpb improvement that adds 20 lines of hacky code? Probably not worth it. A 0.001 val_bpb improvement from deleting code? Definitely keep. An improvement of ~0 but much simpler code? Keep.
|
||||
|
||||
## Output format
|
||||
|
||||
Once the script finishes it prints a summary like this:
|
||||
|
||||
```
|
||||
---
|
||||
val_bpb: 0.997900
|
||||
training_seconds: 300.1
|
||||
total_seconds: 325.9
|
||||
peak_vram_mb: 45060.2
|
||||
mfu_percent: 39.80
|
||||
total_tokens_M: 499.6
|
||||
num_steps: 953
|
||||
num_params_M: 50.3
|
||||
depth: 8
|
||||
```
|
||||
|
||||
This is the baseline to beat.
|
||||
|
||||
You can extract the key metric from the log:
|
||||
|
||||
```
|
||||
grep "^val_bpb:" run.log
|
||||
```
|
||||
|
||||
## Logging results
|
||||
|
||||
When an experiment is done, log it to `results.tsv` (tab-separated, NOT comma-separated — commas break in descriptions).
|
||||
|
||||
The TSV has a header row and 5 columns:
|
||||
|
||||
```
|
||||
commit val_bpb memory_gb status description
|
||||
```
|
||||
|
||||
1. git commit hash (short, 7 chars)
|
||||
2. val_bpb achieved (e.g. 1.234567) — use 0.000000 for crashes
|
||||
3. peak memory in GB, round to .1f (e.g. 12.3 — divide peak_vram_mb by 1024) — use 0.0 for crashes
|
||||
4. status: `keep`, `discard`, or `crash`
|
||||
5. short text description of what this experiment tried
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
commit val_bpb memory_gb status description
|
||||
a1b2c3d 0.997900 44.0 keep baseline
|
||||
b2c3d4e 0.993200 44.2 keep increase LR to 0.04
|
||||
c3d4e5f 1.005000 44.0 discard switch to GeLU activation
|
||||
d4e5f6g 0.000000 0.0 crash double model width (OOM)
|
||||
```
|
||||
|
||||
## The experiment loop
|
||||
|
||||
The experiment runs on a dedicated branch (e.g. `autoresearch/mar5` or `autoresearch/mar5-gpu0`).
|
||||
|
||||
LOOP FOREVER (until I wake up and come back in the morning):
|
||||
|
||||
1. Look at the git state: the current branch/commit we're on
|
||||
2. Tune `train.py` with an experimental idea by directly hacking the code.
|
||||
3. git commit
|
||||
4. run the experiment: `uv run train.py > run.log 2>&1` (redirect everything — do NOT use tee or let output flood your context)
|
||||
5. read out the results: `grep "^val_bpb:\|^peak_vram_mb:" run.log`
|
||||
6. record the results in the tsv
|
||||
7. if val_bpb improved (lower), you "advance" the branch, keeping the git commit
|
||||
8. if val_bpb is equal or worse, you git reset back to where you started
|
||||
|
||||
The idea is that you are a completely autonomous researcher trying things out. If they work, keep. If they don't, discard. And you're advancing the branch so that you can iterate. If you feel like you're getting stuck in some way, you can rewind but you should probably do this very very sparingly (if ever).
|
||||
|
||||
**Timeout**: Each experiment should take ~7 minutes total (5 min training + startup + eval). If a run exceeds 10 minutes, kill it and treat it as a failure (discard and revert).
|
||||
|
||||
**Crashes**: If a run crashes (OOM, or a bug, or etc.), use your judgment: If it's something dumb and easy to fix (e.g. a typo, a missing import), fix it and re-run. If the idea itself is fundamentally broken, just skip it, log "CRASH" in the tsv, and move on.
|
||||
|
||||
**NEVER STOP**: Once the experiment loop has begun (after the initial setup), do NOT pause to ask the human if you should continue. Do NOT ask "should I keep going?" or "is this a good stopping point?". The human might be asleep, or gone from a computer and expects you to continue working *indefinitely* until you are manually stopped. You are autonomous. If you run out of ideas, think harder — read papers referenced in the code, re-read the in-scope files for new angles, try combining previous near-misses, try more radical architectural changes. The loop runs until the human interrupts you, period.
|
||||
|
||||
I will usually leave this running for a number of hours, like... 10 or so. If each experiment is ~7 min, you can do ~8/hour, for a total of approx. 80. The hope is that I come back in the morning and we have some improvements.
|
||||
25
pyproject.toml
Normal file
25
pyproject.toml
Normal file
@ -0,0 +1,25 @@
|
||||
[project]
|
||||
name = "autoresearch"
|
||||
version = "0.1.0"
|
||||
description = "Autonomous pretraining research swarm"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"kernels>=0.11.7",
|
||||
"numpy>=2.2.6",
|
||||
"pyarrow>=21.0.0",
|
||||
"requests>=2.32.0",
|
||||
"rustbpe>=0.1.0",
|
||||
"tiktoken>=0.11.0",
|
||||
"torch==2.9.1",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cu128" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
248
spawn.sh
Executable file
248
spawn.sh
Executable file
@ -0,0 +1,248 @@
|
||||
#!/bin/bash
|
||||
# spawn.sh — launch and manage autonomous research groups
|
||||
#
|
||||
# Usage:
|
||||
# bash spawn.sh launch <tag> <agent:gpu> [agent:gpu ...]
|
||||
# bash spawn.sh stop <tag>
|
||||
# bash spawn.sh status
|
||||
#
|
||||
# Examples:
|
||||
# bash spawn.sh launch mar5 claude:0 claude:1 claude:2 claude:3
|
||||
# bash spawn.sh launch mar5 opus:0 sonnet:1 codex:2 codex:3
|
||||
# bash spawn.sh stop mar5
|
||||
# bash spawn.sh status
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_ROOT="$(cd "$(dirname "$0")" && pwd)"
|
||||
WORKTREE_DIR="${REPO_ROOT}/worktrees"
|
||||
INITIAL_PROMPT="Read program.md and follow the instructions."
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
usage() {
|
||||
echo "Usage:"
|
||||
echo " $0 launch <tag> <agent:gpu> [agent:gpu ...]"
|
||||
echo " $0 stop <tag>"
|
||||
echo " $0 status"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " $0 launch mar5 claude:0 claude:1 claude:2 claude:3"
|
||||
echo " $0 launch mar5 opus:0 sonnet:1 codex:2 codex:3"
|
||||
echo " $0 stop mar5"
|
||||
exit 1
|
||||
}
|
||||
|
||||
setup_worker() {
|
||||
local tag="$1" gpu="$2" agent="$3"
|
||||
local branch="autoresearch/${tag}-gpu${gpu}"
|
||||
local worktree="${WORKTREE_DIR}/gpu${gpu}"
|
||||
local claude_model="${CLAUDE_MODEL:-sonnet}"
|
||||
|
||||
# Create branch (must be fresh)
|
||||
cd "$REPO_ROOT"
|
||||
if git show-ref --verify --quiet "refs/heads/${branch}"; then
|
||||
echo "ERROR: Branch '${branch}' already exists. Use a different tag or clean up first." >&2
|
||||
exit 1
|
||||
fi
|
||||
git checkout -q master
|
||||
git checkout -q -b "$branch"
|
||||
git checkout -q master
|
||||
|
||||
# Set up worktree
|
||||
if [ -d "$worktree" ]; then
|
||||
git worktree remove "$worktree" --force 2>/dev/null || rm -rf "$worktree"
|
||||
fi
|
||||
mkdir -p "$(dirname "$worktree")"
|
||||
git worktree add "$worktree" "$branch"
|
||||
|
||||
# Symlink .venv so uv/python work
|
||||
ln -sf "${REPO_ROOT}/.venv" "${worktree}/.venv"
|
||||
|
||||
# Build agent command
|
||||
case "$agent" in
|
||||
claude|sonnet|opus|haiku)
|
||||
case "$agent" in
|
||||
sonnet|opus|haiku) claude_model="$agent" ;;
|
||||
esac
|
||||
echo "cd ${worktree} && CUDA_VISIBLE_DEVICES=${gpu} claude --dangerously-skip-permissions --model ${claude_model} \"${INITIAL_PROMPT}\""
|
||||
;;
|
||||
codex)
|
||||
echo "cd ${worktree} && CUDA_VISIBLE_DEVICES=${gpu} codex --dangerously-bypass-approvals-and-sandbox --model gpt-5.3-codex-spark \"${INITIAL_PROMPT}\""
|
||||
;;
|
||||
*)
|
||||
echo "ERROR: Unknown agent '${agent}'. Supported: claude, sonnet, opus, haiku, codex" >&2
|
||||
git worktree remove "$worktree" --force 2>/dev/null || true
|
||||
git branch -D "$branch" 2>/dev/null || true
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
cmd_launch() {
|
||||
if [ $# -lt 2 ]; then
|
||||
usage
|
||||
fi
|
||||
|
||||
local tag="$1"
|
||||
shift
|
||||
local specs=("$@")
|
||||
local tmux_session="autoresearch-${tag}"
|
||||
local num_workers=${#specs[@]}
|
||||
|
||||
# Kill existing session
|
||||
tmux kill-session -t "$tmux_session" 2>/dev/null || true
|
||||
|
||||
echo ""
|
||||
echo "============================================"
|
||||
echo " RESEARCH GROUP: ${tag}"
|
||||
echo "============================================"
|
||||
echo " Workers: ${num_workers}"
|
||||
|
||||
local pane_cmds=()
|
||||
local labels=()
|
||||
local agents=()
|
||||
|
||||
for spec in "${specs[@]}"; do
|
||||
local agent="${spec%%:*}"
|
||||
local gpu="${spec#*:}"
|
||||
|
||||
echo " - GPU ${gpu}: ${agent}"
|
||||
|
||||
local cmd
|
||||
cmd=$(setup_worker "$tag" "$gpu" "$agent")
|
||||
pane_cmds+=("$cmd")
|
||||
labels+=("$spec")
|
||||
case "$agent" in
|
||||
sonnet|opus|haiku|claude) agents+=("claude") ;;
|
||||
*) agents+=("$agent") ;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo " tmux: ${tmux_session}"
|
||||
echo "============================================"
|
||||
echo ""
|
||||
|
||||
# Create tmux session with tiled grid
|
||||
tmux new-session -d -s "$tmux_session" -n workers \
|
||||
"${pane_cmds[0]}; echo ''; echo 'Session ended. Press any key to exit.'; read"
|
||||
|
||||
for ((i=1; i<num_workers; i++)); do
|
||||
tmux split-window -t "$tmux_session:workers" \
|
||||
"${pane_cmds[$i]}; echo ''; echo 'Session ended. Press any key to exit.'; read"
|
||||
tmux select-layout -t "$tmux_session:workers" tiled
|
||||
done
|
||||
tmux select-layout -t "$tmux_session:workers" tiled
|
||||
|
||||
echo " Attach: tmux attach -t ${tmux_session}"
|
||||
echo " Detach: Ctrl-b d"
|
||||
echo " Zoom: Ctrl-b z (toggle pane fullscreen)"
|
||||
echo " Navigate: Ctrl-b arrow keys"
|
||||
echo ""
|
||||
|
||||
# Background watcher — nudge idle agents
|
||||
local nudge_interval=120
|
||||
echo "Watcher running (nudge interval: ${nudge_interval}s). Ctrl-C to stop watcher only."
|
||||
echo ""
|
||||
|
||||
declare -A prev_hash
|
||||
|
||||
while tmux has-session -t "$tmux_session" 2>/dev/null; do
|
||||
sleep "$nudge_interval"
|
||||
|
||||
for ((i=0; i<num_workers; i++)); do
|
||||
tmux has-session -t "$tmux_session" 2>/dev/null || break
|
||||
|
||||
local pane="${tmux_session}:workers.${i}"
|
||||
local content curr_hash
|
||||
|
||||
content=$(tmux capture-pane -t "$pane" -p 2>/dev/null) || continue
|
||||
curr_hash=$(echo "$content" | md5sum | cut -d' ' -f1)
|
||||
|
||||
if [ "${prev_hash[$i]:-}" != "$curr_hash" ]; then
|
||||
prev_hash[$i]="$curr_hash"
|
||||
continue
|
||||
fi
|
||||
|
||||
local is_idle=false
|
||||
if echo "$content" | grep -qP '^❯\s*$'; then
|
||||
is_idle=true
|
||||
elif echo "$content" | grep -qP '^›\s*$'; then
|
||||
is_idle=true
|
||||
fi
|
||||
|
||||
if [ "$is_idle" = true ]; then
|
||||
local nudge_msg="Keep going. Do not stop — continue your research loop."
|
||||
if [ "${agents[$i]}" = "codex" ]; then
|
||||
tmux send-keys -t "$pane" "$nudge_msg" Enter
|
||||
else
|
||||
tmux send-keys -t "$pane" "$nudge_msg" C-m
|
||||
fi
|
||||
echo "[$(date +%H:%M:%S)] Nudged pane ${i} (${labels[$i]})"
|
||||
prev_hash[$i]=""
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
echo "tmux session ended. Watcher exiting."
|
||||
}
|
||||
|
||||
cmd_stop() {
|
||||
if [ $# -lt 1 ]; then
|
||||
usage
|
||||
fi
|
||||
|
||||
local tag="$1"
|
||||
local tmux_session="autoresearch-${tag}"
|
||||
|
||||
# Kill tmux session
|
||||
if tmux kill-session -t "$tmux_session" 2>/dev/null; then
|
||||
echo "Killed tmux session '${tmux_session}'"
|
||||
else
|
||||
echo "No tmux session '${tmux_session}' found"
|
||||
fi
|
||||
|
||||
# Remove worktrees
|
||||
local removed=0
|
||||
for wt in "${WORKTREE_DIR}"/gpu*; do
|
||||
[ -d "$wt" ] || continue
|
||||
git -C "$REPO_ROOT" worktree remove "$wt" --force 2>/dev/null && removed=$((removed + 1))
|
||||
done
|
||||
echo "Removed ${removed} worktree(s)"
|
||||
echo "Branches kept for review (git branch -l 'autoresearch/${tag}-*')"
|
||||
}
|
||||
|
||||
cmd_status() {
|
||||
echo "Active tmux sessions:"
|
||||
tmux list-sessions 2>/dev/null | grep "^autoresearch-" || echo " (none)"
|
||||
echo ""
|
||||
echo "Active worktrees:"
|
||||
git -C "$REPO_ROOT" worktree list 2>/dev/null | grep "worktrees/" || echo " (none)"
|
||||
echo ""
|
||||
echo "Research branches:"
|
||||
git -C "$REPO_ROOT" branch -l 'autoresearch/*' 2>/dev/null || echo " (none)"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if [ $# -lt 1 ]; then
|
||||
usage
|
||||
fi
|
||||
|
||||
COMMAND="$1"
|
||||
shift
|
||||
|
||||
case "$COMMAND" in
|
||||
launch) cmd_launch "$@" ;;
|
||||
stop) cmd_stop "$@" ;;
|
||||
status) cmd_status "$@" ;;
|
||||
*) usage ;;
|
||||
esac
|
||||
628
train.py
Normal file
628
train.py
Normal file
@ -0,0 +1,628 @@
|
||||
"""
|
||||
Autoresearch pretraining script. Single-GPU, single-file.
|
||||
Cherry-picked and simplified from nanochat.
|
||||
Usage: uv run train.py
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
|
||||
import gc
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from kernels import get_kernel
|
||||
fa3 = get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
||||
|
||||
from constants import MAX_SEQ_LEN, TIME_BUDGET
|
||||
from prepare import Tokenizer, make_dataloader, evaluate_bpb
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GPT Model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
sequence_len: int = 2048
|
||||
vocab_size: int = 32768
|
||||
n_layer: int = 12
|
||||
n_head: int = 6
|
||||
n_kv_head: int = 6
|
||||
n_embd: int = 768
|
||||
window_pattern: str = "SSSL"
|
||||
|
||||
|
||||
def norm(x):
|
||||
return F.rms_norm(x, (x.size(-1),))
|
||||
|
||||
|
||||
def has_ve(layer_idx, n_layer):
|
||||
"""Returns True if layer should have Value Embedding (alternating, last always included)."""
|
||||
return layer_idx % 2 == (n_layer - 1) % 2
|
||||
|
||||
|
||||
def apply_rotary_emb(x, cos, sin):
|
||||
assert x.ndim == 4
|
||||
d = x.shape[3] // 2
|
||||
x1, x2 = x[..., :d], x[..., d:]
|
||||
y1 = x1 * cos + x2 * sin
|
||||
y2 = x1 * (-sin) + x2 * cos
|
||||
return torch.cat([y1, y2], 3)
|
||||
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.n_head = config.n_head
|
||||
self.n_kv_head = config.n_kv_head
|
||||
self.n_embd = config.n_embd
|
||||
self.head_dim = self.n_embd // self.n_head
|
||||
assert self.n_embd % self.n_head == 0
|
||||
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
||||
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.ve_gate_channels = 32
|
||||
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size):
|
||||
B, T, C = x.size()
|
||||
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
||||
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
|
||||
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
|
||||
if ve is not None:
|
||||
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
|
||||
gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels]))
|
||||
v = v + gate.unsqueeze(-1) * ve
|
||||
|
||||
cos, sin = cos_sin
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
||||
q, k = norm(q), norm(k)
|
||||
|
||||
y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
y = y.contiguous().view(B, T, -1)
|
||||
y = self.c_proj(y)
|
||||
return y
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c_fc(x)
|
||||
x = F.relu(x).square()
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.attn = CausalSelfAttention(config, layer_idx)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size):
|
||||
x = x + self.attn(norm(x), ve, cos_sin, window_size)
|
||||
x = x + self.mlp(norm(x))
|
||||
return x
|
||||
|
||||
|
||||
class GPT(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.window_sizes = self._compute_window_sizes(config)
|
||||
self.transformer = nn.ModuleDict({
|
||||
"wte": nn.Embedding(config.vocab_size, config.n_embd),
|
||||
"h": nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
|
||||
})
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
|
||||
# Value embeddings
|
||||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
self.value_embeds = nn.ModuleDict({
|
||||
str(i): nn.Embedding(config.vocab_size, kv_dim)
|
||||
for i in range(config.n_layer) if has_ve(i, config.n_layer)
|
||||
})
|
||||
# Rotary embeddings
|
||||
self.rotary_seq_len = config.sequence_len * 10
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.register_buffer("cos", cos, persistent=False)
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def init_weights(self):
|
||||
# Embedding and unembedding
|
||||
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
|
||||
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
|
||||
# Transformer blocks
|
||||
n_embd = self.config.n_embd
|
||||
s = 3**0.5 * n_embd**-0.5
|
||||
for block in self.transformer.h:
|
||||
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight)
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
# Per-layer scalars
|
||||
self.resid_lambdas.fill_(1.0)
|
||||
self.x0_lambdas.fill_(0.1)
|
||||
# Value embeddings
|
||||
for ve in self.value_embeds.values():
|
||||
torch.nn.init.uniform_(ve.weight, -s, s)
|
||||
# Gate weights init to zero (sigmoid(0)=0.5, scaled by 2 -> 1.0 = neutral)
|
||||
for block in self.transformer.h:
|
||||
if block.attn.ve_gate is not None:
|
||||
torch.nn.init.zeros_(block.attn.ve_gate.weight)
|
||||
# Rotary embeddings
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
# Cast embeddings to bf16
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=torch.bfloat16)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
if device is None:
|
||||
device = self.transformer.wte.weight.device
|
||||
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
|
||||
inv_freq = 1.0 / (base ** (channel_range / head_dim))
|
||||
t = torch.arange(seq_len, dtype=torch.float32, device=device)
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos, sin = freqs.cos(), freqs.sin()
|
||||
cos, sin = cos.bfloat16(), sin.bfloat16()
|
||||
cos, sin = cos[None, :, None, :], sin[None, :, None, :]
|
||||
return cos, sin
|
||||
|
||||
def _compute_window_sizes(self, config):
|
||||
pattern = config.window_pattern.upper()
|
||||
assert all(c in "SL" for c in pattern)
|
||||
long_window = config.sequence_len
|
||||
short_window = long_window // 2
|
||||
char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
|
||||
window_sizes = []
|
||||
for layer_idx in range(config.n_layer):
|
||||
char = pattern[layer_idx % len(pattern)]
|
||||
window_sizes.append(char_to_window[char])
|
||||
window_sizes[-1] = (long_window, 0)
|
||||
return window_sizes
|
||||
|
||||
def estimate_flops(self):
|
||||
"""Estimated FLOPs per token (forward + backward)."""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
||||
h = self.config.n_head
|
||||
q = self.config.n_embd // self.config.n_head
|
||||
t = self.config.sequence_len
|
||||
attn_flops = 0
|
||||
for window_size in self.window_sizes:
|
||||
window = window_size[0]
|
||||
effective_seq = t if window < 0 else min(window, t)
|
||||
attn_flops += 12 * h * q * effective_seq
|
||||
return 6 * (nparams - nparams_exclude) + attn_flops
|
||||
|
||||
def num_scaling_params(self):
|
||||
wte = sum(p.numel() for p in self.transformer.wte.parameters())
|
||||
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
||||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
return {
|
||||
'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head,
|
||||
'transformer_matrices': transformer_matrices, 'scalars': scalars, 'total': total,
|
||||
}
|
||||
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
|
||||
weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
model_dim = self.config.n_embd
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
value_embeds_params = list(self.value_embeds.parameters())
|
||||
embedding_params = list(self.transformer.wte.parameters())
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) +
|
||||
len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params))
|
||||
# Scale LR ∝ 1/√dmodel (tuned at 768 dim)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
param_groups = [
|
||||
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0),
|
||||
]
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
group_params = [p for p in matrix_params if p.shape == shape]
|
||||
param_groups.append(dict(
|
||||
kind='muon', params=group_params, lr=matrix_lr,
|
||||
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
|
||||
))
|
||||
optimizer = MuonAdamW(param_groups)
|
||||
for group in optimizer.param_groups:
|
||||
group["initial_lr"] = group["lr"]
|
||||
return optimizer
|
||||
|
||||
def forward(self, idx, targets=None, reduction='mean'):
|
||||
B, T = idx.size()
|
||||
assert T <= self.cos.size(1)
|
||||
cos_sin = self.cos[:, :T], self.sin[:, :T]
|
||||
|
||||
x = self.transformer.wte(idx)
|
||||
x = norm(x)
|
||||
x0 = x
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i])
|
||||
x = norm(x)
|
||||
|
||||
softcap = 15
|
||||
logits = self.lm_head(x)
|
||||
logits = logits.float()
|
||||
logits = softcap * torch.tanh(logits / softcap)
|
||||
|
||||
if targets is not None:
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
|
||||
ignore_index=-1, reduction=reduction)
|
||||
return loss
|
||||
return logits
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optimizer (MuonAdamW, single GPU only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
polar_express_coeffs = [
|
||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
|
||||
p.mul_(1 - lr_t * wd_t)
|
||||
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||
bias1 = 1 - beta1_t ** step_t
|
||||
bias2 = 1 - beta2_t ** step_t
|
||||
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
||||
step_size = lr_t / bias1
|
||||
p.add_(exp_avg / denom, alpha=-step_size)
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
|
||||
momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
|
||||
# Nesterov momentum
|
||||
momentum = momentum_t.to(stacked_grads.dtype)
|
||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
# Polar express orthogonalization
|
||||
X = g.bfloat16()
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
if g.size(-2) > g.size(-1):
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X.mT @ X
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + X @ B
|
||||
else:
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
g = X
|
||||
# NorMuon variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
g = g * final_scale.to(g.dtype)
|
||||
# Cautious weight decay + parameter update
|
||||
lr = lr_t.to(g.dtype)
|
||||
wd = wd_t.to(g.dtype)
|
||||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
||||
|
||||
|
||||
class MuonAdamW(torch.optim.Optimizer):
|
||||
"""Combined optimizer: Muon for 2D matrix params, AdamW for others."""
|
||||
|
||||
def __init__(self, param_groups):
|
||||
super().__init__(param_groups, defaults={})
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
def _step_adamw(self, group):
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
state = self.state[p]
|
||||
if not state:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p)
|
||||
state['step'] += 1
|
||||
self._adamw_step_t.fill_(state['step'])
|
||||
self._adamw_lr_t.fill_(group['lr'])
|
||||
self._adamw_beta1_t.fill_(group['betas'][0])
|
||||
self._adamw_beta2_t.fill_(group['betas'][1])
|
||||
self._adamw_eps_t.fill_(group['eps'])
|
||||
self._adamw_wd_t.fill_(group['weight_decay'])
|
||||
adamw_step_fused(p, grad, state['exp_avg'], state['exp_avg_sq'],
|
||||
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
|
||||
|
||||
def _step_muon(self, group):
|
||||
params = group['params']
|
||||
if not params:
|
||||
return
|
||||
p = params[0]
|
||||
state = self.state[p]
|
||||
num_params = len(params)
|
||||
shape, device, dtype = p.shape, p.device, p.dtype
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
if "second_momentum_buffer" not in state:
|
||||
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
|
||||
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
stacked_grads = torch.stack([p.grad for p in params])
|
||||
stacked_params = torch.stack(params)
|
||||
self._muon_momentum_t.fill_(group["momentum"])
|
||||
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._muon_wd_t.fill_(group["weight_decay"])
|
||||
muon_step_fused(stacked_grads, stacked_params,
|
||||
state["momentum_buffer"], state["second_momentum_buffer"],
|
||||
self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
|
||||
self._muon_beta2_t, group["ns_steps"], red_dim)
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
if group['kind'] == 'adamw':
|
||||
self._step_adamw(group)
|
||||
elif group['kind'] == 'muon':
|
||||
self._step_muon(group)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hyperparameters (edit these directly, no CLI flags needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Model architecture
|
||||
ASPECT_RATIO = 64 # model_dim = depth * ASPECT_RATIO
|
||||
HEAD_DIM = 128 # target head dimension for attention
|
||||
WINDOW_PATTERN = "SSSL" # sliding window pattern: L=full, S=half context
|
||||
|
||||
# Optimization
|
||||
TOTAL_BATCH_SIZE = 2**19 # ~524K tokens per optimizer step
|
||||
EMBEDDING_LR = 0.6 # learning rate for token embeddings (Adam)
|
||||
UNEMBEDDING_LR = 0.004 # learning rate for lm_head (Adam)
|
||||
MATRIX_LR = 0.04 # learning rate for matrix parameters (Muon)
|
||||
SCALAR_LR = 0.5 # learning rate for per-layer scalars (Adam)
|
||||
WEIGHT_DECAY = 0.2 # cautious weight decay for Muon
|
||||
ADAM_BETAS = (0.8, 0.95) # Adam beta1, beta2
|
||||
WARMUP_RATIO = 0.0 # fraction of time budget for LR warmup
|
||||
WARMDOWN_RATIO = 0.5 # fraction of time budget for LR warmdown
|
||||
FINAL_LR_FRAC = 0.0 # final LR as fraction of initial
|
||||
|
||||
# Model size
|
||||
DEPTH = 8 # number of transformer layers
|
||||
DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Setup: tokenizer, model, optimizer, dataloader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
t_start = time.time()
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
torch.set_float32_matmul_precision("high")
|
||||
device = torch.device("cuda")
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
H100_BF16_PEAK_FLOPS = 989.5e12
|
||||
|
||||
tokenizer = Tokenizer.from_directory()
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
print(f"Vocab size: {vocab_size:,}")
|
||||
|
||||
def build_model_config(depth):
|
||||
base_dim = depth * ASPECT_RATIO
|
||||
model_dim = ((base_dim + HEAD_DIM - 1) // HEAD_DIM) * HEAD_DIM
|
||||
num_heads = model_dim // HEAD_DIM
|
||||
return GPTConfig(
|
||||
sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size,
|
||||
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
||||
window_pattern=WINDOW_PATTERN,
|
||||
)
|
||||
|
||||
config = build_model_config(DEPTH)
|
||||
print(f"Model config: {asdict(config)}")
|
||||
|
||||
with torch.device("meta"):
|
||||
model = GPT(config)
|
||||
model.to_empty(device=device)
|
||||
model.init_weights()
|
||||
|
||||
param_counts = model.num_scaling_params()
|
||||
print("Parameter counts:")
|
||||
for key, value in param_counts.items():
|
||||
print(f" {key:24s}: {value:,}")
|
||||
num_params = param_counts['total']
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
||||
|
||||
tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
|
||||
assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
|
||||
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=UNEMBEDDING_LR,
|
||||
embedding_lr=EMBEDDING_LR,
|
||||
scalar_lr=SCALAR_LR,
|
||||
adam_betas=ADAM_BETAS,
|
||||
matrix_lr=MATRIX_LR,
|
||||
weight_decay=WEIGHT_DECAY,
|
||||
)
|
||||
|
||||
model = torch.compile(model, dynamic=False)
|
||||
|
||||
train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
|
||||
x, y, epoch = next(train_loader) # prefetch first batch
|
||||
|
||||
print(f"Time budget: {TIME_BUDGET}s")
|
||||
print(f"Gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
# Schedules (all based on progress = training_time / TIME_BUDGET)
|
||||
|
||||
def get_lr_multiplier(progress):
|
||||
if progress < WARMUP_RATIO:
|
||||
return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
|
||||
elif progress < 1.0 - WARMDOWN_RATIO:
|
||||
return 1.0
|
||||
else:
|
||||
cooldown = (1.0 - progress) / WARMDOWN_RATIO
|
||||
return cooldown * 1.0 + (1 - cooldown) * FINAL_LR_FRAC
|
||||
|
||||
def get_muon_momentum(step):
|
||||
frac = min(step / 300, 1)
|
||||
return (1 - frac) * 0.85 + frac * 0.95
|
||||
|
||||
def get_weight_decay(progress):
|
||||
return WEIGHT_DECAY * (1 - progress)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
t_start_training = time.time()
|
||||
smooth_train_loss = 0
|
||||
total_training_time = 0
|
||||
step = 0
|
||||
|
||||
while True:
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach()
|
||||
loss = loss / grad_accum_steps
|
||||
loss.backward()
|
||||
x, y, epoch = next(train_loader)
|
||||
|
||||
# Progress and schedules
|
||||
progress = min(total_training_time / TIME_BUDGET, 1.0)
|
||||
lrm = get_lr_multiplier(progress)
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
muon_weight_decay = get_weight_decay(progress)
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
train_loss_f = train_loss.item()
|
||||
|
||||
# Fast fail: abort if loss is exploding
|
||||
if train_loss_f > 100:
|
||||
print("FAIL")
|
||||
exit(1)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
|
||||
if step > 10:
|
||||
total_training_time += dt
|
||||
|
||||
# Logging
|
||||
ema_beta = 0.9
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
|
||||
pct_done = 100 * progress
|
||||
tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
|
||||
mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / H100_BF16_PEAK_FLOPS
|
||||
remaining = max(0, TIME_BUDGET - total_training_time)
|
||||
|
||||
print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", end="", flush=True)
|
||||
|
||||
# GC management (Python's GC causes ~500ms stalls)
|
||||
if step == 0:
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
gc.disable()
|
||||
elif (step + 1) % 5000 == 0:
|
||||
gc.collect()
|
||||
|
||||
step += 1
|
||||
|
||||
# Time's up — but only stop after warmup steps so we don't count compilation
|
||||
if step > 10 and total_training_time >= TIME_BUDGET:
|
||||
break
|
||||
|
||||
print() # newline after \r training log
|
||||
|
||||
total_tokens = step * TOTAL_BATCH_SIZE
|
||||
|
||||
# Final eval
|
||||
model.eval()
|
||||
with autocast_ctx:
|
||||
val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
|
||||
|
||||
# Final summary
|
||||
t_end = time.time()
|
||||
startup_time = t_start_training - t_start
|
||||
steady_state_mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10) / total_training_time / H100_BF16_PEAK_FLOPS if total_training_time > 0 else 0
|
||||
peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
|
||||
|
||||
print("---")
|
||||
print(f"val_bpb: {val_bpb:.6f}")
|
||||
print(f"training_seconds: {total_training_time:.1f}")
|
||||
print(f"total_seconds: {t_end - t_start:.1f}")
|
||||
print(f"peak_vram_mb: {peak_vram_mb:.1f}")
|
||||
print(f"mfu_percent: {steady_state_mfu:.2f}")
|
||||
print(f"total_tokens_M: {total_tokens / 1e6:.1f}")
|
||||
print(f"num_steps: {step}")
|
||||
print(f"num_params_M: {num_params / 1e6:.1f}")
|
||||
print(f"depth: {DEPTH}")
|
||||
Loading…
Reference in New Issue
Block a user