From 507d75dae6c04934711e60fa7d71e27caa8d37f5 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 30 Apr 2025 21:40:24 +0800 Subject: [PATCH] experiment 4-1 logup/product argument --- ceno_zkvm/src/scheme/constants.rs | 4 +- ceno_zkvm/src/scheme/prover.rs | 62 ++++++++++++++------- ceno_zkvm/src/scheme/utils.rs | 93 +++++++++++++++++++++++-------- 3 files changed, 113 insertions(+), 46 deletions(-) diff --git a/ceno_zkvm/src/scheme/constants.rs b/ceno_zkvm/src/scheme/constants.rs index 4ad5ed0ba..5479631a3 100644 --- a/ceno_zkvm/src/scheme/constants.rs +++ b/ceno_zkvm/src/scheme/constants.rs @@ -2,7 +2,7 @@ pub(crate) const MIN_PAR_SIZE: usize = 64; pub(crate) const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/lookup pub(crate) const SEL_DEGREE: usize = 2; -pub const NUM_FANIN: usize = 2; -pub const NUM_FANIN_LOGUP: usize = 2; +pub const NUM_FANIN: usize = 4; +pub const NUM_FANIN_LOGUP: usize = 4; pub const MAX_NUM_VARIABLES: usize = 24; diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 32eb6415a..74f82afc9 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -397,7 +397,7 @@ impl> ZKVMProver { // TODO optimize last layer to avoid alloc new vector to save memory let lk_records_last_layer = interleaving_mles_to_mles(lk_records_wit, num_instances, NUM_FANIN, chip_record_alpha); - assert_eq!(lk_records_last_layer.len(), 2); + assert_eq!(lk_records_last_layer.len(), NUM_FANIN); exit_span!(span); let span = entered_span!("tower_witness_lk_layers"); @@ -405,18 +405,26 @@ impl> ZKVMProver { exit_span!(span); exit_span!(wit_inference_span); - if cfg!(test) { + if cfg!(debug_assertions) { // sanity check - assert_eq!(lk_wit_layers.len(), log2_num_instances + log2_lk_count); - assert_eq!(r_wit_layers.len(), log2_num_instances + log2_r_count); - assert_eq!(w_wit_layers.len(), log2_num_instances + log2_w_count); + assert_eq!( + lk_wit_layers.len(), + (log2_num_instances + log2_lk_count) / 2 + ); + assert_eq!(r_wit_layers.len(), (log2_num_instances + log2_r_count) / 2); + assert_eq!(w_wit_layers.len(), (log2_num_instances + log2_w_count) / 2); assert!(lk_wit_layers.iter().enumerate().all(|(i, w)| { - let expected_size = 1 << i; - let (p1, p2, q1, q2) = (&w[0], &w[1], &w[2], &w[3]); + let expected_size = 1 << (ceil_log2(NUM_FANIN) * i); + let (p1, p2, p3, p4, q1, q2, q3, q4) = + (&w[0], &w[1], &w[2], &w[3], &w[4], &w[5], &w[6], &w[7]); p1.evaluations().len() == expected_size && p2.evaluations().len() == expected_size + && p3.evaluations().len() == expected_size + && p4.evaluations().len() == expected_size && q1.evaluations().len() == expected_size && q2.evaluations().len() == expected_size + && q3.evaluations().len() == expected_size + && q4.evaluations().len() == expected_size })); assert!(r_wit_layers.iter().enumerate().all(|(i, r_wit_layer)| { let expected_size = 1 << (ceil_log2(NUM_FANIN) * i); @@ -1187,7 +1195,7 @@ impl TowerProver { ) -> (Point, TowerProofs) { // XXX to sumcheck batched product argument with logup, we limit num_product_fanin to 2 // TODO mayber give a better naming? - assert_eq!(num_fanin, 2); + assert_eq!(num_fanin, 4); let mut proofs = TowerProofs::new(prod_specs.len(), logup_specs.len()); let log_num_fanin = ceil_log2(num_fanin); @@ -1244,7 +1252,7 @@ impl TowerProver { if round < s.witness.len() { let layer_polys = &s.witness[round]; // sanity check - assert_eq!(layer_polys.len(), 4); // p1, q1, p2, q2 + assert_eq!(layer_polys.len(), 8); // p1, q1, p2, q2, p3, q3, p4, q4 assert!( layer_polys .iter() @@ -1253,19 +1261,30 @@ impl TowerProver { let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); - let (q2, q1, p2, p1) = ( + let (q4, q3, q2, q1, p4, p3, p2, p1) = ( + &layer_polys[7], + &layer_polys[6], + &layer_polys[5], + &layer_polys[4], &layer_polys[3], &layer_polys[2], &layer_polys[1], &layer_polys[0], ); - // \sum_s eq(rt, s) * alpha_numerator^{i} * (p1 * q2 + p2 * q1) - virtual_polys.add_mle_list(vec![&eq, &p1, &q2], *alpha_numerator); - virtual_polys.add_mle_list(vec![&eq, &p2, &q1], *alpha_numerator); - - // \sum_s eq(rt, s) * alpha_denominator^{i} * (q1 * q2) - virtual_polys.add_mle_list(vec![&eq, &q1, &q2], *alpha_denominator); + // \sum_s eq(rt, s) * alpha_numerator^{i} * ( + // p1 * q2 * q3 * q4 + // + p2 * q1 * q3 * q4 + // + p3 * q1 * q2 * q4 + // + p4 * q1 * q2 * q3 + // ) + virtual_polys.add_mle_list(vec![&eq, &q2, &q3, &q4, &p1], *alpha_numerator); + virtual_polys.add_mle_list(vec![&eq, &q1, &q3, &q4, &p2], *alpha_numerator); + virtual_polys.add_mle_list(vec![&eq, &q1, &q2, &q4, &p3], *alpha_numerator); + virtual_polys.add_mle_list(vec![&eq, &q1, &q2, &q3, &p4], *alpha_numerator); + + // \sum_s eq(rt, s) * alpha_denominator^{i} * (q1 * q2 * q3 * q4) + virtual_polys.add_mle_list(vec![&eq, &q1, &q2, &q3, &q4], *alpha_denominator); } } @@ -1308,12 +1327,15 @@ impl TowerProver { for (i, s) in enumerate(&logup_specs) { if round < s.witness.len() { // collect evals belong to current spec - // p1, q2, p2, q1 - let p1 = *evals_iter.next().expect("insufficient evals length"); let q2 = *evals_iter.next().expect("insufficient evals length"); - let p2 = *evals_iter.next().expect("insufficient evals length"); + let q3 = *evals_iter.next().expect("insufficient evals length"); + let q4 = *evals_iter.next().expect("insufficient evals length"); + let p1 = *evals_iter.next().expect("insufficient evals length"); let q1 = *evals_iter.next().expect("insufficient evals length"); - proofs.push_logup_evals_and_point(i, vec![p1, p2, q1, q2], rt_prime.clone()); + let p2 = *evals_iter.next().expect("insufficient evals length"); + let p3 = *evals_iter.next().expect("insufficient evals length"); + let p4 = *evals_iter.next().expect("insufficient evals length"); + proofs.push_logup_evals_and_point(i, vec![p1, p2, p3, p4, q1, q2, q3, q4], rt_prime.clone()); } } assert_eq!(evals_iter.next(), None); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 388edc6fd..6ac20021a 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -86,16 +86,23 @@ pub(crate) fn interleaving_mles_to_mles<'a, E: ExtensionField>( } macro_rules! tower_mle_4 { - ($p1:ident, $p2:ident, $q1:ident, $q2:ident, $start_index:ident, $cur_len:ident) => {{ + ($p1:ident, $p2:ident, $p3:ident, $p4:ident, $q1:ident, $q2:ident, $q3:ident, $q4:ident, $start_index:ident, $cur_len:ident) => {{ let range = $start_index..($start_index + $cur_len); $q1[range.clone()] .par_iter() .zip(&$q2[range.clone()]) + .zip(&$q3[range.clone()]) + .zip(&$q4[range.clone()]) .zip(&$p1[range.clone()]) - .zip(&$p2[range]) - .map(|(((q1, q2), p1), p2)| { - let p = *q1 * *p2 + *q2 * *p1; - let q = *q1 * *q2; + .zip(&$p2[range.clone()]) + .zip(&$p3[range.clone()]) + .zip(&$p4[range]) + .map(|(((((((q1, q2), q3), q4), p1), p2), p3), p4)| { + let p = *q2 * *q3 * *q4 * *p1 + + *q1 * *q3 * *q4 * *p2 + + *q1 * *q2 * *q4 * *p3 + + *q1 * *q2 * *q3 * *p4; + let q = *q1 * *q2 * *q3 * *q4; (p, q) }) .unzip() @@ -103,63 +110,89 @@ macro_rules! tower_mle_4 { } /// infer logup witness from last layer -/// return is the ([p1,p2], [q1,q2]) for each layer +/// return is the ([p1,p2,p3,p4], [q1,q2,q3,q4]) for each layer pub(crate) fn infer_tower_logup_witness<'a, E: ExtensionField>( p_mles: Option>>, q_mles: Vec>, ) -> Vec>> { - if cfg!(test) { - assert_eq!(q_mles.len(), 2); + if cfg!(debug_assertions) { + assert_eq!(q_mles.len(), 4); assert!(q_mles.iter().map(|q| q.evaluations().len()).all_equal()); } - let num_vars = ceil_log2(q_mles[0].evaluations().len()); - let mut wit_layers = (0..num_vars).fold(vec![(p_mles, q_mles)], |mut acc, _| { + let num_layers = q_mles[0].num_vars() / 2; + let mut wit_layers = (0..num_layers).fold(vec![(p_mles, q_mles)], |mut acc, _| { let (p, q): &( Option>>, Vec>, ) = acc.last().unwrap(); - let (q1, q2) = (&q[0], &q[1]); - let cur_len = q1.evaluations().len() / 2; + let (q1, q2, q3, q4) = (&q[0], &q[1], &q[2], &q[3]); + let cur_len = q1.evaluations().len() / 4; let (next_p, next_q): ( Vec>, Vec>, - ) = (0..2) + ) = (0..4) .map(|index| { let start_index = cur_len * index; let (p_evals, q_evals): (Vec, Vec) = if let Some(p) = p { - let (p1, p2) = (&p[0], &p[1]); + let (p1, p2, p3, p4) = (&p[0], &p[1], &p[2], &p[3]); match ( p1.evaluations(), p2.evaluations(), + p3.evaluations(), + p4.evaluations(), q1.evaluations(), q2.evaluations(), + q3.evaluations(), + q4.evaluations(), ) { ( FieldType::Ext(p1), FieldType::Ext(p2), + FieldType::Ext(p3), + FieldType::Ext(p4), FieldType::Ext(q1), FieldType::Ext(q2), - ) => tower_mle_4!(p1, p2, q1, q2, start_index, cur_len), + FieldType::Ext(q3), + FieldType::Ext(q4), + ) => tower_mle_4!(p1, p2, p3, p4, q1, q2, q3, q4, start_index, cur_len), ( FieldType::Base(p1), FieldType::Base(p2), + FieldType::Base(p3), + FieldType::Base(p4), FieldType::Ext(q1), FieldType::Ext(q2), - ) => tower_mle_4!(p1, p2, q1, q2, start_index, cur_len), + FieldType::Ext(q3), + FieldType::Ext(q4), + ) => tower_mle_4!(p1, p2, p3, p4, q1, q2, q3, q4, start_index, cur_len), _ => unreachable!(), } } else { - match (q1.evaluations(), q2.evaluations()) { - (FieldType::Ext(q1), FieldType::Ext(q2)) => { + match ( + q1.evaluations(), + q2.evaluations(), + q3.evaluations(), + q4.evaluations(), + ) { + ( + FieldType::Ext(q1), + FieldType::Ext(q2), + FieldType::Ext(q3), + FieldType::Ext(q4), + ) => { let range = start_index..(start_index + cur_len); q1[range.clone()] .par_iter() - .zip(&q2[range]) - .map(|(q1, q2)| { - // 1 / q1 + 1 / q2 = (q1+q2) / q1*q2 - // p is numerator and q is denominator - let p = *q1 + *q2; - let q = *q1 * *q2; + .zip(&q2[range.clone()]) + .zip(&q3[range.clone()]) + .zip(&q4[range]) + .map(|(((q1, q2), q3), q4)| { + // 1 / q1 + 1 / q2 + 1 / q3 + 1 / q4 + let p = *q2 * *q3 * *q4 + + *q1 * *q3 * *q4 + + *q1 * *q2 * *q4 + + *q1 * *q2 * *q3; + let q = *q1 * *q2 * *q3 * *q4; (p, q) }) .unzip() @@ -195,6 +228,18 @@ pub(crate) fn infer_tower_logup_witness<'a, E: ExtensionField>( .collect::>() .into_mle() .into(), + (0..len) + .into_par_iter() + .map(|_| E::ONE) + .collect::>() + .into_mle() + .into(), + (0..len) + .into_par_iter() + .map(|_| E::ONE) + .collect::>() + .into_mle() + .into(), ] .into_iter() .chain(q)