@@ -120,9 +120,7 @@ def train_vocab(vocab_size):
 
 def process_shard(args, vocab_size):
     shard_id, shard = args
-    tokenizer_model = None
-    if vocab_size > 0:
-        tokenizer_model = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
+    tokenizer_model = get_tokenizer_model_path(vocab_size)
     enc = Tokenizer(tokenizer_model)
     with open(shard, "r") as f:
         data = json.load(f)
@@ -171,10 +169,12 @@ def pretokenize(vocab_size):
 class PretokDataset(torch.utils.data.IterableDataset):
     """Loads pretokenized examples from disk and yields them as PyTorch tensors."""
 
-    def __init__(self, split, max_seq_len):
+    def __init__(self, split, max_seq_len, vocab_size, vocab_source):
         super().__init__()
         self.split = split
         self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.vocab_source = vocab_source
 
     def __iter__(self):
         # get worker info within a DataLoader
@@ -186,8 +186,14 @@ def __iter__(self):
         seed = 42 + worker_id + 1337 * rank
         rng = random.Random(seed)
         print(f"Created a PretokDataset with rng seed {seed}")
-        data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
-        shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
+        if self.vocab_source == "llama2":
+            # the .bin files are right along the .json files
+            bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
+            shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
+        elif self.vocab_source == "custom":
+            # the .bin files are in tok{N} directory
+            bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
+            shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
         # train/test split. let's use only shard 0 for test split, rest train
         shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
         while True:
@@ -209,12 +215,25 @@ def __iter__(self):
                     y = chunk[1:]
                     yield x, y
 
+# -----------------------------------------------------------------------------
+# public interface functions
+
+def get_tokenizer_model_path(vocab_size):
+    """
+    Returns path to the sentencepiece tokenizer model for a given vocab size
+    vocab_size = 0 designates the default Llama 2 tokenizer, in that case
+    None is returned.
+    """
+    if vocab_size == 0:
+        return None
+    else:
+        return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
 
 class Task:
 
     @staticmethod
-    def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
-        ds = PretokDataset(split, max_seq_len)
+    def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
+        ds = PretokDataset(**dataset_kwargs)
         dl = torch.utils.data.DataLoader(
             ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
         )
@@ -223,6 +242,8 @@ def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
             y = y.to(device, non_blocking=True)
             yield x, y
 
+# -----------------------------------------------------------------------------
+# CLI for constructing the dataset
 
 if __name__ == "__main__":
     """
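For context, a minimal sketch of how a training script might drive the refactored interface after this change. This is not taken from the repository's training script: the config values (batch_size, max_seq_len, vocab_size, vocab_source, device) and the importable module name are illustrative assumptions.

from functools import partial

# assumed import; the module name is hypothetical here
from tinystories import Task, get_tokenizer_model_path

# hypothetical config values, chosen only for illustration
batch_size = 32
max_seq_len = 256
vocab_size = 4096          # 0 would select the default Llama 2 tokenizer
vocab_source = "custom"    # "llama2" or "custom", matching PretokDataset.__iter__

# the new helper maps vocab_size -> sentencepiece model path (None means the default tokenizer)
assert get_tokenizer_model_path(0) is None
assert get_tokenizer_model_path(vocab_size).endswith(f"tok{vocab_size}.model")

# batch_size/device/num_workers are consumed by iter_batches itself;
# everything else is forwarded to PretokDataset via **dataset_kwargs
iter_batches = partial(
    Task.iter_batches,
    batch_size=batch_size,
    device="cpu",
    num_workers=0,
    split="train",
    max_seq_len=max_seq_len,
    vocab_size=vocab_size,
    vocab_source=vocab_source,
)

x, y = next(iter_batches())
# x and y are (batch_size, max_seq_len) int64 token tensors already moved to the device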