Skip to content

Commit 2ff3b82

Browse files
committed
Fixes and comments from Anas
1 parent 35841c7 commit 2ff3b82

File tree

3 files changed

+26
-135
lines changed

3 files changed

+26
-135
lines changed

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,10 @@ static auto whole_shuffle(Register scratch, KRegister mergeMask1, KRegister merg
200200
__ vmovdqu(output2[i], input2[i], vector_len);
201201
}
202202
for (int i = 0; i < regCnt; i++) {
203-
__ evmovshdup(output2[i], k2, input1[i], true, vector_len);
203+
__ evmovshdup(output2[i], mergeMask2, input1[i], true, vector_len);
204204
}
205205
for (int i = 0; i < regCnt; i++) {
206-
__ evmovsldup(input1[i], k1, input2[i], true, vector_len);
206+
__ evmovsldup(input1[i], mergeMask1, input2[i], true, vector_len);
207207
}
208208
break;
209209
// Special cases
@@ -390,7 +390,7 @@ static void storeXmms(Register destination, int offset, const XMMRegister xmmReg
390390
// static int implDilithiumAlmostNtt(int[] coeffs, int zetas[]) {}
391391
//
392392
// coeffs (int[256]) = c_rarg0
393-
// zetas (int[256]) = c_rarg1
393+
// zetas (int[128*8]) = c_rarg1
394394
//
395395
static address generate_dilithiumAlmostNtt_avx(StubGenerator *stubgen,
396396
int vector_len, MacroAssembler *_masm) {
@@ -647,7 +647,7 @@ static address generate_dilithiumAlmostNtt_avx(StubGenerator *stubgen,
647647
// static int implDilithiumAlmostInverseNtt(int[] coeffs, int[] zetas) {}
648648
//
649649
// coeffs (int[256]) = c_rarg0
650-
// zetas (int[256]) = c_rarg1
650+
// zetas (int[128*8]) = c_rarg1
651651
static address generate_dilithiumAlmostInverseNtt_avx(StubGenerator *stubgen,
652652
int vector_len,MacroAssembler *_masm) {
653653
__ align(CodeEntryAlignment);
@@ -1017,12 +1017,13 @@ static address generate_dilithiumMontMulByConstant_avx(StubGenerator *stubgen,
10171017
__ evpbroadcastd(constant, rConstant, Assembler::AVX_512bit); // constant multiplier
10181018

10191019
__ mov64(scratch, 0b0101010101010101); //dw-mask
1020-
__ kmovwl(k2, scratch);
1020+
__ kmovwl(mergeMask, scratch);
10211021
}
10221022

10231023
// Total payload is 256*int32s.
10241024
// - memStep is number of bytes one montMul64 processes.
10251025
// - loopCnt is number of iterations it will take to process entire payload.
1026+
// - (each loop iteration processes two memSteps)
10261027
int memStep = 4 * 64;
10271028
int loopCnt = 2;
10281029
if (vector_len == Assembler::AVX_256bit) {
@@ -1321,15 +1322,15 @@ void StubGenerator::generate_dilithium_stubs() {
13211322
}
13221323
// Generate Dilithium intrinsics code
13231324
if (UseDilithiumIntrinsics) {
1324-
StubRoutines::_dilithiumAlmostNtt =
1325+
StubRoutines::_dilithiumAlmostNtt =
13251326
generate_dilithiumAlmostNtt_avx(this, vector_len, _masm);
1326-
StubRoutines::_dilithiumAlmostInverseNtt =
1327+
StubRoutines::_dilithiumAlmostInverseNtt =
13271328
generate_dilithiumAlmostInverseNtt_avx(this, vector_len, _masm);
1328-
StubRoutines::_dilithiumNttMult =
1329+
StubRoutines::_dilithiumNttMult =
13291330
generate_dilithiumNttMult_avx(this, vector_len, _masm);
1330-
StubRoutines::_dilithiumMontMulByConstant =
1331+
StubRoutines::_dilithiumMontMulByConstant =
13311332
generate_dilithiumMontMulByConstant_avx(this, vector_len, _masm);
1332-
StubRoutines::_dilithiumDecomposePoly =
1333+
StubRoutines::_dilithiumDecomposePoly =
13331334
generate_dilithiumDecomposePoly_avx(this, vector_len, _masm);
13341335
}
13351336
}

test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java

Lines changed: 12 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
import java.lang.reflect.Constructor;
3232
import java.util.HexFormat;
3333

34+
/*
35+
* @test
36+
* @library /test/lib
37+
* @modules java.base/sun.security.provider:open
38+
* @run main ML_DSA_Intrinsic_Test
39+
*/
3440
public class ML_DSA_Intrinsic_Test {
3541
public static void main(String[] args) throws Exception {
3642
MethodHandles.Lookup lookup = MethodHandles.lookup();
@@ -129,7 +135,7 @@ public static void testMult(int[] prod1, int[] prod2, int[] coeffs1, int[] coeff
129135
mult.invoke(prod1, coeffs1, coeffs2);
130136
multJava.invoke(prod2, coeffs1, coeffs2);
131137

132-
if (!Arrays.equals(prod1, parseHex(formatOf(prod2)))) {
138+
if (!Arrays.equals(prod1, prod2)) {
133139
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result mult mismatch: " + formatOf(prod1) + " != " + formatOf(prod2));
134140
}
135141
}
@@ -141,7 +147,9 @@ public static void testMultConst(int[] prod1, int[] prod2,
141147
for (int j = 0; j<ML_DSA_N; j++) {
142148
prod1[j] = prod2[j] = rnd.nextInt();
143149
}
144-
int c = rnd.nextInt();
150+
// Per Algorithm 3 in https://eprint.iacr.org/2018/039.pdf, one of the inputs is bounded, which prevents overflows
151+
int dilithium_q = 8380417;
152+
int c = rnd.nextInt(dilithium_q);
145153

146154
multConst.invoke(prod1, c);
147155
multConstJava.invoke(prod2, c);
@@ -189,9 +197,6 @@ public static void testAlmostNtt(int[] coeffs1, int[] coeffs2,
189197
almostNttJava.invoke(coeffs2);
190198

191199
if (!Arrays.equals(coeffs1, coeffs2)) {
192-
if (false) {
193-
implDilithiumAlmostNttJava(coeffs2);
194-
}
195200
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result AlmostNtt mismatch: " + formatOf(coeffs1) + " != " + formatOf(coeffs2));
196201
}
197202
}
@@ -209,9 +214,6 @@ public static void testInverseNtt(int[] coeffs1, int[] coeffs2,
209214
inverseNttJava.invoke(coeffs2);
210215

211216
if (!Arrays.equals(coeffs1, coeffs2)) {
212-
if (true) {
213-
implDilithiumAlmostInverseNttJava(coeffs3);
214-
}
215217
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result InverseNtt mismatch: " + formatOf(coeffs1) + " != " + formatOf(coeffs2));
216218
}
217219
}
@@ -225,18 +227,8 @@ private static CharSequence formatOf(int[] arr) {
225227
return b.toString();
226228
}
227229

228-
private static int[] parseHex(CharSequence string) {
229-
assert(string.length()%8==0);
230-
231-
int[] r = new int[string.length()/8];
232-
HexFormat hex = HexFormat.of();
233-
for (int i = 0, j = 0; j<string.length(); i++, j+=8) {
234-
r[i] = hex.fromHexDigits(string, j, j+8);
235-
}
236-
return r;
237-
}
238-
239-
private static final int[] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT = new int[]{
230+
// Copied constants from sun.security.provider.ML_DSA
231+
private static final int[] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT = new int[]{
240232
-1976782, 846154, -1400424, -3937738, 1362209, 48306, -3919660, 554416,
241233
3545687, -1612842, 976891, -183443, 2286327, 420899, 2235985, 2939036,
242234
3833893, 260646, 1104333, 1667432, -1910376, 1803090, -1723600, 426683,
@@ -511,107 +503,5 @@ private static int[] parseHex(CharSequence string) {
511503
-2939036, -2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687,
512504
-554416, 3919660, -48306, -1362209, 3937738, 1400424, -846154, 1976782
513505
};
514-
515-
516-
static void implDilithiumAlmostNttJava(int[] coeffs) {
517-
HexFormat hex = HexFormat.of();
518-
int dimension = 256;
519-
int m = 0;
520-
int testLevel = 2;
521-
for (int l = dimension / 2; l >= testLevel; l /= 2) {
522-
for (int s = 0; s < dimension; s += 2 * l) {
523-
for (int j = s; j < s + l; j++) {
524-
StringBuilder bld = new StringBuilder();
525-
bld.append("l = " + l + ", m = " + m + ", j = " + j+": ");
526-
bld.append(hex.toHexDigits(coeffs[j]) + ", " + hex.toHexDigits(coeffs[j + l]) + " * " + hex.toHexDigits(MONT_ZETAS_FOR_NTT[m]));
527-
int tmp = montMul(MONT_ZETAS_FOR_NTT[m], coeffs[j + l]);
528-
coeffs[j + l] = coeffs[j] - tmp;
529-
coeffs[j] = coeffs[j] + tmp;
530-
bld.append(" -> " + hex.toHexDigits(coeffs[j]) + ", " + hex.toHexDigits(coeffs[j + l]) + " (tmp = " + hex.toHexDigits(tmp) + ")");
531-
if (l == testLevel) {
532-
System.out.println(bld.toString());
533-
}
534-
}
535-
m++;
536-
}
537-
}
538-
}
539-
static void implDilithiumAlmostInverseNttJava(int[] coeffs) {
540-
HexFormat hex = HexFormat.of();
541-
int dimension = 256;
542-
int m = MONT_ZETAS_FOR_NTT.length - 1;
543-
int testLevel = 1;
544-
for (int l = 1; l < dimension; l *= 2) {
545-
for (int s = 0; s < dimension; s += 2 * l) {
546-
for (int j = s; j < s + l; j++) {
547-
StringBuilder bld = new StringBuilder();
548-
bld.append("l = " + l + ", m = " + m + ", j = " + j+": ");
549-
int tmp = coeffs[j];
550-
bld.append(hex.toHexDigits(coeffs[j]) + ", " + hex.toHexDigits(coeffs[j + l]) + " -> " + hex.toHexDigits(tmp - coeffs[j + l]) + " * " + hex.toHexDigits(-MONT_ZETAS_FOR_NTT[m]));
551-
coeffs[j] = (tmp + coeffs[j + l]);
552-
coeffs[j + l] = montMul(tmp - coeffs[j + l], -MONT_ZETAS_FOR_NTT[m]);
553-
bld.append(" -> " + hex.toHexDigits(coeffs[j]) + ", " + hex.toHexDigits(coeffs[j + l]));
554-
if (l == testLevel) {
555-
System.out.println(bld.toString());
556-
}
557-
}
558-
m--;
559-
}
560-
}
561-
}
562-
private static int montMul(int b, int c) {
563-
long a = (long) b * (long) c;
564-
int aHigh = (int) (a >> 32);
565-
int aLow = (int) a;
566-
int m = 58728449 * aLow; // signed low product
567-
568-
// subtract signed high product
569-
return (aHigh - (int) (((long)m * 8380417) >> 32));
570-
}
571-
// Zeta values for NTT with montgomery factor precomputed
572-
private static final int[] MONT_ZETAS_FOR_NTT = new int[]{
573-
25847, //0
574-
-2608894, -518909, //1
575-
237124, -777960, -876248, 466468, //2
576-
1826347,
577-
2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, //3
578-
2725464,
579-
1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, -2118186,
580-
-3859737, -1399561, -3277672, 1757237, -19422, 4010497, 280005, // 4
581-
582-
2706023,
583-
95776, 3077325, 3530437, -1661693, -3592148, -2537516, 3915439, -3861115,
584-
-3043716, 3574422, -2867647, 3539968, -300467, 2348700, -539299, -1699267,
585-
-1643818, 3505694, -3821735, 3507263, -2140649, -1600420, 3699596, 811944,
586-
531354, 954230, 3881043, 3900724, -2556880, 2071892, -2797779, //5
587-
588-
-3930395,
589-
-1528703, -3677745, -3041255, -1452451, 3475950, 2176455, -1585221, -1257611,
590-
1939314, -4083598, -1000202, -3190144, -3157330, -3632928, 126922, 3412210,
591-
-983419, 2147896, 2715295, -2967645, -3693493, -411027, -2477047, -671102,
592-
-1228525, -22981, -1308169, -381987, 1349076, 1852771, -1430430, -3343383,
593-
264944, 508951, 3097992, 44288, -1100098, 904516, 3958618, -3724342,
594-
-8578, 1653064, -3249728, 2389356, -210977, 759969, -1316856, 189548,
595-
-3553272, 3159746, -1851402, -2409325, -177440, 1315589, 1341330, 1285669,
596-
-1584928, -812732, -1439742, -3019102, -3881060, -3628969, 3839961, //6
597-
598-
2091667,
599-
3407706, 2316500, 3817976, -3342478, 2244091, -2446433, -3562462, 266997,
600-
2434439, -1235728, 3513181, -3520352, -3759364, -1197226, -3193378, 900702,
601-
1859098, 909542, 819034, 495491, -1613174, -43260, -522500, -655327,
602-
-3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838, 342297,
603-
286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, 2842341,
604-
2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 1595974, -3767016,
605-
1250494, 2635921, -3548272, -2994039, 1869119, 1903435, -1050970, -1333058,
606-
1237275, -3318210, -1430225, -451100, 1312455, 3306115, -1962642, -1279661,
607-
1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031, -542412,
608-
-2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993, -2013608,
609-
2432395, 2454455, -164721, 1957272, 3369112, 185531, -1207385, -3183426,
610-
162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107, -3038916,
611-
3523897, 3866901, 269760, 2213111, -975884, 1717735, 472078, -426683,
612-
1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893, -2939036,
613-
-2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687, -554416,
614-
3919660, -48306, -1362209, 3937738, 1400424, -846154, 1976782
615-
};
616506
}
617507
// java --add-opens java.base/sun.security.provider=ALL-UNNAMED -XX:+UseDilithiumIntrinsics test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java

test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@
4848
@OutputTimeUnit(TimeUnit.MILLISECONDS)
4949
@Fork(value = 1, jvmArgs = {"--add-opens", "java.base/sun.security.provider=ALL-UNNAMED"})
5050
public class MLDSABench extends CryptoBase {
51-
5251
public static final int SET_SIZE = 128;
53-
private static final int ML_DSA_N = 256;
5452

5553
private int[][] coeffs1;
5654
private int[][] coeffs2;
@@ -109,7 +107,7 @@ public void setup() throws Exception {
109107
}
110108

111109
@Benchmark
112-
public void mult1() throws Exception, Throwable {
110+
public void mult() throws Exception, Throwable {
113111
mult.invoke(prod1[index], coeffs1[index], coeffs2[index]);
114112
index = (index + 1) % SET_SIZE;
115113
}
@@ -143,6 +141,8 @@ public void multInverseNtt() throws Exception, Throwable {
143141
index = (index + 1) % SET_SIZE;
144142
}
145143

144+
// Copied constants from sun.security.provider.ML_DSA
145+
private static final int ML_DSA_N = 256;
146146
private static final int[] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT = new int[]{
147147
-1976782, 846154, -1400424, -3937738, 1362209, 48306, -3919660, 554416,
148148
3545687, -1612842, 976891, -183443, 2286327, 420899, 2235985, 2939036,

0 commit comments

Comments
 (0)