
Commit b0cfa24

ok i can train and sample a model with a custom tokenizer
1 parent 4c6f0af commit b0cfa24

File tree

4 files changed: +48 additions, −14 deletions


model.py

Lines changed: 3 additions & 2 deletions
@@ -11,12 +11,13 @@

 @dataclass
 class ModelArgs:
+    # default hyperparameters for the Llama 7B model
     dim: int = 4096
     n_layers: int = 32
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
-    vocab_size: int = -1 # defined later by tokenizer
-    multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
+    vocab_size: int = 32000
+    multiple_of: int = 256 # MLP hidden layer size will be a multiple of this
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
     dropout: float = 0.0
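
With vocab_size now a plain default instead of the old -1 placeholder, a ModelArgs can be built directly for a custom-vocab run. A minimal sketch; the non-vocab values below are illustrative and not part of this commit:

from model import ModelArgs, Transformer

# e.g. a small TinyStories model paired with a custom 512-token tokenizer
args = ModelArgs(dim=288, n_layers=6, n_heads=6, vocab_size=512, max_seq_len=256)
model = Transformer(args)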

sample.py

Lines changed: 5 additions & 1 deletion
@@ -9,6 +9,8 @@
 from model import ModelArgs, Transformer
 from tokenizer import Tokenizer

+from tinystories import get_tokenizer_model_path
+
 # -----------------------------------------------------------------------------
 out_dir = 'out' # ignored if init_from is not 'resume'
 start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
@@ -51,7 +53,9 @@
 model = torch.compile(model) # requires PyTorch 2.0 (optional)

 # load the tokenizer
-enc = Tokenizer()
+assert checkpoint["config"]["dataset"] == "tinystories" # TODO: generalize
+tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
+enc = Tokenizer(tokenizer_model=tokenizer_model)

 # encode the beginning of the prompt
 if start.startswith('FILE:'):
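
In other words, sampling now picks the tokenizer that matches the checkpoint's vocab size instead of always loading the default one. A rough sketch of both paths, assuming (as the two call sites in this diff suggest) that Tokenizer treats a None tokenizer_model as "use the default Llama 2 model", and that encode takes (s, bos, eos):

from tokenizer import Tokenizer
from tinystories import get_tokenizer_model_path

# custom vocab: resolves under DATA_CACHE_DIR as tok512.model
enc = Tokenizer(tokenizer_model=get_tokenizer_model_path(vocab_size=512))

# vocab_size == 0 yields None, i.e. fall back to the default Llama 2 tokenizer
enc_default = Tokenizer(tokenizer_model=get_tokenizer_model_path(vocab_size=0))

start_ids = enc.encode("Once upon a time", bos=True, eos=False)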

tinystories.py

Lines changed: 29 additions & 8 deletions
@@ -120,9 +120,7 @@ def train_vocab(vocab_size):

 def process_shard(args, vocab_size):
     shard_id, shard = args
-    tokenizer_model = None
-    if vocab_size > 0:
-        tokenizer_model = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
+    tokenizer_model = get_tokenizer_model_path(vocab_size)
     enc = Tokenizer(tokenizer_model)
     with open(shard, "r") as f:
         data = json.load(f)
@@ -171,10 +169,12 @@ def pretokenize(vocab_size):
 class PretokDataset(torch.utils.data.IterableDataset):
     """Loads pretokenized examples from disk and yields them as PyTorch tensors."""

-    def __init__(self, split, max_seq_len):
+    def __init__(self, split, max_seq_len, vocab_size, vocab_source):
         super().__init__()
         self.split = split
         self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.vocab_source = vocab_source

     def __iter__(self):
         # get worker info within a DataLoader
@@ -186,8 +186,14 @@ def __iter__(self):
         seed = 42 + worker_id + 1337 * rank
         rng = random.Random(seed)
         print(f"Created a PretokDataset with rng seed {seed}")
-        data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
-        shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
+        if self.vocab_source == "llama2":
+            # the .bin files are right along the .json files
+            bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
+            shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
+        elif self.vocab_source == "custom":
+            # the .bin files are in the tok{N} directory
+            bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
+            shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
         # train/test split. let's use only shard 0 for test split, rest train
         shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
         while True:
@@ -209,12 +215,25 @@ def __iter__(self):
                 y = chunk[1:]
                 yield x, y

+# -----------------------------------------------------------------------------
+# public interface functions
+
+def get_tokenizer_model_path(vocab_size):
+    """
+    Returns the path to the sentencepiece tokenizer model for a given vocab size.
+    vocab_size = 0 designates the default Llama 2 tokenizer; in that case
+    None is returned.
+    """
+    if vocab_size == 0:
+        return None
+    else:
+        return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")

 class Task:

     @staticmethod
-    def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
-        ds = PretokDataset(split, max_seq_len)
+    def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
+        ds = PretokDataset(**dataset_kwargs)
         dl = torch.utils.data.DataLoader(
             ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
         )
@@ -223,6 +242,8 @@ def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
             y = y.to(device, non_blocking=True)
             yield x, y

+# -----------------------------------------------------------------------------
+# CLI for constructing the dataset

 if __name__ == "__main__":
     """

train.py

Lines changed: 11 additions & 3 deletions
@@ -47,6 +47,8 @@
 # data
 batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
 max_seq_len = 256
+vocab_source = "custom" # llama2|custom; use Llama 2 vocab from Meta, or custom trained
+vocab_size = 512
 dataset = "tinystories" # tinystories|tinyshakespeare
 # model
 dim = 288
@@ -83,6 +85,10 @@
 lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
 min_lr = 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla

+# validating checks
+assert vocab_source in ["llama2", "custom"]
+assert vocab_source == "custom" or vocab_size == 32000, "The vocab from Meta has 32K tokens"
+
 # various inits, derived attributes, I/O setup
 ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
 if ddp:
@@ -128,6 +134,8 @@
     task.iter_batches,
     batch_size=batch_size,
     max_seq_len=max_seq_len,
+    vocab_size=vocab_size,
+    vocab_source=vocab_source,
     device=device,
     num_workers=0,
 )
@@ -142,7 +150,7 @@
     n_layers=n_layers,
     n_heads=n_heads,
     n_kv_heads=n_heads,
-    vocab_size=32000,
+    vocab_size=vocab_size,
     multiple_of=multiple_of,
     max_seq_len=max_seq_len,
     dropout=dropout,
@@ -206,7 +214,7 @@ def estimate_loss():
     out = {}
     model.eval()
     for split in ["train", "val"]:
-        batch_iter = iter_batches(split)
+        batch_iter = iter_batches(split=split)
         losses = torch.zeros(eval_iters) # keep on CPU
         for k in range(eval_iters):
             X, Y = next(batch_iter)
@@ -238,7 +246,7 @@ def get_lr(it):
     wandb.init(project=wandb_project, name=wandb_run_name, config=config)

 # training loop
-train_batch_iter = iter_batches("train")
+train_batch_iter = iter_batches(split="train")
 X, Y = next(train_batch_iter) # fetch the very first batch
 t0 = time.time()
 local_iter_num = 0 # number of iterations in the lifetime of this process
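
Why the last two hunks switch to split=...: iter_batches no longer takes split positionally, and once batch_size is bound by partial, a positional argument would collide with it. A standalone toy mirroring just the new signature (not the real training code):

from functools import partial

def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
    # stand-in for Task.iter_batches: extra keywords are forwarded to PretokDataset
    return dataset_kwargs

iter_batches = partial(iter_batches, batch_size=128, device="cpu", num_workers=0)

print(iter_batches(split="train"))  # {'split': 'train'} ends up in dataset_kwargs
# iter_batches("train")            # TypeError: got multiple values for argument 'batch_size'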
