@@ -2,15 +2,25 @@ use reikna::totient::totient;
22use reikna:: factor:: quick_factorize;
33use std:: collections:: HashMap ;
44
5- // Modular arithmetic functions using i64
5+ /// Modular arithmetic functions using i64
66fn mod_add ( a : i64 , b : i64 , p : i64 ) -> i64 {
77 ( a + b) % p
88}
99
10+ /// Modular multiplication
1011fn mod_mul ( a : i64 , b : i64 , p : i64 ) -> i64 {
1112 ( a * b) % p
1213}
1314
15+ /// Modular exponentiation
16+ /// # Arguments
17+ ///
18+ /// * `base` - Base of the exponentiation.
19+ /// * `exp` - Exponent.
20+ /// * `p` - Prime modulus for the operations.
21+ ///
22+ /// # Returns
23+ /// The result of the exponentiation modulo `p`.
1424pub fn mod_exp ( mut base : i64 , mut exp : i64 , p : i64 ) -> i64 {
1525 let mut result = 1 ;
1626 base %= p;
@@ -24,6 +34,14 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 {
2434 result
2535}
2636
37+ /// Extended Euclidean algorithm
38+ /// # Arguments
39+ ///
40+ /// * `a` - First number.
41+ /// * `b` - Second number.
42+ ///
43+ /// # Returns
44+ /// A tuple with the greatest common divisor and the Bézout coefficients.
2745fn extended_gcd ( a : i64 , b : i64 ) -> ( i64 , i64 , i64 ) {
2846 if b == 0 {
2947 ( a, 1 , 0 ) // gcd, x, y
@@ -33,15 +51,38 @@ fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
3351 }
3452}
3553
36- pub fn mod_inv ( a : i64 , modulus : i64 ) -> i64 {
54+ /// Compute the modular inverse of a modulo modulus
55+ fn mod_inv ( a : i64 , modulus : i64 ) -> i64 {
3756 let ( gcd, x, _) = extended_gcd ( a, modulus) ;
3857 if gcd != 1 {
3958 panic ! ( "{} and {} are not coprime, no inverse exists" , a, modulus) ;
4059 }
4160 ( x % modulus + modulus) % modulus // Ensure a positive result
4261}
4362
44- // Compute n-th root of unity (omega) for p not necessarily prime
63+ /// Compute n-th root of unity (omega) for p not necessarily prime
64+ /// # Arguments
65+ ///
66+ /// * `modulus` - Modulus. n must divide each prime power factor.
67+ /// * `n` - Order of the root of unity.
68+ ///
69+ /// # Returns
70+ /// The n-th root of unity modulo `modulus`.
71+ ///
72+ /// # Examples
73+ ///
74+ /// ```
75+ /// // For modulus = 17^2 = 289, we compute and verify an 8th root of unity.
76+ /// let modulus = 17 * 17;
77+ /// let n = 8;
78+ /// let omega = ntt::omega(modulus, n);
79+ /// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus));
80+ ///
81+ /// // For modulus = 17*41*73, we compute and verify an 8th root of unity.
82+ /// let modulus = 17*41*73;
83+ /// let omega = ntt::omega(modulus, n);
84+ /// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus));
85+ /// ```
4586pub fn omega ( modulus : i64 , n : usize ) -> i64 {
4687 let factors = factorize ( modulus as i64 ) ;
4788 if factors. len ( ) == 1 {
@@ -56,7 +97,29 @@ pub fn omega(modulus: i64, n: usize) -> i64 {
5697 }
5798}
5899
59- // Forward transform using NTT, output bit-reversed
100+ /// Forward transform using NTT, output bit-reversed
101+ /// # Arguments
102+ ///
103+ /// * `a` - Input vector.
104+ /// * `omega` - Primitive root of unity modulo `p`.
105+ /// * `n` - Length of the input vector and the result.
106+ /// * `p` - Prime modulus for the operations.
107+ ///
108+ /// # Returns
109+ /// A vector representing the NTT of the input vector.
110+ ///
111+ /// # Examples
112+ ///
113+ /// ```
114+ /// let modulus: i64 = 17; // modulus, n must divide phi(p^k) for each prime factor p
115+ /// let n: usize = 8; // Length of the NTT (must be a power of 2)
116+ /// let omega = ntt::omega(modulus, n); // n-th root of unity
117+ /// let mut a = vec![1, 2, 3, 4];
118+ /// a.resize(n, 0);
119+ /// // Perform the forward NTT
120+ /// let a_ntt = ntt::ntt(&a, omega, n, modulus);
121+ /// let a_ntt_expected = vec![10, 15, 6, 7, 16, 13, 11, 15];
122+ /// assert_eq!(a_ntt, a_ntt_expected);
60123pub fn ntt ( a : & [ i64 ] , omega : i64 , n : usize , p : i64 ) -> Vec < i64 > {
61124 let mut result = a. to_vec ( ) ;
62125 let mut step = n/2 ;
@@ -77,7 +140,16 @@ pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
77140 result
78141}
79142
80- // Inverse transform using INTT, input bit-reversed
143+ /// Inverse transform using INTT, input bit-reversed
144+ /// # Arguments
145+ ///
146+ /// * `a` - Input vector (bit-reversed).
147+ /// * `omega` - Primitive root of unity modulo `p`.
148+ /// * `n` - Length of the input vector and the result.
149+ /// * `p` - Prime modulus for the operations.
150+ ///
151+ /// # Returns
152+ /// A vector representing the inverse NTT of the input vector.
81153pub fn intt ( a : & [ i64 ] , omega : i64 , n : usize , p : i64 ) -> Vec < i64 > {
82154 let omega_inv = mod_inv ( omega, p) ;
83155 let n_inv = mod_inv ( n as i64 , p) ;
@@ -103,7 +175,16 @@ pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
103175 . collect ( )
104176}
105177
106- // Naive polynomial multiplication
178+ /// Naive polynomial multiplication
179+ /// # Arguments
180+ ///
181+ /// * `a` - First polynomial (as a vector of coefficients).
182+ /// * `b` - Second polynomial (as a vector of coefficients).
183+ /// * `n` - Length of the polynomials and the result.
184+ /// * `p` - Prime modulus for the operations.
185+ ///
186+ /// # Returns
187+ /// A vector representing the polynomial product modulo `p`.
107188pub fn polymul ( a : & Vec < i64 > , b : & Vec < i64 > , n : i64 , p : i64 ) -> Vec < i64 > {
108189 let mut result = vec ! [ 0 ; n as usize ] ;
109190 for i in 0 ..a. len ( ) {
@@ -145,7 +226,14 @@ pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec<i6
145226 c
146227}
147228
148- /// Compute the prime factorization of `n` (with multiplicities).
229+ /// Compute the prime factorization of `n` (with multiplicities)
230+ /// Uses reikna::quick_factorize internally
231+ /// # Arguments
232+ ///
233+ /// * `n` - Number to factorize.
234+ ///
235+ /// # Returns
236+ /// A HashMap with the prime factors of `n` as keys and their multiplicities as values.
149237fn factorize ( n : i64 ) -> HashMap < i64 , u32 > {
150238 let mut factors = HashMap :: new ( ) ;
151239 for factor in quick_factorize ( n as u64 ) {
@@ -155,6 +243,23 @@ fn factorize(n: i64) -> HashMap<i64, u32> {
155243}
156244
157245/// Fast computation of a primitive root mod p^e
246+ /// Computes a primitive root mod p and lifts it to p^e by adding successive powers of p
247+ /// # Arguments
248+ ///
249+ /// * `p` - Prime modulus.
250+ /// * `e` - Exponent.
251+ ///
252+ /// # Returns
253+ /// A primitive root modulo `p^e`.
254+ ///
255+ /// # Examples
256+ ///
257+ /// ```
258+ /// // For p = 17 and e = 2, we compute a primitive root modulo 289.
259+ /// let p = 17;
260+ /// let e = 2;
261+ /// let g = ntt::primitive_root(p, e);
262+ /// assert_eq!(ntt::mod_exp(g, p*(p-1), p*p), 1);
158263pub fn primitive_root ( p : i64 , e : u32 ) -> i64 {
159264 let g = primitive_root_mod_p ( p) ;
160265 let mut g_lifted = g; // Lift it to p^e
@@ -167,6 +272,12 @@ pub fn primitive_root(p: i64, e: u32) -> i64 {
167272}
168273
169274/// Finds a primitive root modulo a prime p
275+ /// # Arguments
276+ ///
277+ /// * `p` - Prime modulus.
278+ ///
279+ /// # Returns
280+ /// A primitive root modulo `p`.
170281fn primitive_root_mod_p ( p : i64 ) -> i64 {
171282 let phi = p - 1 ;
172283 let factors = factorize ( phi) ; // Reusing factorize to get both prime factors and multiplicities
@@ -179,7 +290,16 @@ fn primitive_root_mod_p(p: i64) -> i64 {
179290 0 // Should never happen
180291}
181292
182- // the Chinese remainder theorem for two moduli
293+ /// the Chinese remainder theorem for two moduli
294+ /// # Arguments
295+ ///
296+ /// * `a1` - First residue.
297+ /// * `n1` - First modulus.
298+ /// * `a2` - Second residue.
299+ /// * `n2` - Second modulus.
300+ ///
301+ /// # Returns
302+ /// The solution to the system of congruences x = a1 (mod n1) and x = a2 (mod n2).
183303pub fn crt ( a1 : i64 , n1 : i64 , a2 : i64 , n2 : i64 ) -> i64 {
184304 let n = n1 * n2;
185305 let m1 = mod_inv ( n1, n2) ; // Inverse of n1 mod n2
@@ -188,10 +308,17 @@ pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 {
188308 if x < 0 { x + n } else { x }
189309}
190310
191- // computes an n^th root of unity modulo a composite modulus
192- // note we require that an n^th root of unity exists for each multiplicative group modulo p^e
193- // use the CRT isomorphism to pull back each n^th root of unity to the composite modulus
194- // for the NTT, we require than a 2n^th root of unity exists
311+ /// computes an n^th root of unity modulo a composite modulus
312+ /// note we require that an n^th root of unity exists for each multiplicative group modulo p^e
313+ /// use the CRT isomorphism to pull back the list of n^th roots of unity to the composite modulus
314+ /// for the NTT, we require than a 2n^th root of unity exists
315+ /// # Arguments
316+ ///
317+ /// * `modulus` - Modulus. n must divide each prime power factor.
318+ /// * `n` - Order of the root of unity.
319+ ///
320+ /// # Returns
321+ /// The n-th root of unity modulo `modulus`.
195322pub fn root_of_unity ( modulus : i64 , n : i64 ) -> i64 {
196323 let factors = factorize ( modulus) ;
197324 let mut result = 1 ;
@@ -202,10 +329,17 @@ pub fn root_of_unity(modulus: i64, n: i64) -> i64 {
202329 result
203330}
204331
205- //ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
332+ /// ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
333+ /// # Arguments
334+ ///
335+ /// * `omega` - n-th root of unity.
336+ /// * `n` - Order of the root of unity.
337+ /// * `modulus` - Modulus.
338+ ///
339+ /// # Returns
340+ /// True if the root of unity satisfies the condition.
206341pub fn verify_root_of_unity ( omega : i64 , n : i64 , modulus : i64 ) -> bool {
207342 assert ! ( mod_exp( omega, n, modulus as i64 ) == 1 , "omega is not an n-th root of unity" ) ;
208343 assert ! ( mod_exp( omega, n/2 , modulus as i64 ) == modulus-1 , "omgea^(n/2) != -1 (mod modulus)" ) ;
209344 true
210- }
211-
345+ }
0 commit comments