Skip to content

Commit ca7e18e

Browse files
committed
feat: use cache BR for bit_reverse
feat: parallel bit_reverse
1 parent a9441a1 commit ca7e18e

File tree

2 files changed

+78
-19
lines changed

2 files changed

+78
-19
lines changed

starky/src/fft_p.rs

Lines changed: 76 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -3,17 +3,67 @@ use crate::constant::{get_max_workers, MAX_OPS_PER_THREAD, MIN_OPS_PER_THREAD, S
33
use crate::fft_worker::{fft_block, interpolate_prepare_block};
44
use crate::helper::log2_any;
55
use crate::traits::FieldExtension;
6+
use crate::utils::parallells::parallelize;
67
use core::cmp::min;
8+
use lazy_static::lazy_static;
79
use rayon::prelude::*;
10+
use std::collections::HashMap;
11+
use std::sync::Mutex;
812

13+
lazy_static! {
14+
static ref BR_CACHE: Mutex<HashMap<usize, Vec<usize>>> = Mutex::new(HashMap::new());
15+
}
/// Returns the bit-reversed value of `x` for a domain of `1 << domain_pow` elements,
/// reading it from the table stored in `BR_CACHE` under the key `domain_pow`.
///
/// The closure `cal` computes a reversal directly: it swaps 16-bit halves, bytes,
/// nibbles, bit pairs and single bits using 32-bit masks, then shifts the result right
/// by `32 - domain_pow`. The table for `domain_pow` is removed from the map while the
/// lock is held, extended with `cal` when needed, read at index `x`, and inserted back.
///
/// # Panics
///
/// Panics if `domain_pow > 32`, or if locking `BR_CACHE` returns an error (`unwrap`).
///
/// NOTE(review): when `cache_len == x` and `cache_len > n`, neither extension condition
/// holds, so `cache[x]` would index one past the end -- confirm callers never pass
/// such an `x`.
/// NOTE(review): the mutex guard lives until the function returns, so every call is
/// serialized on `BR_CACHE` -- confirm this is acceptable for parallel callers.
916
pub fn BR(x: usize, domain_pow: usize) -> usize {
1017
assert!(domain_pow <= 32);
11-
let mut x = x;
12-
x = (x >> 16) | (x << 16);
13-
x = ((x & 0xFF00FF00) >> 8) | ((x & 0x00FF00FF) << 8);
14-
x = ((x & 0xF0F0F0F0) >> 4) | ((x & 0x0F0F0F0F) << 4);
15-
x = ((x & 0xCCCCCCCC) >> 2) | ((x & 0x33333333) << 2);
16-
(((x & 0xAAAAAAAA) >> 1) | ((x & 0x55555555) << 1)) >> (32 - domain_pow)
18+
let cal = |x: usize, domain_pow: usize| -> usize {
19+
let mut x = x;
20+
x = (x >> 16) | (x << 16);
21+
x = ((x & 0xFF00FF00) >> 8) | ((x & 0x00FF00FF) << 8);
22+
x = ((x & 0xF0F0F0F0) >> 4) | ((x & 0x0F0F0F0F) << 4);
23+
x = ((x & 0xCCCCCCCC) >> 2) | ((x & 0x33333333) << 2);
24+
(((x & 0xAAAAAAAA) >> 1) | ((x & 0x55555555) << 1)) >> (32 - domain_pow)
25+
};
26+
27+
// Take the table for this `domain_pow` out of the map, or start from an empty one.
28+
let mut map = BR_CACHE.lock().unwrap();
29+
let mut cache = if map.contains_key(&domain_pow) {
30+
map.remove(&domain_pow).unwrap() // Move the existing table out; it is re-inserted below.
31+
} else {
32+
vec![]
33+
};
34+
// Extend the table when it has at most `n` entries or fewer than `x` entries.
35+
let cache_len = cache.len();
36+
let n = 1 << domain_pow;
37+
if cache_len <= n || cache_len < x {
38+
let end = if n >= x { n } else { x };
39+
// TODO: compute the new entries in parallel.
40+
for i in cache_len..=end {
41+
let a = cal(i, domain_pow);
42+
cache.push(a);
43+
}
44+
}
45+
let res = cache[x];
46+
// Put the (possibly extended) table back into the map.
47+
map.insert(domain_pow, cache);
48+
res
49+
}
/// Returns the cached bit-reversed values for every index in `start..end` at the given
/// `domain_pow`, collected into a new `Vec`.
///
/// `BR(end, domain_pow)` is called first and its result discarded; the call is there so
/// that the table for `domain_pow` exists in `BR_CACHE` before it is read. The function
/// then locks `BR_CACHE` and copies `cache[i]` for each `i` in `start..end`.
///
/// # Panics
///
/// Panics if `end <= start`, if locking `BR_CACHE` returns an error, or if `BR` panics.
///
/// NOTE(review): the `else` branch calls `BR` while this function still holds the
/// `BR_CACHE` guard, and `BR` locks the same mutex -- that would block or panic if the
/// branch were ever taken. It looks unreachable because the priming call always inserts
/// the key; confirm and consider removing the branch.
50+
fn BRs(start: usize, end: usize, domain_pow: usize) -> Vec<usize> {
51+
assert!(end > start);
52+
// 1. Prime the cache; the value returned by `BR` is intentionally discarded.
53+
// This ensures the table exists and its length is at least `end`.
54+
BR(end, domain_pow);
55+
56+
// 2. Look up the table for this `domain_pow`.
57+
let map = BR_CACHE.lock().unwrap();
58+
let cache = if map.contains_key(&domain_pow) {
59+
map.get(&domain_pow).unwrap()
60+
} else {
61+
// Fallback: try to build the table again if the key is missing.
62+
BR(end, domain_pow);
63+
map.get(&domain_pow).unwrap()
64+
};
65+
66+
(start..end).map(|i| cache[i]).collect()
1767
}
1868

1969
pub fn transpose<F: FieldExtension>(
@@ -44,11 +94,14 @@ pub fn bit_reverse<F: FieldExtension>(
4494
nbits: usize,
4595
) {
4696
let n = 1 << nbits;
47-
for i in 0..n {
48-
let ri = BR(i, nbits);
49-
for k in 0..n_pols {
50-
buffdst[i * n_pols + k] = buffsrc[ri * n_pols + k];
51-
}
97+
let ris = BRs(0, n, nbits); // move it outside the loop. obtain it from cache.
98+
99+
let len = n * n_pols;
100+
assert_eq!(len, buffdst.len());
101+
for j in 0..len {
102+
let i = j / n_pols;
103+
let k = j % n_pols;
104+
buffdst[j] = buffsrc[ris[i] * n_pols + k];
52105
}
53106
}
54107

@@ -59,9 +112,10 @@ pub fn interpolate_bit_reverse<F: FieldExtension>(
59112
nbits: usize,
60113
) {
61114
let n = 1 << nbits;
115+
let ris = BRs(0, n, nbits); // move it outside the loop. obtain it from cache.
116+
62117
for i in 0..n {
63-
let ri = BR(i, nbits);
64-
let rii = (n - ri) % n;
118+
let rii = (n - ris[i]) % n;
65119
for k in 0..n_pols {
66120
buffdst[i * n_pols + k] = buffsrc[rii * n_pols + k];
67121
}
@@ -76,12 +130,15 @@ pub fn inv_bit_reverse<F: FieldExtension>(
76130
) {
77131
let n = 1 << nbits;
78132
let n_inv = F::inv(&F::from(n));
79-
for i in 0..n {
80-
let ri = BR(i, nbits);
81-
let rii = (n - ri) % n;
82-
for p in 0..n_pols {
83-
buffdst[i * n_pols + p] = buffsrc[rii * n_pols + p] * n_inv;
84-
}
133+
let ris = BRs(0, n, nbits); // move it outside the loop. obtain it from cache.
134+
135+
let len = n * n_pols;
136+
assert_eq!(len, buffdst.len());
137+
for j in 0..len {
138+
let i = j / n_pols;
139+
let k = j % n_pols;
140+
let rii = (n - ris[i]) % n;
141+
buffdst[j] = buffsrc[rii * n_pols + k] * n_inv;
85142
}
86143
}
87144

starky/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#![allow(clippy::needless_range_loop)]
2+
#![allow(dead_code)]
23

34
pub mod errors;
45
pub mod polsarray;
@@ -31,6 +32,7 @@ pub mod poseidon_bls12381_opt;
3132

3233
pub mod merklehash;
3334
pub mod merklehash_bls12381;
35+
3436
pub mod merklehash_bn128;
3537

3638
mod digest;

0 commit comments

Comments
 (0)