Honor --download-workers instead of hardcoding 8 download workers

This commit is contained in:
Andrej 2026-03-07 14:17:45 -08:00 committed by GitHub
commit 500114a035
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -88,7 +88,7 @@ def download_single_shard(index):
return False return False
def download_data(num_shards): def download_data(num_shards, download_workers=8):
"""Download training shards + pinned validation shard.""" """Download training shards + pinned validation shard."""
os.makedirs(DATA_DIR, exist_ok=True) os.makedirs(DATA_DIR, exist_ok=True)
num_train = min(num_shards, MAX_SHARD) num_train = min(num_shards, MAX_SHARD)
@@ -105,7 +105,7 @@ def download_data(num_shards):
needed = len(ids) - existing needed = len(ids) - existing
print(f"Data: downloading {needed} shards ({existing} already exist)...") print(f"Data: downloading {needed} shards ({existing} already exist)...")
workers = min(8, needed) workers = max(1, min(download_workers, needed))
with Pool(processes=workers) as pool: with Pool(processes=workers) as pool:
results = pool.map(download_single_shard, ids) results = pool.map(download_single_shard, ids)
@@ -379,7 +379,7 @@ if __name__ == "__main__":
print() print()
# Step 1: Download data # Step 1: Download data
download_data(num_shards) download_data(num_shards, download_workers=args.download_workers)
print() print()
# Step 2: Train tokenizer # Step 2: Train tokenizer