minor tweaks, pin val shard

parent 47ec1ade0a
commit 032d203695

prepare.py (27 lines changed)
@@ -34,12 +34,14 @@ DATA_DIR = os.path.join(CACHE_DIR, "data")
 TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
 BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
 MAX_SHARD = 6542 # the last datashard is shard_06542.parquet
+VAL_SHARD = MAX_SHARD # pinned validation shard (shard_06542)
+VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet"
 VOCAB_SIZE = 8192
 
 # BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3})
 SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
 
-SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(16)]
+SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(4)]
 BOS_TOKEN = "<|reserved_0|>"
 
 # ---------------------------------------------------------------------------
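For context: the new VAL_FILENAME uses zero-padded formatting, so the pinned shard resolves to exactly the file named in the MAX_SHARD comment. A standalone sanity-check sketch, with values copied from the constants above:

    # Standalone sketch: pinned val shard naming.
    MAX_SHARD = 6542
    VAL_SHARD = MAX_SHARD
    VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet"
    assert VAL_FILENAME == "shard_06542.parquet"  # matches the MAX_SHARD comment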
@@ -81,9 +83,12 @@ def download_single_shard(index):
 
 
 def download_data(num_shards):
-    """Download data shards in parallel."""
+    """Download training shards + pinned validation shard."""
     os.makedirs(DATA_DIR, exist_ok=True)
-    ids = list(range(min(num_shards, MAX_SHARD + 1)))
+    num_train = min(num_shards, MAX_SHARD)
+    ids = list(range(num_train))
+    if VAL_SHARD not in ids:
+        ids.append(VAL_SHARD)
 
     # Count what's already downloaded
     existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet")))
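The effect of the new id construction: training ids run 0..num_train-1 and can never include MAX_SHARD, while the val shard is appended exactly once. A minimal sketch with small stand-in values (not the real constants):

    # Sketch of the new shard-id selection; stand-in values for illustration.
    MAX_SHARD = 6          # pretend shards are 0..6
    VAL_SHARD = MAX_SHARD  # shard 6 is pinned for validation

    def shard_ids(num_shards):
        num_train = min(num_shards, MAX_SHARD)  # training ids stop before MAX_SHARD
        ids = list(range(num_train))
        if VAL_SHARD not in ids:
            ids.append(VAL_SHARD)               # val shard downloaded exactly once
        return ids

    assert shard_ids(3) == [0, 1, 2, 6]
    assert shard_ids(100) == [0, 1, 2, 3, 4, 5, 6]  # capped at MAX_SHARD, plus val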
@@ -111,9 +116,9 @@ def list_parquet_files():
     return [os.path.join(DATA_DIR, f) for f in files]
 
 
-def text_iterator(max_chars=2_000_000_000, doc_cap=10_000):
-    """Yield documents from training split (all shards except last)."""
-    parquet_paths = list_parquet_files()[:-1] # last shard is val
+def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
+    """Yield documents from training split (all shards except pinned val shard)."""
+    parquet_paths = [p for p in list_parquet_files() if not p.endswith(VAL_FILENAME)]
     nchars = 0
     for filepath in parquet_paths:
         pf = pq.ParquetFile(filepath)
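The training split now excludes the val shard by filename rather than by slicing off the last list element, so it stays correct regardless of which other shards happen to be on disk. A tiny illustration with hypothetical paths:

    # Filename-based exclusion; hypothetical paths for illustration.
    VAL_FILENAME = "shard_06542.parquet"
    paths = ["data/shard_00000.parquet", "data/shard_00003.parquet", "data/shard_06542.parquet"]
    train = [p for p in paths if not p.endswith(VAL_FILENAME)]
    assert train == ["data/shard_00000.parquet", "data/shard_00003.parquet"]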
@@ -244,7 +249,11 @@ def _document_batches(split, tokenizer_batch_size=128):
     """Infinite iterator over document batches from parquet files."""
     parquet_paths = list_parquet_files()
    assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first."
-    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
+    val_path = os.path.join(DATA_DIR, VAL_FILENAME)
+    if split == "train":
+        parquet_paths = [p for p in parquet_paths if p != val_path]
+    else:
+        parquet_paths = [val_path]
     epoch = 1
     while True:
         for filepath in parquet_paths:
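_document_batches gets the same treatment: the split is keyed on the pinned file, not on list position. A compact sketch of the selection logic in isolation (paths stubbed out):

    # Sketch of the new split selection; paths stubbed for illustration.
    val_path = "data/shard_06542.parquet"
    all_paths = ["data/shard_00000.parquet", "data/shard_00001.parquet", val_path]

    def select(split):
        if split == "train":
            return [p for p in all_paths if p != val_path]
        return [val_path]

    assert select("train") == all_paths[:2]  # everything except the pinned shard
    assert select("val") == [val_path]       # val reads only the pinned shard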
@@ -354,11 +363,11 @@ def evaluate_bpb(model, tokenizer, batch_size):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
-    parser.add_argument("--num-shards", type=int, default=8, help="Number of shards to download (-1 = all 1823)")
+    parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.")
     parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers")
     args = parser.parse_args()
 
-    num_shards = MAX_SHARD + 1 if args.num_shards == -1 else args.num_shards
+    num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards
 
     print(f"Cache directory: {CACHE_DIR}")
     print()
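With the updated -1 handling, num_shards resolves to MAX_SHARD, which yields training shards 0..6541, and download_data appends the pinned shard 6542 — so e.g. `python prepare.py --num-shards -1` still fetches every shard. A sketch of the arithmetic:

    # Sketch: why -1 now maps to MAX_SHARD (not MAX_SHARD + 1).
    MAX_SHARD = 6542
    num_shards = MAX_SHARD                      # the new "-1 = all" resolution
    num_train = min(num_shards, MAX_SHARD)      # 6542 -> training shards 0..6541
    ids = list(range(num_train)) + [MAX_SHARD]  # download_data appends val shard 6542
    assert len(ids) == MAX_SHARD + 1            # all 6543 shards still downloaded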