diff --git a/Cargo.lock b/Cargo.lock index 8e2f49326..f91dde8a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1159,6 +1159,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbd780fe5cc30f81464441920d82ac8740e2e46b29a6fad543ddd075229ce37e" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "hex-conservative" version = "0.2.1" @@ -1582,6 +1588,7 @@ dependencies = [ "ctr", "ff_ext", "generic-array", + "hex", "itertools 0.13.0", "multilinear_extensions", "num-bigint", diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index 3dbf76d76..b80134f8d 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -17,6 +17,7 @@ clap.workspace = true ctr = "0.9" ff_ext = { path = "../ff_ext" } generic-array = { version = "0.14", features = ["serde"] } +hex = "0.4" itertools.workspace = true multilinear_extensions = { path = "../multilinear_extensions" } num-bigint = "0.4" @@ -45,6 +46,10 @@ parallel = ["dep:rayon"] print-trace = ["whir/print-trace"] sanity-check = [] +[[bin]] +name = "generate_test_vector" +path = "bin/generate_test_vector.rs" + [[bench]] harness = false name = "basefold" diff --git a/mpcs/bin/generate_test_vector.rs b/mpcs/bin/generate_test_vector.rs new file mode 100644 index 000000000..5883c8059 --- /dev/null +++ b/mpcs/bin/generate_test_vector.rs @@ -0,0 +1,94 @@ +use ff_ext::{BabyBearExt4, ExtensionField, GoldilocksExt2}; +use mpcs::{ + Basefold, BasefoldRSParams, PolynomialCommitmentScheme, Whir, WhirDefaultSpec, + test_util::{get_point_from_challenge, setup_pcs}, +}; +use multilinear_extensions::virtual_poly::ArcMultilinearExtension; +use rand::{distributions::Standard, prelude::Distribution, thread_rng}; +use transcript::{BasicTranscript, Transcript}; +use witness::RowMajorMatrix; + +type PcsWhirGoldilocks = Whir; +type PcsWhirBabyBear = Whir; +type PcsBasefoldGoldilocks = Basefold; +type PcsBasefoldBabyBear = Basefold; + +use clap::Parser; + +#[derive(Parser)] +struct Args { + #[arg(short = 'f', long, default_value = "goldilocks")] + field: String, + #[arg(short = 'p', long, default_value = "basefold")] + pcs: String, + #[arg(short = 'n', long, default_value = "5")] + num_var: u32, +} + +fn main() { + // pass the parameters to determine which field to use, using the clap::Parser + let args = Args::parse(); + let num_var = args.num_var; + let (vp, comm, eval, proof) = match (args.field.as_str(), args.pcs.as_str()) { + ("goldilocks", "whir") => { + generate_test_vector::(num_var as usize) + } + ("goldilocks", "basefold") => { + generate_test_vector::(num_var as usize) + } + ("babybear", "whir") => { + generate_test_vector::(num_var as usize) + } + ("babybear", "basefold") => { + generate_test_vector::(num_var as usize) + } + _ => panic!("Invalid combination of field and PCS"), + }; + println!("num_vars: {}", num_var); + println!("vp: {}", vp); + println!("comm: {}", comm); + println!("eval: {}", eval); + println!("proof: {}", proof); +} + +pub fn generate_test_vector( + num_vars: usize, +) -> (String, String, String, String) +where + Pcs: PolynomialCommitmentScheme, + Standard: Distribution, +{ + let (pp, vp) = setup_pcs::(num_vars); + let mut test_rng = thread_rng(); + + // Commit and open + let (comm, eval, proof) = { + let mut transcript = BasicTranscript::new(b"BaseFold"); + let rmm = RowMajorMatrix::::rand(&mut test_rng, 1 << num_vars, 1); + let poly: ArcMultilinearExtension = rmm.to_mles().remove(0).into(); + let comm = Pcs::commit_and_write(&pp, rmm, &mut transcript).unwrap(); + + let point = get_point_from_challenge(num_vars, &mut transcript); + let eval = poly.evaluate(point.as_slice()); + transcript.append_field_element_ext(&eval); + + ( + Pcs::get_pure_commitment(&comm), + eval, + Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(), + ) + }; + // Serialize vp, comm, eval, proof using bincode + let vp_bin = bincode::serialize(&vp).unwrap(); + let comm_bin = bincode::serialize(&comm).unwrap(); + let eval_bin = bincode::serialize(&eval).unwrap(); + let proof_bin = bincode::serialize(&proof).unwrap(); + + // Encode them as hex strings + let vp_hex = hex::encode(vp_bin); + let comm_hex = hex::encode(comm_bin); + let eval_hex = hex::encode(eval_bin); + let proof_hex = hex::encode(proof_bin); + + (vp_hex, comm_hex, eval_hex, proof_hex) +} diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 410f0773e..838d74f68 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -357,6 +357,7 @@ pub mod test_util { ) -> Vec { transcript.sample_and_append_vec(b"Point", num_vars) } + pub fn get_points_from_challenge( num_vars: impl Fn(usize) -> usize, num_points: usize,