diff --git a/prepare.py b/prepare.py index 62d63ce..62607b9 100644 --- a/prepare.py +++ b/prepare.py @@ -88,7 +88,7 @@ def download_single_shard(index): return False -def download_data(num_shards): +def download_data(num_shards, download_workers=8): """Download training shards + pinned validation shard.""" os.makedirs(DATA_DIR, exist_ok=True) num_train = min(num_shards, MAX_SHARD) @@ -105,7 +105,7 @@ def download_data(num_shards): needed = len(ids) - existing print(f"Data: downloading {needed} shards ({existing} already exist)...") - workers = min(8, needed) + workers = max(1, min(download_workers, needed)) with Pool(processes=workers) as pool: results = pool.map(download_single_shard, ids) @@ -379,7 +379,7 @@ if __name__ == "__main__": print() # Step 1: Download data - download_data(num_shards) + download_data(num_shards, download_workers=args.download_workers) print() # Step 2: Train tokenizer