3434/*
3535 * @test
3636 * @library /test/lib
37- * @modules java.base/sun.security.provider:open
37+ * @key randomness
38+ * @modules java.base/sun.security.provider:+open
39+ * @run main/othervm ML_DSA_Intrinsic_Test -XX:+UnlockDiagnosticVMOptions -XX:-UseDilithiumIntrinsics
40+ */
41+ /*
42+ * @test
43+ * @library /test/lib
44+ * @key randomness
45+ * @modules java.base/sun.security.provider:+open
46+ * @run main/othervm -XX:UseAVX=2 ML_DSA_Intrinsic_Test
47+ */
48+ /*
49+ * @test
50+ * @library /test/lib
51+ * @key randomness
52+ * @modules java.base/sun.security.provider:+open
3853 * @run main ML_DSA_Intrinsic_Test
3954 */
4055public class ML_DSA_Intrinsic_Test {
4156 public static void main (String [] args ) throws Exception {
4257 MethodHandles .Lookup lookup = MethodHandles .lookup ();
43- Class <?> kClazz = Class .forName ("sun.security.provider.ML_DSA" );
44- Constructor <?> constructor = kClazz .getDeclaredConstructor (
45- int .class );
46- constructor .setAccessible (true );
47-
58+ Class <?> kClazz = sun .security .provider .ML_DSA .class ;
59+
4860 Method m = kClazz .getDeclaredMethod ("implDilithiumNttMult" ,
4961 int [].class , int [].class , int [].class );
5062 m .setAccessible (true );
@@ -123,46 +135,46 @@ public static void main(String[] args) throws Exception {
123135 }
124136
125137 private static final int ML_DSA_N = 256 ;
126- public static void testMult (int [] prod1 , int [] prod2 , int [] coeffs1 , int [] coeffs2 ,
127- MethodHandle mult , MethodHandle multJava , Random rnd ,
138+ public static void testMult (int [] prod1 , int [] prod2 , int [] coeffs1 , int [] coeffs2 ,
139+ MethodHandle mult , MethodHandle multJava , Random rnd ,
128140 long seed , int i ) throws Exception , Throwable {
129-
141+
130142 for (int j = 0 ; j <ML_DSA_N ; j ++) {
131143 coeffs1 [j ] = rnd .nextInt ();
132144 coeffs2 [j ] = rnd .nextInt ();
133145 }
134146
135147 mult .invoke (prod1 , coeffs1 , coeffs2 );
136148 multJava .invoke (prod2 , coeffs1 , coeffs2 );
137-
149+
138150 if (!Arrays .equals (prod1 , prod2 )) {
139151 throw new RuntimeException ("[Seed " +seed +"@" +i +"] Result mult mismatch: " + formatOf (prod1 ) + " != " + formatOf (prod2 ));
140152 }
141153 }
142154
143155 public static void testMultConst (int [] prod1 , int [] prod2 ,
144- MethodHandle multConst , MethodHandle multConstJava , Random rnd ,
156+ MethodHandle multConst , MethodHandle multConstJava , Random rnd ,
145157 long seed , int i ) throws Exception , Throwable {
146-
158+
147159 for (int j = 0 ; j <ML_DSA_N ; j ++) {
148160 prod1 [j ] = prod2 [j ] = rnd .nextInt ();
149161 }
150162 // Per Algorithm 3 in https://eprint.iacr.org/2018/039.pdf, one of the inputs is bound, which prevents overflows
151163 int dilithium_q = 8380417 ;
152- int c = rnd .nextInt (dilithium_q );
164+ int c = rnd .nextInt (dilithium_q );
153165
154166 multConst .invoke (prod1 , c );
155167 multConstJava .invoke (prod2 , c );
156-
168+
157169 if (!Arrays .equals (prod1 , prod2 )) {
158170 throw new RuntimeException ("[Seed " +seed +"@" +i +"] Result multConst mismatch: " + formatOf (prod1 ) + " != " + formatOf (prod2 ));
159171 }
160172 }
161173
162- public static void testDecompose (int [] low1 , int [] high1 , int [] low2 , int [] high2 , int [] coeffs1 , int [] coeffs2 ,
163- MethodHandle decompose , MethodHandle decomposeJava , Random rnd ,
174+ public static void testDecompose (int [] low1 , int [] high1 , int [] low2 , int [] high2 , int [] coeffs1 , int [] coeffs2 ,
175+ MethodHandle decompose , MethodHandle decomposeJava , Random rnd ,
164176 long seed , int i ) throws Exception , Throwable {
165-
177+
166178 for (int j = 0 ; j <ML_DSA_N ; j ++) {
167179 coeffs1 [j ] = coeffs2 [j ] = rnd .nextInt ();
168180 }
@@ -174,7 +186,7 @@ public static void testDecompose(int[] low1, int[] high1, int[] low2, int[] high
174186
175187 decompose .invoke (coeffs1 , low1 , high1 , 2 * gamma2 , multiplier );
176188 decomposeJava .invoke (coeffs2 , low2 , high2 , 2 * gamma2 , multiplier );
177-
189+
178190 if (!Arrays .equals (low1 , low2 )) {
179191 throw new RuntimeException ("[Seed " +seed +"@" +i +"] Result low mismatch: " + formatOf (low1 ) + " != " + formatOf (low2 ));
180192 }
@@ -184,12 +196,10 @@ public static void testDecompose(int[] low1, int[] high1, int[] low2, int[] high
184196 }
185197 }
186198
187- public static void testAlmostNtt (int [] coeffs1 , int [] coeffs2 ,
188- MethodHandle almostNtt , MethodHandle almostNttJava , Random rnd ,
199+ public static void testAlmostNtt (int [] coeffs1 , int [] coeffs2 ,
200+ MethodHandle almostNtt , MethodHandle almostNttJava , Random rnd ,
189201 long seed , int i ) throws Exception , Throwable {
190- //int[] coeffs3 = new int[ML_DSA_N];
191202 for (int j = 0 ; j <ML_DSA_N ; j ++) {
192- //coeffs3[j] =
193203 coeffs1 [j ] = coeffs2 [j ] = rnd .nextInt ();
194204 }
195205
@@ -201,12 +211,12 @@ public static void testAlmostNtt(int[] coeffs1, int[] coeffs2,
201211 }
202212 }
203213
204- public static void testInverseNtt (int [] coeffs1 , int [] coeffs2 ,
205- MethodHandle inverseNtt , MethodHandle inverseNttJava , Random rnd ,
214+ public static void testInverseNtt (int [] coeffs1 , int [] coeffs2 ,
215+ MethodHandle inverseNtt , MethodHandle inverseNttJava , Random rnd ,
206216 long seed , int i ) throws Exception , Throwable {
207217 int [] coeffs3 = new int [ML_DSA_N ];
208218 for (int j = 0 ; j <ML_DSA_N ; j ++) {
209- coeffs3 [j ] =
219+ coeffs3 [j ] =
210220 coeffs1 [j ] = coeffs2 [j ] = rnd .nextInt ();
211221 }
212222
0 commit comments