initial commit

This commit is contained in:
Andrej Karpathy 2026-03-06 21:58:52 +00:00
commit b11d6f283f
10 changed files with 2848 additions and 0 deletions

20
.gitignore vendored Normal file
View File

@ -0,0 +1,20 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
worktrees/
results/
queue/
# Agent prompt files (generated per-session by launchers)
CLAUDE.md
AGENTS.md
# Experimental code/artifacts
dev/

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.10

61
README.md Normal file
View File

@ -0,0 +1,61 @@
# autoresearch
Autonomous LLM pretraining research, driven by AI agents.
The idea: give an AI agent a small but real LLM training setup and let it run experiments overnight. It modifies the code, trains for 5 minutes, checks if the result improved, keeps or discards, and repeats. You wake up in the morning to a log of experiments and (hopefully) a better model.
This particular implementation is trying to be the least fancy baseline, but it's clear how one would adjust the `program.md` file to run more sophisticated research programs with more elaborate instructions. For example, the agent could actively run small side experiments of its own while the main training job is running.
## How it works
The repo is deliberately small and only has a few files:
- **`constants.py`** — fixed rules: sequence length, time budget, eval tokens. Not modified.
- **`prepare.py`** — one-time data prep (downloads training data, trains a BPE tokenizer) and runtime utilities (dataloader, evaluation). Not modified.
- **`train.py`** — the single file the agent edits. Contains the full GPT model, optimizer (Muon + AdamW), and training loop. Everything is fair game: architecture, hyperparameters, optimizer, batch size, etc.
- **`program.md`** — instructions for the agent. Point your agent here and let it go.
Training runs for a **fixed 5-minute time budget** (wall clock, excluding startup/compilation). The metric is **val_bpb** (validation bits per byte) — lower is better, and vocab-size-independent so architectural changes are fairly compared.
## Quick start
**Requirements:** A single NVIDIA GPU (tested on H100), Python 3.10+, [uv](https://docs.astral.sh/uv/).
```bash
# 1. Install dependencies
uv sync
# 2. Download data and train tokenizer (one-time, ~5 min)
uv run prepare.py
# 3. Run a single training experiment (5 min + startup)
uv run train.py
```
## Running the agent
Simply spin up your Claude/Codex or whatever you want in this repo, then you can say something like:
```
Hi have a look at program.md and let's kick off a new experiment! let's do the setup first.
```
The `program.md` file is essentially a super lightweight "skill".
## Project structure
```
constants.py — fixed constants (do not modify)
prepare.py — data prep + runtime utilities (do not modify)
train.py — model, optimizer, training loop (agent modifies this)
program.md — agent instructions
spawn.sh — multi-agent launcher
pyproject.toml — dependencies
```
## Design choices
- **Single file to modify.** The agent only touches `train.py`. This keeps the scope manageable and diffs reviewable.
- **Fixed time budget.** Training always runs for exactly 5 minutes. This makes experiments directly comparable regardless of what the agent changes (model size, batch size, architecture, etc).
- **BPB metric.** Bits per byte is independent of tokenizer vocabulary size, so the agent could in principle change the vocab size and still get a fair comparison.
- **Self-contained.** No external dependencies beyond PyTorch and a few small packages. No distributed training, no complex configs. One GPU, one file, one metric.

7
constants.py Normal file
View File

@ -0,0 +1,7 @@
"""
Fixed constants for autoresearch. Do not modify.
"""
MAX_SEQ_LEN = 2048 # context length
TIME_BUDGET = 300 # training time budget in seconds (5 minutes)
EVAL_TOKENS = 40 * 524288 # number of tokens for val eval

373
prepare.py Normal file
View File

@ -0,0 +1,373 @@
"""
One-time data preparation for autoresearch experiments.
Downloads data shards and trains a BPE tokenizer.
Usage:
python prepare.py # full prep (download + tokenizer)
python prepare.py --num-shards 8 # download only 8 shards (for testing)
Data and tokenizer are stored in ~/.cache/autoresearch/.
"""
import os
import sys
import time
import math
import argparse
import pickle
from multiprocessing import Pool
import requests
import pyarrow.parquet as pq
import rustbpe
import tiktoken
import torch
from constants import MAX_SEQ_LEN, EVAL_TOKENS
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# All artifacts (downloaded shards, trained tokenizer) live under this cache.
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
DATA_DIR = os.path.join(CACHE_DIR, "data")  # parquet shards
TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")  # tokenizer.pkl + token_bytes.pt
# HuggingFace dataset the shards are fetched from.
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
MAX_SHARD = 6542  # the last datashard is shard_06542.parquet
VOCAB_SIZE = 8192  # total vocab size, including the 16 special tokens below
# BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3})
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
# Reserved special tokens; reserved_0 doubles as BOS (see Tokenizer below).
SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(16)]
BOS_TOKEN = "<|reserved_0|>"
# ---------------------------------------------------------------------------
# Data download
# ---------------------------------------------------------------------------
def download_single_shard(index):
    """Download one parquet shard with retries.

    Args:
        index: shard number, formatted into shard_{index:05d}.parquet.

    Returns:
        True if the shard exists locally (already present or freshly
        downloaded), False after all retry attempts fail.
    """
    filename = f"shard_{index:05d}.parquet"
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        return True
    # FIX: the URL and log messages must interpolate the shard filename.
    url = f"{BASE_URL}/{filename}"
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            # Stream to a temp file, then atomically rename, so a partial
            # download can never masquerade as a complete shard.
            temp_path = filepath + ".tmp"
            with open(temp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
            os.rename(temp_path, filepath)
            print(f" Downloaded {filename}")
            return True
        except (requests.RequestException, IOError) as e:
            print(f" Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            # Remove any partial files before retrying.
            for path in [filepath + ".tmp", filepath]:
                if os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        pass
            if attempt < max_attempts:
                time.sleep(2 ** attempt)  # exponential backoff
    return False
def download_data(num_shards, workers=8):
    """Download data shards in parallel.

    Args:
        num_shards: how many shards to fetch (clamped to the dataset size).
        workers: maximum number of parallel download processes (default 8,
            matching the previous hard-coded cap).
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    ids = list(range(min(num_shards, MAX_SHARD + 1)))
    # Count what's already downloaded so we can skip early and report accurately.
    existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet")))
    if existing == len(ids):
        print(f"Data: all {len(ids)} shards already downloaded at {DATA_DIR}")
        return
    needed = len(ids) - existing
    print(f"Data: downloading {needed} shards ({existing} already exist)...")
    # Never spawn more processes than there are shards left to fetch.
    pool_size = min(workers, needed)
    with Pool(processes=pool_size) as pool:
        results = pool.map(download_single_shard, ids)
    ok = sum(1 for r in results if r)
    print(f"Data: {ok}/{len(ids)} shards ready at {DATA_DIR}")
# ---------------------------------------------------------------------------
# Tokenizer training
# ---------------------------------------------------------------------------
def list_parquet_files():
    """Return sorted list of parquet file paths in the data directory."""
    names = [n for n in os.listdir(DATA_DIR) if n.endswith(".parquet") and not n.endswith(".tmp")]
    names.sort()
    return [os.path.join(DATA_DIR, n) for n in names]
def text_iterator(max_chars=2_000_000_000, doc_cap=10_000):
    """Yield documents from training split (all shards except last)."""
    chars_seen = 0
    for filepath in list_parquet_files()[:-1]:  # last shard is held out for val
        pf = pq.ParquetFile(filepath)
        for rg_idx in range(pf.num_row_groups):
            row_group = pf.read_row_group(rg_idx)
            for text in row_group.column("text").to_pylist():
                # Cap very long documents so no single doc dominates training.
                doc = text if len(text) <= doc_cap else text[:doc_cap]
                chars_seen += len(doc)
                yield doc
                if chars_seen >= max_chars:
                    return
def train_tokenizer():
    """Train BPE tokenizer using rustbpe, save as tiktoken pickle.

    Produces two artifacts in TOKENIZER_DIR:
      - tokenizer.pkl: a pickled tiktoken.Encoding.
      - token_bytes.pt: int32 tensor mapping token id -> UTF-8 byte length,
        used by evaluate_bpb (special tokens get 0 so they are excluded).
    Exits the process if fewer than 2 data shards are present.
    """
    tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
    token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
    # Both artifacts already exist -> nothing to do.
    if os.path.exists(tokenizer_pkl) and os.path.exists(token_bytes_path):
        print(f"Tokenizer: already trained at {TOKENIZER_DIR}")
        return
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    parquet_files = list_parquet_files()
    if len(parquet_files) < 2:
        print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.")
        sys.exit(1)
    # --- Train with rustbpe ---
    print("Tokenizer: training BPE tokenizer...")
    t0 = time.time()
    tokenizer = rustbpe.Tokenizer()
    # Leave room at the top of the id space for the special tokens.
    vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)
    tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN)
    # Build tiktoken encoding from trained merges
    pattern = tokenizer.get_pattern()
    mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()}
    tokens_offset = len(mergeable_ranks)
    # Special tokens occupy the ids immediately after the merge table.
    special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=pattern,
        mergeable_ranks=mergeable_ranks,
        special_tokens=special_tokens,
    )
    # Save tokenizer
    with open(tokenizer_pkl, "wb") as f:
        pickle.dump(enc, f)
    t1 = time.time()
    print(f"Tokenizer: trained in {t1 - t0:.1f}s, saved to {tokenizer_pkl}")
    # --- Build token_bytes lookup for BPB evaluation ---
    print("Tokenizer: building token_bytes lookup...")
    special_set = set(SPECIAL_TOKENS)
    token_bytes_list = []
    for token_id in range(enc.n_vocab):
        token_str = enc.decode([token_id])
        if token_str in special_set:
            token_bytes_list.append(0)  # special tokens contribute no bytes to BPB
        else:
            token_bytes_list.append(len(token_str.encode("utf-8")))
    token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32)
    torch.save(token_bytes_tensor, token_bytes_path)
    print(f"Tokenizer: saved token_bytes to {token_bytes_path}")
    # Sanity check
    test = "Hello world! Numbers: 123. Unicode: 你好"
    encoded = enc.encode_ordinary(test)
    decoded = enc.decode(encoded)
    assert decoded == test, f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}"
    print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})")
# ---------------------------------------------------------------------------
# Runtime utilities (imported by train.py)
# ---------------------------------------------------------------------------
class Tokenizer:
    """Minimal tokenizer wrapper. Training is handled above."""

    def __init__(self, enc):
        self.enc = enc
        self.bos_token_id = enc.encode_single_token(BOS_TOKEN)

    @classmethod
    def from_directory(cls, tokenizer_dir=TOKENIZER_DIR):
        """Load the pickled tiktoken encoding from disk and wrap it."""
        pkl_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
        with open(pkl_path, "rb") as f:
            encoding = pickle.load(f)
        return cls(encoding)

    def get_vocab_size(self):
        return self.enc.n_vocab

    def get_bos_token_id(self):
        return self.bos_token_id

    def encode(self, text, prepend=None, num_threads=8):
        """Encode a string or list of strings, optionally prepending a token
        (given as an id or as the special-token string itself)."""
        prepend_id = None
        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            if prepend_id is not None:
                ids.insert(0, prepend_id)
            return ids
        if isinstance(text, list):
            ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            if prepend_id is not None:
                for row in ids:
                    row.insert(0, prepend_id)
            return ids
        raise ValueError(f"Invalid input type: {type(text)}")

    def decode(self, ids):
        return self.enc.decode(ids)
def get_token_bytes(device="cpu"):
    """Load the per-token byte-length tensor saved by train_tokenizer()."""
    token_bytes_file = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
    with open(token_bytes_file, "rb") as fh:
        return torch.load(fh, map_location=device)
def _document_batches(split, tokenizer_batch_size=128):
    """Infinite iterator over document batches from parquet files."""
    paths = list_parquet_files()
    assert len(paths) > 0, "No parquet files found. Run prepare.py first."
    # The last shard is reserved for validation; everything else is train.
    paths = paths[:-1] if split == "train" else paths[-1:]
    epoch = 1
    while True:
        for path in paths:
            pf = pq.ParquetFile(path)
            for rg_idx in range(pf.num_row_groups):
                texts = pf.read_row_group(rg_idx).column('text').to_pylist()
                for start in range(0, len(texts), tokenizer_batch_size):
                    yield texts[start:start + tokenizer_batch_size], epoch
        epoch += 1
def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
    """
    BOS-aligned dataloader with best-fit packing.
    Every row starts with BOS. Documents packed using best-fit to minimize cropping.
    When no document fits remaining space, crops shortest doc to fill exactly.
    100% utilization (no padding).

    Args:
        tokenizer: Tokenizer instance (provides encode() and the BOS id).
        B: batch size (rows per yielded batch).
        T: tokens per row; rows are packed to T+1 so inputs/targets overlap by one.
        split: "train" or "val".
        buffer_size: minimum number of tokenized docs kept on hand for best-fit.

    Yields:
        (inputs, targets, epoch) — two (B, T) long tensors on CUDA plus the
        current epoch of the underlying document stream.
        NOTE: the yielded tensors are views into a reused GPU buffer and are
        overwritten on the next iteration; consume them before calling next().
    """
    assert split in ["train", "val"]
    row_capacity = T + 1  # T inputs + 1 extra token so targets are the row shifted by one
    batches = _document_batches(split)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []  # tokenized docs (each starting with BOS) awaiting packing
    epoch = 1
    def refill_buffer():
        # Pull and tokenize the next raw document batch (BOS prepended to each doc).
        nonlocal epoch
        doc_batch, epoch = next(batches)
        token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
        doc_buffer.extend(token_lists)
    # Pre-allocate buffers: [inputs (B*T) | targets (B*T)]
    # Single pinned CPU staging buffer + single GPU buffer enable one async H2D copy.
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True)
    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device="cuda")
    cpu_inputs = cpu_buffer[:B * T].view(B, T)
    cpu_targets = cpu_buffer[B * T:].view(B, T)
    inputs = gpu_buffer[:B * T].view(B, T)
    targets = gpu_buffer[B * T:].view(B, T)
    while True:
        for row_idx in range(B):
            pos = 0
            while pos < row_capacity:
                # Keep a healthy pool of candidates so best-fit has choices.
                while len(doc_buffer) < buffer_size:
                    refill_buffer()
                remaining = row_capacity - pos
                # Find largest doc that fits entirely
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    doc_len = len(doc)
                    if doc_len <= remaining and doc_len > best_len:
                        best_idx = i
                        best_len = doc_len
                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                    pos += len(doc)
                else:
                    # No doc fits — crop shortest to fill remaining
                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(shortest_idx)
                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
                    pos += remaining
        # inputs = tokens [0:T], targets = the same rows shifted by one token.
        cpu_inputs.copy_(row_buffer[:, :-1])
        cpu_targets.copy_(row_buffer[:, 1:])
        gpu_buffer.copy_(cpu_buffer, non_blocking=True)
        yield inputs, targets, epoch
# ---------------------------------------------------------------------------
# Evaluation (DO NOT CHANGE — this is the fixed metric)
# ---------------------------------------------------------------------------
@torch.no_grad()
def evaluate_bpb(model, tokenizer, batch_size):
    """
    Bits per byte (BPB): vocab size-independent evaluation metric.
    Sums per-token cross-entropy (in nats), sums target byte lengths,
    then converts nats/byte to bits/byte. Special tokens (byte length 0)
    are excluded from both sums.
    Uses fixed MAX_SEQ_LEN so results are comparable across configs.

    Args:
        model: called as model(x, y, reduction='none'); expected to return
            per-token losses (flattened to B*T here).
        tokenizer: Tokenizer instance for building the val dataloader.
        batch_size: rows per eval batch.

    Returns:
        float: validation bits per byte (lower is better).
    """
    token_bytes = get_token_bytes(device="cuda")
    val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
    # Fixed total token budget -> identical number of eval steps across runs.
    steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
    total_nats = 0.0
    total_bytes = 0
    for _ in range(steps):
        x, y, _ = next(val_loader)
        loss_flat = model(x, y, reduction='none').view(-1)
        y_flat = y.view(-1)
        nbytes = token_bytes[y_flat]
        mask = nbytes > 0  # special tokens have byte length 0; exclude them
        total_nats += (loss_flat * mask).sum().item()
        total_bytes += nbytes.sum().item()
    # Convert nats/byte to bits/byte.
    return total_nats / (math.log(2) * total_bytes)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
    # FIX: help text previously said "all 1823"; the dataset actually has
    # MAX_SHARD + 1 = 6543 shards (shard_00000 .. shard_06542).
    parser.add_argument("--num-shards", type=int, default=8, help="Number of shards to download (-1 = all 6543)")
    # NOTE(review): this flag is parsed but not forwarded to download_data(),
    # which caps its own pool size — confirm whether it should be wired through.
    parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers")
    args = parser.parse_args()
    num_shards = MAX_SHARD + 1 if args.num_shards == -1 else args.num_shards
    print(f"Cache directory: {CACHE_DIR}")
    print()
    # Step 1: Download data
    download_data(num_shards)
    print()
    # Step 2: Train tokenizer
    train_tokenizer()
    print()
    print("Done! Ready to train.")

118
program.md Normal file
View File

@ -0,0 +1,118 @@
# autoresearch
This is an experiment to have the LLM do its own research.
## Setup
Check your git state. There are two cases:
**If launched via `spawn.sh` (multi-agent):** The branch and worktree already exist. You are already on the right branch. Skip to step 3.
**If launched manually (single agent):**
1. **Agree on a run tag** with the human: propose a tag based on today's date (e.g. `mar5`). The branch `autoresearch/<tag>` must not already exist — this is a fresh run.
2. **Create the branch**: `git checkout -b autoresearch/<tag>` from current master.
**Then, in both cases:**
3. **Read the in-scope files**: The repo is small. Read these files for full context:
- `constants.py` — fixed constants (`MAX_SEQ_LEN`, `TIME_BUDGET`, `EVAL_TOKENS`). Do not modify.
- `prepare.py` — data prep, tokenizer, dataloader, evaluation. Do not modify.
- `train.py` — the file you modify. Model architecture, optimizer, training loop.
4. **Verify data exists**: Check that `~/.cache/autoresearch/` contains data shards and a tokenizer. If not, tell the human to run `uv run prepare.py`.
5. **Initialize results.tsv**: Create `results.tsv` with header row and baseline entry. The baseline results are already known from the output format section below (val_bpb: 0.997900, peak_vram_mb: 45060.2). Do NOT re-run the baseline — just record it.
6. **Confirm and go**: If a human is present, confirm setup looks good. If launched via `spawn.sh`, proceed directly into the autonomous experiment loop.
## Experimentation
Each experiment runs on a single GPU. The training script runs for a **fixed time budget of 5 minutes** (wall clock training time, excluding startup/compilation). You launch it simply as: `uv run train.py`.
**What you CAN do:**
- Modify `train.py` — this is the only file you edit. Everything is fair game: model architecture, optimizer, hyperparameters, training loop, batch size, model size, etc.
**What you CANNOT do:**
- Modify `constants.py` or `prepare.py`. These are read-only. They contain the fixed evaluation, data loading, tokenizer, and training constants (time budget, sequence length, etc).
- Install new packages or add dependencies. You can only use what's already in `pyproject.toml`.
- Modify the evaluation harness. The `evaluate_bpb` function in `prepare.py` is the ground truth metric.
**The goal is simple: get the lowest val_bpb.** Since the time budget is fixed, you don't need to worry about training time — it's always 5 minutes. Everything is fair game: change the architecture, the optimizer, the hyperparameters, the batch size, the model size. The only constraint is that the code runs without crashing and finishes within the time budget.
**VRAM** is a soft constraint. Some increase is acceptable for meaningful val_bpb gains, but it should not blow up dramatically.
**Simplicity criterion**: All else being equal, simpler is better. A small improvement that adds ugly complexity is not worth it. Conversely, removing something and getting equal or better results is a great outcome — that's a simplification win. When evaluating whether to keep a change, weigh the complexity cost against the improvement magnitude. A 0.001 val_bpb improvement that adds 20 lines of hacky code? Probably not worth it. A 0.001 val_bpb improvement from deleting code? Definitely keep. An improvement of ~0 but much simpler code? Keep.
## Output format
Once the script finishes it prints a summary like this:
```
---
val_bpb: 0.997900
training_seconds: 300.1
total_seconds: 325.9
peak_vram_mb: 45060.2
mfu_percent: 39.80
total_tokens_M: 499.6
num_steps: 953
num_params_M: 50.3
depth: 8
```
This is the baseline to beat.
You can extract the key metric from the log:
```
grep "^val_bpb:" run.log
```
## Logging results
When an experiment is done, log it to `results.tsv` (tab-separated, NOT comma-separated — commas break in descriptions).
The TSV has a header row and 5 columns:
```
commit val_bpb memory_gb status description
```
1. git commit hash (short, 7 chars)
2. val_bpb achieved (e.g. 1.234567) — use 0.000000 for crashes
3. peak memory in GB, round to .1f (e.g. 12.3 — divide peak_vram_mb by 1024) — use 0.0 for crashes
4. status: `keep`, `discard`, or `crash`
5. short text description of what this experiment tried
Example:
```
commit val_bpb memory_gb status description
a1b2c3d 0.997900 44.0 keep baseline
b2c3d4e 0.993200 44.2 keep increase LR to 0.04
c3d4e5f 1.005000 44.0 discard switch to GeLU activation
d4e5f6g 0.000000 0.0 crash double model width (OOM)
```
## The experiment loop
The experiment runs on a dedicated branch (e.g. `autoresearch/mar5` or `autoresearch/mar5-gpu0`).
LOOP FOREVER (until I wake up and come back in the morning):
1. Look at the git state: the current branch/commit we're on
2. Tune `train.py` with an experimental idea by directly hacking the code.
3. git commit
4. run the experiment: `uv run train.py > run.log 2>&1` (redirect everything — do NOT use tee or let output flood your context)
5. read out the results: `grep "^val_bpb:\|^peak_vram_mb:" run.log`
6. record the results in the tsv
7. if val_bpb improved (lower), you "advance" the branch, keeping the git commit
8. if val_bpb is equal or worse, you git reset back to where you started
The idea is that you are a completely autonomous researcher trying things out. If they work, keep. If they don't, discard. And you're advancing the branch so that you can iterate. If you feel like you're getting stuck in some way, you can rewind but you should probably do this very very sparingly (if ever).
**Timeout**: Each experiment should take ~7 minutes total (5 min training + startup + eval). If a run exceeds 10 minutes, kill it and treat it as a failure (discard and revert).
**Crashes**: If a run crashes (OOM, or a bug, or etc.), use your judgment: If it's something dumb and easy to fix (e.g. a typo, a missing import), fix it and re-run. If the idea itself is fundamentally broken, just skip it, log it with status `crash` in the tsv, and move on.
**NEVER STOP**: Once the experiment loop has begun (after the initial setup), do NOT pause to ask the human if you should continue. Do NOT ask "should I keep going?" or "is this a good stopping point?". The human might be asleep, or gone from a computer and expects you to continue working *indefinitely* until you are manually stopped. You are autonomous. If you run out of ideas, think harder — read papers referenced in the code, re-read the in-scope files for new angles, try combining previous near-misses, try more radical architectural changes. The loop runs until the human interrupts you, period.
I will usually leave this running for a number of hours, like... 10 or so. If each experiment is ~7 min, you can do ~8/hour, for a total of approx. 80. The hope is that I come back in the morning and we have some improvements.

25
pyproject.toml Normal file
View File

@ -0,0 +1,25 @@
[project]
name = "autoresearch"
version = "0.1.0"
description = "Autonomous pretraining research swarm"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"kernels>=0.11.7",
"numpy>=2.2.6",
"pyarrow>=21.0.0",
"requests>=2.32.0",
"rustbpe>=0.1.0",
"tiktoken>=0.11.0",
"torch==2.9.1",
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128" },
]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

248
spawn.sh Executable file
View File

@ -0,0 +1,248 @@
#!/bin/bash
# spawn.sh — launch and manage autonomous research groups
#
# Usage:
# bash spawn.sh launch <tag> <agent:gpu> [agent:gpu ...]
# bash spawn.sh stop <tag>
# bash spawn.sh status
#
# Examples:
# bash spawn.sh launch mar5 claude:0 claude:1 claude:2 claude:3
# bash spawn.sh launch mar5 opus:0 sonnet:1 codex:2 codex:3
# bash spawn.sh stop mar5
# bash spawn.sh status
set -euo pipefail
REPO_ROOT="$(cd "$(dirname "$0")" && pwd)"  # absolute path to the repo root
WORKTREE_DIR="${REPO_ROOT}/worktrees"       # one git worktree per worker/GPU
INITIAL_PROMPT="Read program.md and follow the instructions."
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Print usage to stdout and exit non-zero. Called for any malformed invocation.
usage() {
    echo "Usage:"
    echo " $0 launch <tag> <agent:gpu> [agent:gpu ...]"
    echo " $0 stop <tag>"
    echo " $0 status"
    echo ""
    echo "Examples:"
    echo " $0 launch mar5 claude:0 claude:1 claude:2 claude:3"
    echo " $0 launch mar5 opus:0 sonnet:1 codex:2 codex:3"
    echo " $0 stop mar5"
    exit 1
}
# Prepare one worker: create a fresh branch + worktree, then print (on stdout)
# the shell command that launches its agent inside that worktree.
# NOTE: stdout is captured by cmd_launch ($(setup_worker ...)), so only the
# agent command may be echoed here; diagnostics must go to stderr.
setup_worker() {
    local tag="$1" gpu="$2" agent="$3"
    local branch="autoresearch/${tag}-gpu${gpu}"
    local worktree="${WORKTREE_DIR}/gpu${gpu}"
    local claude_model="${CLAUDE_MODEL:-sonnet}"
    # Create branch (must be fresh)
    cd "$REPO_ROOT"
    if git show-ref --verify --quiet "refs/heads/${branch}"; then
        echo "ERROR: Branch '${branch}' already exists. Use a different tag or clean up first." >&2
        exit 1
    fi
    git checkout -q master
    git checkout -q -b "$branch"
    git checkout -q master
    # Set up worktree
    if [ -d "$worktree" ]; then
        git worktree remove "$worktree" --force 2>/dev/null || rm -rf "$worktree"
    fi
    mkdir -p "$(dirname "$worktree")"
    git worktree add "$worktree" "$branch"
    # Symlink .venv so uv/python work
    ln -sf "${REPO_ROOT}/.venv" "${worktree}/.venv"
    # Build agent command
    case "$agent" in
        claude|sonnet|opus|haiku)
            # Model aliases select the Claude model; bare "claude" keeps the default.
            case "$agent" in
                sonnet|opus|haiku) claude_model="$agent" ;;
            esac
            echo "cd ${worktree} && CUDA_VISIBLE_DEVICES=${gpu} claude --dangerously-skip-permissions --model ${claude_model} \"${INITIAL_PROMPT}\""
            ;;
        codex)
            echo "cd ${worktree} && CUDA_VISIBLE_DEVICES=${gpu} codex --dangerously-bypass-approvals-and-sandbox --model gpt-5.3-codex-spark \"${INITIAL_PROMPT}\""
            ;;
        *)
            echo "ERROR: Unknown agent '${agent}'. Supported: claude, sonnet, opus, haiku, codex" >&2
            # Roll back the branch/worktree created above before failing.
            git worktree remove "$worktree" --force 2>/dev/null || true
            git branch -D "$branch" 2>/dev/null || true
            exit 1
            ;;
    esac
}
# ---------------------------------------------------------------------------
# Commands
# ---------------------------------------------------------------------------
# Launch a research group: one tmux pane per <agent:gpu> spec, each agent on
# its own branch/worktree, plus a foreground watcher that nudges idle panes.
cmd_launch() {
    if [ $# -lt 2 ]; then
        usage
    fi
    local tag="$1"
    shift
    local specs=("$@")
    local tmux_session="autoresearch-${tag}"
    local num_workers=${#specs[@]}
    # Kill existing session
    tmux kill-session -t "$tmux_session" 2>/dev/null || true
    echo ""
    echo "============================================"
    echo " RESEARCH GROUP: ${tag}"
    echo "============================================"
    echo " Workers: ${num_workers}"
    local pane_cmds=()
    local labels=()
    local agents=()
    for spec in "${specs[@]}"; do
        local agent="${spec%%:*}"
        local gpu="${spec#*:}"
        echo " - GPU ${gpu}: ${agent}"
        local cmd
        cmd=$(setup_worker "$tag" "$gpu" "$agent")
        pane_cmds+=("$cmd")
        labels+=("$spec")
        # Normalize model aliases to the agent binary for the nudge logic below.
        case "$agent" in
            sonnet|opus|haiku|claude) agents+=("claude") ;;
            *) agents+=("$agent") ;;
        esac
    done
    echo " tmux: ${tmux_session}"
    echo "============================================"
    echo ""
    # Create tmux session with tiled grid
    tmux new-session -d -s "$tmux_session" -n workers \
        "${pane_cmds[0]}; echo ''; echo 'Session ended. Press any key to exit.'; read"
    for ((i=1; i<num_workers; i++)); do
        tmux split-window -t "$tmux_session:workers" \
            "${pane_cmds[$i]}; echo ''; echo 'Session ended. Press any key to exit.'; read"
        tmux select-layout -t "$tmux_session:workers" tiled
    done
    tmux select-layout -t "$tmux_session:workers" tiled
    echo " Attach: tmux attach -t ${tmux_session}"
    echo " Detach: Ctrl-b d"
    echo " Zoom: Ctrl-b z (toggle pane fullscreen)"
    echo " Navigate: Ctrl-b arrow keys"
    echo ""
    # Background watcher — nudge idle agents
    local nudge_interval=120
    echo "Watcher running (nudge interval: ${nudge_interval}s). Ctrl-C to stop watcher only."
    echo ""
    declare -A prev_hash
    while tmux has-session -t "$tmux_session" 2>/dev/null; do
        sleep "$nudge_interval"
        for ((i=0; i<num_workers; i++)); do
            tmux has-session -t "$tmux_session" 2>/dev/null || break
            local pane="${tmux_session}:workers.${i}"
            local content curr_hash
            content=$(tmux capture-pane -t "$pane" -p 2>/dev/null) || continue
            curr_hash=$(echo "$content" | md5sum | cut -d' ' -f1)
            # Pane output changed since the last check -> agent is active, skip.
            if [ "${prev_hash[$i]:-}" != "$curr_hash" ]; then
                prev_hash[$i]="$curr_hash"
                continue
            fi
            # Pane unchanged for a full interval: check for an idle marker.
            # FIX: the original had an if/elif pair testing the exact same grep
            # pattern; the dead elif branch is collapsed into a single check.
            local is_idle=false
            if echo "$content" | grep -qP '^\s*$'; then
                is_idle=true
            fi
            if [ "$is_idle" = true ]; then
                local nudge_msg="Keep going. Do not stop — continue your research loop."
                if [ "${agents[$i]}" = "codex" ]; then
                    tmux send-keys -t "$pane" "$nudge_msg" Enter
                else
                    tmux send-keys -t "$pane" "$nudge_msg" C-m
                fi
                echo "[$(date +%H:%M:%S)] Nudged pane ${i} (${labels[$i]})"
                prev_hash[$i]=""
            fi
        done
    done
    echo "tmux session ended. Watcher exiting."
}
# Tear down a research group: kill its tmux session and remove all worker
# worktrees. Branches are intentionally kept so results can be reviewed.
cmd_stop() {
    if [ $# -lt 1 ]; then
        usage
    fi
    local tag="$1"
    local tmux_session="autoresearch-${tag}"
    # Kill tmux session
    if tmux kill-session -t "$tmux_session" 2>/dev/null; then
        echo "Killed tmux session '${tmux_session}'"
    else
        echo "No tmux session '${tmux_session}' found"
    fi
    # Remove worktrees
    local removed=0
    for wt in "${WORKTREE_DIR}"/gpu*; do
        [ -d "$wt" ] || continue
        git -C "$REPO_ROOT" worktree remove "$wt" --force 2>/dev/null && removed=$((removed + 1))
    done
    echo "Removed ${removed} worktree(s)"
    echo "Branches kept for review (git branch -l 'autoresearch/${tag}-*')"
}
# Report running sessions, active worktrees, and all research branches.
cmd_status() {
    echo "Active tmux sessions:"
    tmux list-sessions 2>/dev/null | grep "^autoresearch-" || echo " (none)"
    echo ""
    echo "Active worktrees:"
    git -C "$REPO_ROOT" worktree list 2>/dev/null | grep "worktrees/" || echo " (none)"
    echo ""
    echo "Research branches:"
    git -C "$REPO_ROOT" branch -l 'autoresearch/*' 2>/dev/null || echo " (none)"
}
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
# Dispatch: the first argument selects the subcommand; the rest pass through.
if [ $# -lt 1 ]; then
    usage
fi
COMMAND="$1"
shift
case "$COMMAND" in
    launch) cmd_launch "$@" ;;
    stop) cmd_stop "$@" ;;
    status) cmd_status "$@" ;;
    *) usage ;;
esac

628
train.py Normal file
View File

@ -0,0 +1,628 @@
"""
Autoresearch pretraining script. Single-GPU, single-file.
Cherry-picked and simplified from nanochat.
Usage: uv run train.py
"""
import os
# Environment variables are set before the `import torch` below so they take
# effect when the libraries initialize.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"  # CUDA allocator setting to reduce fragmentation
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"  # quiet hub output when get_kernel fetches FA3
import gc
import math
import time
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from kernels import get_kernel
fa3 = get_kernel('varunneal/flash-attention-3').flash_attn_interface
from constants import MAX_SEQ_LEN, TIME_BUDGET
from prepare import Tokenizer, make_dataloader, evaluate_bpb
# ---------------------------------------------------------------------------
# GPT Model
# ---------------------------------------------------------------------------
@dataclass
class GPTConfig:
    """Architecture hyperparameters for the GPT model below."""
    sequence_len: int = 2048      # maximum sequence length
    vocab_size: int = 32768
    n_layer: int = 12
    n_head: int = 6               # number of query heads
    n_kv_head: int = 6            # number of key/value heads (GQA when < n_head)
    n_embd: int = 768             # residual stream width
    window_pattern: str = "SSSL"  # per-layer attention window pattern; presumably S=short/sliding, L=long/full — see _compute_window_sizes (TODO confirm)
def norm(x):
return F.rms_norm(x, (x.size(-1),))
def has_ve(layer_idx, n_layer):
    """Returns True if layer should have Value Embedding (alternating, last always included)."""
    last_layer = n_layer - 1
    # Same parity as the last layer -> every other layer, counting backwards
    # from the final one (which is therefore always included).
    return (layer_idx - last_layer) % 2 == 0
def apply_rotary_emb(x, cos, sin):
    """Apply rotary position embeddings to a 4-D tensor, rotating the first
    half of the last dimension against the second half."""
    assert x.ndim == 4
    half = x.shape[3] // 2
    lo = x[..., :half]
    hi = x[..., half:]
    rotated_lo = lo * cos + hi * sin
    rotated_hi = hi * cos - lo * sin
    return torch.cat([rotated_lo, rotated_hi], 3)
class CausalSelfAttention(nn.Module):
    """Causal attention with rotary embeddings, QK-norm, grouped-query support
    (n_kv_head <= n_head), and an optional gated value-embedding mix-in
    (ResFormer-style), computed via the FlashAttention-3 kernel (fa3)."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # Gate for the value-embedding mix-in: computed from the first
        # ve_gate_channels input channels, one gate value per KV head.
        # Only layers selected by has_ve() get a gate (None otherwise).
        self.ve_gate_channels = 32
        self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None

    def forward(self, x, ve, cos_sin, window_size):
        """
        Args:
            x: (B, T, n_embd) input activations.
            ve: value embedding tensor for this layer, or None; reshaped to
                (B, T, n_kv_head, head_dim) and mixed into v with a gate.
            cos_sin: (cos, sin) rotary tables; assumed broadcastable against
                the (B, T, heads, head_dim/2) halves — see apply_rotary_emb.
            window_size: sliding-window spec forwarded to flash_attn_func.

        Returns:
            (B, T, n_embd) attention output.
        """
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
        if ve is not None:
            ve = ve.view(B, T, self.n_kv_head, self.head_dim)
            # 2*sigmoid keeps the gate in (0, 2), with 1 as the neutral midpoint.
            gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels]))
            v = v + gate.unsqueeze(-1) * ve
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)  # QK-norm
        y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
        y = y.contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y
