Skip to content

Commit d3920d3

Browse files
committed
feat: use cache BR for bit_reverse
1 parent 7534ccf commit d3920d3

File tree

1 file changed

+71
-46
lines changed

1 file changed

+71
-46
lines changed

starky/src/fft_p.rs

Lines changed: 71 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,67 @@ use crate::constant::{get_max_workers, MAX_OPS_PER_THREAD, MIN_OPS_PER_THREAD, S
33
use crate::fft_worker::{fft_block, interpolate_prepare_block};
44
use crate::helper::log2_any;
55
use crate::traits::FieldExtension;
6+
use crate::utils::parallells::parallelize;
67
use core::cmp::min;
8+
use lazy_static::lazy_static;
79
use rayon::prelude::*;
8-
use rayon::{current_num_threads, scope};
10+
use std::collections::HashMap;
11+
use std::sync::Mutex;
912

13+
lazy_static! {
14+
static ref BR_CACHE: Mutex<HashMap<usize, Vec<usize>>> = Mutex::new(HashMap::new());
15+
}
1016
pub fn BR(x: usize, domain_pow: usize) -> usize {
1117
assert!(domain_pow <= 32);
12-
let mut x = x;
13-
x = (x >> 16) | (x << 16);
14-
x = ((x & 0xFF00FF00) >> 8) | ((x & 0x00FF00FF) << 8);
15-
x = ((x & 0xF0F0F0F0) >> 4) | ((x & 0x0F0F0F0F) << 4);
16-
x = ((x & 0xCCCCCCCC) >> 2) | ((x & 0x33333333) << 2);
17-
(((x & 0xAAAAAAAA) >> 1) | ((x & 0x55555555) << 1)) >> (32 - domain_pow)
18+
let cal = |x: usize, domain_pow: usize| -> usize {
19+
let mut x = x;
20+
x = (x >> 16) | (x << 16);
21+
x = ((x & 0xFF00FF00) >> 8) | ((x & 0x00FF00FF) << 8);
22+
x = ((x & 0xF0F0F0F0) >> 4) | ((x & 0x0F0F0F0F) << 4);
23+
x = ((x & 0xCCCCCCCC) >> 2) | ((x & 0x33333333) << 2);
24+
(((x & 0xAAAAAAAA) >> 1) | ((x & 0x55555555) << 1)) >> (32 - domain_pow)
25+
};
26+
27+
// get cache by domain_pow
28+
let mut map = BR_CACHE.lock().unwrap();
29+
let mut cache = if map.contains_key(&domain_pow) {
30+
map.remove(&domain_pow).unwrap() // get and remove the old values.
31+
} else {
32+
vec![]
33+
};
34+
// check if need append more to cache
35+
let cache_len = cache.len();
36+
let n = 1 << domain_pow;
37+
if cache_len <= n || cache_len < x {
38+
let end = if n >= x { n } else { x };
39+
// todo parallel
40+
for i in cache_len..=end {
41+
let a = cal(i, domain_pow);
42+
cache.push(a);
43+
}
44+
}
45+
let res = cache[x];
46+
// update map with cache
47+
map.insert(domain_pow, cache);
48+
res
49+
}
50+
fn BRs(start: usize, end: usize, domain_pow: usize) -> Vec<usize> {
51+
assert!(end > start);
52+
// 1. obtain a useless one to precompute the cache.
53+
// to make sure the cache existed and its len >= end.
54+
BR(end, domain_pow);
55+
56+
// 2. get cache by domain_pow
57+
let map = BR_CACHE.lock().unwrap();
58+
let cache = if map.contains_key(&domain_pow) {
59+
map.get(&domain_pow).unwrap()
60+
} else {
61+
// double check
62+
BR(end, domain_pow);
63+
map.get(&domain_pow).unwrap()
64+
};
65+
66+
(start..end).map(|i| cache[i]).collect()
1867
}
1968

2069
pub fn transpose<F: FieldExtension>(
@@ -45,26 +94,16 @@ pub fn bit_reverse<F: FieldExtension>(
4594
nbits: usize,
4695
) {
4796
let n = 1 << nbits;
48-
let num_threads = current_num_threads();
49-
let mut chunk_size = n / num_threads;
50-
if chunk_size < num_threads {
51-
chunk_size = 1;
97+
let ris = BRs(0, n, nbits); // move it outside the loop. obtain it from cache.
98+
99+
// todo parallel. Does the buffdst
100+
let len = n * n_pols;
101+
assert_eq!(len, buffdst.len()); // is ok?
102+
for j in 0..len {
103+
let i = j / n_pols;
104+
let k = j % n_pols;
105+
buffdst[j] = buffsrc[ris[i] * n_pols + k];
52106
}
53-
54-
scope(|scope| {
55-
// for (chunk_num, v) in buffdst.chunks_mut(chunk).enumerate() {
56-
for chunk_num in 0..num_threads {
57-
scope.spawn(move |_| {
58-
let start = chunk_num * chunk_size;
59-
for i in start..(start + chunk_size) {
60-
let ri = BR(i, nbits);
61-
for k in 0..n_pols {
62-
buffdst[i * n_pols + k] = buffsrc[ri * n_pols + k];
63-
}
64-
}
65-
});
66-
}
67-
});
68107
}
69108

70109
pub fn interpolate_bit_reverse<F: FieldExtension>(
@@ -91,27 +130,13 @@ pub fn inv_bit_reverse<F: FieldExtension>(
91130
) {
92131
let n = 1 << nbits;
93132
let n_inv = F::inv(&F::from(n));
94-
let num_threads = current_num_threads();
95-
let mut chunk_size = n / num_threads;
96-
if chunk_size < num_threads {
97-
chunk_size = 1;
98-
}
99-
100-
scope(|scope| {
101-
// for (chunk_num, v) in buffdst.chunks_mut(chunk).enumerate() {
102-
for chunk_num in 0..num_threads {
103-
scope.spawn(move |_| {
104-
let start = chunk_num * chunk_size;
105-
for i in start..(start + chunk_size) {
106-
let ri = BR(i, nbits);
107-
let rii = (n - ri) % n;
108-
for p in 0..n_pols {
109-
buffdst[i * n_pols + p] = buffsrc[rii * n_pols + p] * n_inv;
110-
}
111-
}
112-
});
133+
for i in 0..n {
134+
let ri = BR(i, nbits);
135+
let rii = (n - ri) % n;
136+
for p in 0..n_pols {
137+
buffdst[i * n_pols + p] = buffsrc[rii * n_pols + p] * n_inv;
113138
}
114-
});
139+
}
115140
}
116141

117142
pub fn interpolate_prepare<F: FieldExtension>(buff: &mut Vec<F>, n_pols: usize, nbits: usize) {

0 commit comments

Comments
 (0)