From b3b9203d621b84fb1a522c42d1a1d99c784973e7 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 9 Dec 2024 15:48:48 +0800 Subject: [PATCH] add SPARK pcs: prover logic --- ceno_zkvm/src/scheme/prover.rs | 275 ++++++++++++++++++++++++++++++++- ceno_zkvm/src/scheme/tests.rs | 64 +++++++- 2 files changed, 332 insertions(+), 7 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index b0ff7b656..b4eb74004 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,6 +1,9 @@ +use base64::read; use ff_ext::ExtensionField; +use goldilocks::SmallField; use std::{ collections::{BTreeMap, BTreeSet, HashMap}, + iter::successors, sync::Arc, }; @@ -8,12 +11,18 @@ use ff::Field; use itertools::{Itertools, enumerate, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - mle::{IntoMLE, MultilinearExtension}, - util::ceil_log2, + mle::{DenseMultilinearExtension, IntoMLE, MultilinearExtension}, + op_mle, + util::{ceil_log2, create_uninit_vec, max_usable_threads}, virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, }; -use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use rayon::{ + iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, + }, + slice::ParallelSlice, +}; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProverMessage, IOPProverStateV2}, @@ -21,18 +30,20 @@ use sumcheck::{ use transcript::{ForkableTranscript, Transcript}; use crate::{ + chip_handler::utils::power_sequence, circuit_builder::SetTableAddrType, error::ZKVMError, expression::Instance, scheme::{ - constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, + constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, MIN_PAR_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, wit_infer_by_expr, }, }, structs::{ - Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, + Point, PointAndEval, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, + ZKVMWitnesses, }, utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, virtual_polys::VirtualPolynomials, @@ -1297,3 +1308,257 @@ impl TowerProver { (next_rt, proofs) } } + +pub struct SparkProver; + +pub struct SparkWitness<'a, E: ExtensionField> { + // sparse representative + pub row_index: ArcMultilinearExtension<'a, E>, + pub col_index: ArcMultilinearExtension<'a, E>, + pub value: ArcMultilinearExtension<'a, E>, + // offline memory check + // write_ts can be read_rs + 1 + pub read_ts_row: ArcMultilinearExtension<'a, E>, + pub read_ts_col: ArcMultilinearExtension<'a, E>, + pub audit_ts_row: ArcMultilinearExtension<'a, E>, + pub audit_ts_col: ArcMultilinearExtension<'a, E>, +} + +/// derive power sequence [1, base, base^2, ..., base^(len-1)] of base expression +fn power_sequence_scalar<'a, E: ExtensionField>(base: &'a E) -> impl Iterator + 'a { + successors(Some(E::ONE), move |prev| Some(prev.clone() * base.clone())) +} + +fn spark_e_vec(index_vec: &ArcMultilinearExtension, eq: &[E]) -> Vec { + index_vec + .get_base_field_vec() + .iter() + .map(|index| eq[index.to_canonical_u64() as usize]) + .collect_vec() +} + +/// from spartan p.27 MemoryInTheHead +pub fn mem_in_the_head( + mem_log2_size: usize, + addr: &ArcMultilinearExtension, +) -> (Vec, Vec, Vec) { + let n = addr.evaluations().len(); + let mut read_ts = vec![E::ZERO; n]; + let mut write_ts = vec![E::ZERO; n]; + let mut audit_ts = vec![E::ZERO; 1 << mem_log2_size]; + let addr = addr.get_base_field_vec(); + for i in 0..addr.len() { + let addr_u64 = addr[i].to_canonical_u64() as usize; + let r_ts = audit_ts[addr_u64]; + let ts = r_ts + E::ONE; + read_ts[i] = r_ts; + write_ts[i] = ts; + audit_ts[addr_u64] = ts; + } + (read_ts, write_ts, audit_ts) +} + +fn infer_tower_product_witness_from_poly( + poly: &DenseMultilinearExtension, +) -> Vec>> { + infer_tower_product_witness( + poly.num_vars(), + (0..NUM_FANIN) + .map(move |i| { + let ranged_poly: ArcMultilinearExtension = + Arc::new(poly.get_ranged_mle(NUM_FANIN, i)); + ranged_poly + }) + .collect(), + NUM_FANIN, + ) +} + +fn mem_read_write_set_witness<'a, E: ExtensionField>( + n_threads: usize, + index: &ArcMultilinearExtension<'a, E>, + read_ts: &ArcMultilinearExtension<'a, E>, + e_vec: &DenseMultilinearExtension, + pow_gamma: &[E], +) -> (DenseMultilinearExtension, DenseMultilinearExtension) { + let per_row_rw_chunk_size = e_vec.evaluations().len().div_ceil(n_threads); + let (read_set, write_set): (Vec, Vec) = e_vec + .get_ext_field_vec() + .par_chunks_exact(per_row_rw_chunk_size) + .zip_eq( + index + .get_base_field_vec() + .par_chunks_exact(per_row_rw_chunk_size), + ) + .zip_eq( + read_ts + .get_base_field_vec() + .par_chunks_exact(per_row_rw_chunk_size), + ) + .enumerate() + .flat_map(|(i, ((value, index), read_ts))| { + izip!(value, index, read_ts) + .map(|(value, index, read_ts)| { + // a, v, t + let tmp = pow_gamma[1] * value + index; + (tmp + read_ts, tmp + read_ts + E::ONE) + }) + .collect_vec() + }) + .unzip(); + let (read_set_poly, write_set_poly) = (read_set.into_mle(), write_set.into_mle()); + (read_set_poly, write_set_poly) +} + +fn mem_init_audit_witness( + n_threads: usize, + mem: &[E], + audit_ts: &ArcMultilinearExtension, + pow_gamma: &[E], +) -> (DenseMultilinearExtension, DenseMultilinearExtension) { + let per_row_init_chunk_size = mem.len().div_ceil(n_threads); + let (mem_init_row, mem_audit_row): (Vec, Vec) = mem + .par_chunks_exact(per_row_init_chunk_size) + .zip_eq( + audit_ts + .get_base_field_vec() + .par_chunks_exact(per_row_init_chunk_size), + ) + .enumerate() + .flat_map(|(i, (rows, audit_ts))| { + let addr_offset = i * per_row_init_chunk_size; + izip!(rows, audit_ts) + .enumerate() + .map(|(i, (row, audit_ts))| { + // a, v, t + let tmp: E = pow_gamma[1] * row + E::BaseField::from((addr_offset + i) as u64); + (tmp, tmp + pow_gamma[2] * audit_ts) + }) + .collect_vec() + }) + .unzip(); + let (mem_init_row_poly, mem_audit_row_poly) = + (mem_init_row.into_mle(), mem_audit_row.into_mle()); + + (mem_init_row_poly, mem_audit_row_poly) +} +pub struct SparkProof; + +/// Tower Prover +impl SparkProver { + #[tracing::instrument(skip_all, name = "spark_prover_create_proof", level = "trace")] + pub fn create_proof<'a, E: ExtensionField, PCS: PolynomialCommitmentScheme>( + pp: PCS::ProverParam, + point_n_evals: &[PointAndEval], + witnesses: &[SparkWitness], + transcript: &mut Transcript, + ) -> (Point, SparkProof) { + // product argument to prove offline memory check argument + let n_threads = max_usable_threads(); + // let alpha_pows = power_sequence_scalar(alpha).take(3).collect_vec(); + for (point_n_eval, witnesses) in izip!(point_n_evals, witnesses) { + // sanity check + let (r, eval) = (&point_n_eval.point, &point_n_eval.eval); + assert!(r.len().is_power_of_two()); + let (r_row, r_col) = (&r[..r.len() / 2], &r[r.len() / 2..]); + + let (mem_row, mem_col) = (build_eq_x_r_vec(r_row), build_eq_x_r_vec(r_col)); + // prepare e_row, e_col + let e_row = spark_e_vec(&witnesses.row_index, &mem_row).into_mle(); + let e_col = spark_e_vec(&witnesses.col_index, &mem_col).into_mle(); + let e_vec = vec![e_row, e_col]; + + assert_eq!( + e_vec[0].evaluations().len(), + witnesses.row_index.evaluations().len() + ); + assert_eq!( + e_vec[1].evaluations().len(), + witnesses.col_index.evaluations().len() + ); + + let commit_with_data = PCS::batch_commit_and_write(&pp, &e_vec, transcript).unwrap(); + let gamma = transcript.get_and_append_challenge(b"spark_omc_hash"); + let pow_gamma = power_sequence_scalar(&gamma.elements).take(3).collect_vec(); + + let ((mem_init_row_poly, mem_audit_row_poly), (mem_init_col_poly, mem_audit_col_poly)) = ( + mem_init_audit_witness(n_threads, &mem_row, &witnesses.audit_ts_row, &pow_gamma), + mem_init_audit_witness(n_threads, &mem_col, &witnesses.audit_ts_row, &pow_gamma), + ); + + let ( + mem_init_last_layer_row, + mem_audit_last_layer_row, + mem_init_last_layer_col, + mem_audit_last_layer_col, + ) = ( + infer_tower_product_witness_from_poly(&mem_init_row_poly), + infer_tower_product_witness_from_poly(&mem_audit_row_poly), + infer_tower_product_witness_from_poly(&mem_init_col_poly), + infer_tower_product_witness_from_poly(&mem_audit_col_poly), + ); + + let ((read_set_row, write_set_row), (read_set_col, write_set_col)) = ( + mem_read_write_set_witness( + n_threads, + &witnesses.row_index, + &witnesses.read_ts_row, + &e_vec[0], + &pow_gamma, + ), + mem_read_write_set_witness( + n_threads, + &witnesses.row_index, + &witnesses.read_ts_row, + &e_vec[0], + &pow_gamma, + ), + ); + + let ( + mem_read_last_layer_row, + mem_write_last_layer_row, + mem_read_last_layer_col, + mem_write_last_layer_col, + ) = ( + infer_tower_product_witness_from_poly(&read_set_row), + infer_tower_product_witness_from_poly(&write_set_row), + infer_tower_product_witness_from_poly(&read_set_col), + infer_tower_product_witness_from_poly(&write_set_col), + ); + + let (rt_tower, tower_proof) = TowerProver::create_proof( + vec![ + TowerProverSpec { + witness: mem_init_last_layer_row, + }, + TowerProverSpec { + witness: mem_audit_last_layer_row, + }, + TowerProverSpec { + witness: mem_read_last_layer_row, + }, + TowerProverSpec { + witness: mem_write_last_layer_row, + }, + TowerProverSpec { + witness: mem_init_last_layer_col, + }, + TowerProverSpec { + witness: mem_audit_last_layer_col, + }, + TowerProverSpec { + witness: mem_read_last_layer_col, + }, + TowerProverSpec { + witness: mem_write_last_layer_col, + }, + ], + vec![], + NUM_FANIN, + transcript, + ); + } + unimplemented!() + } +} diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index fe3ce8f07..9df5ada9c 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -8,7 +8,7 @@ use ceno_emul::{ }; use ff::Field; use ff_ext::ExtensionField; -use goldilocks::GoldilocksExt2; +use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; use mpcs::{Basefold, BasefoldDefault, BasefoldRSParams, PolynomialCommitmentScheme}; use multilinear_extensions::{ @@ -25,6 +25,7 @@ use crate::{ Instruction, riscv::{arith::AddInstruction, ecall::HaltInstruction}, }, + scheme::prover::{SparkWitness, mem_in_the_head}, set_val, structs::{ PointAndEval, RAMType::Register, TowerProver, TowerProverSpec, ZKVMConstraintSystem, @@ -37,7 +38,7 @@ use crate::{ use super::{ PublicValues, constants::{MAX_NUM_VARIABLES, NUM_FANIN}, - prover::ZKVMProver, + prover::{SparkProver, ZKVMProver}, utils::infer_tower_product_witness, verifier::{TowerVerify, ZKVMVerifier}, }; @@ -377,3 +378,62 @@ fn test_tower_proof_various_prod_size() { _test_tower_proof_prod_size_2(1 << leaf_layer_size); } } + +#[test] +fn test_spark_prove_n_verify() { + type E = GoldilocksExt2; + type Pcs = BasefoldDefault; + let mut rng = test_rng(); + let num_var = 10; + let read_write_log2_len = 32; + let repeat_log2 = 4; + + // this addr represent full leng of read, but its value range just within [0, 1 << read_write_log2_len/2] + let addr_row: ArcMultilinearExtension = (0..1 << repeat_log2) + .flat_map(|_| { + (0..1 << (read_write_log2_len / 2)) + .map(Goldilocks::from) + .collect_vec() + }) + .collect_vec() + .into_mle() + .into(); + + let addr_col: ArcMultilinearExtension = (0..1 << repeat_log2) + .flat_map(|_| { + (1 << (read_write_log2_len / 2)..0) + .map(Goldilocks::from) + .collect_vec() + }) + .collect_vec() + .into_mle() + .into(); + + let value = (0..addr_col.evaluations().len()) + .map(|_| Goldilocks::random(&mut rng)) + .collect_vec(); + + let mut transcript = Transcript::new(b"test_tower_proof"); + let ((read_ts_row, _, audit_ts_row), (read_ts_col, _, audit_ts_col)) = ( + mem_in_the_head(num_var, &addr_row), + mem_in_the_head(num_var, &addr_col), + ); + let witnesses: Vec> = vec![SparkWitness { + row_index: addr_row, + col_index: addr_col, + value: value.into_mle().into(), + read_ts_row: read_ts_row.into_mle().into(), + read_ts_col: read_ts_col.into_mle().into(), + audit_ts_row: audit_ts_row.into_mle().into(), + audit_ts_col: audit_ts_col.into_mle().into(), + }]; + + // pcs setup + let param = Pcs::setup(1 << 13).unwrap(); + let (pp, _vp) = Pcs::trim(param, 1 << 13).unwrap(); + let point_n_evals = vec![PointAndEval { + point: (0..num_var).map(|_| E::random(&mut rng)).collect_vec(), + eval: E::random(&mut rng), + }]; + SparkProver::create_proof::(pp, &point_n_evals, &witnesses, &mut transcript); +}