fix(prepare): honor --download-workers

This commit is contained in:
Dipesh Babu 2026-03-07 15:39:17 -05:00
parent 6fdefa7265
commit 777e443790

View File

@ -88,7 +88,7 @@ def download_single_shard(index):
return False
def download_data(num_shards):
def download_data(num_shards, download_workers=8):
"""Download training shards + pinned validation shard."""
os.makedirs(DATA_DIR, exist_ok=True)
num_train = min(num_shards, MAX_SHARD)
@ -105,7 +105,7 @@ def download_data(num_shards):
needed = len(ids) - existing
print(f"Data: downloading {needed} shards ({existing} already exist)...")
workers = min(8, needed)
workers = max(1, min(download_workers, needed))
with Pool(processes=workers) as pool:
results = pool.map(download_single_shard, ids)
@ -379,7 +379,7 @@ if __name__ == "__main__":
print()
# Step 1: Download data
download_data(num_shards)
download_data(num_shards, download_workers=args.download_workers)
print()
# Step 2: Train tokenizer