@@ -26,19 +26,24 @@ mod tests {
2626
2727 #[ test]
2828 fn test_polymul_ntt_square_modulus ( ) {
29- let moduli = [ 17 * 17 , 12289 * 12289 ] ; // Different moduli to test
30- let n: usize = 8 ; // Length of the NTT (must be a power of 2)
31-
32- for & modulus in & moduli {
33- let omega = omega ( modulus, n) ; // n-th root of unity
34- let mut a = vec ! [ 1 , 2 , 3 , 4 ] ;
35- let mut b = vec ! [ 5 , 6 , 7 , 8 ] ;
36- a. resize ( n, 0 ) ;
37- b. resize ( n, 0 ) ;
38- let c_std = polymul ( & a, & b, n as i64 , modulus) ;
39- let c_fast = polymul_ntt ( & a, & b, n, modulus, omega) ;
40- assert_eq ! ( c_std, c_fast, "The results of polymul and polymul_ntt do not match" ) ;
29+ let cases = [
30+ ( 17 * 17 , 4 ) , // small square modulus
31+ ( 12289 * 12289 , 512 ) // large square modulus
32+ ] ;
33+
34+ for & ( modulus, n) in & cases {
35+ let omega = omega ( modulus, 2 * n) ; // n-th root of unity
36+ let mut a: Vec < i64 > = ( 0 ..n) . map ( |x| x as i64 ) . collect ( ) ;
37+ let mut b: Vec < i64 > = ( 0 ..n) . map ( |x| x as i64 ) . collect ( ) ;
38+ a. resize ( 2 * n, 0 ) ;
39+ b. resize ( 2 * n, 0 ) ;
40+
41+ let c_std = polymul ( & a, & b, 2 * n as i64 , modulus) ;
42+ let c_fast = polymul_ntt ( & a, & b, 2 * n, modulus, omega) ;
43+
44+ assert_eq ! ( c_std, c_fast, "The results of polymul and polymul_ntt do not match for modulus {} and n {}" , modulus, n) ;
4145 }
46+
4247 }
4348
4449 #[ test]
0 commit comments