""" One-time data preparation for autoresearch experiments. Downloads data shards and trains a BPE tokenizer. Usage: python prepare.py # full prep (download + tokenizer) python prepare.py --num-shards 8 # download only 8 shards (for testing) Data and tokenizer are stored in ~/.cache/autoresearch/. """ import os import sys import time import math import argparse import pickle from multiprocessing import Pool import requests import pyarrow.parquet as pq import rustbpe import tiktoken import torch # --------------------------------------------------------------------------- # Constants (fixed, do not modify) # --------------------------------------------------------------------------- MAX_SEQ_LEN = 2048 # context length TIME_BUDGET = 300 # training time budget in seconds (5 minutes) EVAL_TOKENS = 40 * 524288 # number of tokens for val eval # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch") DATA_DIR = os.path.join(CACHE_DIR, "data") TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer") BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main" MAX_SHARD = 6542 # the last datashard is shard_06542.parquet VAL_SHARD = MAX_SHARD # pinned validation shard (shard_06542) VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet" VOCAB_SIZE = 8192 # BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3}) SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(4)] BOS_TOKEN = "<|reserved_0|>" # --------------------------------------------------------------------------- # Data download # --------------------------------------------------------------------------- def download_single_shard(index): """Download one parquet shard with retries. 
# ---------------------------------------------------------------------------
# Data download
# ---------------------------------------------------------------------------

def download_single_shard(index):
    """Download one parquet shard with retries. Returns True on success."""
    filename = f"shard_{index:05d}.parquet"
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        return True
    url = f"{BASE_URL}/{filename}"
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            temp_path = filepath + ".tmp"
            with open(temp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
            os.rename(temp_path, filepath)
            print(f"  Downloaded {filename}")
            return True
        except (requests.RequestException, IOError) as e:
            print(f"  Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            for path in [filepath + ".tmp", filepath]:
                if os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        pass
            if attempt < max_attempts:
                time.sleep(2 ** attempt)
    return False


def download_data(num_shards, download_workers=8):
    """Download training shards + pinned validation shard."""
    os.makedirs(DATA_DIR, exist_ok=True)
    num_train = min(num_shards, MAX_SHARD)
    ids = list(range(num_train))
    if VAL_SHARD not in ids:
        ids.append(VAL_SHARD)
    # Count what's already downloaded
    existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet")))
    if existing == len(ids):
        print(f"Data: all {len(ids)} shards already downloaded at {DATA_DIR}")
        return
    needed = len(ids) - existing
    print(f"Data: downloading {needed} shards ({existing} already exist)...")
    workers = max(1, min(download_workers, needed))
    with Pool(processes=workers) as pool:
        results = pool.map(download_single_shard, ids)
    ok = sum(1 for r in results if r)
    print(f"Data: {ok}/{len(ids)} shards ready at {DATA_DIR}")

# ---------------------------------------------------------------------------
# Tokenizer training
# ---------------------------------------------------------------------------

def list_parquet_files():
    """Return sorted list of parquet file paths in the data directory."""
    files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".parquet") and not f.endswith(".tmp"))
    return [os.path.join(DATA_DIR, f) for f in files]


def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
    """Yield documents from training split (all shards except pinned val shard)."""
    parquet_paths = [p for p in list_parquet_files() if not p.endswith(VAL_FILENAME)]
    nchars = 0
    for filepath in parquet_paths:
        pf = pq.ParquetFile(filepath)
        for rg_idx in range(pf.num_row_groups):
            rg = pf.read_row_group(rg_idx)
            for text in rg.column("text").to_pylist():
                doc = text[:doc_cap] if len(text) > doc_cap else text
                nchars += len(doc)
                yield doc
                if nchars >= max_chars:
                    return
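# Vocabulary layout built below (derived from the code, not asserted by it): rustbpe is
# trained to VOCAB_SIZE - len(SPECIAL_TOKENS) = 8188 base tokens, and the 4 special tokens
# are appended after the mergeable ranks, so BOS (<|reserved_0|>) becomes the first special
# token id.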
def train_tokenizer():
    """Train BPE tokenizer using rustbpe, save as tiktoken pickle."""
    tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
    token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
    if os.path.exists(tokenizer_pkl) and os.path.exists(token_bytes_path):
        print(f"Tokenizer: already trained at {TOKENIZER_DIR}")
        return
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    parquet_files = list_parquet_files()
    if len(parquet_files) < 2:
        print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.")
        sys.exit(1)

    # --- Train with rustbpe ---
    print("Tokenizer: training BPE tokenizer...")
    t0 = time.time()
    tokenizer = rustbpe.Tokenizer()
    vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)
    tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN)

    # Build tiktoken encoding from trained merges
    pattern = tokenizer.get_pattern()
    mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()}
    tokens_offset = len(mergeable_ranks)
    special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=pattern,
        mergeable_ranks=mergeable_ranks,
        special_tokens=special_tokens,
    )

    # Save tokenizer
    with open(tokenizer_pkl, "wb") as f:
        pickle.dump(enc, f)
    t1 = time.time()
    print(f"Tokenizer: trained in {t1 - t0:.1f}s, saved to {tokenizer_pkl}")

    # --- Build token_bytes lookup for BPB evaluation ---
    print("Tokenizer: building token_bytes lookup...")
    special_set = set(SPECIAL_TOKENS)
    token_bytes_list = []
    for token_id in range(enc.n_vocab):
        token_str = enc.decode([token_id])
        if token_str in special_set:
            token_bytes_list.append(0)
        else:
            token_bytes_list.append(len(token_str.encode("utf-8")))
    token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32)
    torch.save(token_bytes_tensor, token_bytes_path)
    print(f"Tokenizer: saved token_bytes to {token_bytes_path}")

    # Sanity check
    test = "Hello world! Numbers: 123. Unicode: 你好"
    encoded = enc.encode_ordinary(test)
    decoded = enc.decode(encoded)
    assert decoded == test, f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}"
    print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})")

# ---------------------------------------------------------------------------
# Runtime utilities (imported by train.py)
# ---------------------------------------------------------------------------

class Tokenizer:
    """Minimal tokenizer wrapper. Training is handled above."""

    def __init__(self, enc):
        self.enc = enc
        self.bos_token_id = enc.encode_single_token(BOS_TOKEN)

    @classmethod
    def from_directory(cls, tokenizer_dir=TOKENIZER_DIR):
        with open(os.path.join(tokenizer_dir, "tokenizer.pkl"), "rb") as f:
            enc = pickle.load(f)
        return cls(enc)

    def get_vocab_size(self):
        return self.enc.n_vocab

    def get_bos_token_id(self):
        return self.bos_token_id

    def encode(self, text, prepend=None, num_threads=8):
        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            if prepend is not None:
                ids.insert(0, prepend_id)
        elif isinstance(text, list):
            ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            if prepend is not None:
                for row in ids:
                    row.insert(0, prepend_id)
        else:
            raise ValueError(f"Invalid input type: {type(text)}")
        return ids

    def decode(self, ids):
        return self.enc.decode(ids)


def get_token_bytes(device="cpu"):
    path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
    with open(path, "rb") as f:
        return torch.load(f, map_location=device)


def _document_batches(split, tokenizer_batch_size=128):
    """Infinite iterator over document batches from parquet files."""
    parquet_paths = list_parquet_files()
    assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first."
    val_path = os.path.join(DATA_DIR, VAL_FILENAME)
    if split == "train":
        parquet_paths = [p for p in parquet_paths if p != val_path]
        assert len(parquet_paths) > 0, "No training shards found."
    else:
        parquet_paths = [val_path]
    epoch = 1
    while True:
        for filepath in parquet_paths:
            pf = pq.ParquetFile(filepath)
            for rg_idx in range(pf.num_row_groups):
                rg = pf.read_row_group(rg_idx)
                batch = rg.column('text').to_pylist()
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size], epoch
        epoch += 1
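# Illustrative packing example for the best-fit loop below: with row_capacity = 5 and
# buffered docs of token lengths [4, 2, 3], the length-4 doc is placed first (largest that
# fits), leaving 1 slot; no remaining doc fits, so the shortest doc (length 2) is cropped to
# 1 token and the row is filled exactly. (In practice the buffer holds at least
# `buffer_size` documents at each placement.)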
def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
    """
    BOS-aligned dataloader with best-fit packing. Every row starts with BOS.
    Documents are packed using best-fit to minimize cropping. When no document fits the
    remaining space, the shortest buffered doc is cropped to fill it exactly, giving
    100% utilization (no padding).
    """
    assert split in ["train", "val"]
    row_capacity = T + 1
    batches = _document_batches(split)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []
    epoch = 1

    def refill_buffer():
        nonlocal epoch
        doc_batch, epoch = next(batches)
        token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
        doc_buffer.extend(token_lists)

    # Pre-allocate buffers: [inputs (B*T) | targets (B*T)]
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True)
    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device="cuda")
    cpu_inputs = cpu_buffer[:B * T].view(B, T)
    cpu_targets = cpu_buffer[B * T:].view(B, T)
    inputs = gpu_buffer[:B * T].view(B, T)
    targets = gpu_buffer[B * T:].view(B, T)

    while True:
        for row_idx in range(B):
            pos = 0
            while pos < row_capacity:
                while len(doc_buffer) < buffer_size:
                    refill_buffer()
                remaining = row_capacity - pos
                # Find largest doc that fits entirely
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    doc_len = len(doc)
                    if doc_len <= remaining and doc_len > best_len:
                        best_idx = i
                        best_len = doc_len
                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                    pos += len(doc)
                else:
                    # No doc fits — crop shortest to fill remaining
                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(shortest_idx)
                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
                    pos += remaining
        cpu_inputs.copy_(row_buffer[:, :-1])
        cpu_targets.copy_(row_buffer[:, 1:])
        gpu_buffer.copy_(cpu_buffer, non_blocking=True)
        yield inputs, targets, epoch

# ---------------------------------------------------------------------------
# Evaluation (DO NOT CHANGE — this is the fixed metric)
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate_bpb(model, tokenizer, batch_size):
    """
    Bits per byte (BPB): vocab size-independent evaluation metric.
    Sums per-token cross-entropy (in nats), sums target byte lengths, then converts
    nats/byte to bits/byte. Special tokens (byte length 0) are excluded from both sums.
    Uses fixed MAX_SEQ_LEN so results are comparable across configs.
    """
    token_bytes = get_token_bytes(device="cuda")
    val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
    steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
    total_nats = 0.0
    total_bytes = 0
    for _ in range(steps):
        x, y, _ = next(val_loader)
        loss_flat = model(x, y, reduction='none').view(-1)
        y_flat = y.view(-1)
        nbytes = token_bytes[y_flat]
        mask = nbytes > 0
        total_nats += (loss_flat * mask).sum().item()
        total_bytes += nbytes.sum().item()
    return total_nats / (math.log(2) * total_bytes)
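# Worked example for evaluate_bpb above (illustrative numbers): if the mean loss over
# non-special target tokens is 0.90 nats and targets average 3.5 bytes per token, then
#   BPB = 0.90 / (ln(2) * 3.5) ≈ 0.37 bits per byte.
# The number of eval batches is EVAL_TOKENS // (batch_size * MAX_SEQ_LEN); e.g. for
# batch_size=32 that is (40 * 524288) / (32 * 2048) = 320 steps.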
""" token_bytes = get_token_bytes(device="cuda") val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val") steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN) total_nats = 0.0 total_bytes = 0 for _ in range(steps): x, y, _ = next(val_loader) loss_flat = model(x, y, reduction='none').view(-1) y_flat = y.view(-1) nbytes = token_bytes[y_flat] mask = nbytes > 0 total_nats += (loss_flat * mask).sum().item() total_bytes += nbytes.sum().item() return total_nats / (math.log(2) * total_bytes) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- if __name__ == "__main__": parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch") parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.") parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers") args = parser.parse_args() num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards print(f"Cache directory: {CACHE_DIR}") print() # Step 1: Download data download_data(num_shards, download_workers=args.download_workers) print() # Step 2: Train tokenizer train_tokenizer() print() print("Done! Ready to train.")