class MLP(nn.Module):
    """Position-wise feed-forward block: 4x expansion with a ReLU^2 activation."""
    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
    def forward(self, x):
        hidden = self.c_fc(x)
        hidden = F.relu(hidden).square()
        return self.c_proj(hidden)
class Block(nn.Module):
    """One transformer layer: pre-norm attention then pre-norm MLP, both residual."""
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)
    def forward(self, x, ve, cos_sin, window_size):
        attn_out = self.attn(norm(x), ve, cos_sin, window_size)
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        return x + mlp_out
class GPT(nn.Module):
    """Decoder-only transformer with sliding-window attention, value embeddings
    on selected layers, and learned per-layer residual/x0 mixing scalars."""
    def __init__(self, config):
        super().__init__()
        self.config = config
        # One window tuple per layer for the attention kernel
        self.window_sizes = self._compute_window_sizes(config)
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Per-layer scalars: residual-stream gain and token-embedding (x0) mix-in
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
        # Value embeddings
        head_dim = config.n_embd // config.n_head
        kv_dim = config.n_kv_head * head_dim
        self.value_embeds = nn.ModuleDict({
            str(i): nn.Embedding(config.vocab_size, kv_dim)
            for i in range(config.n_layer) if has_ve(i, config.n_layer)
        })
        # Rotary embeddings (precomputed with 10x headroom beyond sequence_len)
        self.rotary_seq_len = config.sequence_len * 10
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
    @torch.no_grad()
    def init_weights(self):
        """Initialize all parameters/buffers in place; called after to_empty()
        materializes the meta-device model on the GPU."""
        # Embedding and unembedding
        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
        # Transformer blocks: uniform(-s, s) with std 1/sqrt(n_embd); out-projections start at zero
        n_embd = self.config.n_embd
        s = 3**0.5 * n_embd**-0.5
        for block in self.transformer.h:
            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
            torch.nn.init.zeros_(block.attn.c_proj.weight)
            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
        # Per-layer scalars
        self.resid_lambdas.fill_(1.0)
        self.x0_lambdas.fill_(0.1)
        # Value embeddings
        for ve in self.value_embeds.values():
            torch.nn.init.uniform_(ve.weight, -s, s)
        # Gate weights init to zero (sigmoid(0)=0.5, scaled by 2 -> 1.0 = neutral)
        for block in self.transformer.h:
            if block.attn.ve_gate is not None:
                torch.nn.init.zeros_(block.attn.ve_gate.weight)
        # Rotary embeddings (recomputed: the __init__ buffers were created on the meta device)
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin
        # Cast embeddings to bf16
        self.transformer.wte.to(dtype=torch.bfloat16)
        for ve in self.value_embeds.values():
            ve.to(dtype=torch.bfloat16)
    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        """Return bf16 (cos, sin) tables shaped (1, seq_len, 1, head_dim // 2)
        for broadcasting against (B, T, H, D) q/k tensors."""
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        # Insert batch and head broadcast dims
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin
    def _compute_window_sizes(self, config):
        """Tile window_pattern (e.g. "SSSL") across layers, mapping L to a
        full-context window and S to a half-context window; the last layer is
        always forced to the long window."""
        pattern = config.window_pattern.upper()
        assert all(c in "SL" for c in pattern)
        long_window = config.sequence_len
        short_window = long_window // 2
        char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
        window_sizes = []
        for layer_idx in range(config.n_layer):
            char = pattern[layer_idx % len(pattern)]
            window_sizes.append(char_to_window[char])
        window_sizes[-1] = (long_window, 0)
        return window_sizes
    def estimate_flops(self):
        """Estimated FLOPs per token (forward + backward)."""
        # 6*N for matmul params (embedding-like params and scalars excluded),
        # plus a per-layer attention term that accounts for each layer's window
        nparams = sum(p.numel() for p in self.parameters())
        value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
        h = self.config.n_head
        q = self.config.n_embd // self.config.n_head
        t = self.config.sequence_len
        attn_flops = 0
        for window_size in self.window_sizes:
            window = window_size[0]
            effective_seq = t if window < 0 else min(window, t)
            attn_flops += 12 * h * q * effective_seq
        return 6 * (nparams - nparams_exclude) + attn_flops
    def num_scaling_params(self):
        """Return parameter counts broken down by group, plus 'total'."""
        wte = sum(p.numel() for p in self.transformer.wte.parameters())
        value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
        lm_head = sum(p.numel() for p in self.lm_head.parameters())
        transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
        total = wte + value_embeds + lm_head + transformer_matrices + scalars
        return {
            'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head,
            'transformer_matrices': transformer_matrices, 'scalars': scalars, 'total': total,
        }
    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
                        weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
        """Build a MuonAdamW optimizer: AdamW groups for embeddings, unembedding
        and per-layer scalars; one Muon group per distinct matrix shape (so the
        Muon step can stack same-shape matrices into a single tensor)."""
        model_dim = self.config.n_embd
        matrix_params = list(self.transformer.h.parameters())
        value_embeds_params = list(self.value_embeds.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        # Sanity check: every parameter lands in exactly one group
        assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) +
            len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params))
        # Scale LR ∝ 1/√dmodel (tuned at 768 dim)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
        param_groups = [
            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0),
        ]
        for shape in sorted({p.shape for p in matrix_params}):
            group_params = [p for p in matrix_params if p.shape == shape]
            param_groups.append(dict(
                kind='muon', params=group_params, lr=matrix_lr,
                momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
            ))
        optimizer = MuonAdamW(param_groups)
        # Record base LRs so the training loop can apply a multiplicative schedule
        for group in optimizer.param_groups:
            group["initial_lr"] = group["lr"]
        return optimizer
    def forward(self, idx, targets=None, reduction='mean'):
        """Given token ids idx of shape (B, T), return float32 logits, or the
        cross-entropy loss when targets is provided (ignore_index=-1)."""
        B, T = idx.size()
        assert T <= self.cos.size(1)
        cos_sin = self.cos[:, :T], self.sin[:, :T]
        x = self.transformer.wte(idx)
        x = norm(x)
        x0 = x  # normalized token embeddings, re-injected each layer via x0_lambdas
        for i, block in enumerate(self.transformer.h):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
            x = block(x, ve, cos_sin, self.window_sizes[i])
        x = norm(x)
        # Soft-cap logits with tanh so they stay in (-softcap, softcap)
        softcap = 15
        logits = self.lm_head(x)
        logits = logits.float()
        logits = softcap * torch.tanh(logits / softcap)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
                                   ignore_index=-1, reduction=reduction)
            return loss
        return logits
