Skip to content

Commit dec49a9

Browse files
committed
Parallelize few stuff
1 parent fd30bd6 commit dec49a9

File tree

6 files changed

+86
-68
lines changed

6 files changed

+86
-68
lines changed

curves/src/pasta/fields/fp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256Parameters, Fp256};
1+
use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256, Fp256Parameters};
22

33
pub type Fp = Fp256<FpParameters>;
44

curves/src/pasta/fields/fq.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256Parameters, Fp256};
1+
use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256, Fp256Parameters};
22

33
pub type Fq = Fp256<FqParameters>;
44

5-
65
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, PartialOrd, Ord, Hash)]
76
pub struct FqParameters;
87

kimchi/src/circuits/expr.rs

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,24 +1021,24 @@ fn unnormalized_lagrange_evals<F: FftField>(
10211021

10221022
impl<'a, F: FftField> EvalResult<'a, F> {
10231023
fn init_<G: Sync + Send + Fn(usize) -> F>(
1024-
res_domain: (Domain, D<F>),
1024+
res_domain: (Domain, &D<F>),
10251025
g: G,
10261026
) -> Evaluations<F, D<F>> {
10271027
let n = res_domain.1.size();
10281028
Evaluations::<F, D<F>>::from_vec_and_domain(
10291029
(0..n).into_par_iter().map(g).collect(),
1030-
res_domain.1,
1030+
res_domain.1.clone(),
10311031
)
10321032
}
10331033

1034-
fn init<G: Sync + Send + Fn(usize) -> F>(res_domain: (Domain, D<F>), g: G) -> Self {
1034+
fn init<G: Sync + Send + Fn(usize) -> F>(res_domain: (Domain, &D<F>), g: G) -> Self {
10351035
Self::Evals {
10361036
domain: res_domain.0,
10371037
evals: Self::init_(res_domain, g),
10381038
}
10391039
}
10401040

1041-
fn add<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D<F>)) -> EvalResult<'c, F> {
1041+
fn add<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, &D<F>)) -> EvalResult<'c, F> {
10421042
use EvalResult::*;
10431043
match (self, other) {
10441044
(Constant(x), Constant(y)) => Constant(x + y),
@@ -1074,7 +1074,7 @@ impl<'a, F: FftField> EvalResult<'a, F> {
10741074
.collect();
10751075
Evals {
10761076
domain: res_domain.0,
1077-
evals: Evaluations::<F, D<F>>::from_vec_and_domain(v, res_domain.1),
1077+
evals: Evaluations::<F, D<F>>::from_vec_and_domain(v, res_domain.1.clone()),
10781078
}
10791079
}
10801080
(
@@ -1151,13 +1151,13 @@ impl<'a, F: FftField> EvalResult<'a, F> {
11511151

11521152
Evals {
11531153
domain: res_domain.0,
1154-
evals: Evaluations::<F, D<F>>::from_vec_and_domain(v, res_domain.1),
1154+
evals: Evaluations::<F, D<F>>::from_vec_and_domain(v, res_domain.1.clone()),
11551155
}
11561156
}
11571157
}
11581158
}
11591159

1160-
fn sub<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D<F>)) -> EvalResult<'c, F> {
1160+
fn sub<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, &D<F>)) -> EvalResult<'c, F> {
11611161
use EvalResult::*;
11621162
match (self, other) {
11631163
(Constant(x), Constant(y)) => Constant(x - y),
@@ -1275,7 +1275,7 @@ impl<'a, F: FftField> EvalResult<'a, F> {
12751275
}
12761276
}
12771277

