""" One-time data preparation for autoresearch experiments. Downloads data shards and trains a BPE tokenizer. Usage: python prepare.py # full prep (download + tokenizer) python prepare.py --num-shards 8 # download only 8 shards (for testing) Data and tokenizer are stored in ~/.cache/autoresearch/. """ import os import sys import time import math import argparse import pickle from multiprocessing import Pool import requests import pyarrow.parquet as pq import rustbpe import tiktoken import torch from constants import MAX_SEQ_LEN, EVAL_TOKENS # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch") DATA_DIR = os.path.join(CACHE_DIR, "data") TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer") BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main" MAX_SHARD = 6542 # the last datashard is shard_06542.parquet VOCAB_SIZE = 8192 # BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3}) SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(16)] BOS_TOKEN = "<|reserved_0|>" # --------------------------------------------------------------------------- # Data download # --------------------------------------------------------------------------- def download_single_shard(index): """Download one parquet shard with retries. Returns True on success.""" filename = f"shard_{index:05d}.parquet" filepath = os.path.join(DATA_DIR, filename) if os.path.exists(filepath): return True url = f"{BASE_URL}/{filename}" max_attempts = 5 for attempt in range(1, max_attempts + 1): try: response = requests.get(url, stream=True, timeout=30) response.raise_for_status() temp_path = filepath + ".tmp" with open(temp_path, "wb") as f: for chunk in response.iter_content(chunk_size=1024 * 1024): if chunk: f.write(chunk) os.rename(temp_path, filepath) print(f" Downloaded {filename}") return True except (requests.RequestException, IOError) as e: print(f" Attempt {attempt}/{max_attempts} failed for {filename}: {e}") for path in [filepath + ".tmp", filepath]: if os.path.exists(path): try: os.remove(path) except OSError: pass if attempt < max_attempts: time.sleep(2 ** attempt) return False def download_data(num_shards): """Download data shards in parallel.""" os.makedirs(DATA_DIR, exist_ok=True) ids = list(range(min(num_shards, MAX_SHARD + 1))) # Count what's already downloaded existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet"))) if existing == len(ids): print(f"Data: all {len(ids)} shards already downloaded at {DATA_DIR}") return needed = len(ids) - existing print(f"Data: downloading {needed} shards ({existing} already exist)...") workers = min(8, needed) with Pool(processes=workers) as pool: results = pool.map(download_single_shard, ids) ok = sum(1 for r in results if r) print(f"Data: {ok}/{len(ids)} shards ready at {DATA_DIR}") # --------------------------------------------------------------------------- # Tokenizer training # --------------------------------------------------------------------------- def list_parquet_files(): """Return sorted list of parquet file paths in the data directory.""" files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".parquet") and not f.endswith(".tmp")) return [os.path.join(DATA_DIR, f) for f in files] def text_iterator(max_chars=2_000_000_000, doc_cap=10_000): """Yield documents from training split (all shards except last).""" parquet_paths = list_parquet_files()[:-1] # last shard is val nchars = 0 for filepath in parquet_paths: pf = pq.ParquetFile(filepath) for rg_idx in range(pf.num_row_groups): rg = pf.read_row_group(rg_idx) for text in rg.column("text").to_pylist(): doc = text[:doc_cap] if len(text) > doc_cap else text nchars += len(doc) yield doc if nchars >= max_chars: return def train_tokenizer(): """Train BPE tokenizer using rustbpe, save as tiktoken pickle.""" tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl") token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt") if os.path.exists(tokenizer_pkl) and os.path.exists(token_bytes_path): print(f"Tokenizer: already trained at {TOKENIZER_DIR}") return os.makedirs(TOKENIZER_DIR, exist_ok=True) parquet_files = list_parquet_files() if len(parquet_files) < 2: print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.") sys.exit(1) # --- Train with rustbpe --- print("Tokenizer: training BPE tokenizer...") t0 = time.time() tokenizer = rustbpe.Tokenizer() vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS) tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN) # Build tiktoken encoding from trained merges pattern = tokenizer.get_pattern() mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()} tokens_offset = len(mergeable_ranks) special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)} enc = tiktoken.Encoding( name="rustbpe", pat_str=pattern, mergeable_ranks=mergeable_ranks, special_tokens=special_tokens, ) # Save tokenizer with open(tokenizer_pkl, "wb") as f: pickle.dump(enc, f) t1 = time.time() print(f"Tokenizer: trained in {t1 - t0:.1f}s, saved to {tokenizer_pkl}") # --- Build token_bytes lookup for BPB evaluation --- print("Tokenizer: building token_bytes lookup...") special_set = set(SPECIAL_TOKENS) token_bytes_list = [] for token_id in range(enc.n_vocab): token_str = enc.decode([token_id]) if token_str in special_set: token_bytes_list.append(0) else: token_bytes_list.append(len(token_str.encode("utf-8"))) token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32) torch.save(token_bytes_tensor, token_bytes_path) print(f"Tokenizer: saved token_bytes to {token_bytes_path}") # Sanity check test = "Hello world! Numbers: 123. Unicode: 你好" encoded = enc.encode_ordinary(test) decoded = enc.decode(encoded) assert decoded == test, f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}" print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})") # --------------------------------------------------------------------------- # Runtime utilities (imported by train.py) # --------------------------------------------------------------------------- class Tokenizer: """Minimal tokenizer wrapper. Training is handled above.""" def __init__(self, enc): self.enc = enc self.bos_token_id = enc.encode_single_token(BOS_TOKEN) @classmethod def from_directory(cls, tokenizer_dir=TOKENIZER_DIR): with open(os.path.join(tokenizer_dir, "tokenizer.pkl"), "rb") as f: enc = pickle.load(f) return cls(enc) def get_vocab_size(self): return self.enc.n_vocab def get_bos_token_id(self): return self.bos_token_id def encode(self, text, prepend=None, num_threads=8): if prepend is not None: prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend) if isinstance(text, str): ids = self.enc.encode_ordinary(text) if prepend is not None: ids.insert(0, prepend_id) elif isinstance(text, list): ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads) if prepend is not None: for row in ids: row.insert(0, prepend_id) else: raise ValueError(f"Invalid input type: {type(text)}") return ids def decode(self, ids): return self.enc.decode(ids) def get_token_bytes(device="cpu"): path = os.path.join(TOKENIZER_DIR, "token_bytes.pt") with open(path, "rb") as f: return torch.load(f, map_location=device) def _document_batches(split, tokenizer_batch_size=128): """Infinite iterator over document batches from parquet files.""" parquet_paths = list_parquet_files() assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first." parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] epoch = 1 while True: for filepath in parquet_paths: pf = pq.ParquetFile(filepath) for rg_idx in range(pf.num_row_groups): rg = pf.read_row_group(rg_idx) batch = rg.column('text').to_pylist() for i in range(0, len(batch), tokenizer_batch_size): yield batch[i:i+tokenizer_batch_size], epoch epoch += 1 def make_dataloader(tokenizer, B, T, split, buffer_size=1000): """ BOS-aligned dataloader with best-fit packing. Every row starts with BOS. Documents packed using best-fit to minimize cropping. When no document fits remaining space, crops shortest doc to fill exactly. 100% utilization (no padding). """ assert split in ["train", "val"] row_capacity = T + 1 batches = _document_batches(split) bos_token = tokenizer.get_bos_token_id() doc_buffer = [] epoch = 1 def refill_buffer(): nonlocal epoch doc_batch, epoch = next(batches) token_lists = tokenizer.encode(doc_batch, prepend=bos_token) doc_buffer.extend(token_lists) # Pre-allocate buffers: [inputs (B*T) | targets (B*T)] row_buffer = torch.empty((B, row_capacity), dtype=torch.long) cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True) gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device="cuda") cpu_inputs = cpu_buffer[:B * T].view(B, T) cpu_targets = cpu_buffer[B * T:].view(B, T) inputs = gpu_buffer[:B * T].view(B, T) targets = gpu_buffer[B * T:].view(B, T) while True: for row_idx in range(B): pos = 0 while pos < row_capacity: while len(doc_buffer) < buffer_size: refill_buffer() remaining = row_capacity - pos # Find largest doc that fits entirely best_idx = -1 best_len = 0 for i, doc in enumerate(doc_buffer): doc_len = len(doc) if doc_len <= remaining and doc_len > best_len: best_idx = i best_len = doc_len if best_idx >= 0: doc = doc_buffer.pop(best_idx) row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long) pos += len(doc) else: # No doc fits — crop shortest to fill remaining shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i])) doc = doc_buffer.pop(shortest_idx) row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long) pos += remaining cpu_inputs.copy_(row_buffer[:, :-1]) cpu_targets.copy_(row_buffer[:, 1:]) gpu_buffer.copy_(cpu_buffer, non_blocking=True) yield inputs, targets, epoch # --------------------------------------------------------------------------- # Evaluation (DO NOT CHANGE — this is the fixed metric) # --------------------------------------------------------------------------- @torch.no_grad() def evaluate_bpb(model, tokenizer, batch_size): """ Bits per byte (BPB): vocab size-independent evaluation metric. Sums per-token cross-entropy (in nats), sums target byte lengths, then converts nats/byte to bits/byte. Special tokens (byte length 0) are excluded from both sums. Uses fixed MAX_SEQ_LEN so results are comparable across configs. """ token_bytes = get_token_bytes(device="cuda") val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val") steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN) total_nats = 0.0 total_bytes = 0 for _ in range(steps): x, y, _ = next(val_loader) loss_flat = model(x, y, reduction='none').view(-1) y_flat = y.view(-1) nbytes = token_bytes[y_flat] mask = nbytes > 0 total_nats += (loss_flat * mask).sum().item() total_bytes += nbytes.sum().item() return total_nats / (math.log(2) * total_bytes) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- if __name__ == "__main__": parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch") parser.add_argument("--num-shards", type=int, default=8, help="Number of shards to download (-1 = all 1823)") parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers") args = parser.parse_args() num_shards = MAX_SHARD + 1 if args.num_shards == -1 else args.num_shards print(f"Cache directory: {CACHE_DIR}") print() # Step 1: Download data download_data(num_shards) print() # Step 2: Train tokenizer train_tokenizer() print() print("Done! Ready to train.")