@@ -508,13 +508,15 @@ void error_usage() {
508
508
fprintf (stderr , " -s <int> random seed, default time(NULL)\n" );
509
509
fprintf (stderr , " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n" );
510
510
fprintf (stderr , " -i <string> input prompt\n" );
511
+ fprintf (stderr , " -z <string> optional path to custom tokenizer\n" );
511
512
exit (EXIT_FAILURE );
512
513
}
513
514
514
515
int main (int argc , char * argv []) {
515
516
516
517
// default inits
517
518
char * checkpoint = NULL ; // e.g. out/model.bin
519
+ char * tokenizer = "tokenizer.bin" ;
518
520
float temperature = 1.0f ; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
519
521
float topp = 1.0f ; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
520
522
rng_seed = 0 ; // seed rng with time by default
@@ -534,6 +536,7 @@ int main(int argc, char *argv[]) {
534
536
else if (argv [i ][1 ] == 's' ) { rng_seed = atoi (argv [i + 1 ]); }
535
537
else if (argv [i ][1 ] == 'n' ) { steps = atoi (argv [i + 1 ]); }
536
538
else if (argv [i ][1 ] == 'i' ) { prompt = argv [i + 1 ]; }
539
+ else if (argv [i ][1 ] == 'z' ) { tokenizer = argv [i + 1 ]; }
537
540
else { error_usage (); }
538
541
}
539
542
if (rng_seed == 0 ) { rng_seed = (unsigned int )time (NULL );}
@@ -567,13 +570,13 @@ int main(int argc, char *argv[]) {
567
570
// right now we cannot run for more than config.seq_len steps
568
571
if (steps <= 0 || steps > config .seq_len ) { steps = config .seq_len ; }
569
572
570
- // read in the tokenizer.bin file
573
+ // read in the tokenizer .bin file
571
574
char * * vocab = (char * * )malloc (config .vocab_size * sizeof (char * ));
572
575
float * vocab_scores = (float * )malloc (config .vocab_size * sizeof (float ));
573
576
unsigned int max_token_length ;
574
577
{
575
- FILE * file = fopen (" tokenizer.bin" , "rb" );
576
- if (!file ) { fprintf (stderr , "couldn't load tokenizer.bin \n" ); return 1 ; }
578
+ FILE * file = fopen (tokenizer , "rb" );
579
+ if (!file ) { fprintf (stderr , "couldn't load %s \n" , tokenizer ); return 1 ; }
577
580
if (fread (& max_token_length , sizeof (int ), 1 , file ) != 1 ) { fprintf (stderr , "failed read\n" ); return 1 ; }
578
581
int len ;
579
582
for (int i = 0 ; i < config .vocab_size ; i ++ ) {
0 commit comments