1278-
fn pow<'b>(self, d: u64, res_domain: (Domain, D<F>)) -> EvalResult<'b, F> {
1278+
fn pow<'b>(self, d: u64, res_domain: (Domain, &D<F>)) -> EvalResult<'b, F> {
12791279
let mut acc = EvalResult::Constant(F::one());
12801280
for i in (0..u64::BITS).rev() {
12811281
acc = acc.square(res_domain);
@@ -1288,7 +1288,7 @@ impl<'a, F: FftField> EvalResult<'a, F> {
12881288
acc
12891289
}
12901290

1291-
fn square<'b>(self, res_domain: (Domain, D<F>)) -> EvalResult<'b, F> {
1291+
fn square<'b>(self, res_domain: (Domain, &D<F>)) -> EvalResult<'b, F> {
12921292
use EvalResult::*;
12931293
match self {
12941294
Constant(x) => Constant(x.square()),
@@ -1312,7 +1312,7 @@ impl<'a, F: FftField> EvalResult<'a, F> {
13121312
}
13131313
}
13141314

1315-
fn mul<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D<F>)) -> EvalResult<'c, F> {
1315+
fn mul<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, &D<F>)) -> EvalResult<'c, F> {
13161316
use EvalResult::*;
13171317
match (self, other) {
13181318
(Constant(x), Constant(y)) => Constant(x * y),
@@ -1424,6 +1424,15 @@ fn get_domain<F: FftField>(d: Domain, env: &Environment<F>) -> D<F> {
14241424
}
14251425
}
14261426

1427+
fn get_domain_ref<'a, F: FftField>(d: Domain, env: &'a Environment<F>) -> &'a D<F> {
1428+
match d {
1429+
Domain::D1 => &env.domain.d1,
1430+
Domain::D2 => &env.domain.d2,
1431+
Domain::D4 => &env.domain.d4,
1432+
Domain::D8 => &env.domain.d8,
1433+
}
1434+
}
1435+
14271436
impl<F: Field> Expr<ConstantExpr<F>> {
14281437
/// Convenience function for constructing expressions from literal
14291438
/// field elements.
@@ -1713,13 +1722,13 @@ impl<F: FftField> Expr<F> {
17131722
assert_eq!(domain, d);
17141723
evals
17151724
}
1716-
EvalResult::Constant(x) => EvalResult::init_((d, get_domain(d, env)), |_| x),
1725+
EvalResult::Constant(x) => EvalResult::init_((d, get_domain_ref(d, env)), |_| x),
17171726
EvalResult::SubEvals {
17181727
evals,
17191728
domain: d_sub,
17201729
shift: s,
17211730
} => {
1722-
let res_domain = get_domain(d, env);
1731+
let res_domain = get_domain_ref(d, env);
17231732
let scale = (d_sub as usize) / (d as usize);
17241733
assert!(scale != 0);
17251734
EvalResult::init_((d, res_domain), |i| {
@@ -1738,7 +1747,7 @@ impl<F: FftField> Expr<F> {
17381747
where
17391748
'a: 'b,
17401749
{
1741-
let dom = (d, get_domain(d, env));
1750+
let dom = (d, get_domain_ref(d, env));
17421751

17431752
let res: EvalResult<'a, F> = match self {
17441753
Expr::Square(x) => match x.evaluations_helper(cache, d, env) {
@@ -1800,10 +1809,11 @@ impl<F: FftField> Expr<F> {
18001809
Expr::Pow(x, p) => {
18011810
let x = x.evaluations_helper(cache, d, env);
18021811
match x {
1803-
Either::Left(x) => x.pow(*p, (d, get_domain(d, env))),
1804-
Either::Right(id) => {
1805-
id.get_from(cache).unwrap().pow(*p, (d, get_domain(d, env)))
1806-
}
1812+
Either::Left(x) => x.pow(*p, (d, get_domain_ref(d, env))),
1813+
Either::Right(id) => id
1814+
.get_from(cache)
1815+
.unwrap()
1816+
.pow(*p, (d, get_domain_ref(d, env))),
18071817
}
18081818
}
18091819
Expr::VanishesOnZeroKnowledgeAndPreviousRows => EvalResult::SubEvals {
@@ -1837,7 +1847,7 @@ impl<F: FftField> Expr<F> {
18371847
}
18381848
}
18391849
Expr::BinOp(op, e1, e2) => {
1840-
let dom = (d, get_domain(d, env));
1850+
let dom = (d, get_domain_ref(d, env));
18411851
let f = |x: EvalResult<F>, y: EvalResult<F>| match op {
18421852
Op2::Mul => x.mul(y, dom),
18431853
Op2::Add => x.add(y, dom),

poly-commitment/src/combine.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ fn affine_window_combine_one_endo_base<P: SWModelParameters>(
295295
) -> Vec<SWJAffine<P>> {
296296
fn assign<A: Copy>(dst: &mut [A], src: &[A]) {
297297
let n = dst.len();
298-
dst[..n].clone_from_slice(&src[..n]);
298+
dst[..n].copy_from_slice(&src[..n]);
299299
}
300300

301301
fn get_bit(limbs_lsb: &[u64], i: u64) -> u64 {

poly-commitment/src/evaluation_proof.rs

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -225,25 +225,31 @@ impl<G: CommitmentCurve> SRS<G> {
225225
let rand_l = <G::ScalarField as UniformRand>::rand(rng);
226226
let rand_r = <G::ScalarField as UniformRand>::rand(rng);
227227

228-
let l = call_msm(
229-
&[&g[0..n], &[self.h, u]].concat(),
230-
&[&a[n..], &[rand_l, inner_prod(a_hi, b_lo)]]
231-
.concat()
232-
.iter()
233-
.map(|x| x.into_repr())
234-
.collect::<Vec<_>>(),
235-
)
236-
.into_affine();
237-
238-
let r = call_msm(
239-
&[&g[n..], &[self.h, u]].concat(),
240-
&[&a[0..n], &[rand_r, inner_prod(a_lo, b_hi)]]
241-
.concat()
242-
.iter()
243-
.map(|x| x.into_repr())
244-
.collect::<Vec<_>>(),
245-
)
246-
.into_affine();
228+
let call_l = || {
229+
call_msm(
230+
&[&g[0..n], &[self.h, u]].concat(),
231+
&[&a[n..], &[rand_l, inner_prod(a_hi, b_lo)]]
232+
.concat()
233+
.iter()
234+
.map(|x| x.into_repr())
235+
.collect::<Vec<_>>(),
236+
)
237+
.into_affine()
238+
};
239+
240+
let call_r = || {
241+
call_msm(
242+
&[&g[n..], &[self.h, u]].concat(),
243+
&[&a[0..n], &[rand_r, inner_prod(a_lo, b_hi)]]
244+
.concat()
245+
.iter()
246+
.map(|x| x.into_repr())
247+
.collect::<Vec<_>>(),
248+
)
249+
.into_affine()
250+
};
251+
252+
let (l, r) = rayon::join(call_l, call_r);
247253

248254
lr.push((l, r));
249255
blinders.push((rand_l, rand_r));
@@ -258,29 +264,33 @@ impl<G: CommitmentCurve> SRS<G> {
258264
chals.push(u);
259265
chal_invs.push(u_inv);
260266

261-
a = a_hi
262-
.par_iter()
263-
.zip(a_lo)
264-
.map(|(&hi, &lo)| {
265-
// lo + u_inv * hi
266-
let mut res = hi;
267-
res *= u_inv;
268-
res += &lo;
269-
res
270-
})
271-
.collect();
272-
273-
b = b_lo
274-
.par_iter()
275-
.zip(b_hi)
276-
.map(|(&lo, &hi)| {
277-
// lo + u * hi
278-
let mut res = hi;
279-
res *= u;
280-
res += &lo;
281-
res
282-
})
283-
.collect();
267+
let call_a = || {
268+
a_hi.par_iter()
269+
.zip(a_lo)
270+
.map(|(&hi, &lo)| {
271+
// lo + u_inv * hi
272+
let mut res = hi;
273+
res *= u_inv;
274+
res += &lo;
275+
res
276+
})
277+
.collect()
278+
};
279+
280+
let call_b = || {
281+
b_lo.par_iter()
282+
.zip(b_hi)
283+
.map(|(&lo, &hi)| {
284+
// lo + u * hi
285+
let mut res = hi;
286+
res *= u;
287+
res += &lo;
288+
res
289+
})
290+
.collect()
291+
};
292+
293+
(a, b) = rayon::join(call_a, call_b);
284294

285295
g = G::combine_one_endo(endo_r, endo_q, &g_lo, &g_hi, u_pre);
286296
}

poly-commitment/src/msm.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,9 @@ pub fn call_msm_impl<G: CommitmentCurve>(
123123
// Safety: We're reinterpreting generic types to their concret types
124124
// proof-systems contains too much useless generic types
125125
// It's safe because we just asserted they are the same types
126-
let result = my_msm::<G::Params>(
127-
unsafe { std::mem::transmute(points) },
128-
unsafe { std::mem::transmute(scalars) },
129-
);
126+
let result = my_msm::<G::Params>(unsafe { std::mem::transmute(points) }, unsafe {
127+
std::mem::transmute(scalars)
128+
});
130129
unsafe { *(&result as *const _ as *const G::Projective) }
131130
}
132131

0 commit comments

Comments
 (0)