Skip to content

Commit f5fc0c2

Browse files
committed
final piece: run.c support for new tokenizer, super ez
1 parent ea4cedc commit f5fc0c2

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

run.c

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -508,13 +508,15 @@ void error_usage() {
508508
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
509509
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
510510
fprintf(stderr, " -i <string> input prompt\n");
511+
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
511512
exit(EXIT_FAILURE);
512513
}
513514

514515
int main(int argc, char *argv[]) {
515516

516517
// default inits
517518
char *checkpoint = NULL; // e.g. out/model.bin
519+
char *tokenizer = "tokenizer.bin";
518520
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
519521
float topp = 1.0f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
520522
rng_seed = 0; // seed rng with time by default
@@ -534,6 +536,7 @@ int main(int argc, char *argv[]) {
534536
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
535537
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
536538
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
539+
else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; }
537540
else { error_usage(); }
538541
}
539542
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
@@ -567,13 +570,13 @@ int main(int argc, char *argv[]) {
567570
// right now we cannot run for more than config.seq_len steps
568571
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
569572

570-
// read in the tokenizer.bin file
573+
// read in the tokenizer .bin file
571574
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
572575
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
573576
unsigned int max_token_length;
574577
{
575-
FILE *file = fopen("tokenizer.bin", "rb");
576-
if (!file) { fprintf(stderr, "couldn't load tokenizer.bin\n"); return 1; }
578+
FILE *file = fopen(tokenizer, "rb");
579+
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); return 1; }
577580
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
578581
int len;
579582
for (int i = 0; i < config.vocab_size; i++) {

0 commit comments

Comments
 (0)