Skip to content

[WIP] add SPARK PCS #713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 270 additions & 5 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,49 @@
use base64::read;
use ff_ext::ExtensionField;
use goldilocks::SmallField;
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
iter::successors,
sync::Arc,
};

use ff::Field;
use itertools::{Itertools, enumerate, izip};
use mpcs::PolynomialCommitmentScheme;
use multilinear_extensions::{
mle::{IntoMLE, MultilinearExtension},
util::ceil_log2,
mle::{DenseMultilinearExtension, IntoMLE, MultilinearExtension},
op_mle,
util::{ceil_log2, create_uninit_vec, max_usable_threads},
virtual_poly::build_eq_x_r_vec,
virtual_poly_v2::ArcMultilinearExtension,
};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use rayon::{
iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
},
slice::ParallelSlice,
};
use sumcheck::{
macros::{entered_span, exit_span},
structs::{IOPProverMessage, IOPProverStateV2},
};
use transcript::{ForkableTranscript, Transcript};

use crate::{
chip_handler::utils::power_sequence,
circuit_builder::SetTableAddrType,
error::ZKVMError,
expression::Instance,
scheme::{
constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP},
constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, MIN_PAR_SIZE, NUM_FANIN, NUM_FANIN_LOGUP},
utils::{
infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles,
wit_infer_by_expr,
},
},
structs::{
Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses,
Point, PointAndEval, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey,
ZKVMWitnesses,
},
utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads},
virtual_polys::VirtualPolynomials,
Expand Down Expand Up @@ -1297,3 +1308,257 @@ impl TowerProver {
(next_rt, proofs)
}
}

pub struct SparkProver;

pub struct SparkWitness<'a, E: ExtensionField> {
// sparse representative
pub row_index: ArcMultilinearExtension<'a, E>,
pub col_index: ArcMultilinearExtension<'a, E>,
pub value: ArcMultilinearExtension<'a, E>,
// offline memory check
// write_ts can be read_rs + 1
pub read_ts_row: ArcMultilinearExtension<'a, E>,
pub read_ts_col: ArcMultilinearExtension<'a, E>,
pub audit_ts_row: ArcMultilinearExtension<'a, E>,
pub audit_ts_col: ArcMultilinearExtension<'a, E>,
}

/// derive power sequence [1, base, base^2, ..., base^(len-1)] of base expression
fn power_sequence_scalar<'a, E: ExtensionField>(base: &'a E) -> impl Iterator<Item = E> + 'a {
successors(Some(E::ONE), move |prev| Some(prev.clone() * base.clone()))
}

fn spark_e_vec<E: ExtensionField>(index_vec: &ArcMultilinearExtension<E>, eq: &[E]) -> Vec<E> {
index_vec
.get_base_field_vec()
.iter()
.map(|index| eq[index.to_canonical_u64() as usize])
.collect_vec()
}

/// from spartan p.27 MemoryInTheHead
pub fn mem_in_the_head<E: ExtensionField>(
mem_log2_size: usize,
addr: &ArcMultilinearExtension<E>,
) -> (Vec<E>, Vec<E>, Vec<E>) {
let n = addr.evaluations().len();
let mut read_ts = vec![E::ZERO; n];
let mut write_ts = vec![E::ZERO; n];
let mut audit_ts = vec![E::ZERO; 1 << mem_log2_size];
let addr = addr.get_base_field_vec();
for i in 0..addr.len() {
let addr_u64 = addr[i].to_canonical_u64() as usize;
let r_ts = audit_ts[addr_u64];
let ts = r_ts + E::ONE;
read_ts[i] = r_ts;
write_ts[i] = ts;
audit_ts[addr_u64] = ts;
}
(read_ts, write_ts, audit_ts)
}

fn infer_tower_product_witness_from_poly<E: ExtensionField>(
poly: &DenseMultilinearExtension<E>,
) -> Vec<Vec<ArcMultilinearExtension<E>>> {
infer_tower_product_witness(
poly.num_vars(),
(0..NUM_FANIN)
.map(move |i| {
let ranged_poly: ArcMultilinearExtension<E> =
Arc::new(poly.get_ranged_mle(NUM_FANIN, i));
ranged_poly
})
.collect(),
NUM_FANIN,
)
}

fn mem_read_write_set_witness<'a, E: ExtensionField>(
n_threads: usize,
index: &ArcMultilinearExtension<'a, E>,
read_ts: &ArcMultilinearExtension<'a, E>,
e_vec: &DenseMultilinearExtension<E>,
pow_gamma: &[E],
) -> (DenseMultilinearExtension<E>, DenseMultilinearExtension<E>) {
let per_row_rw_chunk_size = e_vec.evaluations().len().div_ceil(n_threads);
let (read_set, write_set): (Vec<E>, Vec<E>) = e_vec
.get_ext_field_vec()
.par_chunks_exact(per_row_rw_chunk_size)
.zip_eq(
index
.get_base_field_vec()
.par_chunks_exact(per_row_rw_chunk_size),
)
.zip_eq(
read_ts
.get_base_field_vec()
.par_chunks_exact(per_row_rw_chunk_size),
)
.enumerate()
.flat_map(|(i, ((value, index), read_ts))| {
izip!(value, index, read_ts)
.map(|(value, index, read_ts)| {
// a, v, t
let tmp = pow_gamma[1] * value + index;
(tmp + read_ts, tmp + read_ts + E::ONE)
})
.collect_vec()
})
.unzip();
let (read_set_poly, write_set_poly) = (read_set.into_mle(), write_set.into_mle());
(read_set_poly, write_set_poly)
}

