From 032d2036954a6c41a268005e6a50fc5b18f4a0a7 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Sat, 7 Mar 2026 17:59:52 +0000
Subject: [PATCH] minor tweaks, pin val shard

Pin the validation split to a fixed shard (shard_06542) instead of
whichever downloaded shard happens to sort last, so that changing
--num-shards no longer changes the val set. download_data() now fetches
training shards 0..num_train-1 and always grabs the pinned val shard
alongside them. Also: reserved special tokens 16 -> 4, tokenizer text
budget 2B -> 1B chars, default download 8 -> 10 shards.
---
 prepare.py | 27 ++++++++++++++++++---------
 1 file changed, 18 insertions(+), 9 deletions(-)

diff --git a/prepare.py b/prepare.py
index baa6515..dd95c16 100644
--- a/prepare.py
+++ b/prepare.py
@@ -34,12 +34,14 @@ DATA_DIR = os.path.join(CACHE_DIR, "data")
 TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
 BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
 MAX_SHARD = 6542 # the last datashard is shard_06542.parquet
+VAL_SHARD = MAX_SHARD # pinned validation shard (shard_06542)
+VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet"
 VOCAB_SIZE = 8192
 
 # BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3})
 SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
 
-SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(16)]
+SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(4)]
 BOS_TOKEN = "<|reserved_0|>"
 
 # ---------------------------------------------------------------------------
@@ -81,9 +83,12 @@ def download_single_shard(index):
 
 
 def download_data(num_shards):
-    """Download data shards in parallel."""
+    """Download training shards + pinned validation shard."""
     os.makedirs(DATA_DIR, exist_ok=True)
-    ids = list(range(min(num_shards, MAX_SHARD + 1)))
+    num_train = min(num_shards, MAX_SHARD)
+    ids = list(range(num_train))
+    if VAL_SHARD not in ids:
+        ids.append(VAL_SHARD)
 
     # Count what's already downloaded
     existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet")))
@@ -111,9 +116,9 @@ def list_parquet_files():
     return [os.path.join(DATA_DIR, f) for f in files]
 
 
-def text_iterator(max_chars=2_000_000_000, doc_cap=10_000):
-    """Yield documents from training split (all shards except last)."""
-    parquet_paths = list_parquet_files()[:-1] # last shard is val
+def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
+    """Yield documents from training split (all shards except pinned val shard)."""
+    parquet_paths = [p for p in list_parquet_files() if not p.endswith(VAL_FILENAME)]
     nchars = 0
     for filepath in parquet_paths:
         pf = pq.ParquetFile(filepath)
@@ -244,7 +249,11 @@ def _document_batches(split, tokenizer_batch_size=128):
     """Infinite iterator over document batches from parquet files."""
     parquet_paths = list_parquet_files()
     assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first."
-    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
+    val_path = os.path.join(DATA_DIR, VAL_FILENAME)
+    if split == "train":
+        parquet_paths = [p for p in parquet_paths if p != val_path]
+    else:
+        parquet_paths = [val_path]
     epoch = 1
     while True:
         for filepath in parquet_paths:
@@ -354,11 +363,11 @@ def evaluate_bpb(model, tokenizer, batch_size):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
-    parser.add_argument("--num-shards", type=int, default=8, help="Number of shards to download (-1 = all 1823)")
+    parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.")
     parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers")
     args = parser.parse_args()
 
-    num_shards = MAX_SHARD + 1 if args.num_shards == -1 else args.num_shards
+    num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards
     print(f"Cache directory: {CACHE_DIR}")
     print()
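
A quick sanity check of the new shard selection, as a standalone sketch
(MAX_SHARD, VAL_SHARD and the id logic mirror download_data() in the patch;
shard_ids_to_download() is a hypothetical helper name used here only for
illustration, it does not exist in prepare.py):

# sketch of post-patch shard selection; not part of the patch itself
MAX_SHARD = 6542
VAL_SHARD = MAX_SHARD  # val is pinned to shard_06542, never a training shard

def shard_ids_to_download(num_shards):
    num_train = min(num_shards, MAX_SHARD)  # shards 0..num_train-1 are training
    ids = list(range(num_train))
    if VAL_SHARD not in ids:                # always true here; kept as a guard
        ids.append(VAL_SHARD)               # pinned val shard rides along
    return ids

# default run: 10 training shards plus the pinned val shard
assert shard_ids_to_download(10) == list(range(10)) + [6542]
# --num-shards -1 maps to MAX_SHARD, i.e. every shard 0..6542 gets downloaded
assert shard_ids_to_download(MAX_SHARD) == list(range(MAX_SHARD + 1))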