Honor --download-workers instead of hardcoding 8 download workers
This commit is contained in:
commit
500114a035
@ -88,7 +88,7 @@ def download_single_shard(index):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def download_data(num_shards):
|
def download_data(num_shards, download_workers=8):
|
||||||
"""Download training shards + pinned validation shard."""
|
"""Download training shards + pinned validation shard."""
|
||||||
os.makedirs(DATA_DIR, exist_ok=True)
|
os.makedirs(DATA_DIR, exist_ok=True)
|
||||||
num_train = min(num_shards, MAX_SHARD)
|
num_train = min(num_shards, MAX_SHARD)
|
||||||
@ -105,7 +105,7 @@ def download_data(num_shards):
|
|||||||
needed = len(ids) - existing
|
needed = len(ids) - existing
|
||||||
print(f"Data: downloading {needed} shards ({existing} already exist)...")
|
print(f"Data: downloading {needed} shards ({existing} already exist)...")
|
||||||
|
|
||||||
workers = min(8, needed)
|
workers = max(1, min(download_workers, needed))
|
||||||
with Pool(processes=workers) as pool:
|
with Pool(processes=workers) as pool:
|
||||||
results = pool.map(download_single_shard, ids)
|
results = pool.map(download_single_shard, ids)
|
||||||
|
|
||||||
@ -379,7 +379,7 @@ if __name__ == "__main__":
|
|||||||
print()
|
print()
|
||||||
|
|
||||||
# Step 1: Download data
|
# Step 1: Download data
|
||||||
download_data(num_shards)
|
download_data(num_shards, download_workers=args.download_workers)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Step 2: Train tokenizer
|
# Step 2: Train tokenizer
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user