# ---------------------------------------------------------------------------
# Optimizer (MuonAdamW, single GPU only)
# ---------------------------------------------------------------------------
# Per-iteration (a, b, c) polynomial coefficients for the "polar express"
# orthogonalization in muon_step_fused; entry i is consumed at iteration i,
# so ns_steps must not exceed len(polar_express_coeffs).
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """One AdamW step on a single tensor, fully in place.

    All scalar hyperparameters arrive as 0-D tensors so the compiled graph is
    reused when their values change between steps.
    """
    # Decoupled weight decay: multiplicative shrink before the Adam update
    p.mul_(1 - lr_t * wd_t)
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    # Bias correction for the exponentially-weighted moment estimates
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """One Muon update for a stack of same-shape matrices (leading dim = #params).

    Pipeline: Nesterov momentum -> polar-express orthogonalization in bf16 ->
    NorMuon-style second-moment normalization along red_dim -> cautious weight
    decay and in-place parameter update. Scalars arrive as 0-D tensors to avoid
    torch.compile recompilation.
    """
    # Nesterov momentum
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization
    X = g.bfloat16()
    # Pre-scale by the Frobenius norm (with slack) so the iteration is stable
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):
        # Tall matrices: form the smaller Gram matrix X^T X
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        # Wide (or square) matrices: use X X^T instead
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # NorMuon variance reduction
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    # Rescale so the overall update norm is preserved after per-row/col normalization
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Cautious weight decay + parameter update: decay only where the update and
    # the parameter agree in sign
    lr = lr_t.to(g.dtype)
    wd = wd_t.to(g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Each param group declares its kind via group['kind'] ('muon' or 'adamw').
    Schedule-driven scalars (lr, momentum, weight decay, ...) are funneled into
    the compiled step kernels through reusable 0-D CPU tensors so that value
    changes between steps do not trigger torch.compile recompilation.
    """
    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
    def _step_adamw(self, group):
        """Apply one AdamW step to each param in the group (state lazily created)."""
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            state['step'] += 1
            self._adamw_step_t.fill_(state['step'])
            self._adamw_lr_t.fill_(group['lr'])
            self._adamw_beta1_t.fill_(group['betas'][0])
            self._adamw_beta2_t.fill_(group['betas'][1])
            self._adamw_eps_t.fill_(group['eps'])
            self._adamw_wd_t.fill_(group['weight_decay'])
            adamw_step_fused(p, grad, state['exp_avg'], state['exp_avg_sq'],
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
    def _step_muon(self, group):
        """Apply one Muon step to the whole group at once.

        All params in a muon group share one shape (guaranteed by setup), so they
        are stacked into a single (num_params, *shape) tensor; both momentum
        buffers live in the optimizer state of the group's first param.
        """
        params = group['params']
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        if "second_momentum_buffer" not in state:
            # Second moment kept per-row for tall matrices, per-column for wide ones
            state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Muon LR scaling for non-square matrices: lr * sqrt(max(1, rows/cols))
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # Copy the updated stacked values back into the individual param tensors
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))
    @torch.no_grad()
    def step(self):
        """Run one optimization step over all param groups (no closure support)."""
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                self._step_adamw(group)
            elif group['kind'] == 'muon':
                self._step_muon(group)
# ---------------------------------------------------------------------------
# Hyperparameters (edit these directly, no CLI flags needed)
# ---------------------------------------------------------------------------
# Model architecture
ASPECT_RATIO = 64 # model_dim = depth * ASPECT_RATIO (rounded up to a multiple of HEAD_DIM)
HEAD_DIM = 128 # target head dimension for attention
WINDOW_PATTERN = "SSSL" # sliding window pattern: L=full, S=half context
# Optimization (these override the defaults of GPT.setup_optimizer)
TOTAL_BATCH_SIZE = 2**19 # ~524K tokens per optimizer step
EMBEDDING_LR = 0.6 # learning rate for token embeddings (Adam)
UNEMBEDDING_LR = 0.004 # learning rate for lm_head (Adam)
MATRIX_LR = 0.04 # learning rate for matrix parameters (Muon)
SCALAR_LR = 0.5 # learning rate for per-layer scalars (Adam)
WEIGHT_DECAY = 0.2 # cautious weight decay for Muon (annealed to 0 by get_weight_decay)
ADAM_BETAS = (0.8, 0.95) # Adam beta1, beta2
WARMUP_RATIO = 0.0 # fraction of time budget for LR warmup
WARMDOWN_RATIO = 0.5 # fraction of time budget for LR warmdown
FINAL_LR_FRAC = 0.0 # final LR as fraction of initial
# Model size
DEPTH = 8 # number of transformer layers
DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM)
# ---------------------------------------------------------------------------
# Setup: tokenizer, model, optimizer, dataloader
# ---------------------------------------------------------------------------
t_start = time.time()  # wall-clock start, used for the total_seconds summary line
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
H100_BF16_PEAK_FLOPS = 989.5e12  # theoretical peak, used only for MFU reporting
tokenizer = Tokenizer.from_directory()
vocab_size = tokenizer.get_vocab_size()
print(f"Vocab size: {vocab_size:,}")
def build_model_config(depth):
    """Derive a GPTConfig from depth: width is depth * ASPECT_RATIO, rounded up
    to a multiple of HEAD_DIM so heads divide the model dimension evenly."""
    raw_dim = depth * ASPECT_RATIO
    model_dim = -(-raw_dim // HEAD_DIM) * HEAD_DIM  # ceil-divide, then rescale
    n_heads = model_dim // HEAD_DIM
    return GPTConfig(
        sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size,
        n_layer=depth, n_head=n_heads, n_kv_head=n_heads, n_embd=model_dim,
        window_pattern=WINDOW_PATTERN,
    )
config = build_model_config(DEPTH)
print(f"Model config: {asdict(config)}")
# Construct on the meta device (no allocation), then materialize on GPU and init
with torch.device("meta"):
    model = GPT(config)
model.to_empty(device=device)
model.init_weights()
param_counts = model.num_scaling_params()
print("Parameter counts:")
for key, value in param_counts.items():
    print(f"  {key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops()
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
# Gradient accumulation: total batch must be a multiple of the per-device batch
tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
optimizer = model.setup_optimizer(
    unembedding_lr=UNEMBEDDING_LR,
    embedding_lr=EMBEDDING_LR,
    scalar_lr=SCALAR_LR,
    adam_betas=ADAM_BETAS,
    matrix_lr=MATRIX_LR,
    weight_decay=WEIGHT_DECAY,
)
model = torch.compile(model, dynamic=False)
train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
x, y, epoch = next(train_loader)  # prefetch first batch
print(f"Time budget: {TIME_BUDGET}s")
print(f"Gradient accumulation steps: {grad_accum_steps}")
# Schedules (all based on progress = training_time / TIME_BUDGET)
def get_lr_multiplier(progress):
    """Trapezoidal LR multiplier for progress in [0, 1].

    Linear warmup over the first WARMUP_RATIO of training, then constant at 1.0,
    then a linear warmdown over the final WARMDOWN_RATIO from 1.0 down to
    FINAL_LR_FRAC. Degenerate ratios of 0 disable the corresponding phase.
    """
    if progress < WARMUP_RATIO:
        return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
    # Guard WARMDOWN_RATIO == 0: without it, progress == 1.0 falls through to a
    # division by zero below (progress is clamped to exactly 1.0 at end of training)
    if WARMDOWN_RATIO <= 0 or progress < 1.0 - WARMDOWN_RATIO:
        return 1.0
    cooldown = (1.0 - progress) / WARMDOWN_RATIO
    return cooldown * 1.0 + (1 - cooldown) * FINAL_LR_FRAC
def get_muon_momentum(step):
    """Warm Muon momentum up linearly from 0.85 to 0.95 over the first 300 steps."""
    frac = min(step / 300, 1)
    return 0.85 + frac * 0.10
def get_weight_decay(progress):
    """Anneal Muon weight decay linearly from WEIGHT_DECAY down to 0 over training."""
    remaining = 1 - progress
    return WEIGHT_DECAY * remaining
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
t_start_training = time.time()
smooth_train_loss = 0  # EMA of the raw train loss; debiased at log time
total_training_time = 0  # wall-clock seconds, excluding early steps (see step > 10 below)
step = 0
while True:
    torch.cuda.synchronize()
    t0 = time.time()
    # Forward/backward with gradient accumulation; the next batch is fetched
    # right after backward so host-side data prep overlaps with GPU work
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()  # last micro-batch loss, used for logging
        loss = loss / grad_accum_steps  # average the accumulated gradients
        loss.backward()
        x, y, epoch = next(train_loader)
    # Progress and schedules
    progress = min(total_training_time / TIME_BUDGET, 1.0)
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    muon_weight_decay = get_weight_decay(progress)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
            group["weight_decay"] = muon_weight_decay
    optimizer.step()
    model.zero_grad(set_to_none=True)
    train_loss_f = train_loss.item()  # host sync on the logged loss value
    # Fast fail: abort if loss is exploding
    if train_loss_f > 100:
        print("FAIL")
        exit(1)
    torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0
    # Only count time after the first steps, so torch.compile warmup is excluded
    if step > 10:
        total_training_time += dt
    # Logging
    ema_beta = 0.9
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
    pct_done = 100 * progress
    tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
    mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / H100_BF16_PEAK_FLOPS
    remaining = max(0, TIME_BUDGET - total_training_time)
    print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", end="", flush=True)
    # GC management (Python's GC causes ~500ms stalls)
    if step == 0:
        gc.collect()
        gc.freeze()
        gc.disable()
    elif (step + 1) % 5000 == 0:
        gc.collect()
    step += 1
    # Time's up — but only stop after warmup steps so we don't count compilation
    if step > 10 and total_training_time >= TIME_BUDGET:
        break
print() # newline after \r training log
total_tokens = step * TOTAL_BATCH_SIZE  # includes the untimed warmup steps
# Final eval
model.eval()
with autocast_ctx:
    val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
# Final summary
t_end = time.time()
# NOTE(review): startup_time is computed here but never printed below
startup_time = t_start_training - t_start
# MFU over steady state only: (step - 10) matches the steps counted in total_training_time
steady_state_mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10) / total_training_time / H100_BF16_PEAK_FLOPS if total_training_time > 0 else 0
peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
print("---")
print(f"val_bpb: {val_bpb:.6f}")
print(f"training_seconds: {total_training_time:.1f}")
print(f"total_seconds: {t_end - t_start:.1f}")
print(f"peak_vram_mb: {peak_vram_mb:.1f}")
print(f"mfu_percent: {steady_state_mfu:.2f}")
print(f"total_tokens_M: {total_tokens / 1e6:.1f}")
print(f"num_steps: {step}")
print(f"num_params_M: {num_params / 1e6:.1f}")
print(f"depth: {DEPTH}")

1367
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff