Honor --download-workers instead of hardcoding 8 download workers
This commit is contained in:
commit
500114a035
@ -88,7 +88,7 @@ def download_single_shard(index):
|
||||
return False
|
||||
|
||||
|
||||
def download_data(num_shards):
|
||||
def download_data(num_shards, download_workers=8):
|
||||
"""Download training shards + pinned validation shard."""
|
||||
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
num_train = min(num_shards, MAX_SHARD)
|
||||
@ -105,7 +105,7 @@ def download_data(num_shards):
|
||||
needed = len(ids) - existing
|
||||
print(f"Data: downloading {needed} shards ({existing} already exist)...")
|
||||
|
||||
workers = min(8, needed)
|
||||
workers = max(1, min(download_workers, needed))
|
||||
with Pool(processes=workers) as pool:
|
||||
results = pool.map(download_single_shard, ids)
|
||||
|
||||
@ -379,7 +379,7 @@ if __name__ == "__main__":
|
||||
print()
|
||||
|
||||
# Step 1: Download data
|
||||
download_data(num_shards)
|
||||
download_data(num_shards, download_workers=args.download_workers)
|
||||
print()
|
||||
|
||||
# Step 2: Train tokenizer
|
||||
|
||||
Loading…
Reference in New Issue
Block a user