From c4e483ccb2f5bce3b64dc4925202900a920443a2 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 18 Dec 2024 11:00:43 +0800 Subject: [PATCH 01/17] test extreme case --- ceno_zkvm/src/scheme/prover.rs | 38 +++++++-- ceno_zkvm/src/scheme/utils.rs | 138 +++++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 8 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2c8cae8bc..54d65f75c 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -9,11 +9,13 @@ use itertools::{Itertools, enumerate, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension}, - util::ceil_log2, + util::{ceil_log2, max_usable_threads}, virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, }; -use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, +}; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProverMessage, IOPProverStateV2}, @@ -25,10 +27,10 @@ use crate::{ error::ZKVMError, expression::Instance, scheme::{ - constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, + constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, MIN_PAR_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, - wit_infer_by_expr, + wit_infer_by_expr, wit_infer_by_expr_in_place, }, }, structs::{ @@ -238,14 +240,34 @@ impl> ZKVMProver { let wit_inference_span = entered_span!("wit_inference", profiling_3 = true); // main constraint: read/write record witness inference let record_span = entered_span!("record"); + // let records_wit: Vec> = cs + // .r_expressions + // .par_iter() + // .chain(cs.w_expressions.par_iter()) + // .chain(cs.lk_expressions.par_iter()) + // .map(|expr| { + // assert_eq!(expr.degree(), 1); + // wit_infer_by_expr(&[], &witnesses, pi, challenges, expr) + // }) + // .collect(); + let n_threads = max_usable_threads(); let records_wit: Vec> = cs .r_expressions - .par_iter() - .chain(cs.w_expressions.par_iter()) - .chain(cs.lk_expressions.par_iter()) + .iter() + .chain(cs.w_expressions.iter()) + .chain(cs.lk_expressions.iter()) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&[], &witnesses, pi, challenges, expr) + let len = witnesses[0].evaluations().len(); + let data = (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::ZERO) + .collect::>() + .into_mle() + .into(); + // data.into_mle().into() + wit_infer_by_expr_in_place(&[], &witnesses, pi, challenges, expr, n_threads, data) }) .collect(); let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len()); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index c8ec6453a..09188c535 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -347,6 +347,144 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( ) } +pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( + fixed: &[ArcMultilinearExtension<'a, E>], + witnesses: &[ArcMultilinearExtension<'a, E>], + instance: &[ArcMultilinearExtension<'a, E>], + challenges: &[E; N], + expr: &Expression, + n_threads: usize, + mutable_res: ArcMultilinearExtension<'a, E>, +) -> ArcMultilinearExtension<'a, E> { + expr.evaluate_with_instance::>( + &|f| fixed[f.0].clone(), + &|witness_id| witnesses[witness_id as usize].clone(), + &|i| instance[i.0].clone(), + &|scalar| { + let scalar: ArcMultilinearExtension = + Arc::new(DenseMultilinearExtension::from_evaluations_vec(0, vec![ + scalar, + ])); + scalar + }, + &|challenge_id, pow, scalar, offset| { + // TODO cache challenge power to be acquired once for each power + let challenge = challenges[challenge_id as usize]; + let challenge: ArcMultilinearExtension = Arc::new( + DenseMultilinearExtension::from_evaluations_ext_vec(0, vec![ + challenge.pow([pow as u64]) * scalar + offset, + ]), + ); + challenge + }, + &|a, b| { + commutative_op_mle_pair!(|a, b| { + match (a.len(), b.len()) { + (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + vec![a[0] + b[0]], + )), + (1, _) => { + (0..n_threads).into_par_iter().for_each(|thread_id| { + (0..b.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + let _ = a[0] + b[i]; + }) + }); + mutable_res.clone() + } + (_, 1) => { + (0..n_threads).into_par_iter().for_each(|thread_id| { + (0..a.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + let _ = a[i] + b[0]; + }) + }); + mutable_res.clone() + } + (_, _) => { + (0..n_threads).into_par_iter().for_each(|thread_id| { + (0..a.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + let _ = a[i] + b[i]; + }) + }); + mutable_res.clone() + } + } + }) + }, + &|a, b| { + commutative_op_mle_pair!(|a, b| { + match (a.len(), b.len()) { + (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + vec![a[0] * b[0]], + )), + (1, _) => { + (0..n_threads).into_par_iter().for_each(|thread_id| { + (0..a.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + let _ = a[0] * b[i]; + }) + }); + mutable_res.clone() + } + (_, 1) => { + (0..n_threads).into_par_iter().for_each(|thread_id| { + (0..a.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + let _ = a[i] * b[0]; + }) + }); + mutable_res.clone() + } + (_, _) => { + assert_eq!(a.len(), b.len()); + // we do the pointwise evaluation multiplication here without involving FFT + // the evaluations outside of range will be checked via sumcheck + identity polynomial + (0..n_threads).into_par_iter().for_each(|thread_id| { + (0..a.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + let _ = a[i] * b[i]; + }) + }); + mutable_res.clone() + } + } + }) + }, + &|x, a, b| { + op_mle_xa_b!(|x, a, b| { + assert_eq!(a.len(), 1); + assert_eq!(b.len(), 1); + let (a, b) = (a[0], b[0]); + (0..n_threads).into_par_iter().for_each(|thread_id| { + (0..x.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + let _ = a * x[i] + b; + }) + }); + mutable_res.clone() + }) + }, + ) +} + pub(crate) fn eval_by_expr( witnesses: &[E], challenges: &[E], From a0a34073f801eb7c4633818bf58ae634715398fa Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 00:45:52 +0800 Subject: [PATCH 02/17] simple object pool --- ceno_zkvm/src/expression.rs | 69 ++++++++++++++ ceno_zkvm/src/lib.rs | 1 + ceno_zkvm/src/scheme/utils.rs | 145 ++++++++++++++++++++---------- ceno_zkvm/src/uint/util.rs | 36 ++++++++ multilinear_extensions/src/mle.rs | 76 ++++++++++++++++ 5 files changed, 280 insertions(+), 47 deletions(-) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index c2b523014..05bb6a429 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -18,6 +18,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, structs::{ChallengeId, RAMType, WitnessId}, + uint::util::SimpleVecPool, }; #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -141,6 +142,74 @@ impl Expression { } } + #[allow(clippy::too_many_arguments)] + pub fn evaluate_with_instance_pool( + &self, + fixed_in: &impl Fn(&Fixed) -> T, + wit_in: &impl Fn(WitnessId) -> T, // witin id + instance: &impl Fn(Instance) -> T, + constant: &impl Fn(E::BaseField) -> T, + challenge: &impl Fn(ChallengeId, usize, E, E) -> T, + sum: &impl Fn(T, T, &mut SimpleVecPool>, &mut SimpleVecPool>) -> T, + product: &impl Fn(T, T, &mut SimpleVecPool>, &mut SimpleVecPool>) -> T, + scaled: &impl Fn( + T, + T, + T, + &mut SimpleVecPool>, + &mut SimpleVecPool>, + ) -> T, + pool_e: &mut SimpleVecPool>, + pool_b: &mut SimpleVecPool>, + ) -> T { + match self { + Expression::Fixed(f) => fixed_in(f), + Expression::WitIn(witness_id) => wit_in(*witness_id), + Expression::Instance(i) => instance(*i), + Expression::Constant(scalar) => constant(*scalar), + Expression::Sum(a, b) => { + let a = a.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + let b = b.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + sum(a, b, pool_e, pool_b) + } + Expression::Product(a, b) => { + let a = a.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + let b = b.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + product(a, b, pool_e, pool_b) + } + Expression::ScaledSum(x, a, b) => { + let x = x.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + let a = a.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + let b = b.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + scaled(x, a, b, pool_e, pool_b) + } + Expression::Challenge(challenge_id, pow, scalar, offset) => { + challenge(*challenge_id, *pow, *scalar, *offset) + } + } + } + pub fn is_monomial_form(&self) -> bool { Self::is_monomial_form_inner(MonomialState::SumTerm, self) } diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 945404ff3..9aae503f5 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -3,6 +3,7 @@ #![feature(stmt_expr_attributes)] #![feature(variant_count)] #![feature(strict_overflow_ops)] +#![feature(sync_unsafe_cell)] pub mod error; pub mod instructions; diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 09188c535..75bf9d33e 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,10 +1,10 @@ -use std::sync::Arc; +use std::{cell::SyncUnsafeCell, ops::Add, sync::Arc}; use ark_std::iterable::Iterable; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{ - commutative_op_mle_pair, + commutative_op_mle_pair, commutative_op_mle_pair_pool, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, op_mle_xa_b, op_mle3_range, util::ceil_log2, @@ -19,7 +19,8 @@ use rayon::{ }; use crate::{ - expression::Expression, scheme::constants::MIN_PAR_SIZE, utils::next_pow2_instance_padding, + expression::Expression, scheme::constants::MIN_PAR_SIZE, uint::util::SimpleVecPool, + utils::next_pow2_instance_padding, }; /// interleaving multiple mles into mles, and num_limbs indicate number of final limbs vector @@ -347,6 +348,29 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( ) } +fn mutable_a_plus_c(n_threads: usize, a: &A, b: &[B], res: &mut [A]) +where + B: Sync + Send + Copy, + A: Sync + Send + Copy + Add + Default, +{ + unsafe { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| { + let ptr = (*res.get()).as_mut_ptr(); + (0..b.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + *ptr.add(i) = *a + b[i]; + }) + }); + } +} + +use ff::Field; + +const POOL_CAP: usize = 2; + pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( fixed: &[ArcMultilinearExtension<'a, E>], witnesses: &[ArcMultilinearExtension<'a, E>], @@ -356,7 +380,22 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( n_threads: usize, mutable_res: ArcMultilinearExtension<'a, E>, ) -> ArcMultilinearExtension<'a, E> { - expr.evaluate_with_instance::>( + let len = witnesses[0].evaluations().len(); + let mut pool_e: SimpleVecPool> = SimpleVecPool::new(POOL_CAP, || { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::ZERO) + .collect::>() + }); + let mut pool_b: SimpleVecPool> = SimpleVecPool::new(POOL_CAP, || { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::BaseField::ZERO) + .collect::>() + }); + expr.evaluate_with_instance_pool::>( &|f| fixed[f.0].clone(), &|witness_id| witnesses[witness_id as usize].clone(), &|i| instance[i.0].clone(), @@ -377,50 +416,60 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( ); challenge }, - &|a, b| { - commutative_op_mle_pair!(|a, b| { - match (a.len(), b.len()) { - (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] + b[0]], - )), - (1, _) => { - (0..n_threads).into_par_iter().for_each(|thread_id| { - (0..b.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - let _ = a[0] + b[i]; - }) - }); - mutable_res.clone() - } - (_, 1) => { - (0..n_threads).into_par_iter().for_each(|thread_id| { - (0..a.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - let _ = a[i] + b[0]; - }) - }); - mutable_res.clone() - } - (_, _) => { - (0..n_threads).into_par_iter().for_each(|thread_id| { - (0..a.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - let _ = a[i] + b[i]; - }) - }); - mutable_res.clone() + &|a, b, pool_e, pool_b| { + commutative_op_mle_pair_pool!( + |a, b, res| { + match (a.len(), b.len()) { + (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + vec![a[0] + b[0]], + )), + (1, _) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..b.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + *ptr.add(i) = a[0] + b[i]; + }) + }); + res.into_inner().into_mle().into() + } + (_, 1) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + *ptr.add(i) = a[i] + b[0]; + }) + }); + res.into_inner().into_mle().into() + } + (_, _) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + *ptr.add(i) = a[i] + b[i]; + }) + }); + res.into_inner().into_mle().into() + } } - } - }) + }, + pool_e, + pool_b + ) }, - &|a, b| { + &|a, b, pool_e, pool_b| { commutative_op_mle_pair!(|a, b| { match (a.len(), b.len()) { (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( @@ -466,7 +515,7 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( } }) }, - &|x, a, b| { + &|x, a, b, pool_e, pool_b| { op_mle_xa_b!(|x, a, b| { assert_eq!(a.len(), 1); assert_eq!(b.len(), 1); @@ -482,6 +531,8 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( mutable_res.clone() }) }, + &mut pool_e, + &mut pool_b, ) } diff --git a/ceno_zkvm/src/uint/util.rs b/ceno_zkvm/src/uint/util.rs index 3ee46752a..a59ce5723 100644 --- a/ceno_zkvm/src/uint/util.rs +++ b/ceno_zkvm/src/uint/util.rs @@ -1,3 +1,5 @@ +use std::collections::VecDeque; + // calculate the maximum number of combinations for stars and bars formula const fn max_combinations(degree: usize, num_cells: usize) -> usize { // compute factorial of n using usize @@ -66,3 +68,37 @@ mod tests { assert_eq!(131070, max_carry_word_for_multiplication(2, 32, 16)); } } + +pub struct SimpleVecPool { + pool: VecDeque, +} + +impl SimpleVecPool { + // Create a new pool with a factory closure + pub fn new T>(cap: usize, init: F) -> Self { + let mut pool = SimpleVecPool { + pool: VecDeque::new(), + }; + (0..cap).for_each(|_| { + pool.add(init()); + }); + pool + } + + // Add a new item to the pool + pub fn add(&mut self, item: T) { + self.pool.push_back(item); + } + + // Borrow an item from the pool, or create a new one if empty + pub fn borrow(&mut self) -> T { + self.pool + .pop_front() + .expect("pool is empty, consider increase cap size") + } + + // Return an item to the pool + pub fn return_to_pool(&mut self, item: T) { + self.pool.push_back(item); + } +} diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index b4e8df983..91017a59e 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -1228,3 +1228,79 @@ macro_rules! commutative_op_mle_pair { commutative_op_mle_pair!(|$a, $b| $op, |out| out) }; } + +/// macro support op(a, b) and tackles type matching internally. +/// Please noted that op must satisfy commutative rule w.r.t op(b, a) operand swap. +#[macro_export] +macro_rules! commutative_op_mle_pair_pool { + (|$first:ident, $second:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => { + match (&$first.evaluations(), &$second.evaluations()) { + ($crate::mle::FieldType::Base(base1), $crate::mle::FieldType::Base(base2)) => { + let $first = if let Some((start, offset)) = $first.evaluations_range() { + &base1[start..][..offset] + } else { + &base1[..] + }; + let $second = if let Some((start, offset)) = $second.evaluations_range() { + &base2[start..][..offset] + } else { + &base2[..] + }; + let $res = $pool_b.borrow(); + let $bb_out = $op; + $op_bb_out + } + ($crate::mle::FieldType::Ext(ext), $crate::mle::FieldType::Base(base)) => { + let $first = if let Some((start, offset)) = $first.evaluations_range() { + &ext[start..][..offset] + } else { + &ext[..] + }; + let $second = if let Some((start, offset)) = $second.evaluations_range() { + &base[start..][..offset] + } else { + &base[..] + }; + let $res = $pool_e.borrow(); + $op + } + ($crate::mle::FieldType::Base(base), $crate::mle::FieldType::Ext(ext)) => { + let base = if let Some((start, offset)) = $first.evaluations_range() { + &base[start..][..offset] + } else { + &base[..] + }; + let ext = if let Some((start, offset)) = $second.evaluations_range() { + &ext[start..][..offset] + } else { + &ext[..] + }; + // swap first and second to make ext field come first before base field. + // so the same coding template can apply. + // that's why first and second operand must be commutative + let $first = ext; + let $second = base; + let $res = $pool_e.borrow(); + $op + } + ($crate::mle::FieldType::Ext(ext), $crate::mle::FieldType::Ext(base)) => { + let $first = if let Some((start, offset)) = $first.evaluations_range() { + &ext[start..][..offset] + } else { + &ext[..] + }; + let $second = if let Some((start, offset)) = $second.evaluations_range() { + &base[start..][..offset] + } else { + &base[..] + }; + let $res = $pool_e.borrow(); + $op + } + _ => unreachable!(), + } + }; + (|$a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident) => { + commutative_op_mle_pair_pool!(|$a, $b, $res| $op, $pool_e, $pool_b, |out| out) + }; +} From 9b0a44acd73f5b4ce8885a0235c3af68e90dc189 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 09:29:32 +0800 Subject: [PATCH 03/17] apply same techniqus to product --- ceno_zkvm/src/scheme/utils.rs | 96 +++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 43 deletions(-) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 75bf9d33e..db6cf252c 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -470,50 +470,60 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( ) }, &|a, b, pool_e, pool_b| { - commutative_op_mle_pair!(|a, b| { - match (a.len(), b.len()) { - (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] * b[0]], - )), - (1, _) => { - (0..n_threads).into_par_iter().for_each(|thread_id| { - (0..a.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - let _ = a[0] * b[i]; - }) - }); - mutable_res.clone() - } - (_, 1) => { - (0..n_threads).into_par_iter().for_each(|thread_id| { - (0..a.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - let _ = a[i] * b[0]; - }) - }); - mutable_res.clone() - } - (_, _) => { - assert_eq!(a.len(), b.len()); - // we do the pointwise evaluation multiplication here without involving FFT - // the evaluations outside of range will be checked via sumcheck + identity polynomial - (0..n_threads).into_par_iter().for_each(|thread_id| { - (0..a.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - let _ = a[i] * b[i]; - }) - }); - mutable_res.clone() + commutative_op_mle_pair_pool!( + |a, b, res| { + match (a.len(), b.len()) { + (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + vec![a[0] * b[0]], + )), + (1, _) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + *ptr.add(i) = a[0] * b[i]; + }) + }); + res.into_inner().into_mle().into() + } + (_, 1) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + *ptr.add(i) = a[i] * b[0]; + }) + }); + res.into_inner().into_mle().into() + } + (_, _) => { + assert_eq!(a.len(), b.len()); + // we do the pointwise evaluation multiplication here without involving FFT + // the evaluations outside of range will be checked via sumcheck + identity polynomial + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + *ptr.add(i) = a[i] * b[i]; + }) + }); + res.into_inner().into_mle().into() + } } - } - }) + }, + pool_e, + pool_b + ) }, &|x, a, b, pool_e, pool_b| { op_mle_xa_b!(|x, a, b| { From 8c56b0a0c3e661da44c610b67326df0d7830857d Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 10:05:40 +0800 Subject: [PATCH 04/17] add op_mle_xa_b_pool and pool recycle function --- ceno_zkvm/src/scheme/utils.rs | 118 +++++++++++++++++++----------- ceno_zkvm/src/uint/util.rs | 38 +++++----- multilinear_extensions/src/mle.rs | 39 ++++++++++ 3 files changed, 135 insertions(+), 60 deletions(-) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index db6cf252c..c68507a3f 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,4 +1,4 @@ -use std::{cell::SyncUnsafeCell, ops::Add, sync::Arc}; +use std::{borrow::Cow, cell::SyncUnsafeCell, ops::Add, sync::Arc}; use ark_std::iterable::Iterable; use ff_ext::ExtensionField; @@ -6,7 +6,7 @@ use itertools::Itertools; use multilinear_extensions::{ commutative_op_mle_pair, commutative_op_mle_pair_pool, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, - op_mle_xa_b, op_mle3_range, + op_mle_xa_b, op_mle_xa_b_pool, op_mle3_range, util::ceil_log2, virtual_poly_v2::ArcMultilinearExtension, }; @@ -369,7 +369,25 @@ where use ff::Field; -const POOL_CAP: usize = 2; +const POOL_CAP: usize = 12; + +// fn recycle_arcpoly( +// poly: Cow>, +// pool_e: &mut SimpleVecPool>, +// pool_b: &mut SimpleVecPool>, +// ) { +// match poly { +// Cow::Borrowed(_) => (), +// Cow::Owned(_) => { +// let poly = Arc::try_unwrap(poly.into_owned()).unwrap().downcast::<_>(); +// match poly.evaluations_to_owned() { +// FieldType::Base(vec) => pool_b.return_to_pool(vec), +// FieldType::Ext(vec) => pool_e.return_to_pool(vec), +// _ => unreachable!(), +// }; +// } +// }; +// } pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( fixed: &[ArcMultilinearExtension<'a, E>], @@ -395,16 +413,16 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( .map(|_| E::BaseField::ZERO) .collect::>() }); - expr.evaluate_with_instance_pool::>( - &|f| fixed[f.0].clone(), - &|witness_id| witnesses[witness_id as usize].clone(), - &|i| instance[i.0].clone(), + let poly = expr.evaluate_with_instance_pool::>>( + &|f| Cow::Borrowed(&fixed[f.0]), + &|witness_id| Cow::Borrowed(&witnesses[witness_id as usize]), + &|i| Cow::Borrowed(&instance[i.0]), &|scalar| { let scalar: ArcMultilinearExtension = Arc::new(DenseMultilinearExtension::from_evaluations_vec(0, vec![ scalar, ])); - scalar + Cow::Owned(scalar) }, &|challenge_id, pow, scalar, offset| { // TODO cache challenge power to be acquired once for each power @@ -414,16 +432,21 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( challenge.pow([pow as u64]) * scalar + offset, ]), ); - challenge + Cow::Owned(challenge) }, - &|a, b, pool_e, pool_b| { - commutative_op_mle_pair_pool!( + &|cow_a, cow_b, pool_e, pool_b| { + let (a, b) = (cow_a.as_ref(), cow_b.as_ref()); + let poly = commutative_op_mle_pair_pool!( |a, b, res| { match (a.len(), b.len()) { - (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] + b[0]], - )), + (1, 1) => { + let poly: ArcMultilinearExtension<_> = Arc::new( + DenseMultilinearExtension::from_evaluation_vec_smart(0, vec![ + a[0] + b[0], + ]), + ); + Cow::Owned(poly) + } (1, _) => { let res = SyncUnsafeCell::new(res); (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { @@ -435,7 +458,7 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( *ptr.add(i) = a[0] + b[i]; }) }); - res.into_inner().into_mle().into() + Cow::Owned(res.into_inner().into_mle().into()) } (_, 1) => { let res = SyncUnsafeCell::new(res); @@ -448,7 +471,7 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( *ptr.add(i) = a[i] + b[0]; }) }); - res.into_inner().into_mle().into() + Cow::Owned(res.into_inner().into_mle().into()) } (_, _) => { let res = SyncUnsafeCell::new(res); @@ -461,22 +484,27 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( *ptr.add(i) = a[i] + b[i]; }) }); - res.into_inner().into_mle().into() + Cow::Owned(res.into_inner().into_mle().into()) } } }, pool_e, pool_b - ) + ); + poly }, &|a, b, pool_e, pool_b| { commutative_op_mle_pair_pool!( |a, b, res| { match (a.len(), b.len()) { - (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] * b[0]], - )), + (1, 1) => { + let poly: ArcMultilinearExtension<_> = Arc::new( + DenseMultilinearExtension::from_evaluation_vec_smart(0, vec![ + a[0] * b[0], + ]), + ); + Cow::Owned(poly) + } (1, _) => { let res = SyncUnsafeCell::new(res); (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { @@ -488,7 +516,7 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( *ptr.add(i) = a[0] * b[i]; }) }); - res.into_inner().into_mle().into() + Cow::Owned(res.into_inner().into_mle().into()) } (_, 1) => { let res = SyncUnsafeCell::new(res); @@ -501,7 +529,7 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( *ptr.add(i) = a[i] * b[0]; }) }); - res.into_inner().into_mle().into() + Cow::Owned(res.into_inner().into_mle().into()) } (_, _) => { assert_eq!(a.len(), b.len()); @@ -517,7 +545,7 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( *ptr.add(i) = a[i] * b[i]; }) }); - res.into_inner().into_mle().into() + Cow::Owned(res.into_inner().into_mle().into()) } } }, @@ -526,24 +554,32 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( ) }, &|x, a, b, pool_e, pool_b| { - op_mle_xa_b!(|x, a, b| { - assert_eq!(a.len(), 1); - assert_eq!(b.len(), 1); - let (a, b) = (a[0], b[0]); - (0..n_threads).into_par_iter().for_each(|thread_id| { - (0..x.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - let _ = a * x[i] + b; - }) - }); - mutable_res.clone() - }) + op_mle_xa_b_pool!( + |x, a, b| { + assert_eq!(a.len(), 1); + assert_eq!(b.len(), 1); + let (a, b) = (a[0], b[0]); + (0..n_threads).into_par_iter().for_each(|thread_id| { + (0..x.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + let _ = a * x[i] + b; + }) + }); + Cow::Owned(mutable_res.clone()) + }, + pool_e, + pool_b + ) }, &mut pool_e, &mut pool_b, - ) + ); + match poly { + Cow::Borrowed(poly) => poly.clone(), + Cow::Owned(_) => poly.into_owned(), + } } pub(crate) fn eval_by_expr( diff --git a/ceno_zkvm/src/uint/util.rs b/ceno_zkvm/src/uint/util.rs index a59ce5723..8cb4cc226 100644 --- a/ceno_zkvm/src/uint/util.rs +++ b/ceno_zkvm/src/uint/util.rs @@ -50,25 +50,6 @@ pub(crate) const fn max_carry_word_for_multiplication(n: usize, m: usize, c: usi max_carry_value_gt as u64 } -#[cfg(test)] -mod tests { - use crate::uint::util::{max_carry_word_for_multiplication, max_combinations}; - - #[test] - fn test_max_combinations_degree() { - // degree=1 is pure add, therefore only one term - assert_eq!(1, max_combinations(1, 4)); - // for degree=2 mul, we have u[0]*v[3], u[1]*v[2], u[2]*v[1], u[3]*v[0] - // thus 4 terms - assert_eq!(4, max_combinations(2, 4)); - } - - #[test] - fn test_max_word_of_limb_degree() { - assert_eq!(131070, max_carry_word_for_multiplication(2, 32, 16)); - } -} - pub struct SimpleVecPool { pool: VecDeque, } @@ -102,3 +83,22 @@ impl SimpleVecPool { self.pool.push_back(item); } } + +#[cfg(test)] +mod tests { + use crate::uint::util::{max_carry_word_for_multiplication, max_combinations}; + + #[test] + fn test_max_combinations_degree() { + // degree=1 is pure add, therefore only one term + assert_eq!(1, max_combinations(1, 4)); + // for degree=2 mul, we have u[0]*v[3], u[1]*v[2], u[2]*v[1], u[3]*v[0] + // thus 4 terms + assert_eq!(4, max_combinations(2, 4)); + } + + #[test] + fn test_max_word_of_limb_degree() { + assert_eq!(131070, max_carry_word_for_multiplication(2, 32, 16)); + } +} diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 91017a59e..82601dab4 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -1091,6 +1091,45 @@ macro_rules! op_mle_xa_b { }; } +/// deal with x * a + b +#[macro_export] +macro_rules! op_mle_xa_b_pool { + (|$x:ident, $a:ident, $b:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => { + match (&$x.evaluations(), &$a.evaluations(), &$b.evaluations()) { + ( + $crate::mle::FieldType::Base(x_vec), + $crate::mle::FieldType::Base(a_vec), + $crate::mle::FieldType::Base(b_vec), + ) => { + op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + } + ( + $crate::mle::FieldType::Base(x_vec), + $crate::mle::FieldType::Ext(a_vec), + $crate::mle::FieldType::Base(b_vec), + ) => { + op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + } + ( + $crate::mle::FieldType::Base(x_vec), + $crate::mle::FieldType::Ext(a_vec), + $crate::mle::FieldType::Ext(b_vec), + ) => { + op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + } + (x, a, b) => unreachable!( + "unmatched pattern {:?} {:?} {:?}", + x.variant_name(), + a.variant_name(), + b.variant_name() + ), + } + }; + (|$x:ident, $a:ident, $b:ident| $op:expr, $pool_e:ident, $pool_b:ident) => { + op_mle_xa_b_pool!(|$x, $a, $b| $op, $pool_e, $pool_b, |out| out) + }; +} + /// deal with f1 * f2 * f3 /// applying cumulative rule for f1, f2, f3 to canonical form: Ext field comes first following by Base Field #[macro_export] From ba963bd904372b8ebfb3ad411c9aaf0587faf87d Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 11:04:26 +0800 Subject: [PATCH 05/17] downcast arc to avoid massive change --- ceno_zkvm/src/scheme/utils.rs | 409 +++++++++++++++++++--------------- ceno_zkvm/src/uint/util.rs | 1 + 2 files changed, 230 insertions(+), 180 deletions(-) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index c68507a3f..9ebc34bae 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,8 +1,8 @@ -use std::{borrow::Cow, cell::SyncUnsafeCell, ops::Add, sync::Arc}; +use std::{any::TypeId, borrow::Cow, cell::SyncUnsafeCell, ops::Add, ptr, sync::Arc}; use ark_std::iterable::Iterable; use ff_ext::ExtensionField; -use itertools::Itertools; +use itertools::{Either, Itertools}; use multilinear_extensions::{ commutative_op_mle_pair, commutative_op_mle_pair_pool, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, @@ -371,22 +371,66 @@ use ff::Field; const POOL_CAP: usize = 12; -// fn recycle_arcpoly( -// poly: Cow>, -// pool_e: &mut SimpleVecPool>, -// pool_b: &mut SimpleVecPool>, -// ) { -// match poly { -// Cow::Borrowed(_) => (), -// Cow::Owned(_) => { -// let poly = Arc::try_unwrap(poly.into_owned()).unwrap().downcast::<_>(); -// match poly.evaluations_to_owned() { -// FieldType::Base(vec) => pool_b.return_to_pool(vec), -// FieldType::Ext(vec) => pool_e.return_to_pool(vec), -// _ => unreachable!(), -// }; +fn try_recycle_arcpoly( + poly: Cow>, + pool_e: &mut SimpleVecPool>, + pool_b: &mut SimpleVecPool>, + pool_expected_size_vec: usize, +) { + fn downcast_arc( + arc: ArcMultilinearExtension<'_, E>, + ) -> DenseMultilinearExtension { + unsafe { + // get the raw pointer from the Arc + let raw = Arc::into_raw(arc); + // cast the raw pointer to the desired concrete type + let typed_ptr = raw as *const DenseMultilinearExtension; + // manually drop the Arc without dropping the value + Arc::decrement_strong_count(raw); + // reconstruct the Arc with the concrete type + // Move the value out + ptr::read(typed_ptr) + } + } + let len = poly.evaluations().len(); + if len == pool_expected_size_vec { + match poly { + Cow::Borrowed(_) => (), + Cow::Owned(_) => { + let poly = downcast_arc(poly.into_owned()); + + match poly.evaluations { + FieldType::Base(vec) => pool_b.return_to_pool(vec), + FieldType::Ext(vec) => pool_e.return_to_pool(vec), + _ => unreachable!(), + }; + } + }; + } +} + +// fn try_unwrap_and_downcast( +// arc: ArcMultilinearExtension<'_, E>, +// ) -> DenseMultilinearExtension { +// // Attempt to unwrap the Arc +// match Arc::try_unwrap(arc) { +// Ok(obj) => { +// // Check if the type matches +// if obj.type_id() == TypeId::of::() { +// // Safe to downcast +// let raw_ptr = &obj as *const dyn MyTrait as *const T; +// unsafe { +// // Take ownership of the concrete type +// let concrete: T = raw_ptr.read(); +// Ok(concrete) +// } +// } else { +// // Type mismatch +// Err(Arc::new(obj)) +// } // } -// }; +// Err(shared_arc) => Err(shared_arc), +// } // } pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( @@ -413,169 +457,174 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( .map(|_| E::BaseField::ZERO) .collect::>() }); - let poly = expr.evaluate_with_instance_pool::>>( - &|f| Cow::Borrowed(&fixed[f.0]), - &|witness_id| Cow::Borrowed(&witnesses[witness_id as usize]), - &|i| Cow::Borrowed(&instance[i.0]), - &|scalar| { - let scalar: ArcMultilinearExtension = - Arc::new(DenseMultilinearExtension::from_evaluations_vec(0, vec![ - scalar, - ])); - Cow::Owned(scalar) - }, - &|challenge_id, pow, scalar, offset| { - // TODO cache challenge power to be acquired once for each power - let challenge = challenges[challenge_id as usize]; - let challenge: ArcMultilinearExtension = Arc::new( - DenseMultilinearExtension::from_evaluations_ext_vec(0, vec![ - challenge.pow([pow as u64]) * scalar + offset, - ]), - ); - Cow::Owned(challenge) - }, - &|cow_a, cow_b, pool_e, pool_b| { - let (a, b) = (cow_a.as_ref(), cow_b.as_ref()); - let poly = commutative_op_mle_pair_pool!( - |a, b, res| { - match (a.len(), b.len()) { - (1, 1) => { - let poly: ArcMultilinearExtension<_> = Arc::new( - DenseMultilinearExtension::from_evaluation_vec_smart(0, vec![ - a[0] + b[0], - ]), - ); - Cow::Owned(poly) - } - (1, _) => { - let res = SyncUnsafeCell::new(res); - (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { - let ptr = (*res.get()).as_mut_ptr(); - (0..b.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - *ptr.add(i) = a[0] + b[i]; - }) - }); - Cow::Owned(res.into_inner().into_mle().into()) - } - (_, 1) => { - let res = SyncUnsafeCell::new(res); - (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { - let ptr = (*res.get()).as_mut_ptr(); - (0..a.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - *ptr.add(i) = a[i] + b[0]; - }) - }); - Cow::Owned(res.into_inner().into_mle().into()) - } - (_, _) => { - let res = SyncUnsafeCell::new(res); - (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { - let ptr = (*res.get()).as_mut_ptr(); - (0..a.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - *ptr.add(i) = a[i] + b[i]; - }) - }); - Cow::Owned(res.into_inner().into_mle().into()) - } - } - }, - pool_e, - pool_b - ); - poly - }, - &|a, b, pool_e, pool_b| { - commutative_op_mle_pair_pool!( - |a, b, res| { - match (a.len(), b.len()) { - (1, 1) => { - let poly: ArcMultilinearExtension<_> = Arc::new( - DenseMultilinearExtension::from_evaluation_vec_smart(0, vec![ - a[0] * b[0], - ]), - ); - Cow::Owned(poly) - } - (1, _) => { - let res = SyncUnsafeCell::new(res); - (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { - let ptr = (*res.get()).as_mut_ptr(); - (0..a.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - *ptr.add(i) = a[0] * b[i]; - }) - }); - Cow::Owned(res.into_inner().into_mle().into()) - } - (_, 1) => { - let res = SyncUnsafeCell::new(res); - (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { - let ptr = (*res.get()).as_mut_ptr(); - (0..a.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - *ptr.add(i) = a[i] * b[0]; - }) - }); - Cow::Owned(res.into_inner().into_mle().into()) - } - (_, _) => { - assert_eq!(a.len(), b.len()); - // we do the pointwise evaluation multiplication here without involving FFT - // the evaluations outside of range will be checked via sumcheck + identity polynomial - let res = SyncUnsafeCell::new(res); - (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { - let ptr = (*res.get()).as_mut_ptr(); - (0..a.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - *ptr.add(i) = a[i] * b[i]; - }) - }); - Cow::Owned(res.into_inner().into_mle().into()) - } - } - }, - pool_e, - pool_b - ) - }, - &|x, a, b, pool_e, pool_b| { - op_mle_xa_b_pool!( - |x, a, b| { - assert_eq!(a.len(), 1); - assert_eq!(b.len(), 1); - let (a, b) = (a[0], b[0]); - (0..n_threads).into_par_iter().for_each(|thread_id| { - (0..x.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - let _ = a * x[i] + b; - }) - }); - Cow::Owned(mutable_res.clone()) - }, - pool_e, - pool_b - ) - }, - &mut pool_e, - &mut pool_b, - ); + let poly = + expr.evaluate_with_instance_pool::>>( + &|f| Cow::Borrowed(&fixed[f.0]), + &|witness_id| Cow::Borrowed(&witnesses[witness_id as usize]), + &|i| Cow::Borrowed(&instance[i.0]), + &|scalar| { + let scalar: ArcMultilinearExtension = + Arc::new(DenseMultilinearExtension::from_evaluations_vec(0, vec![ + scalar, + ])); + Cow::Owned(scalar) + }, + &|challenge_id, pow, scalar, offset| { + // TODO cache challenge power to be acquired once for each power + let challenge = challenges[challenge_id as usize]; + let challenge: ArcMultilinearExtension = Arc::new( + DenseMultilinearExtension::from_evaluations_ext_vec(0, vec![ + challenge.pow([pow as u64]) * scalar + offset, + ]), + ); + Cow::Owned(challenge) + }, + &|cow_a, cow_b, pool_e, pool_b| { + let (a, b) = (cow_a.as_ref(), cow_b.as_ref()); + let poly = + commutative_op_mle_pair_pool!( + |a, b, res| { + match (a.len(), b.len()) { + (1, 1) => { + let poly: ArcMultilinearExtension<_> = Arc::new( + DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + vec![a[0] + b[0]], + ), + ); + Cow::Owned(poly) + } + (1, _) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..b.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[0] + b[i]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + (_, 1) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[i] + b[0]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + (_, _) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[i] + b[i]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + } + }, + pool_e, + pool_b + ); + try_recycle_arcpoly(cow_a, pool_e, pool_b, len); + try_recycle_arcpoly(cow_b, pool_e, pool_b, len); + poly + }, + &|cow_a, cow_b, pool_e, pool_b| { + let (a, b) = (cow_a.as_ref(), cow_b.as_ref()); + let poly = + commutative_op_mle_pair_pool!( + |a, b, res| { + match (a.len(), b.len()) { + (1, 1) => { + let poly: ArcMultilinearExtension<_> = Arc::new( + DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + vec![a[0] * b[0]], + ), + ); + Cow::Owned(poly) + } + (1, _) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[0] * b[i]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + (_, 1) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[i] * b[0]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + (_, _) => { + assert_eq!(a.len(), b.len()); + // we do the pointwise evaluation multiplication here without involving FFT + // the evaluations outside of range will be checked via sumcheck + identity polynomial + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[i] * b[i]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + } + }, + pool_e, + pool_b + ); + try_recycle_arcpoly(cow_a, pool_e, pool_b, len); + try_recycle_arcpoly(cow_b, pool_e, pool_b, len); + poly + }, + &|x, a, b, pool_e, pool_b| { + op_mle_xa_b_pool!( + |x, a, b| { + assert_eq!(a.len(), 1); + assert_eq!(b.len(), 1); + let (a, b) = (a[0], b[0]); + (0..n_threads).into_par_iter().for_each(|thread_id| { + (0..x.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + let _ = a * x[i] + b; + }) + }); + Cow::Owned(mutable_res.clone()) + }, + pool_e, + pool_b + ) + }, + &mut pool_e, + &mut pool_b, + ); match poly { Cow::Borrowed(poly) => poly.clone(), Cow::Owned(_) => poly.into_owned(), diff --git a/ceno_zkvm/src/uint/util.rs b/ceno_zkvm/src/uint/util.rs index 8cb4cc226..983382d6d 100644 --- a/ceno_zkvm/src/uint/util.rs +++ b/ceno_zkvm/src/uint/util.rs @@ -80,6 +80,7 @@ impl SimpleVecPool { // Return an item to the pool pub fn return_to_pool(&mut self, item: T) { + println!("got return!"); self.pool.push_back(item); } } From d20c90566d3059577ef3747e053c7a5745fffac1 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 11:11:43 +0800 Subject: [PATCH 06/17] debug --- ceno_zkvm/src/scheme/utils.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 9ebc34bae..c2beb36a3 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -382,6 +382,7 @@ fn try_recycle_arcpoly( ) -> DenseMultilinearExtension { unsafe { // get the raw pointer from the Arc + assert_eq!(Arc::strong_count(&arc), 1); let raw = Arc::into_raw(arc); // cast the raw pointer to the desired concrete type let typed_ptr = raw as *const DenseMultilinearExtension; @@ -625,6 +626,7 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( &mut pool_e, &mut pool_b, ); + println!("??"); match poly { Cow::Borrowed(poly) => poly.clone(), Cow::Owned(_) => poly.into_owned(), From 099c9d54882f571eb835ce0ab489aa351d450948 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 11:35:02 +0800 Subject: [PATCH 07/17] scaled_sum also support vector pool --- ceno_zkvm/src/scheme/utils.rs | 15 ++++--- multilinear_extensions/src/mle.rs | 73 ++++++++++++++++++++++++++++--- 2 files changed, 76 insertions(+), 12 deletions(-) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index c2beb36a3..bbdf329ef 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -6,7 +6,7 @@ use itertools::{Either, Itertools}; use multilinear_extensions::{ commutative_op_mle_pair, commutative_op_mle_pair_pool, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, - op_mle_xa_b, op_mle_xa_b_pool, op_mle3_range, + op_mle_xa_b, op_mle_xa_b_pool, op_mle3_range, op_mle3_range_pool, util::ceil_log2, virtual_poly_v2::ArcMultilinearExtension, }; @@ -603,21 +603,24 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( try_recycle_arcpoly(cow_b, pool_e, pool_b, len); poly }, - &|x, a, b, pool_e, pool_b| { + &|cow_x, cow_a, cow_b, pool_e, pool_b| { + let (x, a, b) = (cow_x.as_ref(), cow_a.as_ref(), cow_b.as_ref()); op_mle_xa_b_pool!( - |x, a, b| { + |x, a, b, res| { + let res = SyncUnsafeCell::new(res); assert_eq!(a.len(), 1); assert_eq!(b.len(), 1); let (a, b) = (a[0], b[0]); - (0..n_threads).into_par_iter().for_each(|thread_id| { + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); (0..x.len()) .skip(thread_id) .step_by(n_threads) .for_each(|i| { - let _ = a * x[i] + b; + *ptr.add(i) = a * x[i] + b; }) }); - Cow::Owned(mutable_res.clone()) + Cow::Owned(res.into_inner().into_mle().into()) }, pool_e, pool_b diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 82601dab4..4bf81c819 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -1091,31 +1091,92 @@ macro_rules! op_mle_xa_b { }; } +#[macro_export] +macro_rules! op_mle3_range_pool { + ($x:ident, $a:ident, $b:ident, $res:ident, $x_vec:ident, $a_vec:ident, $b_vec:ident, $res_vec:ident, $op:expr, |$bb_out:ident| $op_bb_out:expr) => {{ + let $x = if let Some((start, offset)) = $x.evaluations_range() { + &$x_vec[start..][..offset] + } else { + &$x_vec[..] + }; + let $a = if let Some((start, offset)) = $a.evaluations_range() { + &$a_vec[start..][..offset] + } else { + &$a_vec[..] + }; + let $b = if let Some((start, offset)) = $b.evaluations_range() { + &$b_vec[start..][..offset] + } else { + &$b_vec[..] + }; + let $res = $res_vec; + assert_eq!($res.len(), $x.len()); + let $bb_out = $op; + $op_bb_out + }}; +} + /// deal with x * a + b #[macro_export] macro_rules! op_mle_xa_b_pool { - (|$x:ident, $a:ident, $b:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => { + (|$x:ident, $a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => { match (&$x.evaluations(), &$a.evaluations(), &$b.evaluations()) { ( $crate::mle::FieldType::Base(x_vec), $crate::mle::FieldType::Base(a_vec), $crate::mle::FieldType::Base(b_vec), ) => { - op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + let res_vec = $pool_b.borrow(); + op_mle3_range_pool!( + $x, + $a, + $b, + $res, + x_vec, + a_vec, + b_vec, + res_vec, + $op, + |$bb_out| { $op_bb_out } + ) } ( $crate::mle::FieldType::Base(x_vec), $crate::mle::FieldType::Ext(a_vec), $crate::mle::FieldType::Base(b_vec), ) => { - op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + let res_vec = $pool_e.borrow(); + op_mle3_range_pool!( + $x, + $a, + $b, + $res, + x_vec, + a_vec, + b_vec, + res_vec, + $op, + |$bb_out| { $op_bb_out } + ) } ( $crate::mle::FieldType::Base(x_vec), $crate::mle::FieldType::Ext(a_vec), $crate::mle::FieldType::Ext(b_vec), ) => { - op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + let res_vec = $pool_e.borrow(); + op_mle3_range_pool!( + $x, + $a, + $b, + $res, + x_vec, + a_vec, + b_vec, + res_vec, + $op, + |$bb_out| { $op_bb_out } + ) } (x, a, b) => unreachable!( "unmatched pattern {:?} {:?} {:?}", @@ -1125,8 +1186,8 @@ macro_rules! op_mle_xa_b_pool { ), } }; - (|$x:ident, $a:ident, $b:ident| $op:expr, $pool_e:ident, $pool_b:ident) => { - op_mle_xa_b_pool!(|$x, $a, $b| $op, $pool_e, $pool_b, |out| out) + (|$x:ident, $a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident) => { + op_mle_xa_b_pool!(|$x, $a, $b, $res| $op, $pool_e, $pool_b, |out| out) }; } From a24fc4e8de3e0eaa0600195c5d0481f17da51c82 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 11:52:12 +0800 Subject: [PATCH 08/17] fix op_mle_xa_b_pool --- ceno_zkvm/src/scheme/prover.rs | 34 ++++---------------- ceno_zkvm/src/scheme/utils.rs | 58 ++++++---------------------------- 2 files changed, 16 insertions(+), 76 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 54d65f75c..55419dbd9 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -13,9 +13,7 @@ use multilinear_extensions::{ virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, }; -use rayon::iter::{ - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, -}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProverMessage, IOPProverStateV2}, @@ -27,10 +25,10 @@ use crate::{ error::ZKVMError, expression::Instance, scheme::{ - constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, MIN_PAR_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, + constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, - wit_infer_by_expr, wit_infer_by_expr_in_place, + wit_infer_by_expr, wit_infer_by_expr_in_pool, }, }, structs::{ @@ -240,16 +238,6 @@ impl> ZKVMProver { let wit_inference_span = entered_span!("wit_inference", profiling_3 = true); // main constraint: read/write record witness inference let record_span = entered_span!("record"); - // let records_wit: Vec> = cs - // .r_expressions - // .par_iter() - // .chain(cs.w_expressions.par_iter()) - // .chain(cs.lk_expressions.par_iter()) - // .map(|expr| { - // assert_eq!(expr.degree(), 1); - // wit_infer_by_expr(&[], &witnesses, pi, challenges, expr) - // }) - // .collect(); let n_threads = max_usable_threads(); let records_wit: Vec> = cs .r_expressions @@ -258,16 +246,7 @@ impl> ZKVMProver { .chain(cs.lk_expressions.iter()) .map(|expr| { assert_eq!(expr.degree(), 1); - let len = witnesses[0].evaluations().len(); - let data = (0..len) - .into_par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|_| E::ZERO) - .collect::>() - .into_mle() - .into(); - // data.into_mle().into() - wit_infer_by_expr_in_place(&[], &witnesses, pi, challenges, expr, n_threads, data) + wit_infer_by_expr_in_pool(&[], &witnesses, pi, challenges, expr, n_threads) }) .collect(); let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len()); @@ -547,7 +526,7 @@ impl> ZKVMProver { // sanity check in debug build and output != instance index for zero check sumcheck poly if cfg!(debug_assertions) { let expected_zero_poly = - wit_infer_by_expr(&[], &witnesses, pi, challenges, expr); + wit_infer_by_expr_in_pool(&[], &witnesses, pi, challenges, expr, n_threads); let top_100_errors = expected_zero_poly .get_base_field_vec() .iter() @@ -723,6 +702,7 @@ impl> ZKVMProver { let wit_inference_span = entered_span!("wit_inference"); // main constraint: lookup denominator and numerator record witness inference let record_span = entered_span!("record"); + let n_threads = max_usable_threads(); let mut records_wit: Vec> = cs .r_table_expressions .par_iter() @@ -736,7 +716,7 @@ impl> ZKVMProver { .chain(cs.lk_table_expressions.par_iter().map(|lk| &lk.values)) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr) + wit_infer_by_expr_in_pool(&fixed, &witnesses, pi, challenges, expr, n_threads) }) .collect(); let max_log2_num_instance = records_wit.iter().map(|mle| mle.num_vars()).max().unwrap(); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index bbdf329ef..d6bbe5ec7 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,8 +1,8 @@ -use std::{any::TypeId, borrow::Cow, cell::SyncUnsafeCell, ops::Add, ptr, sync::Arc}; +use std::{borrow::Cow, cell::SyncUnsafeCell, ptr, sync::Arc}; use ark_std::iterable::Iterable; use ff_ext::ExtensionField; -use itertools::{Either, Itertools}; +use itertools::Itertools; use multilinear_extensions::{ commutative_op_mle_pair, commutative_op_mle_pair_pool, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, @@ -348,25 +348,6 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( ) } -fn mutable_a_plus_c(n_threads: usize, a: &A, b: &[B], res: &mut [A]) -where - B: Sync + Send + Copy, - A: Sync + Send + Copy + Add + Default, -{ - unsafe { - let res = SyncUnsafeCell::new(res); - (0..n_threads).into_par_iter().for_each(|thread_id| { - let ptr = (*res.get()).as_mut_ptr(); - (0..b.len()) - .skip(thread_id) - .step_by(n_threads) - .for_each(|i| { - *ptr.add(i) = *a + b[i]; - }) - }); - } -} - use ff::Field; const POOL_CAP: usize = 12; @@ -410,38 +391,13 @@ fn try_recycle_arcpoly( } } -// fn try_unwrap_and_downcast( -// arc: ArcMultilinearExtension<'_, E>, -// ) -> DenseMultilinearExtension { -// // Attempt to unwrap the Arc -// match Arc::try_unwrap(arc) { -// Ok(obj) => { -// // Check if the type matches -// if obj.type_id() == TypeId::of::() { -// // Safe to downcast -// let raw_ptr = &obj as *const dyn MyTrait as *const T; -// unsafe { -// // Take ownership of the concrete type -// let concrete: T = raw_ptr.read(); -// Ok(concrete) -// } -// } else { -// // Type mismatch -// Err(Arc::new(obj)) -// } -// } -// Err(shared_arc) => Err(shared_arc), -// } -// } - -pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( +pub(crate) fn wit_infer_by_expr_in_pool<'a, E: ExtensionField, const N: usize>( fixed: &[ArcMultilinearExtension<'a, E>], witnesses: &[ArcMultilinearExtension<'a, E>], instance: &[ArcMultilinearExtension<'a, E>], challenges: &[E; N], expr: &Expression, n_threads: usize, - mutable_res: ArcMultilinearExtension<'a, E>, ) -> ArcMultilinearExtension<'a, E> { let len = witnesses[0].evaluations().len(); let mut pool_e: SimpleVecPool> = SimpleVecPool::new(POOL_CAP, || { @@ -605,7 +561,7 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( }, &|cow_x, cow_a, cow_b, pool_e, pool_b| { let (x, a, b) = (cow_x.as_ref(), cow_a.as_ref(), cow_b.as_ref()); - op_mle_xa_b_pool!( + let poly = op_mle_xa_b_pool!( |x, a, b, res| { let res = SyncUnsafeCell::new(res); assert_eq!(a.len(), 1); @@ -624,7 +580,11 @@ pub(crate) fn wit_infer_by_expr_in_place<'a, E: ExtensionField, const N: usize>( }, pool_e, pool_b - ) + ); + try_recycle_arcpoly(cow_a, pool_e, pool_b, len); + try_recycle_arcpoly(cow_b, pool_e, pool_b, len); + try_recycle_arcpoly(cow_x, pool_e, pool_b, len); + poly }, &mut pool_e, &mut pool_b, From 381f8ac71758326d07cc906a81f7d3393de69851 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 12:02:51 +0800 Subject: [PATCH 09/17] replace old wit_infer_by_expr --- ceno_zkvm/src/scheme/mock_prover.rs | 77 +++++++++++----- ceno_zkvm/src/scheme/prover.rs | 8 +- ceno_zkvm/src/scheme/utils.rs | 134 +++------------------------- multilinear_extensions/src/mle.rs | 39 -------- 4 files changed, 70 insertions(+), 188 deletions(-) diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 037f8a130..e11ebdee5 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -23,7 +23,9 @@ use ff_ext::ExtensionField; use generic_static::StaticTypeMap; use goldilocks::SmallField; use itertools::{Itertools, enumerate, izip}; -use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension}; +use multilinear_extensions::{ + mle::IntoMLEs, util::max_usable_threads, virtual_poly_v2::ArcMultilinearExtension, +}; use rand::thread_rng; use std::{ collections::{HashMap, HashSet}, @@ -426,6 +428,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { challenge: Option<[E; 2]>, lkm: Option, ) -> Result<(), Vec>> { + let n_threads = max_usable_threads(); let program = Program::new( CENO_PLATFORM.pc_base(), CENO_PLATFORM.pc_base(), @@ -473,10 +476,12 @@ impl<'a, E: ExtensionField + Hash> MockProver { let (left, right) = expr.unpack_sum().unwrap(); let right = right.neg(); - let left_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &left); + let left_evaluated = + wit_infer_by_expr(&[], wits_in, pi, &challenge, &left, n_threads); let left_evaluated = left_evaluated.get_base_field_vec(); - let right_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &right); + let right_evaluated = + wit_infer_by_expr(&[], wits_in, pi, &challenge, &right, n_threads); let right_evaluated = right_evaluated.get_base_field_vec(); // left_evaluated.len() ?= right_evaluated.len() due to padding instance @@ -496,7 +501,8 @@ impl<'a, E: ExtensionField + Hash> MockProver { } } else { // contains require_zero - let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); + let expr_evaluated = + wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads); let expr_evaluated = expr_evaluated.get_base_field_vec(); for (inst_id, element) in enumerate(expr_evaluated) { @@ -519,7 +525,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { .iter() .zip_eq(cb.cs.lk_expressions_namespace_map.iter()) { - let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); + let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads); let expr_evaluated = expr_evaluated.get_ext_field_vec(); // Check each lookup expr exists in t vec @@ -550,7 +556,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { .map(|expr| { // TODO generalized to all inst_id let inst_id = 0; - wit_infer_by_expr(&[], wits_in, pi, &challenge, expr) + wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads) .get_base_field_vec()[inst_id] .to_canonical_u64() }) @@ -742,6 +748,7 @@ Hints: witnesses: &ZKVMWitnesses, pi: &PublicValues, ) { + let n_threads = max_usable_threads(); let instance = pi .to_vec::() .concat() @@ -815,10 +822,16 @@ Hints: .zip(cs.lk_expressions_namespace_map.clone().into_iter()) .zip(cs.lk_expressions_items_map.clone().into_iter()) { - let lk_input = - (wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, expr) - .get_ext_field_vec())[..num_rows] - .to_vec(); + let lk_input = (wit_infer_by_expr( + &fixed, + &witness, + &pi_mles, + &challenges, + expr, + n_threads, + ) + .get_ext_field_vec())[..num_rows] + .to_vec(); rom_inputs.entry(rom_type).or_default().push(( lk_input, circuit_name.clone(), @@ -838,10 +851,16 @@ Hints: .iter() .zip(cs.lk_expressions_items_map.clone().into_iter()) { - let lk_table = - wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, &expr.values) - .get_ext_field_vec() - .to_vec(); + let lk_table = wit_infer_by_expr( + &fixed, + &witness, + &pi_mles, + &challenges, + &expr.values, + n_threads, + ) + .get_ext_field_vec() + .to_vec(); let multiplicity = wit_infer_by_expr( &fixed, @@ -849,6 +868,7 @@ Hints: &pi_mles, &challenges, &expr.multiplicity, + n_threads, ) .get_base_field_vec() .to_vec(); @@ -968,10 +988,16 @@ Hints: .zip_eq(cs.w_ram_types.iter()) .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { - let write_rlc_records = - (wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, w_rlc_expr) - .get_ext_field_vec())[..*num_rows] - .to_vec(); + let write_rlc_records = (wit_infer_by_expr( + fixed, + witness, + &pi_mles, + &challenges, + w_rlc_expr, + n_threads, + ) + .get_ext_field_vec())[..*num_rows] + .to_vec(); if $ram_type == RAMType::GlobalState { // w_exprs = [GlobalState, pc, timestamp] @@ -986,6 +1012,7 @@ Hints: &pi_mles, &challenges, expr, + n_threads, ); v.get_base_field_vec()[..*num_rows].to_vec() }) @@ -1030,10 +1057,16 @@ Hints: .zip_eq(cs.r_ram_types.iter()) .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { - let read_records = - wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, r_expr) - .get_ext_field_vec()[..*num_rows] - .to_vec(); + let read_records = wit_infer_by_expr( + fixed, + witness, + &pi_mles, + &challenges, + r_expr, + n_threads, + ) + .get_ext_field_vec()[..*num_rows] + .to_vec(); let mut records = vec![]; for (row, record) in enumerate(read_records) { // TODO: return error diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 55419dbd9..20c106c3d 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -28,7 +28,7 @@ use crate::{ constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, - wit_infer_by_expr, wit_infer_by_expr_in_pool, + wit_infer_by_expr, }, }, structs::{ @@ -246,7 +246,7 @@ impl> ZKVMProver { .chain(cs.lk_expressions.iter()) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr_in_pool(&[], &witnesses, pi, challenges, expr, n_threads) + wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads) }) .collect(); let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len()); @@ -526,7 +526,7 @@ impl> ZKVMProver { // sanity check in debug build and output != instance index for zero check sumcheck poly if cfg!(debug_assertions) { let expected_zero_poly = - wit_infer_by_expr_in_pool(&[], &witnesses, pi, challenges, expr, n_threads); + wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads); let top_100_errors = expected_zero_poly .get_base_field_vec() .iter() @@ -716,7 +716,7 @@ impl> ZKVMProver { .chain(cs.lk_table_expressions.par_iter().map(|lk| &lk.values)) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr_in_pool(&fixed, &witnesses, pi, challenges, expr, n_threads) + wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr, n_threads) }) .collect(); let max_log2_num_instance = records_wit.iter().map(|mle| mle.num_vars()).max().unwrap(); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index d6bbe5ec7..c431c4912 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -4,12 +4,17 @@ use ark_std::iterable::Iterable; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{ - commutative_op_mle_pair, commutative_op_mle_pair_pool, + commutative_op_mle_pair_pool, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, - op_mle_xa_b, op_mle_xa_b_pool, op_mle3_range, op_mle3_range_pool, + op_mle_xa_b_pool, op_mle3_range_pool, util::ceil_log2, virtual_poly_v2::ArcMultilinearExtension, }; + +use ff::Field; + +const POOL_CAP: usize = 12; + use rayon::{ iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, @@ -234,124 +239,6 @@ pub(crate) fn infer_tower_product_witness( wit_layers } -pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( - fixed: &[ArcMultilinearExtension<'a, E>], - witnesses: &[ArcMultilinearExtension<'a, E>], - instance: &[ArcMultilinearExtension<'a, E>], - challenges: &[E; N], - expr: &Expression, -) -> ArcMultilinearExtension<'a, E> { - expr.evaluate_with_instance::>( - &|f| fixed[f.0].clone(), - &|witness_id| witnesses[witness_id as usize].clone(), - &|i| instance[i.0].clone(), - &|scalar| { - let scalar: ArcMultilinearExtension = - Arc::new(DenseMultilinearExtension::from_evaluations_vec(0, vec![ - scalar, - ])); - scalar - }, - &|challenge_id, pow, scalar, offset| { - // TODO cache challenge power to be acquired once for each power - let challenge = challenges[challenge_id as usize]; - let challenge: ArcMultilinearExtension = Arc::new( - DenseMultilinearExtension::from_evaluations_ext_vec(0, vec![ - challenge.pow([pow as u64]) * scalar + offset, - ]), - ); - challenge - }, - &|a, b| { - commutative_op_mle_pair!(|a, b| { - match (a.len(), b.len()) { - (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] + b[0]], - )), - (1, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(b.len()), - b.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|b| a[0] + *b) - .collect(), - )), - (_, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|a| *a + b[0]) - .collect(), - )), - (_, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .zip(b.par_iter()) - .with_min_len(MIN_PAR_SIZE) - .map(|(a, b)| *a + b) - .collect(), - )), - } - }) - }, - &|a, b| { - commutative_op_mle_pair!(|a, b| { - match (a.len(), b.len()) { - (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] * b[0]], - )), - (1, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(b.len()), - b.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|b| a[0] * *b) - .collect(), - )), - (_, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|a| *a * b[0]) - .collect(), - )), - (_, _) => { - assert_eq!(a.len(), b.len()); - // we do the pointwise evaluation multiplication here without involving FFT - // the evaluations outside of range will be checked via sumcheck + identity polynomial - Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .zip(b.par_iter()) - .with_min_len(MIN_PAR_SIZE) - .map(|(a, b)| *a * b) - .collect(), - )) - } - } - }) - }, - &|x, a, b| { - op_mle_xa_b!(|x, a, b| { - assert_eq!(a.len(), 1); - assert_eq!(b.len(), 1); - let (a, b) = (a[0], b[0]); - Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(x.len()), - x.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|x| a * x + b) - .collect(), - )) - }) - }, - ) -} - -use ff::Field; - -const POOL_CAP: usize = 12; - fn try_recycle_arcpoly( poly: Cow>, pool_e: &mut SimpleVecPool>, @@ -391,7 +278,7 @@ fn try_recycle_arcpoly( } } -pub(crate) fn wit_infer_by_expr_in_pool<'a, E: ExtensionField, const N: usize>( +pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( fixed: &[ArcMultilinearExtension<'a, E>], witnesses: &[ArcMultilinearExtension<'a, E>], instance: &[ArcMultilinearExtension<'a, E>], @@ -665,11 +552,10 @@ mod tests { expression::{Expression, ToExpr}, scheme::utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, + wit_infer_by_expr, }, }; - use super::wit_infer_by_expr; - #[test] fn test_infer_tower_witness() { type E = GoldilocksExt2; @@ -931,6 +817,7 @@ mod tests { &[], &[], &expr, + 1, ); res.get_base_field_vec(); } @@ -961,6 +848,7 @@ mod tests { &[], &[E::ONE], &expr, + 1, ); res.get_ext_field_vec(); } diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 4bf81c819..96db7125c 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -1052,45 +1052,6 @@ macro_rules! op_mle3_range { }}; } -/// deal with x * a + b -#[macro_export] -macro_rules! op_mle_xa_b { - (|$x:ident, $a:ident, $b:ident| $op:expr, |$bb_out:ident| $op_bb_out:expr) => { - match (&$x.evaluations(), &$a.evaluations(), &$b.evaluations()) { - ( - $crate::mle::FieldType::Base(x_vec), - $crate::mle::FieldType::Base(a_vec), - $crate::mle::FieldType::Base(b_vec), - ) => { - op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) - } - ( - $crate::mle::FieldType::Base(x_vec), - $crate::mle::FieldType::Ext(a_vec), - $crate::mle::FieldType::Base(b_vec), - ) => { - op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) - } - ( - $crate::mle::FieldType::Base(x_vec), - $crate::mle::FieldType::Ext(a_vec), - $crate::mle::FieldType::Ext(b_vec), - ) => { - op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) - } - (x, a, b) => unreachable!( - "unmatched pattern {:?} {:?} {:?}", - x.variant_name(), - a.variant_name(), - b.variant_name() - ), - } - }; - (|$x:ident, $a:ident, $b:ident| $op:expr) => { - op_mle_xa_b!(|$x, $a, $b| $op, |out| out) - }; -} - #[macro_export] macro_rules! op_mle3_range_pool { ($x:ident, $a:ident, $b:ident, $res:ident, $x_vec:ident, $a_vec:ident, $b_vec:ident, $res_vec:ident, $op:expr, |$bb_out:ident| $op_bb_out:expr) => {{ From 3c362d9f8d26677644cb4d0a866f6f20f3bc3bff Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 12:19:11 +0800 Subject: [PATCH 10/17] slighly refactor --- ceno_zkvm/src/expression.rs | 2 +- ceno_zkvm/src/scheme/utils.rs | 5 +++-- ceno_zkvm/src/uint/util.rs | 37 ----------------------------------- ceno_zkvm/src/utils.rs | 37 ++++++++++++++++++++++++++++++++++- 4 files changed, 40 insertions(+), 41 deletions(-) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 05bb6a429..29b800303 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -18,7 +18,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, structs::{ChallengeId, RAMType, WitnessId}, - uint::util::SimpleVecPool, + utils::SimpleVecPool, }; #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index c431c4912..40eb623e0 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -24,8 +24,9 @@ use rayon::{ }; use crate::{ - expression::Expression, scheme::constants::MIN_PAR_SIZE, uint::util::SimpleVecPool, - utils::next_pow2_instance_padding, + expression::Expression, + scheme::constants::MIN_PAR_SIZE, + utils::{SimpleVecPool, next_pow2_instance_padding}, }; /// interleaving multiple mles into mles, and num_limbs indicate number of final limbs vector diff --git a/ceno_zkvm/src/uint/util.rs b/ceno_zkvm/src/uint/util.rs index 983382d6d..3ee46752a 100644 --- a/ceno_zkvm/src/uint/util.rs +++ b/ceno_zkvm/src/uint/util.rs @@ -1,5 +1,3 @@ -use std::collections::VecDeque; - // calculate the maximum number of combinations for stars and bars formula const fn max_combinations(degree: usize, num_cells: usize) -> usize { // compute factorial of n using usize @@ -50,41 +48,6 @@ pub(crate) const fn max_carry_word_for_multiplication(n: usize, m: usize, c: usi max_carry_value_gt as u64 } -pub struct SimpleVecPool { - pool: VecDeque, -} - -impl SimpleVecPool { - // Create a new pool with a factory closure - pub fn new T>(cap: usize, init: F) -> Self { - let mut pool = SimpleVecPool { - pool: VecDeque::new(), - }; - (0..cap).for_each(|_| { - pool.add(init()); - }); - pool - } - - // Add a new item to the pool - pub fn add(&mut self, item: T) { - self.pool.push_back(item); - } - - // Borrow an item from the pool, or create a new one if empty - pub fn borrow(&mut self) -> T { - self.pool - .pop_front() - .expect("pool is empty, consider increase cap size") - } - - // Return an item to the pool - pub fn return_to_pool(&mut self, item: T) { - println!("got return!"); - self.pool.push_back(item); - } -} - #[cfg(test)] mod tests { use crate::uint::util::{max_carry_word_for_multiplication, max_combinations}; diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 8b7d8cbde..53d745194 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, VecDeque}, fmt::Display, hash::Hash, panic::{self, PanicHookInfo}, @@ -229,3 +229,38 @@ where result } + +pub struct SimpleVecPool { + pool: VecDeque, +} + +impl SimpleVecPool { + // Create a new pool with a factory closure + pub fn new T>(cap: usize, init: F) -> Self { + let mut pool = SimpleVecPool { + pool: VecDeque::new(), + }; + (0..cap).for_each(|_| { + pool.add(init()); + }); + pool + } + + // Add a new item to the pool + pub fn add(&mut self, item: T) { + self.pool.push_back(item); + } + + // Borrow an item from the pool, or create a new one if empty + pub fn borrow(&mut self) -> T { + self.pool + .pop_front() + .expect("pool is empty, consider increase cap size") + } + + // Return an item to the pool + pub fn return_to_pool(&mut self, item: T) { + println!("got return!"); + self.pool.push_back(item); + } +} From a40f6bf231e66e92b2307d6ab6b43adfd591abed Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 14:50:23 +0800 Subject: [PATCH 11/17] make downcast work --- ceno_zkvm/src/scheme/utils.rs | 9 ++++++--- multilinear_extensions/src/mle.rs | 16 ++++++++++++++++ multilinear_extensions/src/virtual_poly_v2.rs | 3 +++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 40eb623e0..e1c6c0ff0 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -5,10 +5,10 @@ use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{ commutative_op_mle_pair_pool, - mle::{DenseMultilinearExtension, FieldType, IntoMLE}, + mle::{DenseMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, op_mle_xa_b_pool, op_mle3_range_pool, util::ceil_log2, - virtual_poly_v2::ArcMultilinearExtension, + virtual_poly_v2::{ArcMultilinearExtension, DynMultilinearExtension}, }; use ff::Field; @@ -267,7 +267,10 @@ fn try_recycle_arcpoly( match poly { Cow::Borrowed(_) => (), Cow::Owned(_) => { - let poly = downcast_arc(poly.into_owned()); + let poly = poly.into_owned(); + let poly: Box>> = + poly.dyn_try_unwrap().unwrap(); + let poly = Box::downcast::>(poly).unwrap(); match poly.evaluations { FieldType::Base(vec) => pool_b.return_to_pool(vec), diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 96db7125c..c1a66e152 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -62,6 +62,10 @@ pub trait MultilinearExtension: Send + Sync { _ => panic!("evaluation not in base field"), } } + + fn dyn_try_unwrap( + self: Arc, + ) -> Option>>; } impl Debug for dyn MultilinearExtension> { @@ -821,6 +825,12 @@ impl MultilinearExtension for DenseMultilinearExtension FieldType::Unreachable => unreachable!(), } } + + fn dyn_try_unwrap( + self: Arc, + ) -> Option>> { + Arc::try_unwrap(self).ok().map(|it| Box::new(it) as _) + } } pub struct RangedMultilinearExtension<'a, E: ExtensionField> { @@ -992,6 +1002,12 @@ impl<'a, E: ExtensionField> MultilinearExtension for RangedMultilinearExtensi fn dup(&self, _num_instances: usize, _num_dups: usize) -> DenseMultilinearExtension { unimplemented!() } + + fn dyn_try_unwrap( + self: Arc, + ) -> Option>> { + unimplemented!() + } } #[macro_export] diff --git a/multilinear_extensions/src/virtual_poly_v2.rs b/multilinear_extensions/src/virtual_poly_v2.rs index 5d64d88bc..89ea6de1b 100644 --- a/multilinear_extensions/src/virtual_poly_v2.rs +++ b/multilinear_extensions/src/virtual_poly_v2.rs @@ -11,6 +11,9 @@ use serde::{Deserialize, Serialize}; pub type ArcMultilinearExtension<'a, E> = Arc> + 'a>; + +pub type DynMultilinearExtension<'a, E> = + dyn MultilinearExtension> + 'a; #[rustfmt::skip] /// A virtual polynomial is a sum of products of multilinear polynomials; /// where the multilinear polynomials are stored via their multilinear From fcb9d7422efeead3e41dd3f75f28af85e7cd8833 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 15:04:27 +0800 Subject: [PATCH 12/17] clean up arc downcasting logic --- ceno_zkvm/src/scheme/utils.rs | 55 +++++++++---------- multilinear_extensions/src/mle.rs | 6 +- multilinear_extensions/src/virtual_poly_v2.rs | 2 - 3 files changed, 30 insertions(+), 33 deletions(-) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index e1c6c0ff0..12372b8df 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, cell::SyncUnsafeCell, ptr, sync::Arc}; +use std::{borrow::Cow, cell::SyncUnsafeCell, sync::Arc}; use ark_std::iterable::Iterable; use ff_ext::ExtensionField; @@ -8,7 +8,7 @@ use multilinear_extensions::{ mle::{DenseMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, op_mle_xa_b_pool, op_mle3_range_pool, util::ceil_log2, - virtual_poly_v2::{ArcMultilinearExtension, DynMultilinearExtension}, + virtual_poly_v2::ArcMultilinearExtension, }; use ff::Field; @@ -246,37 +246,37 @@ fn try_recycle_arcpoly( pool_b: &mut SimpleVecPool>, pool_expected_size_vec: usize, ) { - fn downcast_arc( - arc: ArcMultilinearExtension<'_, E>, - ) -> DenseMultilinearExtension { - unsafe { - // get the raw pointer from the Arc - assert_eq!(Arc::strong_count(&arc), 1); - let raw = Arc::into_raw(arc); - // cast the raw pointer to the desired concrete type - let typed_ptr = raw as *const DenseMultilinearExtension; - // manually drop the Arc without dropping the value - Arc::decrement_strong_count(raw); - // reconstruct the Arc with the concrete type - // Move the value out - ptr::read(typed_ptr) - } - } + // fn downcast_arc( + // arc: ArcMultilinearExtension<'_, E>, + // ) -> DenseMultilinearExtension { + // unsafe { + // // get the raw pointer from the Arc + // assert_eq!(Arc::strong_count(&arc), 1); + // let raw = Arc::into_raw(arc); + // // cast the raw pointer to the desired concrete type + // let typed_ptr = raw as *const DenseMultilinearExtension; + // // manually drop the Arc without dropping the value + // Arc::decrement_strong_count(raw); + // // reconstruct the Arc with the concrete type + // // Move the value out + // ptr::read(typed_ptr) + // } + // } let len = poly.evaluations().len(); if len == pool_expected_size_vec { match poly { Cow::Borrowed(_) => (), Cow::Owned(_) => { let poly = poly.into_owned(); - let poly: Box>> = - poly.dyn_try_unwrap().unwrap(); - let poly = Box::downcast::>(poly).unwrap(); - - match poly.evaluations { - FieldType::Base(vec) => pool_b.return_to_pool(vec), - FieldType::Ext(vec) => pool_e.return_to_pool(vec), - _ => unreachable!(), - }; + let poly = poly.dyn_try_unwrap().unwrap(); + + // let poly = Box::downcast::>(poly).unwrap(); + + // match poly.evaluations { + // FieldType::Base(vec) => pool_b.return_to_pool(vec), + // FieldType::Ext(vec) => pool_e.return_to_pool(vec), + // _ => unreachable!(), + // }; } }; } @@ -480,7 +480,6 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( &mut pool_e, &mut pool_b, ); - println!("??"); match poly { Cow::Borrowed(poly) => poly.clone(), Cow::Owned(_) => poly.into_owned(), diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index c1a66e152..576f4856d 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -65,7 +65,7 @@ pub trait MultilinearExtension: Send + Sync { fn dyn_try_unwrap( self: Arc, - ) -> Option>>; + ) -> Option + Send + Sync>>; } impl Debug for dyn MultilinearExtension> { @@ -828,7 +828,7 @@ impl MultilinearExtension for DenseMultilinearExtension fn dyn_try_unwrap( self: Arc, - ) -> Option>> { + ) -> Option + Send + Sync>> { Arc::try_unwrap(self).ok().map(|it| Box::new(it) as _) } } @@ -1005,7 +1005,7 @@ impl<'a, E: ExtensionField> MultilinearExtension for RangedMultilinearExtensi fn dyn_try_unwrap( self: Arc, - ) -> Option>> { + ) -> Option + Send + Sync>> { unimplemented!() } } diff --git a/multilinear_extensions/src/virtual_poly_v2.rs b/multilinear_extensions/src/virtual_poly_v2.rs index 89ea6de1b..184a726c9 100644 --- a/multilinear_extensions/src/virtual_poly_v2.rs +++ b/multilinear_extensions/src/virtual_poly_v2.rs @@ -12,8 +12,6 @@ use serde::{Deserialize, Serialize}; pub type ArcMultilinearExtension<'a, E> = Arc> + 'a>; -pub type DynMultilinearExtension<'a, E> = - dyn MultilinearExtension> + 'a; #[rustfmt::skip] /// A virtual polynomial is a sum of products of multilinear polynomials; /// where the multilinear polynomials are stored via their multilinear From a15a32d144568e6d4cf0a06144ea84ac7a9a14a5 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 15:24:45 +0800 Subject: [PATCH 13/17] fix bug --- ceno_zkvm/src/scheme/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 12372b8df..624cf89bf 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -406,7 +406,7 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( let res = SyncUnsafeCell::new(res); (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { let ptr = (*res.get()).as_mut_ptr(); - (0..a.len()).skip(thread_id).step_by(n_threads).for_each( + (0..b.len()).skip(thread_id).step_by(n_threads).for_each( |i| { *ptr.add(i) = a[0] * b[i]; }, From c7263ab18d807b04d1713a442b48c537e7ecc452 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 16:47:07 +0800 Subject: [PATCH 14/17] implement right way of arc unwrap --- ceno_zkvm/src/scheme/utils.rs | 18 +++++++----------- ceno_zkvm/src/utils.rs | 1 - multilinear_extensions/src/mle.rs | 16 ++++++---------- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 624cf89bf..4d9f81d52 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -5,7 +5,7 @@ use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{ commutative_op_mle_pair_pool, - mle::{DenseMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, + mle::{DenseMultilinearExtension, FieldType, IntoMLE}, op_mle_xa_b_pool, op_mle3_range_pool, util::ceil_log2, virtual_poly_v2::ArcMultilinearExtension, @@ -13,7 +13,7 @@ use multilinear_extensions::{ use ff::Field; -const POOL_CAP: usize = 12; +const POOL_CAP: usize = 3; use rayon::{ iter::{ @@ -268,15 +268,11 @@ fn try_recycle_arcpoly( Cow::Borrowed(_) => (), Cow::Owned(_) => { let poly = poly.into_owned(); - let poly = poly.dyn_try_unwrap().unwrap(); - - // let poly = Box::downcast::>(poly).unwrap(); - - // match poly.evaluations { - // FieldType::Base(vec) => pool_b.return_to_pool(vec), - // FieldType::Ext(vec) => pool_e.return_to_pool(vec), - // _ => unreachable!(), - // }; + match poly.arc_try_unwrap().unwrap() { + FieldType::Base(vec) => pool_b.return_to_pool(vec), + FieldType::Ext(vec) => pool_e.return_to_pool(vec), + _ => unreachable!(), + }; } }; } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 53d745194..8d89431e3 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -260,7 +260,6 @@ impl SimpleVecPool { // Return an item to the pool pub fn return_to_pool(&mut self, item: T) { - println!("got return!"); self.pool.push_back(item); } } diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 576f4856d..0c8d51393 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -63,9 +63,7 @@ pub trait MultilinearExtension: Send + Sync { } } - fn dyn_try_unwrap( - self: Arc, - ) -> Option + Send + Sync>>; + fn arc_try_unwrap(self: Arc) -> Option>; } impl Debug for dyn MultilinearExtension> { @@ -826,10 +824,10 @@ impl MultilinearExtension for DenseMultilinearExtension } } - fn dyn_try_unwrap( - self: Arc, - ) -> Option + Send + Sync>> { - Arc::try_unwrap(self).ok().map(|it| Box::new(it) as _) + fn arc_try_unwrap(self: Arc) -> Option> { + Arc::try_unwrap(self) + .ok() + .map(|it| it.evaluations_to_owned()) } } @@ -1003,9 +1001,7 @@ impl<'a, E: ExtensionField> MultilinearExtension for RangedMultilinearExtensi unimplemented!() } - fn dyn_try_unwrap( - self: Arc, - ) -> Option + Send + Sync>> { + fn arc_try_unwrap(self: Arc) -> Option> { unimplemented!() } } From dd98c57739a3534222ec827e5690f97f2623bcf6 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 17:21:20 +0800 Subject: [PATCH 15/17] shared vector pool across whole witness inference --- ceno_zkvm/src/expression.rs | 24 ++++++--- ceno_zkvm/src/scheme/mock_prover.rs | 77 ++++++++-------------------- ceno_zkvm/src/scheme/prover.rs | 78 +++++++++++++++++++++++------ ceno_zkvm/src/scheme/utils.rs | 48 ++++++++++++------ ceno_zkvm/src/utils.rs | 18 +++---- 5 files changed, 144 insertions(+), 101 deletions(-) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 29b800303..507507397 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -143,24 +143,34 @@ impl Expression { } #[allow(clippy::too_many_arguments)] - pub fn evaluate_with_instance_pool( + pub fn evaluate_with_instance_pool Vec, PF2: Fn() -> Vec>( &self, fixed_in: &impl Fn(&Fixed) -> T, wit_in: &impl Fn(WitnessId) -> T, // witin id instance: &impl Fn(Instance) -> T, constant: &impl Fn(E::BaseField) -> T, challenge: &impl Fn(ChallengeId, usize, E, E) -> T, - sum: &impl Fn(T, T, &mut SimpleVecPool>, &mut SimpleVecPool>) -> T, - product: &impl Fn(T, T, &mut SimpleVecPool>, &mut SimpleVecPool>) -> T, + sum: &impl Fn( + T, + T, + &mut SimpleVecPool, PF1>, + &mut SimpleVecPool, PF2>, + ) -> T, + product: &impl Fn( + T, + T, + &mut SimpleVecPool, PF1>, + &mut SimpleVecPool, PF2>, + ) -> T, scaled: &impl Fn( T, T, T, - &mut SimpleVecPool>, - &mut SimpleVecPool>, + &mut SimpleVecPool, PF1>, + &mut SimpleVecPool, PF2>, ) -> T, - pool_e: &mut SimpleVecPool>, - pool_b: &mut SimpleVecPool>, + pool_e: &mut SimpleVecPool, PF1>, + pool_b: &mut SimpleVecPool, PF2>, ) -> T { match self { Expression::Fixed(f) => fixed_in(f), diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index e11ebdee5..037f8a130 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -23,9 +23,7 @@ use ff_ext::ExtensionField; use generic_static::StaticTypeMap; use goldilocks::SmallField; use itertools::{Itertools, enumerate, izip}; -use multilinear_extensions::{ - mle::IntoMLEs, util::max_usable_threads, virtual_poly_v2::ArcMultilinearExtension, -}; +use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension}; use rand::thread_rng; use std::{ collections::{HashMap, HashSet}, @@ -428,7 +426,6 @@ impl<'a, E: ExtensionField + Hash> MockProver { challenge: Option<[E; 2]>, lkm: Option, ) -> Result<(), Vec>> { - let n_threads = max_usable_threads(); let program = Program::new( CENO_PLATFORM.pc_base(), CENO_PLATFORM.pc_base(), @@ -476,12 +473,10 @@ impl<'a, E: ExtensionField + Hash> MockProver { let (left, right) = expr.unpack_sum().unwrap(); let right = right.neg(); - let left_evaluated = - wit_infer_by_expr(&[], wits_in, pi, &challenge, &left, n_threads); + let left_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &left); let left_evaluated = left_evaluated.get_base_field_vec(); - let right_evaluated = - wit_infer_by_expr(&[], wits_in, pi, &challenge, &right, n_threads); + let right_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &right); let right_evaluated = right_evaluated.get_base_field_vec(); // left_evaluated.len() ?= right_evaluated.len() due to padding instance @@ -501,8 +496,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { } } else { // contains require_zero - let expr_evaluated = - wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads); + let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); let expr_evaluated = expr_evaluated.get_base_field_vec(); for (inst_id, element) in enumerate(expr_evaluated) { @@ -525,7 +519,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { .iter() .zip_eq(cb.cs.lk_expressions_namespace_map.iter()) { - let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads); + let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); let expr_evaluated = expr_evaluated.get_ext_field_vec(); // Check each lookup expr exists in t vec @@ -556,7 +550,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { .map(|expr| { // TODO generalized to all inst_id let inst_id = 0; - wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads) + wit_infer_by_expr(&[], wits_in, pi, &challenge, expr) .get_base_field_vec()[inst_id] .to_canonical_u64() }) @@ -748,7 +742,6 @@ Hints: witnesses: &ZKVMWitnesses, pi: &PublicValues, ) { - let n_threads = max_usable_threads(); let instance = pi .to_vec::() .concat() @@ -822,16 +815,10 @@ Hints: .zip(cs.lk_expressions_namespace_map.clone().into_iter()) .zip(cs.lk_expressions_items_map.clone().into_iter()) { - let lk_input = (wit_infer_by_expr( - &fixed, - &witness, - &pi_mles, - &challenges, - expr, - n_threads, - ) - .get_ext_field_vec())[..num_rows] - .to_vec(); + let lk_input = + (wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, expr) + .get_ext_field_vec())[..num_rows] + .to_vec(); rom_inputs.entry(rom_type).or_default().push(( lk_input, circuit_name.clone(), @@ -851,16 +838,10 @@ Hints: .iter() .zip(cs.lk_expressions_items_map.clone().into_iter()) { - let lk_table = wit_infer_by_expr( - &fixed, - &witness, - &pi_mles, - &challenges, - &expr.values, - n_threads, - ) - .get_ext_field_vec() - .to_vec(); + let lk_table = + wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, &expr.values) + .get_ext_field_vec() + .to_vec(); let multiplicity = wit_infer_by_expr( &fixed, @@ -868,7 +849,6 @@ Hints: &pi_mles, &challenges, &expr.multiplicity, - n_threads, ) .get_base_field_vec() .to_vec(); @@ -988,16 +968,10 @@ Hints: .zip_eq(cs.w_ram_types.iter()) .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { - let write_rlc_records = (wit_infer_by_expr( - fixed, - witness, - &pi_mles, - &challenges, - w_rlc_expr, - n_threads, - ) - .get_ext_field_vec())[..*num_rows] - .to_vec(); + let write_rlc_records = + (wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, w_rlc_expr) + .get_ext_field_vec())[..*num_rows] + .to_vec(); if $ram_type == RAMType::GlobalState { // w_exprs = [GlobalState, pc, timestamp] @@ -1012,7 +986,6 @@ Hints: &pi_mles, &challenges, expr, - n_threads, ); v.get_base_field_vec()[..*num_rows].to_vec() }) @@ -1057,16 +1030,10 @@ Hints: .zip_eq(cs.r_ram_types.iter()) .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { - let read_records = wit_infer_by_expr( - fixed, - witness, - &pi_mles, - &challenges, - r_expr, - n_threads, - ) - .get_ext_field_vec()[..*num_rows] - .to_vec(); + let read_records = + wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, r_expr) + .get_ext_field_vec()[..*num_rows] + .to_vec(); let mut records = vec![]; for (row, record) in enumerate(read_records) { // TODO: return error diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 20c106c3d..f728cd5ad 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -13,7 +13,9 @@ use multilinear_extensions::{ virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, }; -use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, +}; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProverMessage, IOPProverStateV2}, @@ -25,16 +27,18 @@ use crate::{ error::ZKVMError, expression::Instance, scheme::{ - constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, + constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, MIN_PAR_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, - wit_infer_by_expr, + wit_infer_by_expr, wit_infer_by_expr_pool, }, }, structs::{ Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, - utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, + utils::{ + SimpleVecPool, get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads, + }, virtual_polys::VirtualPolynomials, }; @@ -238,6 +242,21 @@ impl> ZKVMProver { let wit_inference_span = entered_span!("wit_inference", profiling_3 = true); // main constraint: read/write record witness inference let record_span = entered_span!("record"); + let len = witnesses[0].evaluations().len(); + let mut pool_e: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::ZERO) + .collect::>() + }); + let mut pool_b: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::BaseField::ZERO) + .collect::>() + }); let n_threads = max_usable_threads(); let records_wit: Vec> = cs .r_expressions @@ -246,7 +265,16 @@ impl> ZKVMProver { .chain(cs.lk_expressions.iter()) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads) + wit_infer_by_expr_pool( + &[], + &witnesses, + pi, + challenges, + expr, + n_threads, + &mut pool_e, + &mut pool_b, + ) }) .collect(); let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len()); @@ -526,7 +554,7 @@ impl> ZKVMProver { // sanity check in debug build and output != instance index for zero check sumcheck poly if cfg!(debug_assertions) { let expected_zero_poly = - wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads); + wit_infer_by_expr(&[], &witnesses, pi, challenges, expr); let top_100_errors = expected_zero_poly .get_base_field_vec() .iter() @@ -702,21 +730,41 @@ impl> ZKVMProver { let wit_inference_span = entered_span!("wit_inference"); // main constraint: lookup denominator and numerator record witness inference let record_span = entered_span!("record"); + let len = witnesses[0].evaluations().len(); + let mut pool_e: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::ZERO) + .collect::>() + }); + let mut pool_b: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::BaseField::ZERO) + .collect::>() + }); let n_threads = max_usable_threads(); let mut records_wit: Vec> = cs .r_table_expressions - .par_iter() + .iter() .map(|r| &r.expr) - .chain(cs.w_table_expressions.par_iter().map(|w| &w.expr)) - .chain( - cs.lk_table_expressions - .par_iter() - .map(|lk| &lk.multiplicity), - ) - .chain(cs.lk_table_expressions.par_iter().map(|lk| &lk.values)) + .chain(cs.w_table_expressions.iter().map(|w| &w.expr)) + .chain(cs.lk_table_expressions.iter().map(|lk| &lk.multiplicity)) + .chain(cs.lk_table_expressions.iter().map(|lk| &lk.values)) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr, n_threads) + wit_infer_by_expr_pool( + &fixed, + &witnesses, + pi, + challenges, + expr, + n_threads, + &mut pool_e, + &mut pool_b, + ) }) .collect(); let max_log2_num_instance = records_wit.iter().map(|mle| mle.num_vars()).max().unwrap(); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 4d9f81d52..e0fb49840 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -7,14 +7,12 @@ use multilinear_extensions::{ commutative_op_mle_pair_pool, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, op_mle_xa_b_pool, op_mle3_range_pool, - util::ceil_log2, + util::{ceil_log2, max_usable_threads}, virtual_poly_v2::ArcMultilinearExtension, }; use ff::Field; -const POOL_CAP: usize = 3; - use rayon::{ iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, @@ -240,10 +238,10 @@ pub(crate) fn infer_tower_product_witness( wit_layers } -fn try_recycle_arcpoly( +fn try_recycle_arcpoly Vec, PF2: Fn() -> Vec>( poly: Cow>, - pool_e: &mut SimpleVecPool>, - pool_b: &mut SimpleVecPool>, + pool_e: &mut SimpleVecPool, PF1>, + pool_b: &mut SimpleVecPool, PF2>, pool_expected_size_vec: usize, ) { // fn downcast_arc( @@ -284,25 +282,49 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( instance: &[ArcMultilinearExtension<'a, E>], challenges: &[E; N], expr: &Expression, - n_threads: usize, ) -> ArcMultilinearExtension<'a, E> { + let n_threads = max_usable_threads(); let len = witnesses[0].evaluations().len(); - let mut pool_e: SimpleVecPool> = SimpleVecPool::new(POOL_CAP, || { + let mut pool_e: SimpleVecPool, _> = SimpleVecPool::new(|| { (0..len) .into_par_iter() .with_min_len(MIN_PAR_SIZE) .map(|_| E::ZERO) .collect::>() }); - let mut pool_b: SimpleVecPool> = SimpleVecPool::new(POOL_CAP, || { + let mut pool_b: SimpleVecPool, _> = SimpleVecPool::new(|| { (0..len) .into_par_iter() .with_min_len(MIN_PAR_SIZE) .map(|_| E::BaseField::ZERO) .collect::>() }); + wit_infer_by_expr_pool( + fixed, + witnesses, + instance, + challenges, + expr, + n_threads, + &mut pool_e, + &mut pool_b, + ) +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>( + fixed: &[ArcMultilinearExtension<'a, E>], + witnesses: &[ArcMultilinearExtension<'a, E>], + instance: &[ArcMultilinearExtension<'a, E>], + challenges: &[E; N], + expr: &Expression, + n_threads: usize, + pool_e: &mut SimpleVecPool, impl Fn() -> Vec>, + pool_b: &mut SimpleVecPool, impl Fn() -> Vec>, +) -> ArcMultilinearExtension<'a, E> { + let len = witnesses[0].evaluations().len(); let poly = - expr.evaluate_with_instance_pool::>>( + expr.evaluate_with_instance_pool::>, _, _>( &|f| Cow::Borrowed(&fixed[f.0]), &|witness_id| Cow::Borrowed(&witnesses[witness_id as usize]), &|i| Cow::Borrowed(&instance[i.0]), @@ -473,8 +495,8 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( try_recycle_arcpoly(cow_x, pool_e, pool_b, len); poly }, - &mut pool_e, - &mut pool_b, + pool_e, + pool_b, ); match poly { Cow::Borrowed(poly) => poly.clone(), @@ -816,7 +838,6 @@ mod tests { &[], &[], &expr, - 1, ); res.get_base_field_vec(); } @@ -847,7 +868,6 @@ mod tests { &[], &[E::ONE], &expr, - 1, ); res.get_ext_field_vec(); } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 8d89431e3..898e8364c 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -230,19 +230,19 @@ where result } -pub struct SimpleVecPool { +pub struct SimpleVecPool T> { pool: VecDeque, + factory_fn: F, } -impl SimpleVecPool { +impl T> SimpleVecPool { // Create a new pool with a factory closure - pub fn new T>(cap: usize, init: F) -> Self { - let mut pool = SimpleVecPool { + pub fn new(init: F) -> Self { + let pool = SimpleVecPool { pool: VecDeque::new(), + factory_fn: init, }; - (0..cap).for_each(|_| { - pool.add(init()); - }); + pool } @@ -253,9 +253,7 @@ impl SimpleVecPool { // Borrow an item from the pool, or create a new one if empty pub fn borrow(&mut self) -> T { - self.pool - .pop_front() - .expect("pool is empty, consider increase cap size") + self.pool.pop_front().unwrap_or_else(|| (self.factory_fn)()) } // Return an item to the pool From d3a2f4bb685bf9d2ebc8567a685a77c3fa53edb1 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 19:53:17 +0800 Subject: [PATCH 16/17] code cosmetics --- ceno_zkvm/src/scheme/utils.rs | 46 ++++++++++--------------------- ceno_zkvm/src/utils.rs | 19 +++++-------- multilinear_extensions/src/mle.rs | 27 ++++++------------ 3 files changed, 31 insertions(+), 61 deletions(-) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index e0fb49840..af3ed8f8d 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -4,9 +4,9 @@ use ark_std::iterable::Iterable; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{ - commutative_op_mle_pair_pool, + commutative_op_mle_pair, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, - op_mle_xa_b_pool, op_mle3_range_pool, + op_mle_xa_b, op_mle3_range, util::{ceil_log2, max_usable_threads}, virtual_poly_v2::ArcMultilinearExtension, }; @@ -238,28 +238,12 @@ pub(crate) fn infer_tower_product_witness( wit_layers } -fn try_recycle_arcpoly Vec, PF2: Fn() -> Vec>( +fn optional_arcpoly_unwrap_pushback( poly: Cow>, - pool_e: &mut SimpleVecPool, PF1>, - pool_b: &mut SimpleVecPool, PF2>, + pool_e: &mut SimpleVecPool, impl Fn() -> Vec>, + pool_b: &mut SimpleVecPool, impl Fn() -> Vec>, pool_expected_size_vec: usize, ) { - // fn downcast_arc( - // arc: ArcMultilinearExtension<'_, E>, - // ) -> DenseMultilinearExtension { - // unsafe { - // // get the raw pointer from the Arc - // assert_eq!(Arc::strong_count(&arc), 1); - // let raw = Arc::into_raw(arc); - // // cast the raw pointer to the desired concrete type - // let typed_ptr = raw as *const DenseMultilinearExtension; - // // manually drop the Arc without dropping the value - // Arc::decrement_strong_count(raw); - // // reconstruct the Arc with the concrete type - // // Move the value out - // ptr::read(typed_ptr) - // } - // } let len = poly.evaluations().len(); if len == pool_expected_size_vec { match poly { @@ -348,7 +332,7 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>( &|cow_a, cow_b, pool_e, pool_b| { let (a, b) = (cow_a.as_ref(), cow_b.as_ref()); let poly = - commutative_op_mle_pair_pool!( + commutative_op_mle_pair!( |a, b, res| { match (a.len(), b.len()) { (1, 1) => { @@ -401,14 +385,14 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>( pool_e, pool_b ); - try_recycle_arcpoly(cow_a, pool_e, pool_b, len); - try_recycle_arcpoly(cow_b, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len); poly }, &|cow_a, cow_b, pool_e, pool_b| { let (a, b) = (cow_a.as_ref(), cow_b.as_ref()); let poly = - commutative_op_mle_pair_pool!( + commutative_op_mle_pair!( |a, b, res| { match (a.len(), b.len()) { (1, 1) => { @@ -464,13 +448,13 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>( pool_e, pool_b ); - try_recycle_arcpoly(cow_a, pool_e, pool_b, len); - try_recycle_arcpoly(cow_b, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len); poly }, &|cow_x, cow_a, cow_b, pool_e, pool_b| { let (x, a, b) = (cow_x.as_ref(), cow_a.as_ref(), cow_b.as_ref()); - let poly = op_mle_xa_b_pool!( + let poly = op_mle_xa_b!( |x, a, b, res| { let res = SyncUnsafeCell::new(res); assert_eq!(a.len(), 1); @@ -490,9 +474,9 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>( pool_e, pool_b ); - try_recycle_arcpoly(cow_a, pool_e, pool_b, len); - try_recycle_arcpoly(cow_b, pool_e, pool_b, len); - try_recycle_arcpoly(cow_x, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_x, pool_e, pool_b, len); poly }, pool_e, diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 898e8364c..79fa3c490 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -230,33 +230,28 @@ where result } +/// a simple vector pool +/// not support multi-thread access pub struct SimpleVecPool T> { pool: VecDeque, factory_fn: F, } impl T> SimpleVecPool { - // Create a new pool with a factory closure + // new pool with a factory closure pub fn new(init: F) -> Self { - let pool = SimpleVecPool { + SimpleVecPool { pool: VecDeque::new(), factory_fn: init, - }; - - pool - } - - // Add a new item to the pool - pub fn add(&mut self, item: T) { - self.pool.push_back(item); + } } - // Borrow an item from the pool, or create a new one if empty + // borrow an item from the pool, or create a new one if empty pub fn borrow(&mut self) -> T { self.pool.pop_front().unwrap_or_else(|| (self.factory_fn)()) } - // Return an item to the pool + // push an item to the pool pub fn return_to_pool(&mut self, item: T) { self.pool.push_back(item); } diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 0c8d51393..85947721d 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -1062,10 +1062,7 @@ macro_rules! op_mle3_range { let $bb_out = $op; $op_bb_out }}; -} -#[macro_export] -macro_rules! op_mle3_range_pool { ($x:ident, $a:ident, $b:ident, $res:ident, $x_vec:ident, $a_vec:ident, $b_vec:ident, $res_vec:ident, $op:expr, |$bb_out:ident| $op_bb_out:expr) => {{ let $x = if let Some((start, offset)) = $x.evaluations_range() { &$x_vec[start..][..offset] @@ -1091,7 +1088,7 @@ macro_rules! op_mle3_range_pool { /// deal with x * a + b #[macro_export] -macro_rules! op_mle_xa_b_pool { +macro_rules! op_mle_xa_b { (|$x:ident, $a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => { match (&$x.evaluations(), &$a.evaluations(), &$b.evaluations()) { ( @@ -1100,7 +1097,7 @@ macro_rules! op_mle_xa_b_pool { $crate::mle::FieldType::Base(b_vec), ) => { let res_vec = $pool_b.borrow(); - op_mle3_range_pool!( + op_mle3_range!( $x, $a, $b, @@ -1119,7 +1116,7 @@ macro_rules! op_mle_xa_b_pool { $crate::mle::FieldType::Base(b_vec), ) => { let res_vec = $pool_e.borrow(); - op_mle3_range_pool!( + op_mle3_range!( $x, $a, $b, @@ -1138,7 +1135,7 @@ macro_rules! op_mle_xa_b_pool { $crate::mle::FieldType::Ext(b_vec), ) => { let res_vec = $pool_e.borrow(); - op_mle3_range_pool!( + op_mle3_range!( $x, $a, $b, @@ -1160,7 +1157,7 @@ macro_rules! op_mle_xa_b_pool { } }; (|$x:ident, $a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident) => { - op_mle_xa_b_pool!(|$x, $a, $b, $res| $op, $pool_e, $pool_b, |out| out) + op_mle_xa_b!(|$x, $a, $b, $res| $op, $pool_e, $pool_b, |out| out) }; } @@ -1297,15 +1294,6 @@ macro_rules! commutative_op_mle_pair { _ => unreachable!(), } }; - (|$a:ident, $b:ident| $op:expr) => { - commutative_op_mle_pair!(|$a, $b| $op, |out| out) - }; -} - -/// macro support op(a, b) and tackles type matching internally. -/// Please noted that op must satisfy commutative rule w.r.t op(b, a) operand swap. -#[macro_export] -macro_rules! commutative_op_mle_pair_pool { (|$first:ident, $second:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => { match (&$first.evaluations(), &$second.evaluations()) { ($crate::mle::FieldType::Base(base1), $crate::mle::FieldType::Base(base2)) => { @@ -1374,6 +1362,9 @@ macro_rules! commutative_op_mle_pair_pool { } }; (|$a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident) => { - commutative_op_mle_pair_pool!(|$a, $b, $res| $op, $pool_e, $pool_b, |out| out) + commutative_op_mle_pair!(|$a, $b, $res| $op, $pool_e, $pool_b, |out| out) + }; + (|$a:ident, $b:ident| $op:expr) => { + commutative_op_mle_pair!(|$a, $b| $op, |out| out) }; } From d5632d587b970f01cb2b7d3d966a9b278b5124aa Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 19 Dec 2024 20:46:26 +0800 Subject: [PATCH 17/17] code cosmetics --- multilinear_extensions/src/virtual_poly_v2.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/multilinear_extensions/src/virtual_poly_v2.rs b/multilinear_extensions/src/virtual_poly_v2.rs index 184a726c9..5d64d88bc 100644 --- a/multilinear_extensions/src/virtual_poly_v2.rs +++ b/multilinear_extensions/src/virtual_poly_v2.rs @@ -11,7 +11,6 @@ use serde::{Deserialize, Serialize}; pub type ArcMultilinearExtension<'a, E> = Arc> + 'a>; - #[rustfmt::skip] /// A virtual polynomial is a sum of products of multilinear polynomials; /// where the multilinear polynomials are stored via their multilinear