fn mem_init_audit_witness<E: ExtensionField>(
n_threads: usize,
mem: &[E],
audit_ts: &ArcMultilinearExtension<E>,
pow_gamma: &[E],
) -> (DenseMultilinearExtension<E>, DenseMultilinearExtension<E>) {
let per_row_init_chunk_size = mem.len().div_ceil(n_threads);
let (mem_init_row, mem_audit_row): (Vec<E>, Vec<E>) = mem
.par_chunks_exact(per_row_init_chunk_size)
.zip_eq(
audit_ts
.get_base_field_vec()
.par_chunks_exact(per_row_init_chunk_size),
)
.enumerate()
.flat_map(|(i, (rows, audit_ts))| {
let addr_offset = i * per_row_init_chunk_size;
izip!(rows, audit_ts)
.enumerate()
.map(|(i, (row, audit_ts))| {
// a, v, t
let tmp: E = pow_gamma[1] * row + E::BaseField::from((addr_offset + i) as u64);
(tmp, tmp + pow_gamma[2] * audit_ts)
})
.collect_vec()
})
.unzip();
let (mem_init_row_poly, mem_audit_row_poly) =
(mem_init_row.into_mle(), mem_audit_row.into_mle());

(mem_init_row_poly, mem_audit_row_poly)
}
pub struct SparkProof;

/// Tower Prover
impl SparkProver {
#[tracing::instrument(skip_all, name = "spark_prover_create_proof", level = "trace")]
pub fn create_proof<'a, E: ExtensionField, PCS: PolynomialCommitmentScheme<E>>(
pp: PCS::ProverParam,
point_n_evals: &[PointAndEval<E>],
witnesses: &[SparkWitness<E>],
transcript: &mut Transcript<E>,
) -> (Point<E>, SparkProof) {
// product argument to prove offline memory check argument
let n_threads = max_usable_threads();
// let alpha_pows = power_sequence_scalar(alpha).take(3).collect_vec();
for (point_n_eval, witnesses) in izip!(point_n_evals, witnesses) {
// sanity check
let (r, eval) = (&point_n_eval.point, &point_n_eval.eval);
assert!(r.len().is_power_of_two());
let (r_row, r_col) = (&r[..r.len() / 2], &r[r.len() / 2..]);

let (mem_row, mem_col) = (build_eq_x_r_vec(r_row), build_eq_x_r_vec(r_col));
// prepare e_row, e_col
let e_row = spark_e_vec(&witnesses.row_index, &mem_row).into_mle();
let e_col = spark_e_vec(&witnesses.col_index, &mem_col).into_mle();
let e_vec = vec![e_row, e_col];

assert_eq!(
e_vec[0].evaluations().len(),
witnesses.row_index.evaluations().len()
);
assert_eq!(
e_vec[1].evaluations().len(),
witnesses.col_index.evaluations().len()
);

let commit_with_data = PCS::batch_commit_and_write(&pp, &e_vec, transcript).unwrap();
let gamma = transcript.get_and_append_challenge(b"spark_omc_hash");
let pow_gamma = power_sequence_scalar(&gamma.elements).take(3).collect_vec();

let ((mem_init_row_poly, mem_audit_row_poly), (mem_init_col_poly, mem_audit_col_poly)) = (
mem_init_audit_witness(n_threads, &mem_row, &witnesses.audit_ts_row, &pow_gamma),
mem_init_audit_witness(n_threads, &mem_col, &witnesses.audit_ts_row, &pow_gamma),
);

let (
mem_init_last_layer_row,
mem_audit_last_layer_row,
mem_init_last_layer_col,
mem_audit_last_layer_col,
) = (
infer_tower_product_witness_from_poly(&mem_init_row_poly),
infer_tower_product_witness_from_poly(&mem_audit_row_poly),
infer_tower_product_witness_from_poly(&mem_init_col_poly),
infer_tower_product_witness_from_poly(&mem_audit_col_poly),
);

let ((read_set_row, write_set_row), (read_set_col, write_set_col)) = (
mem_read_write_set_witness(
n_threads,
&witnesses.row_index,
&witnesses.read_ts_row,
&e_vec[0],
&pow_gamma,
),
mem_read_write_set_witness(
n_threads,
&witnesses.row_index,
&witnesses.read_ts_row,
&e_vec[0],
&pow_gamma,
),
);

let (
mem_read_last_layer_row,
mem_write_last_layer_row,
mem_read_last_layer_col,
mem_write_last_layer_col,
) = (
infer_tower_product_witness_from_poly(&read_set_row),
infer_tower_product_witness_from_poly(&write_set_row),
infer_tower_product_witness_from_poly(&read_set_col),
infer_tower_product_witness_from_poly(&write_set_col),
);

let (rt_tower, tower_proof) = TowerProver::create_proof(
vec![
TowerProverSpec {
witness: mem_init_last_layer_row,
},
TowerProverSpec {
witness: mem_audit_last_layer_row,
},
TowerProverSpec {
witness: mem_read_last_layer_row,
},
TowerProverSpec {
witness: mem_write_last_layer_row,
},
TowerProverSpec {
witness: mem_init_last_layer_col,
},
TowerProverSpec {
witness: mem_audit_last_layer_col,
},
TowerProverSpec {
witness: mem_read_last_layer_col,
},
TowerProverSpec {
witness: mem_write_last_layer_col,
},
],
vec![],
NUM_FANIN,
transcript,
);
}
unimplemented!()
}
}
Loading
Loading