From 35841c7a335158f6d63992c25f33e320b1760249 Mon Sep 17 00:00:00 2001 From: Volodymyr Paprotski Date: Wed, 9 Jul 2025 02:58:22 +0000 Subject: [PATCH 1/9] AVX2 and AVX512 intrinsics for MLDSA --- src/hotspot/cpu/x86/assembler_x86.cpp | 42 + src/hotspot/cpu/x86/assembler_x86.hpp | 5 + src/hotspot/cpu/x86/macroAssembler_x86.hpp | 1 + .../x86/stubGenerator_x86_64_dilithium.cpp | 1687 ++++++++++------- src/hotspot/cpu/x86/vm_version_x86.cpp | 2 +- .../provider/acvp/ML_DSA_Intrinsic_Test.java | 617 ++++++ .../bench/javax/crypto/full/MLDSABench.java | 421 ++++ 7 files changed, 2081 insertions(+), 694 deletions(-) create mode 100644 test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java create mode 100644 test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java diff --git a/src/hotspot/cpu/x86/assembler_x86.cpp b/src/hotspot/cpu/x86/assembler_x86.cpp index d1b6897f287c8..af2e5de5803d8 100644 --- a/src/hotspot/cpu/x86/assembler_x86.cpp +++ b/src/hotspot/cpu/x86/assembler_x86.cpp @@ -3754,6 +3754,48 @@ void Assembler::evmovdquq(Address dst, KRegister mask, XMMRegister src, bool mer emit_operand(src, dst, 0); } +void Assembler::vmovsldup(XMMRegister dst, XMMRegister src, int vector_len) { + assert(vector_len == AVX_128bit ? VM_Version::supports_avx() : + (vector_len == AVX_256bit ? VM_Version::supports_avx2() : + (vector_len == AVX_512bit ? VM_Version::supports_evex() : false)), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_F3, VEX_OPCODE_0F, &attributes); + emit_int16(0x12, (0xC0 | encode)); +} + +void Assembler::vmovshdup(XMMRegister dst, XMMRegister src, int vector_len) { + assert(vector_len == AVX_128bit ? VM_Version::supports_avx() : + (vector_len == AVX_256bit ? VM_Version::supports_avx2() : + (vector_len == AVX_512bit ? 
VM_Version::supports_evex() : false)), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_F3, VEX_OPCODE_0F, &attributes); + emit_int16(0x16, (0xC0 | encode)); +} + +void Assembler::evmovsldup(XMMRegister dst, KRegister mask, XMMRegister src, bool merge, int vector_len) { + assert(VM_Version::supports_evex(), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true); + attributes.set_embedded_opmask_register_specifier(mask); + attributes.set_is_evex_instruction(); + if (merge) { + attributes.reset_is_clear_context(); + } + int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_F3, VEX_OPCODE_0F, &attributes); + emit_int16(0x12, (0xC0 | encode)); +} + +void Assembler::evmovshdup(XMMRegister dst, KRegister mask, XMMRegister src, bool merge, int vector_len) { + assert(VM_Version::supports_evex(), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true); + attributes.set_embedded_opmask_register_specifier(mask); + attributes.set_is_evex_instruction(); + if (merge) { + attributes.reset_is_clear_context(); + } + int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_F3, VEX_OPCODE_0F, &attributes); + emit_int16(0x16, (0xC0 | encode)); +} + // Uses zero extension on 64bit void Assembler::movl(Register dst, int32_t imm32) { diff --git a/src/hotspot/cpu/x86/assembler_x86.hpp b/src/hotspot/cpu/x86/assembler_x86.hpp index 45c24f8c83256..2530df532a093 100644 --- a/src/hotspot/cpu/x86/assembler_x86.hpp +++ b/src/hotspot/cpu/x86/assembler_x86.hpp @@ -1641,6 +1641,11 @@ class Assembler : public AbstractAssembler { void evmovdqaq(XMMRegister dst, Address src, int vector_len); void evmovdqaq(XMMRegister dst, 
KRegister mask, Address src, bool merge, int vector_len); + void vmovsldup(XMMRegister dst, XMMRegister src, int vector_len); + void vmovshdup(XMMRegister dst, XMMRegister src, int vector_len); + void evmovsldup(XMMRegister dst, KRegister mask, XMMRegister src, bool merge, int vector_len); + void evmovshdup(XMMRegister dst, KRegister mask, XMMRegister src, bool merge, int vector_len); + // Move lower 64bit to high 64bit in 128bit register void movlhps(XMMRegister dst, XMMRegister src); diff --git a/src/hotspot/cpu/x86/macroAssembler_x86.hpp b/src/hotspot/cpu/x86/macroAssembler_x86.hpp index d75c9b624fd3a..595cc133d3e7a 100644 --- a/src/hotspot/cpu/x86/macroAssembler_x86.hpp +++ b/src/hotspot/cpu/x86/macroAssembler_x86.hpp @@ -1368,6 +1368,7 @@ class MacroAssembler: public Assembler { void vpcmpeqw(XMMRegister dst, XMMRegister nds, Address src, int vector_len); void vpcmpeqw(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); + using Assembler::evpcmpeqd; void evpcmpeqd(KRegister kdst, KRegister mask, XMMRegister nds, AddressLiteral src, int vector_len, Register rscratch = noreg); // Vector compares diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp index 7121db2ab9165..bca7363b9321a 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp @@ -30,8 +30,6 @@ #define __ _masm-> -#define xmm(i) as_XMMRegister(i) - #ifdef PRODUCT #define BLOCK_COMMENT(str) /* nothing */ #else @@ -40,15 +38,13 @@ #define BIND(label) bind(label); BLOCK_COMMENT(#label ":") -#define XMMBYTES 64 - // Constants // ATTRIBUTE_ALIGNED(64) static const uint32_t dilithiumAvx512Consts[] = { 58728449, // montQInvModR - 8380417, // dilithium_q - 2365951, // montRSquareModQ - 5373807 // Barrett addend for modular reduction + 8380417, // dilithium_q + 2365951, // montRSquareModQ + 5373807 // Barrett addend for modular reduction }; const int 
montQInvModRIdx = 0; @@ -60,207 +56,334 @@ static address dilithiumAvx512ConstsAddr(int offset) { return ((address) dilithiumAvx512Consts) + offset; } -const Register scratch = r10; -const XMMRegister montMulPerm = xmm28; -const XMMRegister montQInvModR = xmm30; -const XMMRegister dilithium_q = xmm31; +ATTRIBUTE_ALIGNED(64) static const uint32_t unshufflePerms[] = { + // Shuffle for the 128-bit element swap (uint64_t) + 0, 0, 1, 0, 8, 0, 9, 0, 4, 0, 5, 0, 12, 0, 13, 0, + 10, 0, 11, 0, 2, 0, 3, 0, 14, 0, 15, 0, 6, 0, 7, 0, + // Final shuffle for AlmostNtt + 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, + 24, 8, 25, 9, 26, 10, 27, 11, 28, 12, 29, 13, 30, 14, 31, 15, -ATTRIBUTE_ALIGNED(64) static const uint32_t dilithiumAvx512Perms[] = { - // collect montmul results into the destination register - 17, 1, 19, 3, 21, 5, 23, 7, 25, 9, 27, 11, 29, 13, 31, 15, - // ntt - // level 4 - 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, - 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, - // level 5 - 0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27, - 4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31, - // level 6 - 0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29, - 2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31, - // level 7 - 0, 16, 2, 18, 4, 20, 6, 22, 8, 24, 10, 26, 12, 28, 14, 30, - 1, 17, 3, 19, 5, 21, 7, 23, 9, 25, 11, 27, 13, 29, 15, 31, - 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, - 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, - - // ntt inverse - // level 0 - 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, - 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, - // level 1 - 0, 16, 2, 18, 4, 20, 6, 22, 8, 24, 10, 26, 12, 28, 14, 30, - 1, 17, 3, 19, 5, 21, 7, 23, 9, 25, 11, 27, 13, 29, 15, 31, - // level 2 - 0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29, - 2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31, - // level 3 - 0, 1, 2, 3, 
16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27, - 4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31, - // level 4 - 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, - 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 + // Initial shuffle for AlmostInverseNtt + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 17, 19, 21, 23, 25, 27, 29, 31, 1, 3, 5, 7, 9, 11, 13, 15 }; -const int montMulPermsIdx = 0; -const int nttL4PermsIdx = 64; -const int nttL5PermsIdx = 192; -const int nttL6PermsIdx = 320; -const int nttL7PermsIdx = 448; -const int nttInvL0PermsIdx = 704; -const int nttInvL1PermsIdx = 832; -const int nttInvL2PermsIdx = 960; -const int nttInvL3PermsIdx = 1088; -const int nttInvL4PermsIdx = 1216; - -static address dilithiumAvx512PermsAddr() { - return (address) dilithiumAvx512Perms; +static address unshufflePermsAddr(int offset) { + return ((address) unshufflePerms) + offset*64; } -// We do Montgomery multiplications of two vectors of 16 ints each in 4 steps: +// The following function swaps elements A<->B, C<->D, and so forth. +// input1[] is shuffled in place; shuffle of input2[] is copied to output2[]. +// Element size (in bits) is specified by size parameter. +// size 0 and 1 are used for initial and final shuffles respectivelly of +// dilithiumAlmostInverseNtt and dilithiumAlmostNtt. +// NOTE: For size 0 and 1, input1[] and input2[] are modified in-place +// +// +-----+-----+-----+-----+----- +// | | A | | C | ... +// +-----+-----+-----+-----+----- +// +-----+-----+-----+-----+----- +// | B | | D | | ... 
+// +-----+-----+-----+-----+----- +// Using C++ lambdas for improved readability (to hide parameters that always repeat) +static auto whole_shuffle(Register scratch, KRegister mergeMask1, KRegister mergeMask2, + const XMMRegister unshuffle1, const XMMRegister unshuffle2, int vector_len, MacroAssembler *_masm) { + + int regCnt = 4; + if (vector_len == Assembler::AVX_256bit) { + regCnt = 2; + } + + return [=](const XMMRegister output2[], const XMMRegister input1[], + const XMMRegister input2[], int size) { + if (vector_len == Assembler::AVX_256bit) { + switch (size) { + case 128: + for (int i = 0; i < regCnt; i++) { + __ vperm2i128(output2[i], input1[i], input2[i], 0b110001); + } + for (int i = 0; i < regCnt; i++) { + __ vinserti128(input1[i], input1[i], input2[i], 1); + } + break; + case 64: + for (int i = 0; i < regCnt; i++) { + __ vshufpd(output2[i], input1[i], input2[i], 0b11111111, vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ vshufpd(input1[i], input1[i], input2[i], 0b00000000, vector_len); + } + break; + case 32: + for (int i = 0; i < regCnt; i++) { + __ vmovshdup(output2[i], input1[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ vpblendd(output2[i], output2[i], input2[i], 0b10101010, vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ vmovsldup(input2[i], input2[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ vpblendd(input1[i], input1[i], input2[i], 0b10101010, vector_len); + } + break; + case 1: + for (int i = 0; i < regCnt; i++) { + // 0b-1-2-3-1 + __ vshufps(output2[i], input1[i], input2[i], 0b11011101, vector_len); + } + for (int i = 0; i < regCnt; i++) { + // 0b-2-0-2-0 + __ vshufps(input1[i], input1[i], input2[i], 0b10001000, vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ vpermq(input2[i], output2[i], 0b11011000, vector_len); + } + for (int i = 0; i < regCnt; i++) { + // 0b-3-1-2-0 + __ vpermq(input1[i], input1[i], 0b11011000, vector_len); + } + break; + case 0: + for (int i = 0; i < 
regCnt; i++) { + __ vpunpckhdq(output2[i], input1[i], input2[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ vpunpckldq(input1[i], input1[i], input2[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ vperm2i128(input2[i], input1[i], output2[i], 0b110001); + } + for (int i = 0; i < regCnt; i++) { + __ vinserti128(input1[i], input1[i], output2[i], 1); + } + break; + default: + assert(false, "Don't call here"); + } + } else { + switch (size) { + case 256: + for (int i = 0; i < regCnt; i++) { + // 0b-3-2-3-2 + __ evshufi64x2(output2[i], input1[i], input2[i], 0b11101110, vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ vinserti64x4(input1[i], input1[i], input2[i], 1); + } + break; + case 128: + for (int i = 0; i < regCnt; i++) { + __ vmovdqu(output2[i], input2[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ evpermt2q(output2[i], unshuffle2, input1[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ evpermt2q(input1[i], unshuffle1, input2[i], vector_len); + } + + break; + case 64: + for (int i = 0; i < regCnt; i++) { + __ vshufpd(output2[i], input1[i], input2[i], 0b11111111, vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ vshufpd(input1[i], input1[i], input2[i], 0b00000000, vector_len); + } + break; + case 32: + for (int i = 0; i < regCnt; i++) { + __ vmovdqu(output2[i], input2[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ evmovshdup(output2[i], k2, input1[i], true, vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ evmovsldup(input1[i], k1, input2[i], true, vector_len); + } + break; + // Special cases + case 1: // initial shuffle for dilithiumAlmostInverseNtt + for (int i = 0; i < regCnt; i++) { + __ vmovdqu(output2[i], input2[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ evpermt2d(input2[i], unshuffle2, input1[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ evpermt2d(input1[i], unshuffle1, output2[i], vector_len); + } + break; + case 0: // 
final unshuffle for dilithiumAlmostNtt + for (int i = 0; i < regCnt; i++) { + __ vmovdqu(output2[i], input2[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ evpermt2d(input2[i], unshuffle2, input1[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ evpermt2d(input1[i], unshuffle1, output2[i], vector_len); + } + break; + default: + assert(false, "Don't call here"); + } + } + }; // return +} + +// We do Montgomery multiplications of two AVX registers in 4 steps: // 1. Do the multiplications of the corresponding even numbered slots into -// the odd numbered slots of a third register. -// 2. Swap the even and odd numbered slots of the original input registers. -// 3. Similar to step 1, but into a different output register. +// the odd numbered slots of a scratch2 register. +// 2. Swap the even and odd numbered slots of the original input registers.* +// 3. Similar to step 1, but into output register. // 4. Combine the outputs of step 1 and step 3 into the output of the Montgomery // multiplication. -// (For levels 0-6 in the Ntt and levels 1-7 of the inverse Ntt we only swap the -// odd-even slots of the first multiplicand as in the second (zetas) the -// odd slots contain the same number as the corresponding even one.) -// The indexes of the registers to be multiplied -// are in inputRegs1[] and inputRegs[2]. -// The results go to the registers whose indexes are in outputRegs. -// scratchRegs should contain 12 different register indexes. -// The set in outputRegs should not overlap with the set of the middle four -// scratch registers. -// The sets in inputRegs1 and inputRegs2 cannot overlap with the set of the -// first eight scratch registers. -// In most of the cases, the odd and the corresponding even slices of the -// registers indexed by the numbers in inputRegs2 will contain the same number, -// this should be indicated by calling this function with -// input2NeedsShuffle=false . 
+// (*For levels 0-6 in the Ntt and levels 1-7 of the inverse Ntt, need NOT swap +// the second operand (zetas) since the odd slots contain the same number +// as the corresponding even one. This is indicated by input2NeedsShuffle=false) // -static void montMul64(int outputRegs[], int inputRegs1[], int inputRegs2[], - int scratchRegs[], bool input2NeedsShuffle, - MacroAssembler *_masm) { - - for (int i = 0; i < 4; i++) { - __ vpmuldq(xmm(scratchRegs[i]), xmm(inputRegs1[i]), xmm(inputRegs2[i]), - Assembler::AVX_512bit); - } - for (int i = 0; i < 4; i++) { - __ vpmulld(xmm(scratchRegs[i + 4]), xmm(scratchRegs[i]), montQInvModR, - Assembler::AVX_512bit); - } - for (int i = 0; i < 4; i++) { - __ vpmuldq(xmm(scratchRegs[i + 4]), xmm(scratchRegs[i + 4]), dilithium_q, - Assembler::AVX_512bit); - } - for (int i = 0; i < 4; i++) { - __ evpsubd(xmm(scratchRegs[i + 4]), k0, xmm(scratchRegs[i]), - xmm(scratchRegs[i + 4]), false, Assembler::AVX_512bit); +// The registers to be multiplied are in input1[] and inputs2[]. The results go +// into output[]. Two scratch[] register arrays are expected. input1[] can +// overlap with either output[] or scratch1[] +// - If AVX512, all register arrays are of length 4 +// - If AVX2, first two registers of each array are in xmm0-xmm15 range +// Constants montQInvModR, dilithium_q and mergeMask expected to have already +// been loaded. 
+// +// Using C++ lambdas for improved readability (to hide parameters that always repeat) +static auto whole_montMul(XMMRegister montQInvModR, XMMRegister dilithium_q, + KRegister mergeMask, int vector_len, MacroAssembler *_masm) { + int regCnt = 4; + int regSize = 64; + if (vector_len == Assembler::AVX_256bit) { + regCnt = 2; + regSize = 32; } - for (int i = 0; i < 4; i++) { - __ vpshufd(xmm(inputRegs1[i]), xmm(inputRegs1[i]), 0xB1, - Assembler::AVX_512bit); - if (input2NeedsShuffle) { - __ vpshufd(xmm(inputRegs2[i]), xmm(inputRegs2[i]), 0xB1, - Assembler::AVX_512bit); + return [=](const XMMRegister output[], const XMMRegister input1[], + const XMMRegister input2[], const XMMRegister scratch1[], + const XMMRegister scratch2[], bool input2NeedsShuffle = false) { + // (Register overloading) Can't always use scratch1 (could override input1). + // If so, use output: + const XMMRegister* scratch = scratch1 == input1 ? output: scratch1; + + // scratch = input1_even*intput2_even + for (int i = 0; i < regCnt; i++) { + __ vpmuldq(scratch[i], input1[i], input2[i], vector_len); } - } - for (int i = 0; i < 4; i++) { - __ vpmuldq(xmm(scratchRegs[i]), xmm(inputRegs1[i]), xmm(inputRegs2[i]), - Assembler::AVX_512bit); - } - for (int i = 0; i < 4; i++) { - __ vpmulld(xmm(scratchRegs[i + 8]), xmm(scratchRegs[i]), montQInvModR, - Assembler::AVX_512bit); - } - for (int i = 0; i < 4; i++) { - __ vpmuldq(xmm(scratchRegs[i + 8]), xmm(scratchRegs[i + 8]), dilithium_q, - Assembler::AVX_512bit); - } - for (int i = 0; i < 4; i++) { - __ evpsubd(xmm(outputRegs[i]), k0, xmm(scratchRegs[i]), - xmm(scratchRegs[i + 8]), false, Assembler::AVX_512bit); - } + // scratch2_low = scratch_low * montQInvModR + for (int i = 0; i < regCnt; i++) { + __ vpmuldq(scratch2[i], scratch[i], montQInvModR, vector_len); + } - for (int i = 0; i < 4; i++) { - __ evpermt2d(xmm(outputRegs[i]), montMulPerm, xmm(scratchRegs[i + 4]), - Assembler::AVX_512bit); - } -} + // scratch2 = scratch2_low * dilithium_q + for (int i 
= 0; i < regCnt; i++) { + __ vpmuldq(scratch2[i], scratch2[i], dilithium_q, vector_len); + } + + // scratch2_high = scratch2_high - scratch_high + for (int i = 0; i < regCnt; i++) { + __ vpsubd(scratch2[i], scratch[i], scratch2[i], vector_len); + } + + // input1_even = input1_odd + // input2_even = input2_odd + for (int i = 0; i < regCnt; i++) { + __ vpshufd(input1[i], input1[i], 0xB1, vector_len); + if (input2NeedsShuffle) { + __ vpshufd(input2[i], input2[i], 0xB1, vector_len); + } + } + + // scratch1 = input1_even*intput2_even + for (int i = 0; i < regCnt; i++) { + __ vpmuldq(scratch1[i], input1[i], input2[i], vector_len); + } -static void montMul64(int outputRegs[], int inputRegs1[], int inputRegs2[], - int scratchRegs[], MacroAssembler *_masm) { - montMul64(outputRegs, inputRegs1, inputRegs2, scratchRegs, false, _masm); + // output = scratch1_low * montQInvModR + for (int i = 0; i < regCnt; i++) { + __ vpmuldq(output[i], scratch1[i], montQInvModR, vector_len); + } + + // output = output * dilithium_q + for (int i = 0; i < regCnt; i++) { + __ vpmuldq(output[i], output[i], dilithium_q, vector_len); + } + + // output_high = scratch1_high - output_high + for (int i = 0; i < regCnt; i++) { + __ vpsubd(output[i], scratch1[i], output[i], vector_len); + } + + // output = select(output_high, scratch2_high) + if (vector_len == Assembler::AVX_256bit) { + for (int i = 0; i < regCnt; i++) { + __ vmovshdup(scratch2[i], scratch2[i], vector_len); + } + for (int i = 0; i < regCnt; i++) { + __ vpblendd(output[i], output[i], scratch2[i], 0b01010101, vector_len); + } + } else { + for (int i = 0; i < regCnt; i++) { + __ evmovshdup(output[i], mergeMask, scratch2[i], true, vector_len); + } + } + }; // return } -static void sub_add(int subResult[], int addResult[], - int input1[], int input2[], MacroAssembler *_masm) { +static void sub_add(const XMMRegister subResult[], const XMMRegister addResult[], + const XMMRegister input1[], const XMMRegister input2[], + int vector_len, 
MacroAssembler *_masm) { + int regCnt = 4; + if (vector_len == Assembler::AVX_256bit) { + regCnt = 2; + } - for (int i = 0; i < 4; i++) { - __ evpsubd(xmm(subResult[i]), k0, xmm(input1[i]), xmm(input2[i]), false, - Assembler::AVX_512bit); + for (int i = 0; i < regCnt; i++) { + __ vpsubd(subResult[i], input1[i], input2[i], vector_len); } - for (int i = 0; i < 4; i++) { - __ evpaddd(xmm(addResult[i]), k0, xmm(input1[i]), xmm(input2[i]), false, - Assembler::AVX_512bit); + for (int i = 0; i < regCnt; i++) { + __ vpaddd(addResult[i], input1[i], input2[i], vector_len); } } -static void loadPerm(int destinationRegs[], Register perms, - int offset, MacroAssembler *_masm) { - __ evmovdqul(xmm(destinationRegs[0]), Address(perms, offset), - Assembler::AVX_512bit); - for (int i = 1; i < 4; i++) { - __ evmovdqul(xmm(destinationRegs[i]), xmm(destinationRegs[0]), - Assembler::AVX_512bit); - } -} +static void loadXmms(const XMMRegister destinationRegs[], Register source, int offset, + int vector_len, MacroAssembler *_masm, int regCnt = -1, int memStep = -1) { -static void load4Xmms(int destinationRegs[], Register source, int offset, - MacroAssembler *_masm) { - for (int i = 0; i < 4; i++) { - __ evmovdqul(xmm(destinationRegs[i]), Address(source, offset + i * XMMBYTES), - Assembler::AVX_512bit); + if (vector_len == Assembler::AVX_256bit) { + regCnt = regCnt == -1 ? 2 : regCnt; + memStep = memStep == -1 ? 
32 : memStep; + } else { + regCnt = 4; + memStep = 64; } -} -static void loadXmm29(Register source, int offset, MacroAssembler *_masm) { - __ evmovdqul(xmm29, Address(source, offset), Assembler::AVX_512bit); + for (int i = 0; i < regCnt; i++) { + __ vmovdqu(destinationRegs[i], Address(source, offset + i * memStep), vector_len); + } } -static void store4Xmms(Register destination, int offset, int xmmRegs[], - MacroAssembler *_masm) { - for (int i = 0; i < 4; i++) { - __ evmovdqul(Address(destination, offset + i * XMMBYTES), xmm(xmmRegs[i]), - Assembler::AVX_512bit); +static void storeXmms(Register destination, int offset, const XMMRegister xmmRegs[], + int vector_len, MacroAssembler *_masm, int regCnt = -1, int memStep = -1) { + if (vector_len == Assembler::AVX_256bit) { + regCnt = regCnt == -1 ? 2 : regCnt; + memStep = memStep == -1 ? 32 : memStep; + } else { + regCnt = 4; + memStep = 64; } -} -static int xmm0_3[] = {0, 1, 2, 3}; -static int xmm0145[] = {0, 1, 4, 5}; -static int xmm0246[] = {0, 2, 4, 6}; -static int xmm0426[] = {0, 4, 2, 6}; -static int xmm1357[] = {1, 3, 5, 7}; -static int xmm1537[] = {1, 5, 3, 7}; -static int xmm2367[] = {2, 3, 6, 7}; -static int xmm4_7[] = {4, 5, 6, 7}; -static int xmm8_11[] = {8, 9, 10, 11}; -static int xmm12_15[] = {12, 13, 14, 15}; -static int xmm16_19[] = {16, 17, 18, 19}; -static int xmm20_23[] = {20, 21, 22, 23}; -static int xmm20222426[] = {20, 22, 24, 26}; -static int xmm21232527[] = {21, 23, 25, 27}; -static int xmm24_27[] = {24, 25, 26, 27}; -static int xmm4_20_24[] = {4, 5, 6, 7, 20, 21, 22, 23, 24, 25, 26, 27}; -static int xmm16_27[] = {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}; -static int xmm29_29[] = {29, 29, 29, 29}; + for (int i = 0; i < regCnt; i++) { + __ vmovdqu(Address(destination, offset + i * memStep), xmmRegs[i], vector_len); + } +} // Dilithium NTT function except for the final "normalization" to |coeff| < Q. 
// Implements @@ -269,184 +392,249 @@ static int xmm29_29[] = {29, 29, 29, 29}; // coeffs (int[256]) = c_rarg0 // zetas (int[256]) = c_rarg1 // -// -static address generate_dilithiumAlmostNtt_avx512(StubGenerator *stubgen, - MacroAssembler *_masm) { - +static address generate_dilithiumAlmostNtt_avx(StubGenerator *stubgen, + int vector_len, MacroAssembler *_masm) { __ align(CodeEntryAlignment); StubGenStubId stub_id = dilithiumAlmostNtt_id; StubCodeMark mark(stubgen, stub_id); address start = __ pc(); __ enter(); - Label L_loop, L_end; - const Register coeffs = c_rarg0; const Register zetas = c_rarg1; - const Register iterations = c_rarg2; - - const Register perms = r11; - - __ lea(perms, ExternalAddress(dilithiumAvx512PermsAddr())); - - __ evmovdqul(montMulPerm, Address(perms, montMulPermsIdx), Assembler::AVX_512bit); + const Register scratch = r10; // Each level represents one iteration of the outer for loop of the Java version // In each of these iterations half of the coefficients are (Montgomery) // multiplied by a zeta corresponding to the coefficient and then these // products will be added to and subtracted from the other half of the - // coefficients. In each level we just collect the coefficients (using - // evpermi2d() instructions where necessary, i.e. in levels 4-7) that need to + // coefficients. In each level we just shuffle the coefficients that need to // be multiplied by the zetas in one set, the rest to another set of vector // registers, then redistribute the addition/substraction results. // For levels 0 and 1 the zetas are not different within the 4 xmm registers - // that we would use for them, so we use only one, xmm29. - loadXmm29(zetas, 0, _masm); + // that we would use for them, so we use only one register. 
+ + // AVX2 version uses the first half of these arrays + const XMMRegister Coeffs1[] = {xmm0, xmm1, xmm16, xmm17}; + const XMMRegister Coeffs2[] = {xmm2, xmm3, xmm18, xmm19}; + const XMMRegister Coeffs3[] = {xmm4, xmm5, xmm20, xmm21}; + const XMMRegister Coeffs4[] = {xmm6, xmm7, xmm22, xmm23}; + const XMMRegister Scratch1[] = {xmm8, xmm9, xmm24, xmm25}; + const XMMRegister Scratch2[] = {xmm10, xmm11, xmm26, xmm27}; + const XMMRegister Zetas1[] = {xmm12, xmm12, xmm12, xmm12}; + const XMMRegister Zetas2[] = {xmm12, xmm12, xmm13, xmm13}; + const XMMRegister Zetas3[] = {xmm12, xmm13, xmm28, xmm29}; + const XMMRegister montQInvModR = xmm14; + const XMMRegister dilithium_q = xmm15; + const XMMRegister unshuffle1 = xmm30; + const XMMRegister unshuffle2 = xmm31; + KRegister mergeMask1 = k1; + KRegister mergeMask2 = k2; + // lambdas to hide repeated parameters + auto shuffle = whole_shuffle(scratch, mergeMask1, mergeMask2, unshuffle1, unshuffle2, vector_len, _masm); + auto montMul64 = whole_montMul(montQInvModR, dilithium_q, mergeMask2, vector_len, _masm); + __ vpbroadcastd(montQInvModR, ExternalAddress(dilithiumAvx512ConstsAddr(montQInvModRIdx)), - Assembler::AVX_512bit, scratch); // q^-1 mod 2^32 + vector_len, scratch); // q^-1 mod 2^32 __ vpbroadcastd(dilithium_q, ExternalAddress(dilithiumAvx512ConstsAddr(dilithium_qIdx)), - Assembler::AVX_512bit, scratch); // q - - // load all coefficients into the vector registers Zmm_0-Zmm_15, - // 16 coefficients into each - load4Xmms(xmm0_3, coeffs, 0, _masm); - load4Xmms(xmm4_7, coeffs, 4 * XMMBYTES, _masm); - load4Xmms(xmm8_11, coeffs, 8 * XMMBYTES, _masm); - load4Xmms(xmm12_15, coeffs, 12 * XMMBYTES, _masm); - - // level 0 and 1 can be done entirely in registers as the zetas on these - // levels are the same for all the montmuls that we can do in parallel - - // level 0 - montMul64(xmm16_19, xmm8_11, xmm29_29, xmm16_27, _masm); - sub_add(xmm8_11, xmm0_3, xmm0_3, xmm16_19, _masm); - montMul64(xmm16_19, xmm12_15, xmm29_29, 
xmm16_27, _masm); - loadXmm29(zetas, 512, _masm); // for level 1 - sub_add(xmm12_15, xmm4_7, xmm4_7, xmm16_19, _masm); - - // level 1 - - montMul64(xmm16_19, xmm4_7, xmm29_29, xmm16_27, _masm); - loadXmm29(zetas, 768, _masm); - sub_add(xmm4_7, xmm0_3, xmm0_3, xmm16_19, _masm); - montMul64(xmm16_19, xmm12_15, xmm29_29, xmm16_27, _masm); - sub_add(xmm12_15, xmm8_11, xmm8_11, xmm16_19, _masm); - - // levels 2 to 7 are done in 2 batches, by first saving half of the coefficients - // from level 1 into memory, doing all the level 2 to level 7 computations - // on the remaining half in the vector registers, saving the result to - // memory after level 7, then loading back the coefficients that we saved after - // level 1 and do the same computation with those - - store4Xmms(coeffs, 8 * XMMBYTES, xmm8_11, _masm); - store4Xmms(coeffs, 12 * XMMBYTES, xmm12_15, _masm); - - __ movl(iterations, 2); - - __ align(OptoLoopAlignment); - __ BIND(L_loop); - - __ subl(iterations, 1); - - // level 2 - load4Xmms(xmm12_15, zetas, 2 * 512, _masm); - montMul64(xmm16_19, xmm2367, xmm12_15, xmm16_27, _masm); - load4Xmms(xmm12_15, zetas, 3 * 512, _masm); // for level 3 - sub_add(xmm2367, xmm0145, xmm0145, xmm16_19, _masm); - - // level 3 - - montMul64(xmm16_19, xmm1357, xmm12_15, xmm16_27, _masm); - sub_add(xmm1357, xmm0246, xmm0246, xmm16_19, _masm); - - // level 4 - loadPerm(xmm16_19, perms, nttL4PermsIdx, _masm); - loadPerm(xmm12_15, perms, nttL4PermsIdx + 64, _masm); - load4Xmms(xmm24_27, zetas, 4 * 512, _masm); - - for (int i = 0; i < 8; i += 2) { - __ evpermi2d(xmm(i/2 + 16), xmm(i), xmm(i + 1), Assembler::AVX_512bit); - } - for (int i = 0; i < 8; i += 2) { - __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit); - } - - montMul64(xmm12_15, xmm12_15, xmm24_27, xmm4_20_24, _masm); - sub_add(xmm1357, xmm0246, xmm16_19, xmm12_15, _masm); - - // level 5 - loadPerm(xmm16_19, perms, nttL5PermsIdx, _masm); - loadPerm(xmm12_15, perms, nttL5PermsIdx + 64, _masm); - 
load4Xmms(xmm24_27, zetas, 5 * 512, _masm); - - for (int i = 0; i < 8; i += 2) { - __ evpermi2d(xmm(i/2 + 16), xmm(i), xmm(i + 1), Assembler::AVX_512bit); - } - for (int i = 0; i < 8; i += 2) { - __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit); - } - - montMul64(xmm12_15, xmm12_15, xmm24_27, xmm4_20_24, _masm); - sub_add(xmm1357, xmm0246, xmm16_19, xmm12_15, _masm); - - // level 6 - loadPerm(xmm16_19, perms, nttL6PermsIdx, _masm); - loadPerm(xmm12_15, perms, nttL6PermsIdx + 64, _masm); - load4Xmms(xmm24_27, zetas, 6 * 512, _masm); - - for (int i = 0; i < 8; i += 2) { - __ evpermi2d(xmm(i/2 + 16), xmm(i), xmm(i + 1), Assembler::AVX_512bit); - } - for (int i = 0; i < 8; i += 2) { - __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit); - } - - montMul64(xmm12_15, xmm12_15, xmm24_27, xmm4_20_24, _masm); - sub_add(xmm1357, xmm0246, xmm16_19, xmm12_15, _masm); - - // level 7 - loadPerm(xmm16_19, perms, nttL7PermsIdx, _masm); - loadPerm(xmm12_15, perms, nttL7PermsIdx + 64, _masm); - load4Xmms(xmm24_27, zetas, 7 * 512, _masm); - - for (int i = 0; i < 8; i += 2) { - __ evpermi2d(xmm(i / 2 + 16), xmm(i), xmm(i + 1), Assembler::AVX_512bit); - } - for (int i = 0; i < 8; i += 2) { - __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit); - } + vector_len, scratch); // q + + if (vector_len == Assembler::AVX_512bit) { + // levels 0-3, register shuffles: + const XMMRegister Coeffs1_1[] = {xmm0, xmm1, xmm2, xmm3}; + const XMMRegister Coeffs2_1[] = {xmm16, xmm17, xmm18, xmm19}; + const XMMRegister Coeffs3_1[] = {xmm4, xmm5, xmm6, xmm7}; + const XMMRegister Coeffs4_1[] = {xmm20, xmm21, xmm22, xmm23}; + const XMMRegister Coeffs1_2[] = {xmm0, xmm16, xmm2, xmm18}; + const XMMRegister Coeffs2_2[] = {xmm1, xmm17, xmm3, xmm19}; + const XMMRegister Coeffs3_2[] = {xmm4, xmm20, xmm6, xmm22}; + const XMMRegister Coeffs4_2[] = {xmm5, xmm21, xmm7, xmm23}; + + // Constants for shuffle and montMul64 + __ mov64(scratch, 
0b1010101010101010); + __ kmovwl(mergeMask1, scratch); + __ knotwl(mergeMask2, mergeMask1); + __ vmovdqu(unshuffle1, ExternalAddress(unshufflePermsAddr(0)), vector_len, scratch); + __ vmovdqu(unshuffle2, ExternalAddress(unshufflePermsAddr(1)), vector_len, scratch); + + int memStep = 4 * 64; // 4*64-byte registers + loadXmms(Coeffs1, coeffs, 0*memStep, vector_len, _masm); + loadXmms(Coeffs2, coeffs, 1*memStep, vector_len, _masm); + loadXmms(Coeffs3, coeffs, 2*memStep, vector_len, _masm); + loadXmms(Coeffs4, coeffs, 3*memStep, vector_len, _masm); + + // level 0-3 can be done by shuffling registers (also notice fewer zetas loads, they repeat) + // level 0 - 128 + // scratch1 = coeffs3 * zetas1 + // coeffs3, coeffs1 = coeffs1±scratch1 + // scratch1 = coeffs4 * zetas1 + // coeffs4, coeffs2 = coeffs2 ± scratch1 + __ vmovdqu(Zetas1[0], Address(zetas, 0), vector_len); + montMul64(Scratch1, Coeffs3, Zetas1, Coeffs3, Scratch2); + sub_add(Coeffs3, Coeffs1, Coeffs1, Scratch1, vector_len, _masm); + montMul64(Scratch1, Coeffs4, Zetas1, Coeffs4, Scratch2); + sub_add(Coeffs4, Coeffs2, Coeffs2, Scratch1, vector_len, _masm); + + // level 1 - 64 + __ vmovdqu(Zetas1[0], Address(zetas, 512), vector_len); + montMul64(Scratch1, Coeffs2, Zetas1, Coeffs2, Scratch2); + sub_add(Coeffs2, Coeffs1, Coeffs1, Scratch1, vector_len, _masm); + + __ vmovdqu(Zetas1[0], Address(zetas, 4*64 + 512), vector_len); + montMul64(Scratch1, Coeffs4, Zetas1, Coeffs4, Scratch2); + sub_add(Coeffs4, Coeffs3, Coeffs3, Scratch1, vector_len, _masm); + + // level 2 - 32 + __ vmovdqu(Zetas2[0], Address(zetas, 2 * 512), vector_len); + __ vmovdqu(Zetas2[2], Address(zetas, 2*64 + 2 * 512), vector_len); + montMul64(Scratch1, Coeffs2_1, Zetas2, Coeffs2_1, Scratch2); + sub_add(Coeffs2_1, Coeffs1_1, Coeffs1_1, Scratch1, vector_len, _masm); + + __ vmovdqu(Zetas2[0], Address(zetas, 4*64 + 2 * 512), vector_len); + __ vmovdqu(Zetas2[2], Address(zetas, 6*64 + 2 * 512), vector_len); + montMul64(Scratch1, Coeffs4_1, Zetas2, 
Coeffs4_1, Scratch2); + sub_add(Coeffs4_1, Coeffs3_1, Coeffs3_1, Scratch1, vector_len, _masm); + + // level 3 - 16 + loadXmms(Zetas3, zetas, 3 * 512, vector_len, _masm); + montMul64(Scratch1, Coeffs2_2, Zetas3, Coeffs2_2, Scratch2); + sub_add(Coeffs2_2, Coeffs1_2, Coeffs1_2, Scratch1, vector_len, _masm); + + loadXmms(Zetas3, zetas, 4*64 + 3 * 512, vector_len, _masm); + montMul64(Scratch1, Coeffs4_2, Zetas3, Coeffs4_2, Scratch2); + sub_add(Coeffs4_2, Coeffs3_2, Coeffs3_2, Scratch1, vector_len, _masm); + + for (int level = 4, distance = 8; level<8; level++, distance /= 2) { + // zetas = load(level * 512) + // coeffs1_2, scratch1 = shuffle(coeffs1_2, coeffs2_2) + // scratch1 = scratch1 * zetas + // coeffs2_2 = coeffs1_2 - scratch1 + // coeffs1_2 = coeffs1_2 + scratch1 + loadXmms(Zetas3, zetas, level * 512, vector_len, _masm); + shuffle(Scratch1, Coeffs1_2, Coeffs2_2, distance * 32); //Coeffs2_2 freed + montMul64(Scratch1, Scratch1, Zetas3, Coeffs2_2, Scratch2, level==7); + sub_add(Coeffs2_2, Coeffs1_2, Coeffs1_2, Scratch1, vector_len, _masm); + + loadXmms(Zetas3, zetas, 4*64 + level * 512, vector_len, _masm); + shuffle(Scratch1, Coeffs3_2, Coeffs4_2, distance * 32); //Coeffs4_2 freed + montMul64(Scratch1, Scratch1, Zetas3, Coeffs4_2, Scratch2, level==7); + sub_add(Coeffs4_2, Coeffs3_2, Coeffs3_2, Scratch1, vector_len, _masm); + } - montMul64(xmm12_15, xmm12_15, xmm24_27, xmm4_20_24, true, _masm); - loadPerm(xmm0246, perms, nttL7PermsIdx + 2 * XMMBYTES, _masm); - loadPerm(xmm1357, perms, nttL7PermsIdx + 3 * XMMBYTES, _masm); - sub_add(xmm21232527, xmm20222426, xmm16_19, xmm12_15, _masm); + // Constants for final unshuffle + __ vmovdqu(unshuffle1, ExternalAddress(unshufflePermsAddr(2)), vector_len, scratch); + __ vmovdqu(unshuffle2, ExternalAddress(unshufflePermsAddr(3)), vector_len, scratch); + shuffle(Scratch1, Coeffs1_2, Coeffs2_2, 0); + shuffle(Scratch1, Coeffs3_2, Coeffs4_2, 0); + + storeXmms(coeffs, 0*memStep, Coeffs1, vector_len, _masm); + storeXmms(coeffs, 
1*memStep, Coeffs2, vector_len, _masm); + storeXmms(coeffs, 2*memStep, Coeffs3, vector_len, _masm); + storeXmms(coeffs, 3*memStep, Coeffs4, vector_len, _masm); + } else { // Assembler::AVX_256bit + // levels 0-4, register shuffles: + const XMMRegister Coeffs1_1[] = {xmm0, xmm2}; + const XMMRegister Coeffs2_1[] = {xmm1, xmm3}; + const XMMRegister Coeffs3_1[] = {xmm4, xmm6}; + const XMMRegister Coeffs4_1[] = {xmm5, xmm7}; + + const XMMRegister Coeffs1_2[] = {xmm0, xmm1, xmm2, xmm3}; + const XMMRegister Coeffs2_2[] = {xmm4, xmm5, xmm6, xmm7}; + + // Since we cannot fit the entire payload into registers, we process + // input in two stages. First half, load 8 registers 32 integers each apart. + // With one load, we can process level 0-2 (128-, 64- and 32-integers apart) + // Remaining levels, load 8 registers from consecutive memory (16-, 8-, 4-, + // 2-, 1-integer apart) + // Levels 5, 6, 7 (4-, 2-, 1-integer apart) require shuffles within registers + // Other levels, shuffles can be done by re-arranging register order + + // Four batches of 8 registers each, 128 bytes apart + for (int i=0; i<4; i++) { + loadXmms(Coeffs1_2, coeffs, i*32 + 0*128, vector_len, _masm, 4, 128); + loadXmms(Coeffs2_2, coeffs, i*32 + 4*128, vector_len, _masm, 4, 128); + + // level 0-2 can be done by shuffling registers (also notice fewer zetas loads, they repeat) + // level 0 - 128 + __ vmovdqu(Zetas1[0], Address(zetas, 0), vector_len); + montMul64(Scratch1, Coeffs3, Zetas1, Coeffs3, Scratch2); + sub_add(Coeffs3, Coeffs1, Coeffs1, Scratch1, vector_len, _masm); + montMul64(Scratch1, Coeffs4, Zetas1, Coeffs4, Scratch2); + sub_add(Coeffs4, Coeffs2, Coeffs2, Scratch1, vector_len, _masm); + + // level 1 - 64 + __ vmovdqu(Zetas1[0], Address(zetas, 512), vector_len); + montMul64(Scratch1, Coeffs2, Zetas1, Coeffs2, Scratch2); + sub_add(Coeffs2, Coeffs1, Coeffs1, Scratch1, vector_len, _masm); + + __ vmovdqu(Zetas1[0], Address(zetas, 4*64 + 512), vector_len); + montMul64(Scratch1, Coeffs4, Zetas1,
Coeffs4, Scratch2); + sub_add(Coeffs4, Coeffs3, Coeffs3, Scratch1, vector_len, _masm); + + // level 2 - 32 + loadXmms(Zetas3, zetas, 2 * 512, vector_len, _masm, 2, 128); + montMul64(Scratch1, Coeffs2_1, Zetas3, Coeffs2_1, Scratch2); + sub_add(Coeffs2_1, Coeffs1_1, Coeffs1_1, Scratch1, vector_len, _masm); + + loadXmms(Zetas3, zetas, 4*64 + 2 * 512, vector_len, _masm, 2, 128); + montMul64(Scratch1, Coeffs4_1, Zetas3, Coeffs4_1, Scratch2); + sub_add(Coeffs4_1, Coeffs3_1, Coeffs3_1, Scratch1, vector_len, _masm); + + storeXmms(coeffs, i*32 + 0*128, Coeffs1_2, vector_len, _masm, 4, 128); + storeXmms(coeffs, i*32 + 4*128, Coeffs2_2, vector_len, _masm, 4, 128); + } - for (int i = 0; i < 8; i += 2) { - __ evpermi2d(xmm(i), xmm(i + 20), xmm(i + 21), Assembler::AVX_512bit); - __ evpermi2d(xmm(i + 1), xmm(i + 20), xmm(i + 21), Assembler::AVX_512bit); + // Four batches of 8 registers, consecutive loads + for (int i=0; i<4; i++) { + loadXmms(Coeffs1_2, coeffs, i*256, vector_len, _masm, 4); + loadXmms(Coeffs2_2, coeffs, 128 + i*256, vector_len, _masm, 4); + + // level 3 - 16 + __ vmovdqu(Zetas1[0], Address(zetas, i*128 + 3 * 512), vector_len); + montMul64(Scratch1, Coeffs2, Zetas1, Coeffs2, Scratch2); + sub_add(Coeffs2, Coeffs1, Coeffs1, Scratch1, vector_len, _masm); + + __ vmovdqu(Zetas1[0], Address(zetas, i*128 + 64 + 3 * 512), vector_len); + montMul64(Scratch1, Coeffs4, Zetas1, Coeffs4, Scratch2); + sub_add(Coeffs4, Coeffs3, Coeffs3, Scratch1, vector_len, _masm); + + // level 4 - 8 + loadXmms(Zetas3, zetas, i*128 + 4 * 512, vector_len, _masm); + montMul64(Scratch1, Coeffs2_1, Zetas3, Coeffs2_1, Scratch2); + sub_add(Coeffs2_1, Coeffs1_1, Coeffs1_1, Scratch1, vector_len, _masm); + + loadXmms(Zetas3, zetas, i*128 + 64 + 4 * 512, vector_len, _masm); + montMul64(Scratch1, Coeffs4_1, Zetas3, Coeffs4_1, Scratch2); + sub_add(Coeffs4_1, Coeffs3_1, Coeffs3_1, Scratch1, vector_len, _masm); + + for (int level = 5, distance = 4; level<8; level++, distance /= 2) { + // zetas = load(level * 
512) + // coeffs1_2, scratch1 = shuffle(coeffs1_2, coeffs2_2) + // scratch1 = scratch1 * zetas + // coeffs2_2 = coeffs1_2 - scratch1 + // coeffs1_2 = coeffs1_2 + scratch1 + loadXmms(Zetas3, zetas, i*128 + level * 512, vector_len, _masm); + shuffle(Scratch1, Coeffs1_1, Coeffs2_1, distance * 32); //Coeffs2_1 freed + montMul64(Scratch1, Scratch1, Zetas3, Coeffs2_1, Scratch2, level==7); + sub_add(Coeffs2_1, Coeffs1_1, Coeffs1_1, Scratch1, vector_len, _masm); + + loadXmms(Zetas3, zetas, i*128 + 64 + level * 512, vector_len, _masm); + shuffle(Scratch1, Coeffs3_1, Coeffs4_1, distance * 32); //Coeffs4_1 freed + montMul64(Scratch1, Scratch1, Zetas3, Coeffs4_1, Scratch2, level==7); + sub_add(Coeffs4_1, Coeffs3_1, Coeffs3_1, Scratch1, vector_len, _masm); + } + + shuffle(Scratch1, Coeffs1_1, Coeffs2_1, 0); + shuffle(Scratch1, Coeffs3_1, Coeffs4_1, 0); + + storeXmms(coeffs, i*256, Coeffs1_2, vector_len, _masm, 4); + storeXmms(coeffs, 128 + i*256, Coeffs2_2, vector_len, _masm, 4); + } } - __ cmpl(iterations, 0); - __ jcc(Assembler::equal, L_end); - - store4Xmms(coeffs, 0, xmm0_3, _masm); - store4Xmms(coeffs, 4 * XMMBYTES, xmm4_7, _masm); - - load4Xmms(xmm0_3, coeffs, 8 * XMMBYTES, _masm); - load4Xmms(xmm4_7, coeffs, 12 * XMMBYTES, _masm); - - __ addptr(zetas, 4 * XMMBYTES); - - __ jmp(L_loop); - - __ BIND(L_end); - - store4Xmms(coeffs, 8 * XMMBYTES, xmm0_3, _masm); - store4Xmms(coeffs, 12 * XMMBYTES, xmm4_7, _masm); - __ leave(); // required for proper stackwalking of RuntimeStub frame __ mov64(rax, 0); // return 0 __ ret(0); @@ -460,172 +648,233 @@ static address generate_dilithiumAlmostNtt_avx512(StubGenerator *stubgen, // // coeffs (int[256]) = c_rarg0 // zetas (int[256]) = c_rarg1 -static address generate_dilithiumAlmostInverseNtt_avx512(StubGenerator *stubgen, - MacroAssembler *_masm) { - +static address generate_dilithiumAlmostInverseNtt_avx(StubGenerator *stubgen, + int vector_len,MacroAssembler *_masm) { __ align(CodeEntryAlignment); StubGenStubId stub_id =
dilithiumAlmostInverseNtt_id; StubCodeMark mark(stubgen, stub_id); address start = __ pc(); __ enter(); - Label L_loop, L_end; - const Register coeffs = c_rarg0; const Register zetas = c_rarg1; + const Register scratch = r10; + + // AVX2 version uses the first half of these arrays + const XMMRegister Coeffs1[] = {xmm0, xmm1, xmm16, xmm17}; + const XMMRegister Coeffs2[] = {xmm2, xmm3, xmm18, xmm19}; + const XMMRegister Coeffs3[] = {xmm4, xmm5, xmm20, xmm21}; + const XMMRegister Coeffs4[] = {xmm6, xmm7, xmm22, xmm23}; + const XMMRegister Scratch1[] = {xmm8, xmm9, xmm24, xmm25}; + const XMMRegister Scratch2[] = {xmm10, xmm11, xmm26, xmm27}; + const XMMRegister Zetas1[] = {xmm12, xmm12, xmm12, xmm12}; + const XMMRegister Zetas2[] = {xmm12, xmm12, xmm13, xmm13}; + const XMMRegister Zetas3[] = {xmm12, xmm13, xmm28, xmm29}; + const XMMRegister montQInvModR = xmm14; + const XMMRegister dilithium_q = xmm15; + const XMMRegister unshuffle1 = xmm30; + const XMMRegister unshuffle2 = xmm31; + KRegister mergeMask1 = k1; + KRegister mergeMask2 = k2; + // lambdas to hide repeated parameters + auto shuffle = whole_shuffle(scratch, mergeMask1, mergeMask2, unshuffle1, unshuffle2, vector_len, _masm); + auto montMul64 = whole_montMul(montQInvModR, dilithium_q, mergeMask2, vector_len, _masm); - const Register iterations = c_rarg2; - - const Register perms = r11; - - __ lea(perms, ExternalAddress(dilithiumAvx512PermsAddr())); - - __ evmovdqul(montMulPerm, Address(perms, montMulPermsIdx), Assembler::AVX_512bit); __ vpbroadcastd(montQInvModR, ExternalAddress(dilithiumAvx512ConstsAddr(montQInvModRIdx)), - Assembler::AVX_512bit, scratch); // q^-1 mod 2^32 + vector_len, scratch); // q^-1 mod 2^32 __ vpbroadcastd(dilithium_q, ExternalAddress(dilithiumAvx512ConstsAddr(dilithium_qIdx)), - Assembler::AVX_512bit, scratch); // q + vector_len, scratch); // q // Each level represents one iteration of the outer for loop of the // Java version. 
// In each of these iterations half of the coefficients are added to and // subtracted from the other half of the coefficients then the result of - // the substartion is (Montgomery) multiplied by the corresponding zetas. - // In each level we just collect the coefficients (using evpermi2d() - // instructions where necessary, i.e. on levels 0-4) so that the results of + // the subtraction is (Montgomery) multiplied by the corresponding zetas. + // In each level we just shuffle the coefficients so that the results of // the additions and subtractions go to the vector registers so that they // align with each other and the zetas. - // We do levels 0-6 in two batches, each batch entirely in the vector registers - load4Xmms(xmm0_3, coeffs, 0, _masm); - load4Xmms(xmm4_7, coeffs, 4 * XMMBYTES, _masm); - - __ movl(iterations, 2); - - __ align(OptoLoopAlignment); - __ BIND(L_loop); - - __ subl(iterations, 1); - - // level 0 - loadPerm(xmm8_11, perms, nttInvL0PermsIdx, _masm); - loadPerm(xmm12_15, perms, nttInvL0PermsIdx + 64, _masm); - - for (int i = 0; i < 8; i += 2) { - __ evpermi2d(xmm(i / 2 + 8), xmm(i), xmm(i + 1), Assembler::AVX_512bit); - __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit); - } - - load4Xmms(xmm4_7, zetas, 0, _masm); - sub_add(xmm24_27, xmm0_3, xmm8_11, xmm12_15, _masm); - montMul64(xmm4_7, xmm4_7, xmm24_27, xmm16_27, true, _masm); - - // level 1 - loadPerm(xmm8_11, perms, nttInvL1PermsIdx, _masm); - loadPerm(xmm12_15, perms, nttInvL1PermsIdx + 64, _masm); - - for (int i = 0; i < 4; i++) { - __ evpermi2d(xmm(i + 8), xmm(i), xmm(i + 4), Assembler::AVX_512bit); - __ evpermi2d(xmm(i + 12), xmm(i), xmm(i + 4), Assembler::AVX_512bit); - } - - load4Xmms(xmm4_7, zetas, 512, _masm); - sub_add(xmm24_27, xmm0_3, xmm8_11, xmm12_15, _masm); - montMul64(xmm4_7, xmm24_27, xmm4_7, xmm16_27, _masm); - - // level 2 - loadPerm(xmm8_11, perms, nttInvL2PermsIdx, _masm); - loadPerm(xmm12_15, perms, nttInvL2PermsIdx + 64, _masm); - - for (int i = 0;
i < 4; i++) { - __ evpermi2d(xmm(i + 8), xmm(i), xmm(i + 4), Assembler::AVX_512bit); - __ evpermi2d(xmm(i + 12), xmm(i), xmm(i + 4), Assembler::AVX_512bit); - } - - load4Xmms(xmm4_7, zetas, 2 * 512, _masm); - sub_add(xmm24_27, xmm0_3, xmm8_11, xmm12_15, _masm); - montMul64(xmm4_7, xmm24_27, xmm4_7, xmm16_27, _masm); - - // level 3 - loadPerm(xmm8_11, perms, nttInvL3PermsIdx, _masm); - loadPerm(xmm12_15, perms, nttInvL3PermsIdx + 64, _masm); - - for (int i = 0; i < 4; i++) { - __ evpermi2d(xmm(i + 8), xmm(i), xmm(i + 4), Assembler::AVX_512bit); - __ evpermi2d(xmm(i + 12), xmm(i), xmm(i + 4), Assembler::AVX_512bit); - } - - load4Xmms(xmm4_7, zetas, 3 * 512, _masm); - sub_add(xmm24_27, xmm0_3, xmm8_11, xmm12_15, _masm); - montMul64(xmm4_7, xmm24_27, xmm4_7, xmm16_27, _masm); - - // level 4 - loadPerm(xmm8_11, perms, nttInvL4PermsIdx, _masm); - loadPerm(xmm12_15, perms, nttInvL4PermsIdx + 64, _masm); - - for (int i = 0; i < 4; i++) { - __ evpermi2d(xmm(i + 8), xmm(i), xmm(i + 4), Assembler::AVX_512bit); - __ evpermi2d(xmm(i + 12), xmm(i), xmm(i + 4), Assembler::AVX_512bit); - } - - load4Xmms(xmm4_7, zetas, 4 * 512, _masm); - sub_add(xmm24_27, xmm0_3, xmm8_11, xmm12_15, _masm); - montMul64(xmm4_7, xmm24_27, xmm4_7, xmm16_27, _masm); - - // level 5 - load4Xmms(xmm12_15, zetas, 5 * 512, _masm); - sub_add(xmm8_11, xmm0_3, xmm0426, xmm1537, _masm); - montMul64(xmm4_7, xmm8_11, xmm12_15, xmm16_27, _masm); - - // level 6 - load4Xmms(xmm12_15, zetas, 6 * 512, _masm); - sub_add(xmm8_11, xmm0_3, xmm0145, xmm2367, _masm); - montMul64(xmm4_7, xmm8_11, xmm12_15, xmm16_27, _masm); - - __ cmpl(iterations, 0); - __ jcc(Assembler::equal, L_end); - - // save the coefficients of the first batch, adjust the zetas - // and load the second batch of coefficients - store4Xmms(coeffs, 0, xmm0_3, _masm); - store4Xmms(coeffs, 4 * XMMBYTES, xmm4_7, _masm); - - __ addptr(zetas, 4 * XMMBYTES); + if (vector_len == Assembler::AVX_512bit) { + // levels 4-7, register shuffles: + const XMMRegister 
Coeffs1_1[] = {xmm0, xmm1, xmm2, xmm3}; + const XMMRegister Coeffs2_1[] = {xmm16, xmm17, xmm18, xmm19}; + const XMMRegister Coeffs3_1[] = {xmm4, xmm5, xmm6, xmm7}; + const XMMRegister Coeffs4_1[] = {xmm20, xmm21, xmm22, xmm23}; + const XMMRegister Coeffs1_2[] = {xmm0, xmm16, xmm2, xmm18}; + const XMMRegister Coeffs2_2[] = {xmm1, xmm17, xmm3, xmm19}; + const XMMRegister Coeffs3_2[] = {xmm4, xmm20, xmm6, xmm22}; + const XMMRegister Coeffs4_2[] = {xmm5, xmm21, xmm7, xmm23}; + + // Constants for shuffle and montMul64 + __ mov64(scratch, 0b1010101010101010); + __ kmovwl(mergeMask1, scratch); + __ knotwl(mergeMask2, mergeMask1); + __ vmovdqu(unshuffle1, ExternalAddress(unshufflePermsAddr(4)), vector_len, scratch); + __ vmovdqu(unshuffle2, ExternalAddress(unshufflePermsAddr(5)), vector_len, scratch); + + int memStep = 4 * 64; + loadXmms(Coeffs1, coeffs, 0*memStep, vector_len, _masm); + loadXmms(Coeffs2, coeffs, 1*memStep, vector_len, _masm); + loadXmms(Coeffs3, coeffs, 2*memStep, vector_len, _masm); + loadXmms(Coeffs4, coeffs, 3*memStep, vector_len, _masm); + + shuffle(Scratch1, Coeffs1_2, Coeffs2_2, 1); + shuffle(Scratch1, Coeffs3_2, Coeffs4_2, 1); + + // Constants for shuffle(128) + __ vmovdqu(unshuffle1, ExternalAddress(unshufflePermsAddr(0)), vector_len, scratch); + __ vmovdqu(unshuffle2, ExternalAddress(unshufflePermsAddr(1)), vector_len, scratch); + for (int level = 0, distance = 1; level<4; level++, distance *= 2) { + // zetas = load(level * 512) + // coeffs1_2 = coeffs1_2 + coeffs2_2 + // scratch1 = coeffs1_2 - coeffs2_2 + // scratch1 = scratch1 * zetas + // coeffs1_2, coeffs2_2 = shuffle(coeffs1_2, scratch1) + loadXmms(Zetas3, zetas, level * 512, vector_len, _masm); + sub_add(Scratch1, Coeffs1_2, Coeffs1_2, Coeffs2_2, vector_len, _masm); // Coeffs2_2 freed + montMul64(Scratch1, Scratch1, Zetas3, Coeffs2_2, Scratch2, level==0); + shuffle(Coeffs2_2, Coeffs1_2, Scratch1, distance * 32); + + loadXmms(Zetas3, zetas, 4*64 + level * 512, vector_len, _masm); + 
sub_add(Scratch1, Coeffs3_2, Coeffs3_2, Coeffs4_2, vector_len, _masm); // Coeffs4_2 freed + montMul64(Scratch1, Scratch1, Zetas3, Coeffs4_2, Scratch2, level==0); + shuffle(Coeffs4_2, Coeffs3_2, Scratch1, distance * 32); + } - load4Xmms(xmm0_3, coeffs, 8 * XMMBYTES, _masm); - load4Xmms(xmm4_7, coeffs, 12 * XMMBYTES, _masm); + // level 4 + loadXmms(Zetas3, zetas, 4 * 512, vector_len, _masm); + sub_add(Scratch1, Coeffs1_2, Coeffs1_2, Coeffs2_2, vector_len, _masm); // Coeffs2_2 freed + montMul64(Coeffs2_2, Scratch1, Zetas3, Scratch1, Scratch2); - __ jmp(L_loop); + loadXmms(Zetas3, zetas, 4*64 + 4 * 512, vector_len, _masm); + sub_add(Scratch1, Coeffs3_2, Coeffs3_2, Coeffs4_2, vector_len, _masm); // Coeffs4_2 freed + montMul64(Coeffs4_2, Scratch1, Zetas3, Scratch1, Scratch2); - __ BIND(L_end); + // level 5 + __ vmovdqu(Zetas2[0], Address(zetas, 5 * 512), vector_len); + __ vmovdqu(Zetas2[2], Address(zetas, 2*64 + 5 * 512), vector_len); + sub_add(Scratch1, Coeffs1_1, Coeffs1_1, Coeffs2_1, vector_len, _masm); // Coeffs2_1 freed + montMul64(Coeffs2_1, Scratch1, Zetas2, Scratch1, Scratch2); - // load the coeffs of the first batch of coefficients that were saved after - // level 6 into Zmm_8-Zmm_15 and do the last level entirely in the vector - // registers - load4Xmms(xmm8_11, coeffs, 0, _masm); - load4Xmms(xmm12_15, coeffs, 4 * XMMBYTES, _masm); + __ vmovdqu(Zetas2[0], Address(zetas, 4*64 + 5 * 512), vector_len); + __ vmovdqu(Zetas2[2], Address(zetas, 6*64 + 5 * 512), vector_len); + sub_add(Scratch1, Coeffs3_1, Coeffs3_1, Coeffs4_1, vector_len, _masm); // Coeffs4_1 freed + montMul64(Coeffs4_1, Scratch1, Zetas2, Scratch1, Scratch2); - // level 7 + // level 6 + __ vmovdqu(Zetas1[0], Address(zetas, 6 * 512), vector_len); + sub_add(Scratch1, Coeffs1, Coeffs1, Coeffs2, vector_len, _masm); // Coeffs2 freed + montMul64(Coeffs2, Scratch1, Zetas1, Scratch1, Scratch2); - loadXmm29(zetas, 7 * 512, _masm); + __ vmovdqu(Zetas1[0], Address(zetas, 4*64 + 6 * 512), vector_len); + 
sub_add(Scratch1, Coeffs3, Coeffs3, Coeffs4, vector_len, _masm); // Coeffs4 freed + montMul64(Coeffs4, Scratch1, Zetas1, Scratch1, Scratch2); - for (int i = 0; i < 8; i++) { - __ evpaddd(xmm(i + 16), k0, xmm(i), xmm(i + 8), false, Assembler::AVX_512bit); - } + // level 7 + __ vmovdqu(Zetas1[0], Address(zetas, 7 * 512), vector_len); + sub_add(Scratch1, Coeffs1, Coeffs1, Coeffs3, vector_len, _masm); // Coeffs3 freed + montMul64(Coeffs3, Scratch1, Zetas1, Scratch1, Scratch2); + sub_add(Scratch1, Coeffs2, Coeffs2, Coeffs4, vector_len, _masm); // Coeffs4 freed + montMul64(Coeffs4, Scratch1, Zetas1, Scratch1, Scratch2); + + storeXmms(coeffs, 0*memStep, Coeffs1, vector_len, _masm); + storeXmms(coeffs, 1*memStep, Coeffs2, vector_len, _masm); + storeXmms(coeffs, 2*memStep, Coeffs3, vector_len, _masm); + storeXmms(coeffs, 3*memStep, Coeffs4, vector_len, _masm); + } else { // Assembler::AVX_256bit + // Permutations of Coeffs1, Coeffs2, Coeffs3 and Coeffs4 + const XMMRegister Coeffs1_1[] = {xmm0, xmm2}; + const XMMRegister Coeffs2_1[] = {xmm1, xmm3}; + const XMMRegister Coeffs3_1[] = {xmm4, xmm6}; + const XMMRegister Coeffs4_1[] = {xmm5, xmm7}; + + const XMMRegister Coeffs1_2[] = {xmm0, xmm1, xmm2, xmm3}; + const XMMRegister Coeffs2_2[] = {xmm4, xmm5, xmm6, xmm7}; + + // Four batches of 8 registers, consecutive loads + for (int i=0; i<4; i++) { + loadXmms(Coeffs1_2, coeffs, i*256, vector_len, _masm, 4); + loadXmms(Coeffs2_2, coeffs, 128 + i*256, vector_len, _masm, 4); + + shuffle(Scratch1, Coeffs1_1, Coeffs2_1, 1); + shuffle(Scratch1, Coeffs3_1, Coeffs4_1, 1); + + for (int level = 0, distance = 1; level <= 2; level++, distance *= 2) { + // zetas = load(level * 512) + // coeffs1_2 = coeffs1_2 + coeffs2_2 + // scratch1 = coeffs1_2 - coeffs2_2 + // scratch1 = scratch1 * zetas + // coeffs1_2, coeffs2_2 = shuffle(coeffs1_2, scratch1) + loadXmms(Zetas3, zetas, i*128 + level * 512, vector_len, _masm); + sub_add(Scratch1, Coeffs1_1, Coeffs1_1, Coeffs2_1, vector_len, _masm); // 
Coeffs2_1 freed + montMul64(Scratch1, Scratch1, Zetas3, Coeffs2_1, Scratch2, level==0); + shuffle(Coeffs2_1, Coeffs1_1, Scratch1, distance * 32); + + loadXmms(Zetas3, zetas, i*128 + 64 + level * 512, vector_len, _masm); + sub_add(Scratch1, Coeffs3_1, Coeffs3_1, Coeffs4_1, vector_len, _masm); // Coeffs4_1 freed + montMul64(Scratch1, Scratch1, Zetas3, Coeffs4_1, Scratch2, level==0); + shuffle(Coeffs4_1, Coeffs3_1, Scratch1, distance * 32); + } + + // level 3 + loadXmms(Zetas3, zetas, i*128 + 3 * 512, vector_len, _masm); + sub_add(Scratch1, Coeffs1_1, Coeffs1_1, Coeffs2_1, vector_len, _masm); // Coeffs2_1 freed + montMul64(Coeffs2_1, Scratch1, Zetas3, Scratch1, Scratch2); + + loadXmms(Zetas3, zetas, i*128 + 64 + 3 * 512, vector_len, _masm); + sub_add(Scratch1, Coeffs3_1, Coeffs3_1, Coeffs4_1, vector_len, _masm); // Coeffs4_1 freed + montMul64(Coeffs4_1, Scratch1, Zetas3, Scratch1, Scratch2); + + // level 4 + __ vmovdqu(Zetas1[0], Address(zetas, i*128 + 4 * 512), vector_len); + sub_add(Scratch1, Coeffs1, Coeffs1, Coeffs2, vector_len, _masm); // Coeffs2 freed + montMul64(Coeffs2, Scratch1, Zetas1, Scratch1, Scratch2); + + __ vmovdqu(Zetas1[0], Address(zetas, i*128 + 64 + 4 * 512), vector_len); + sub_add(Scratch1, Coeffs3, Coeffs3, Coeffs4, vector_len, _masm); // Coeffs4 freed + montMul64(Coeffs4, Scratch1, Zetas1, Scratch1, Scratch2); + + storeXmms(coeffs, i*256, Coeffs1_2, vector_len, _masm, 4); + storeXmms(coeffs, 128 + i*256, Coeffs2_2, vector_len, _masm, 4); + } - for (int i = 0; i < 8; i++) { - __ evpsubd(xmm(i), k0, xmm(i + 8), xmm(i), false, Assembler::AVX_512bit); + // Four batches of 8 registers each, 128 bytes apart + for (int i=0; i<4; i++) { + loadXmms(Coeffs1_2, coeffs, i*32 + 0*128, vector_len, _masm, 4, 128); + loadXmms(Coeffs2_2, coeffs, i*32 + 4*128, vector_len, _masm, 4, 128); + + // level 5 + loadXmms(Zetas3, zetas, 5 * 512, vector_len, _masm, 2, 128); + sub_add(Scratch1, Coeffs1_1, Coeffs1_1, Coeffs2_1, vector_len, _masm); // Coeffs2_1 freed +
montMul64(Coeffs2_1, Scratch1, Zetas3, Scratch1, Scratch2); + + loadXmms(Zetas3, zetas, 4*64 + 5 * 512, vector_len, _masm, 2, 128); + sub_add(Scratch1, Coeffs3_1, Coeffs3_1, Coeffs4_1, vector_len, _masm); // Coeffs4_1 freed + montMul64(Coeffs4_1, Scratch1, Zetas3, Scratch1, Scratch2); + + // level 6 + __ vmovdqu(Zetas1[0], Address(zetas, 6 * 512), vector_len); + sub_add(Scratch1, Coeffs1, Coeffs1, Coeffs2, vector_len, _masm); // Coeffs2 freed + montMul64(Coeffs2, Scratch1, Zetas1, Scratch1, Scratch2); + + __ vmovdqu(Zetas1[0], Address(zetas, 4*64 + 6 * 512), vector_len); + sub_add(Scratch1, Coeffs3, Coeffs3, Coeffs4, vector_len, _masm); // Coeffs4 freed + montMul64(Coeffs4, Scratch1, Zetas1, Scratch1, Scratch2); + + // level 7 + __ vmovdqu(Zetas1[0], Address(zetas, 7 * 512), vector_len); + sub_add(Scratch1, Coeffs1, Coeffs1, Coeffs3, vector_len, _masm); // Coeffs3 freed + montMul64(Coeffs3, Scratch1, Zetas1, Scratch1, Scratch2); + sub_add(Scratch1, Coeffs2, Coeffs2, Coeffs4, vector_len, _masm); // Coeffs4 freed + montMul64(Coeffs4, Scratch1, Zetas1, Scratch1, Scratch2); + + storeXmms(coeffs, i*32 + 0*128, Coeffs1_2, vector_len, _masm, 4, 128); + storeXmms(coeffs, i*32 + 4*128, Coeffs2_2, vector_len, _masm, 4, 128); + } } - store4Xmms(coeffs, 0, xmm16_19, _masm); - store4Xmms(coeffs, 4 * XMMBYTES, xmm20_23, _masm); - montMul64(xmm0_3, xmm0_3, xmm29_29, xmm16_27, _masm); - montMul64(xmm4_7, xmm4_7, xmm29_29, xmm16_27, _masm); - store4Xmms(coeffs, 8 * XMMBYTES, xmm0_3, _masm); - store4Xmms(coeffs, 12 * XMMBYTES, xmm4_7, _masm); - __ leave(); // required for proper stackwalking of RuntimeStub frame __ mov64(rax, 0); // return 0 __ ret(0); @@ -641,8 +890,8 @@ static address generate_dilithiumAlmostInverseNtt_avx512(StubGenerator *stubgen, // result (int[256]) = c_rarg0 // poly1 (int[256]) = c_rarg1 // poly2 (int[256]) = c_rarg2 -static address generate_dilithiumNttMult_avx512(StubGenerator *stubgen, - MacroAssembler *_masm) { +static address 
generate_dilithiumNttMult_avx(StubGenerator *stubgen, + int vector_len, MacroAssembler *_masm) { __ align(CodeEntryAlignment); StubGenStubId stub_id = dilithiumNttMult_id; @@ -655,40 +904,60 @@ static address generate_dilithiumNttMult_avx512(StubGenerator *stubgen, const Register result = c_rarg0; const Register poly1 = c_rarg1; const Register poly2 = c_rarg2; - - const Register perms = r10; // scratch reused after not needed any more + const Register scratch = r10; const Register len = r11; - const XMMRegister montRSquareModQ = xmm29; + const XMMRegister montQInvModR = xmm8; + const XMMRegister dilithium_q = xmm9; + + const XMMRegister Poly1[] = {xmm0, xmm1, xmm16, xmm17}; + const XMMRegister Poly2[] = {xmm2, xmm3, xmm18, xmm19}; + const XMMRegister Scratch1[] = {xmm4, xmm5, xmm20, xmm21}; + const XMMRegister Scratch2[] = {xmm6, xmm7, xmm22, xmm23}; + const XMMRegister MontRSquareModQ[] = {xmm10, xmm10, xmm10, xmm10}; + KRegister mergeMask = k1; + // lambda to hide repeated parameters + auto montMul64 = whole_montMul(montQInvModR, dilithium_q, mergeMask, vector_len, _masm); __ vpbroadcastd(montQInvModR, ExternalAddress(dilithiumAvx512ConstsAddr(montQInvModRIdx)), - Assembler::AVX_512bit, scratch); // q^-1 mod 2^32 + vector_len, scratch); // q^-1 mod 2^32 __ vpbroadcastd(dilithium_q, ExternalAddress(dilithiumAvx512ConstsAddr(dilithium_qIdx)), - Assembler::AVX_512bit, scratch); // q - __ vpbroadcastd(montRSquareModQ, + vector_len, scratch); // q + __ vpbroadcastd(MontRSquareModQ[0], ExternalAddress(dilithiumAvx512ConstsAddr(montRSquareModQIdx)), - Assembler::AVX_512bit, scratch); // 2^64 mod q + vector_len, scratch); // 2^64 mod q + if (vector_len == Assembler::AVX_512bit) { + __ mov64(scratch, 0b0101010101010101); + __ kmovwl(mergeMask, scratch); + } - __ lea(perms, ExternalAddress(dilithiumAvx512PermsAddr())); - __ evmovdqul(montMulPerm, Address(perms, montMulPermsIdx), Assembler::AVX_512bit); + // Total payload is 256*int32s. 
+ // - memStep is number of bytes one iteration processes. + // - loopCnt is number of iterations it will take to process entire payload. + int loopCnt = 4; + int memStep = 4 * 64; + if (vector_len == Assembler::AVX_256bit) { + loopCnt = 16; + memStep = 2 * 32; + } - __ movl(len, 4); + __ movl(len, loopCnt); __ align(OptoLoopAlignment); __ BIND(L_loop); - load4Xmms(xmm4_7, poly2, 0, _masm); - load4Xmms(xmm0_3, poly1, 0, _masm); - montMul64(xmm4_7, xmm4_7, xmm29_29, xmm16_27, _masm); - montMul64(xmm0_3, xmm0_3, xmm4_7, xmm16_27, true, _masm); - store4Xmms(result, 0, xmm0_3, _masm); + loadXmms(Poly2, poly2, 0, vector_len, _masm); + loadXmms(Poly1, poly1, 0, vector_len, _masm); + montMul64(Poly2, Poly2, MontRSquareModQ, Scratch1, Scratch2); + montMul64(Poly1, Poly1, Poly2, Scratch1, Scratch2, true); + storeXmms(result, 0, Poly1, vector_len, _masm); __ subl(len, 1); - __ addptr(poly1, 4 * XMMBYTES); - __ addptr(poly2, 4 * XMMBYTES); - __ addptr(result, 4 * XMMBYTES); + __ addptr(poly1, memStep); + __ addptr(poly2, memStep); + __ addptr(result, memStep); __ cmpl(len, 0); __ jcc(Assembler::notEqual, L_loop); @@ -705,8 +974,8 @@ static address generate_dilithiumNttMult_avx512(StubGenerator *stubgen, // // coeffs (int[256]) = c_rarg0 // constant (int) = c_rarg1 -static address generate_dilithiumMontMulByConstant_avx512(StubGenerator *stubgen, - MacroAssembler *_masm) { +static address generate_dilithiumMontMulByConstant_avx(StubGenerator *stubgen, + int vector_len, MacroAssembler *_masm) { __ align(CodeEntryAlignment); StubGenStubId stub_id = dilithiumMontMulByConstant_id; @@ -718,38 +987,63 @@ static address generate_dilithiumMontMulByConstant_avx512(StubGenerator *stubgen const Register coeffs = c_rarg0; const Register rConstant = c_rarg1; - - const Register perms = c_rarg2; // not used for argument + const Register scratch = r10; const Register len = r11; - const XMMRegister constant = xmm29; + const XMMRegister montQInvModR = xmm8; + const XMMRegister dilithium_q = 
xmm9; - __ lea(perms, ExternalAddress(dilithiumAvx512PermsAddr())); + const XMMRegister Coeffs1[] = {xmm0, xmm1, xmm16, xmm17}; + const XMMRegister Coeffs2[] = {xmm2, xmm3, xmm18, xmm19}; + const XMMRegister Scratch1[] = {xmm4, xmm5, xmm20, xmm21}; + const XMMRegister Scratch2[] = {xmm6, xmm7, xmm22, xmm23}; + const XMMRegister Constant[] = {xmm10, xmm10, xmm10, xmm10}; + XMMRegister constant = Constant[0]; + KRegister mergeMask = k1; + // lambda to hide repeated parameters + auto montMul64 = whole_montMul(montQInvModR, dilithium_q, mergeMask, vector_len, _masm); - // the following four vector registers are used in montMul64 + // load constants for montMul64 __ vpbroadcastd(montQInvModR, ExternalAddress(dilithiumAvx512ConstsAddr(montQInvModRIdx)), - Assembler::AVX_512bit, scratch); // q^-1 mod 2^32 + vector_len, scratch); // q^-1 mod 2^32 __ vpbroadcastd(dilithium_q, ExternalAddress(dilithiumAvx512ConstsAddr(dilithium_qIdx)), - Assembler::AVX_512bit, scratch); // q - __ evmovdqul(montMulPerm, Address(perms, montMulPermsIdx), Assembler::AVX_512bit); - __ evpbroadcastd(constant, rConstant, Assembler::AVX_512bit); // constant multiplier + vector_len, scratch); // q + if (vector_len == Assembler::AVX_256bit) { + __ movdl(constant, rConstant); + __ vpbroadcastd(constant, constant, vector_len); // constant multiplier + } else { + __ evpbroadcastd(constant, rConstant, Assembler::AVX_512bit); // constant multiplier + + __ mov64(scratch, 0b0101010101010101); //dw-mask + __ kmovwl(k2, scratch); + } + + // Total payload is 256*int32s. + // - memStep is number of bytes one montMul64 processes. + // - loopCnt is number of iterations it will take to process entire payload. 
+ int memStep = 4 * 64; + int loopCnt = 2; + if (vector_len == Assembler::AVX_256bit) { + memStep = 2 * 32; + loopCnt = 8; + } - __ movl(len, 2); + __ movl(len, loopCnt); __ align(OptoLoopAlignment); __ BIND(L_loop); - load4Xmms(xmm0_3, coeffs, 0, _masm); - load4Xmms(xmm4_7, coeffs, 4 * XMMBYTES, _masm); - montMul64(xmm0_3, xmm0_3, xmm29_29, xmm16_27, _masm); - montMul64(xmm4_7, xmm4_7, xmm29_29, xmm16_27, _masm); - store4Xmms(coeffs, 0, xmm0_3, _masm); - store4Xmms(coeffs, 4 * XMMBYTES, xmm4_7, _masm); + loadXmms(Coeffs1, coeffs, 0, vector_len, _masm); + loadXmms(Coeffs2, coeffs, memStep, vector_len, _masm); + montMul64(Coeffs1, Coeffs1, Constant, Scratch1, Scratch2); + montMul64(Coeffs2, Coeffs2, Constant, Scratch1, Scratch2); + storeXmms(coeffs, 0, Coeffs1, vector_len, _masm); + storeXmms(coeffs, memStep, Coeffs2, vector_len, _masm); __ subl(len, 1); - __ addptr(coeffs, 512); + __ addptr(coeffs, 2 * memStep); __ cmpl(len, 0); __ jcc(Assembler::notEqual, L_loop); @@ -769,9 +1063,8 @@ static address generate_dilithiumMontMulByConstant_avx512(StubGenerator *stubgen // highPart (int[256]) = c_rarg2 // twoGamma2 (int) = c_rarg3 // multiplier (int) = c_rarg4 -static address generate_dilithiumDecomposePoly_avx512(StubGenerator *stubgen, - MacroAssembler *_masm) { - +static address generate_dilithiumDecomposePoly_avx(StubGenerator *stubgen, + int vector_len, MacroAssembler *_masm) { __ align(CodeEntryAlignment); StubGenStubId stub_id = dilithiumDecomposePoly_id; StubCodeMark mark(stubgen, stub_id); @@ -785,26 +1078,45 @@ static address generate_dilithiumDecomposePoly_avx512(StubGenerator *stubgen, const Register highPart = c_rarg2; const Register rTwoGamma2 = c_rarg3; + const Register scratch = r10; const Register len = r11; - const XMMRegister zero = xmm24; - const XMMRegister one = xmm25; - const XMMRegister qMinus1 = xmm26; - const XMMRegister gamma2 = xmm27; - const XMMRegister twoGamma2 = xmm28; - const XMMRegister barrettMultiplier = xmm29; - const XMMRegister 
barrettAddend = xmm30; - - __ vpxor(zero, zero, zero, Assembler::AVX_512bit); // 0 - __ vpternlogd(xmm0, 0xff, xmm0, xmm0, Assembler::AVX_512bit); // -1 - __ vpsubd(one, zero, xmm0, Assembler::AVX_512bit); // 1 + + const XMMRegister one = xmm0; + const XMMRegister gamma2 = xmm1; + const XMMRegister twoGamma2 = xmm2; + const XMMRegister barrettMultiplier = xmm3; + const XMMRegister barrettAddend = xmm4; + const XMMRegister dilithium_q = xmm5; + const XMMRegister zero = xmm29; // AVX512-only + const XMMRegister minusOne = xmm30; // AVX512-only + const XMMRegister qMinus1 = xmm31; // AVX512-only + + XMMRegister RPlus[] = {xmm6, xmm7, xmm16, xmm17}; + XMMRegister Quotient[] = {xmm8, xmm9, xmm18, xmm19}; + XMMRegister R0[] = {xmm10, xmm11, xmm20, xmm21}; + XMMRegister Mask[] = {xmm12, xmm13, xmm22, xmm23}; + XMMRegister Tmp1[] = {xmm14, xmm15, xmm24, xmm25}; + __ vpbroadcastd(dilithium_q, ExternalAddress(dilithiumAvx512ConstsAddr(dilithium_qIdx)), - Assembler::AVX_512bit, scratch); // q + vector_len, scratch); // q __ vpbroadcastd(barrettAddend, ExternalAddress(dilithiumAvx512ConstsAddr(barrettAddendIdx)), - Assembler::AVX_512bit, scratch); // addend for Barrett reduction + vector_len, scratch); // addend for Barrett reduction + if (vector_len == Assembler::AVX_512bit) { + __ vpxor(zero, zero, zero, vector_len); // 0 + __ vpternlogd(minusOne, 0xff, minusOne, minusOne, vector_len); // -1 + __ vpsrld(one, minusOne, 31, vector_len); + __ vpsubd(qMinus1, dilithium_q, one, vector_len); // q - 1 + __ evpbroadcastd(twoGamma2, rTwoGamma2, vector_len); // 2 * gamma2 + } else { + __ vpcmpeqd(one, one, one, vector_len); + __ vpsrld(one, one, 31, vector_len); + __ movdl(twoGamma2, rTwoGamma2); + __ vpbroadcastd(twoGamma2, twoGamma2, vector_len); // 2 * gamma2 + } - __ evpbroadcastd(twoGamma2, rTwoGamma2, Assembler::AVX_512bit); // 2 * gamma2 + __ vpsrad(gamma2, twoGamma2, 1, vector_len); // gamma2 #ifndef _WIN64 const Register rMultiplier = c_rarg4; @@ -813,201 +1125,186 @@ static 
address generate_dilithiumDecomposePoly_avx512(StubGenerator *stubgen, const Register rMultiplier = c_rarg3; // arg3 is already consumed, reused here __ movptr(rMultiplier, multiplier_mem); #endif - __ evpbroadcastd(barrettMultiplier, rMultiplier, - Assembler::AVX_512bit); // multiplier for mod 2 * gamma2 reduce + if (vector_len == Assembler::AVX_512bit) { + __ evpbroadcastd(barrettMultiplier, rMultiplier, + vector_len); // multiplier for mod 2 * gamma2 reduce + } else { + __ movdl(barrettMultiplier, rMultiplier); + __ vpbroadcastd(barrettMultiplier, barrettMultiplier, vector_len); + } - __ evpsubd(qMinus1, k0, dilithium_q, one, false, Assembler::AVX_512bit); // q - 1 - __ evpsrad(gamma2, k0, twoGamma2, 1, false, Assembler::AVX_512bit); // gamma2 + // Total payload is 1024 bytes + int memStep = 4 * 64; // Number of bytes per loop iteration + int regCnt = 4; // Register array length + if (vector_len == Assembler::AVX_256bit) { + memStep = 2 * 32; + regCnt = 2; + } __ movl(len, 1024); __ align(OptoLoopAlignment); __ BIND(L_loop); - load4Xmms(xmm0_3, input, 0, _masm); + loadXmms(RPlus, input, 0, vector_len, _masm); - __ addptr(input, 4 * XMMBYTES); + __ addptr(input, memStep); - // rplus in xmm0 // rplus = rplus - ((rplus + 5373807) >> 23) * dilithium_q; - __ evpaddd(xmm4, k0, xmm0, barrettAddend, false, Assembler::AVX_512bit); - __ evpaddd(xmm5, k0, xmm1, barrettAddend, false, Assembler::AVX_512bit); - __ evpaddd(xmm6, k0, xmm2, barrettAddend, false, Assembler::AVX_512bit); - __ evpaddd(xmm7, k0, xmm3, barrettAddend, false, Assembler::AVX_512bit); - - __ evpsrad(xmm4, k0, xmm4, 23, false, Assembler::AVX_512bit); - __ evpsrad(xmm5, k0, xmm5, 23, false, Assembler::AVX_512bit); - __ evpsrad(xmm6, k0, xmm6, 23, false, Assembler::AVX_512bit); - __ evpsrad(xmm7, k0, xmm7, 23, false, Assembler::AVX_512bit); - - __ evpmulld(xmm4, k0, xmm4, dilithium_q, false, Assembler::AVX_512bit); - __ evpmulld(xmm5, k0, xmm5, dilithium_q, false, Assembler::AVX_512bit); - __ evpmulld(xmm6, 
k0, xmm6, dilithium_q, false, Assembler::AVX_512bit); - __ evpmulld(xmm7, k0, xmm7, dilithium_q, false, Assembler::AVX_512bit); - - __ evpsubd(xmm0, k0, xmm0, xmm4, false, Assembler::AVX_512bit); - __ evpsubd(xmm1, k0, xmm1, xmm5, false, Assembler::AVX_512bit); - __ evpsubd(xmm2, k0, xmm2, xmm6, false, Assembler::AVX_512bit); - __ evpsubd(xmm3, k0, xmm3, xmm7, false, Assembler::AVX_512bit); - // rplus in xmm0 + for (int i = 0; i < regCnt; i++) { + __ vpaddd(Tmp1[i], RPlus[i], barrettAddend, vector_len); + } + + for (int i = 0; i < regCnt; i++) { + __ vpsrad(Tmp1[i], Tmp1[i], 23, vector_len); + } + + for (int i = 0; i < regCnt; i++) { + __ vpmulld(Tmp1[i], Tmp1[i], dilithium_q, vector_len); + } + + for (int i = 0; i < regCnt; i++) { + __ vpsubd(RPlus[i], RPlus[i], Tmp1[i], vector_len); + } + // rplus = rplus + ((rplus >> 31) & dilithium_q); - __ evpsrad(xmm4, k0, xmm0, 31, false, Assembler::AVX_512bit); - __ evpsrad(xmm5, k0, xmm1, 31, false, Assembler::AVX_512bit); - __ evpsrad(xmm6, k0, xmm2, 31, false, Assembler::AVX_512bit); - __ evpsrad(xmm7, k0, xmm3, 31, false, Assembler::AVX_512bit); - - __ evpandd(xmm4, k0, xmm4, dilithium_q, false, Assembler::AVX_512bit); - __ evpandd(xmm5, k0, xmm5, dilithium_q, false, Assembler::AVX_512bit); - __ evpandd(xmm6, k0, xmm6, dilithium_q, false, Assembler::AVX_512bit); - __ evpandd(xmm7, k0, xmm7, dilithium_q, false, Assembler::AVX_512bit); - - __ evpaddd(xmm0, k0, xmm0, xmm4, false, Assembler::AVX_512bit); - __ evpaddd(xmm1, k0, xmm1, xmm5, false, Assembler::AVX_512bit); - __ evpaddd(xmm2, k0, xmm2, xmm6, false, Assembler::AVX_512bit); - __ evpaddd(xmm3, k0, xmm3, xmm7, false, Assembler::AVX_512bit); - // rplus in xmm0 + for (int i = 0; i < regCnt; i++) { + __ vpsrad(Tmp1[i], RPlus[i], 31, vector_len); + } + + for (int i = 0; i < regCnt; i++) { + __ vpand(Tmp1[i], Tmp1[i], dilithium_q, vector_len); + } + + for (int i = 0; i < regCnt; i++) { + __ vpaddd(RPlus[i], RPlus[i], Tmp1[i], vector_len); + } + // int quotient = (rplus * 
barrettMultiplier) >> 22; - __ evpmulld(xmm4, k0, xmm0, barrettMultiplier, false, Assembler::AVX_512bit); - __ evpmulld(xmm5, k0, xmm1, barrettMultiplier, false, Assembler::AVX_512bit); - __ evpmulld(xmm6, k0, xmm2, barrettMultiplier, false, Assembler::AVX_512bit); - __ evpmulld(xmm7, k0, xmm3, barrettMultiplier, false, Assembler::AVX_512bit); - - __ evpsrad(xmm4, k0, xmm4, 22, false, Assembler::AVX_512bit); - __ evpsrad(xmm5, k0, xmm5, 22, false, Assembler::AVX_512bit); - __ evpsrad(xmm6, k0, xmm6, 22, false, Assembler::AVX_512bit); - __ evpsrad(xmm7, k0, xmm7, 22, false, Assembler::AVX_512bit); - // quotient in xmm4 + for (int i = 0; i < regCnt; i++) { + __ vpmulld(Quotient[i], RPlus[i], barrettMultiplier, vector_len); + } + + for (int i = 0; i < regCnt; i++) { + __ vpsrad(Quotient[i], Quotient[i], 22, vector_len); + } + // int r0 = rplus - quotient * twoGamma2; - __ evpmulld(xmm8, k0, xmm4, twoGamma2, false, Assembler::AVX_512bit); - __ evpmulld(xmm9, k0, xmm5, twoGamma2, false, Assembler::AVX_512bit); - __ evpmulld(xmm10, k0, xmm6, twoGamma2, false, Assembler::AVX_512bit); - __ evpmulld(xmm11, k0, xmm7, twoGamma2, false, Assembler::AVX_512bit); - - __ evpsubd(xmm8, k0, xmm0, xmm8, false, Assembler::AVX_512bit); - __ evpsubd(xmm9, k0, xmm1, xmm9, false, Assembler::AVX_512bit); - __ evpsubd(xmm10, k0, xmm2, xmm10, false, Assembler::AVX_512bit); - __ evpsubd(xmm11, k0, xmm3, xmm11, false, Assembler::AVX_512bit); - // r0 in xmm8 + for (int i = 0; i < regCnt; i++) { + __ vpmulld(R0[i], Quotient[i], twoGamma2, vector_len); + } + + for (int i = 0; i < regCnt; i++) { + __ vpsubd(R0[i], RPlus[i], R0[i], vector_len); + } + // int mask = (twoGamma2 - r0) >> 22; - __ evpsubd(xmm12, k0, twoGamma2, xmm8, false, Assembler::AVX_512bit); - __ evpsubd(xmm13, k0, twoGamma2, xmm9, false, Assembler::AVX_512bit); - __ evpsubd(xmm14, k0, twoGamma2, xmm10, false, Assembler::AVX_512bit); - __ evpsubd(xmm15, k0, twoGamma2, xmm11, false, Assembler::AVX_512bit); - - __ evpsrad(xmm12, k0, 
xmm12, 22, false, Assembler::AVX_512bit); - __ evpsrad(xmm13, k0, xmm13, 22, false, Assembler::AVX_512bit); - __ evpsrad(xmm14, k0, xmm14, 22, false, Assembler::AVX_512bit); - __ evpsrad(xmm15, k0, xmm15, 22, false, Assembler::AVX_512bit); - // mask in xmm12 + for (int i = 0; i < regCnt; i++) { + __ vpsubd(Mask[i], twoGamma2, R0[i], vector_len); + } + + for (int i = 0; i < regCnt; i++) { + __ vpsrad(Mask[i], Mask[i], 22, vector_len); + } + // r0 -= (mask & twoGamma2); - __ evpandd(xmm16, k0, xmm12, twoGamma2, false, Assembler::AVX_512bit); - __ evpandd(xmm17, k0, xmm13, twoGamma2, false, Assembler::AVX_512bit); - __ evpandd(xmm18, k0, xmm14, twoGamma2, false, Assembler::AVX_512bit); - __ evpandd(xmm19, k0, xmm15, twoGamma2, false, Assembler::AVX_512bit); - - __ evpsubd(xmm8, k0, xmm8, xmm16, false, Assembler::AVX_512bit); - __ evpsubd(xmm9, k0, xmm9, xmm17, false, Assembler::AVX_512bit); - __ evpsubd(xmm10, k0, xmm10, xmm18, false, Assembler::AVX_512bit); - __ evpsubd(xmm11, k0, xmm11, xmm19, false, Assembler::AVX_512bit); - // r0 in xmm8 + for (int i = 0; i < regCnt; i++) { + __ vpand(Tmp1[i], Mask[i], twoGamma2, vector_len); + } + + for (int i = 0; i < regCnt; i++) { + __ vpsubd(R0[i], R0[i], Tmp1[i], vector_len); + } + // quotient += (mask & 1); - __ evpandd(xmm16, k0, xmm12, one, false, Assembler::AVX_512bit); - __ evpandd(xmm17, k0, xmm13, one, false, Assembler::AVX_512bit); - __ evpandd(xmm18, k0, xmm14, one, false, Assembler::AVX_512bit); - __ evpandd(xmm19, k0, xmm15, one, false, Assembler::AVX_512bit); + for (int i = 0; i < regCnt; i++) { + __ vpand(Tmp1[i], Mask[i], one, vector_len); + } - __ evpaddd(xmm4, k0, xmm4, xmm16, false, Assembler::AVX_512bit); - __ evpaddd(xmm5, k0, xmm5, xmm17, false, Assembler::AVX_512bit); - __ evpaddd(xmm6, k0, xmm6, xmm18, false, Assembler::AVX_512bit); - __ evpaddd(xmm7, k0, xmm7, xmm19, false, Assembler::AVX_512bit); + for (int i = 0; i < regCnt; i++) { + __ vpaddd(Quotient[i], Quotient[i], Tmp1[i], vector_len); + } // 
mask = (twoGamma2 / 2 - r0) >> 31; - __ evpsubd(xmm12, k0, gamma2, xmm8, false, Assembler::AVX_512bit); - __ evpsubd(xmm13, k0, gamma2, xmm9, false, Assembler::AVX_512bit); - __ evpsubd(xmm14, k0, gamma2, xmm10, false, Assembler::AVX_512bit); - __ evpsubd(xmm15, k0, gamma2, xmm11, false, Assembler::AVX_512bit); + for (int i = 0; i < regCnt; i++) { + __ vpsubd(Mask[i], gamma2, R0[i], vector_len); + } - __ evpsrad(xmm12, k0, xmm12, 31, false, Assembler::AVX_512bit); - __ evpsrad(xmm13, k0, xmm13, 31, false, Assembler::AVX_512bit); - __ evpsrad(xmm14, k0, xmm14, 31, false, Assembler::AVX_512bit); - __ evpsrad(xmm15, k0, xmm15, 31, false, Assembler::AVX_512bit); + for (int i = 0; i < regCnt; i++) { + __ vpsrad(Mask[i], Mask[i], 31, vector_len); + } // r0 -= (mask & twoGamma2); - __ evpandd(xmm16, k0, xmm12, twoGamma2, false, Assembler::AVX_512bit); - __ evpandd(xmm17, k0, xmm13, twoGamma2, false, Assembler::AVX_512bit); - __ evpandd(xmm18, k0, xmm14, twoGamma2, false, Assembler::AVX_512bit); - __ evpandd(xmm19, k0, xmm15, twoGamma2, false, Assembler::AVX_512bit); - - __ evpsubd(xmm8, k0, xmm8, xmm16, false, Assembler::AVX_512bit); - __ evpsubd(xmm9, k0, xmm9, xmm17, false, Assembler::AVX_512bit); - __ evpsubd(xmm10, k0, xmm10, xmm18, false, Assembler::AVX_512bit); - __ evpsubd(xmm11, k0, xmm11, xmm19, false, Assembler::AVX_512bit); - // r0 in xmm8 + for (int i = 0; i < regCnt; i++) { + __ vpand(Tmp1[i], Mask[i], twoGamma2, vector_len); + } + + for (int i = 0; i < regCnt; i++) { + __ vpsubd(R0[i], R0[i], Tmp1[i], vector_len); + } + // quotient += (mask & 1); - __ evpandd(xmm16, k0, xmm12, one, false, Assembler::AVX_512bit); - __ evpandd(xmm17, k0, xmm13, one, false, Assembler::AVX_512bit); - __ evpandd(xmm18, k0, xmm14, one, false, Assembler::AVX_512bit); - __ evpandd(xmm19, k0, xmm15, one, false, Assembler::AVX_512bit); - - __ evpaddd(xmm4, k0, xmm4, xmm16, false, Assembler::AVX_512bit); - __ evpaddd(xmm5, k0, xmm5, xmm17, false, Assembler::AVX_512bit); - __ 
evpaddd(xmm6, k0, xmm6, xmm18, false, Assembler::AVX_512bit); - __ evpaddd(xmm7, k0, xmm7, xmm19, false, Assembler::AVX_512bit); - // quotient in xmm4 + for (int i = 0; i < regCnt; i++) { + __ vpand(Tmp1[i], Mask[i], one, vector_len); + } + + for (int i = 0; i < regCnt; i++) { + __ vpaddd(Quotient[i], Quotient[i], Tmp1[i], vector_len); + } + // r1 in RPlus // int r1 = rplus - r0 - (dilithium_q - 1); - __ evpsubd(xmm16, k0, xmm0, xmm8, false, Assembler::AVX_512bit); - __ evpsubd(xmm17, k0, xmm1, xmm9, false, Assembler::AVX_512bit); - __ evpsubd(xmm18, k0, xmm2, xmm10, false, Assembler::AVX_512bit); - __ evpsubd(xmm19, k0, xmm3, xmm11, false, Assembler::AVX_512bit); - - __ evpsubd(xmm16, k0, xmm16, xmm26, false, Assembler::AVX_512bit); - __ evpsubd(xmm17, k0, xmm17, xmm26, false, Assembler::AVX_512bit); - __ evpsubd(xmm18, k0, xmm18, xmm26, false, Assembler::AVX_512bit); - __ evpsubd(xmm19, k0, xmm19, xmm26, false, Assembler::AVX_512bit); - // r1 in xmm16 + for (int i = 0; i < regCnt; i++) { + __ vpsubd(RPlus[i], RPlus[i], R0[i], vector_len); + } + // r1 = (r1 | (-r1)) >> 31; // 0 if rplus - r0 == (dilithium_q - 1), -1 otherwise - __ evpsubd(xmm20, k0, zero, xmm16, false, Assembler::AVX_512bit); - __ evpsubd(xmm21, k0, zero, xmm17, false, Assembler::AVX_512bit); - __ evpsubd(xmm22, k0, zero, xmm18, false, Assembler::AVX_512bit); - __ evpsubd(xmm23, k0, zero, xmm19, false, Assembler::AVX_512bit); - - __ evporq(xmm16, k0, xmm16, xmm20, false, Assembler::AVX_512bit); - __ evporq(xmm17, k0, xmm17, xmm21, false, Assembler::AVX_512bit); - __ evporq(xmm18, k0, xmm18, xmm22, false, Assembler::AVX_512bit); - __ evporq(xmm19, k0, xmm19, xmm23, false, Assembler::AVX_512bit); - - __ evpsubd(xmm12, k0, zero, one, false, Assembler::AVX_512bit); // -1 - - __ evpsrad(xmm0, k0, xmm16, 31, false, Assembler::AVX_512bit); - __ evpsrad(xmm1, k0, xmm17, 31, false, Assembler::AVX_512bit); - __ evpsrad(xmm2, k0, xmm18, 31, false, Assembler::AVX_512bit); - __ evpsrad(xmm3, k0, xmm19, 31, 
false, Assembler::AVX_512bit); - // r1 in xmm0 - // r0 += ~r1; - __ evpxorq(xmm20, k0, xmm0, xmm12, false, Assembler::AVX_512bit); - __ evpxorq(xmm21, k0, xmm1, xmm12, false, Assembler::AVX_512bit); - __ evpxorq(xmm22, k0, xmm2, xmm12, false, Assembler::AVX_512bit); - __ evpxorq(xmm23, k0, xmm3, xmm12, false, Assembler::AVX_512bit); - - __ evpaddd(xmm8, k0, xmm8, xmm20, false, Assembler::AVX_512bit); - __ evpaddd(xmm9, k0, xmm9, xmm21, false, Assembler::AVX_512bit); - __ evpaddd(xmm10, k0, xmm10, xmm22, false, Assembler::AVX_512bit); - __ evpaddd(xmm11, k0, xmm11, xmm23, false, Assembler::AVX_512bit); - // r0 in xmm8 - // r1 = r1 & quotient; - __ evpandd(xmm0, k0, xmm4, xmm0, false, Assembler::AVX_512bit); - __ evpandd(xmm1, k0, xmm5, xmm1, false, Assembler::AVX_512bit); - __ evpandd(xmm2, k0, xmm6, xmm2, false, Assembler::AVX_512bit); - __ evpandd(xmm3, k0, xmm7, xmm3, false, Assembler::AVX_512bit); - // r1 in xmm0 + if (vector_len == Assembler::AVX_512bit) { + KRegister EqMsk[] = {k1, k2, k3, k4}; + for (int i = 0; i < regCnt; i++) { + __ evpcmpeqd(EqMsk[i], k0, RPlus[i], qMinus1, vector_len); + } + + // r0 += ~r1; // add -1 or keep as is, using EqMsk as filter + for (int i = 0; i < regCnt; i++) { + __ evpaddd(R0[i], EqMsk[i], R0[i], minusOne, true, vector_len); + } + + // r1 in Quotient + // r1 = r1 & quotient; // copy 0 or keep as is, using EqMsk as filter + for (int i = 0; i < regCnt; i++) { + // FIXME: replace with void evmovdqul(Address dst, KRegister mask, XMMRegister src, bool merge, int vector_len);? 
+ __ evpandd(Quotient[i], EqMsk[i], Quotient[i], zero, true, vector_len); + } + } else { + const XMMRegister qMinus1 = Tmp1[0]; + __ vpsubd(qMinus1, dilithium_q, one, vector_len); // q - 1 + + for (int i = 0; i < regCnt; i++) { + __ vpcmpeqd(Mask[i], RPlus[i], qMinus1, vector_len); + } + + // r0 += ~r1; + // Mask already negated + for (int i = 0; i < regCnt; i++) { + __ vpaddd(R0[i], R0[i], Mask[i], vector_len); + } + + // r1 in Quotient + // r1 = r1 & quotient; + for (int i = 0; i < regCnt; i++) { + __ vpandn(Quotient[i], Mask[i], Quotient[i], vector_len); + } + } + + // r1 in Quotient // lowPart[m] = r0; // highPart[m] = r1; - store4Xmms(highPart, 0, xmm0_3, _masm); - store4Xmms(lowPart, 0, xmm8_11, _masm); + storeXmms(highPart, 0, Quotient, vector_len, _masm); + storeXmms(lowPart, 0, R0, vector_len, _masm); - __ addptr(highPart, 4 * XMMBYTES); - __ addptr(lowPart, 4 * XMMBYTES); - __ subl(len, 4 * XMMBYTES); + __ addptr(highPart, memStep); + __ addptr(lowPart, memStep); + __ subl(len, memStep); __ jcc(Assembler::notEqual, L_loop); __ leave(); // required for proper stackwalking of RuntimeStub frame @@ -1018,17 +1315,21 @@ static address generate_dilithiumDecomposePoly_avx512(StubGenerator *stubgen, } void StubGenerator::generate_dilithium_stubs() { + int vector_len = Assembler::AVX_256bit; + if (VM_Version::supports_evex() && VM_Version::supports_avx512bw()) { + vector_len = Assembler::AVX_512bit; + } // Generate Dilithium intrinsics code if (UseDilithiumIntrinsics) { StubRoutines::_dilithiumAlmostNtt = - generate_dilithiumAlmostNtt_avx512(this, _masm); + generate_dilithiumAlmostNtt_avx(this, vector_len, _masm); StubRoutines::_dilithiumAlmostInverseNtt = - generate_dilithiumAlmostInverseNtt_avx512(this, _masm); + generate_dilithiumAlmostInverseNtt_avx(this, vector_len, _masm); StubRoutines::_dilithiumNttMult = - generate_dilithiumNttMult_avx512(this, _masm); + generate_dilithiumNttMult_avx(this, vector_len, _masm); StubRoutines::_dilithiumMontMulByConstant = - 
generate_dilithiumMontMulByConstant_avx512(this, _masm); + generate_dilithiumMontMulByConstant_avx(this, vector_len, _masm); StubRoutines::_dilithiumDecomposePoly = - generate_dilithiumDecomposePoly_avx512(this, _masm); + generate_dilithiumDecomposePoly_avx(this, vector_len, _masm); } } diff --git a/src/hotspot/cpu/x86/vm_version_x86.cpp b/src/hotspot/cpu/x86/vm_version_x86.cpp index e4a101a597734..5e38759962ae9 100644 --- a/src/hotspot/cpu/x86/vm_version_x86.cpp +++ b/src/hotspot/cpu/x86/vm_version_x86.cpp @@ -1263,7 +1263,7 @@ void VM_Version::get_processor_features() { // Dilithium Intrinsics // Currently we only have them for AVX512 - if (supports_evex() && supports_avx512bw()) { + if (UseAVX > 1) { if (FLAG_IS_DEFAULT(UseDilithiumIntrinsics)) { UseDilithiumIntrinsics = true; } diff --git a/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java b/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java new file mode 100644 index 0000000000000..0f46496aa8323 --- /dev/null +++ b/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java @@ -0,0 +1,617 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 
+ * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.util.Arrays; +import java.util.Random; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.Constructor; +import java.util.HexFormat; + +public class ML_DSA_Intrinsic_Test { + public static void main(String[] args) throws Exception { + MethodHandles.Lookup lookup = MethodHandles.lookup(); + Class kClazz = Class.forName("sun.security.provider.ML_DSA"); + Constructor constructor = kClazz.getDeclaredConstructor( + int.class); + constructor.setAccessible(true); + + Method m = kClazz.getDeclaredMethod("implDilithiumNttMult", + int[].class, int[].class, int[].class); + m.setAccessible(true); + MethodHandle mult = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumNttMultJava", + int[].class, int[].class, int[].class); + m.setAccessible(true); + MethodHandle multJava = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumMontMulByConstant", + int[].class, int.class); + m.setAccessible(true); + MethodHandle multConst = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumMontMulByConstantJava", + int[].class, int.class); + m.setAccessible(true); + MethodHandle multConstJava = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumDecomposePoly", + int[].class, int[].class, int[].class, int.class, int.class); + m.setAccessible(true); + MethodHandle decompose = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("decomposePolyJava", + int[].class, int[].class, int[].class, int.class, int.class); + m.setAccessible(true); + MethodHandle decomposeJava = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumAlmostNtt", + int[].class, int[].class); + m.setAccessible(true); + MethodHandle almostNtt = 
lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumAlmostNttJava", + int[].class); + m.setAccessible(true); + MethodHandle almostNttJava = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumAlmostInverseNtt", + int[].class, int[].class); + m.setAccessible(true); + MethodHandle inverseNtt = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumAlmostInverseNttJava", + int[].class); + m.setAccessible(true); + MethodHandle inverseNttJava = lookup.unreflect(m); + + Random rnd = new Random(); + long seed = rnd.nextLong(); + rnd.setSeed(seed); + //Note: it might be useful to increase this number during development of new intrinsics + final int repeat = 10000000; + int[] coeffs1 = new int[ML_DSA_N]; + int[] coeffs2 = new int[ML_DSA_N]; + int[] prod1 = new int[ML_DSA_N]; + int[] prod2 = new int[ML_DSA_N]; + int[] prod3 = new int[ML_DSA_N]; + int[] prod4 = new int[ML_DSA_N]; + try { + for (int i = 0; i < repeat; i++) { + // seed = rnd.nextLong(); + //rnd.setSeed(seed); + testMult(prod1, prod2, coeffs1, coeffs2, mult, multJava, rnd, seed, i); + testMultConst(prod1, prod2, multConst, multConstJava, rnd, seed, i); + testDecompose(prod1, prod2, prod3, prod4, coeffs1, coeffs2, decompose, decomposeJava, rnd, seed, i); + testAlmostNtt(coeffs1, coeffs2, almostNtt, almostNttJava, rnd, seed, i); + testInverseNtt(coeffs1, coeffs2, inverseNtt, inverseNttJava, rnd, seed, i); + } + System.out.println("Fuzz Success"); + } catch (Throwable e) { + System.out.println("Fuzz Failed: " + e); + } + } + + private static final int ML_DSA_N = 256; + public static void testMult(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2, + MethodHandle mult, MethodHandle multJava, Random rnd, + long seed, int i) throws Exception, Throwable { + + for (int j = 0; j " + hex.toHexDigits(coeffs[j]) + ", " + hex.toHexDigits(coeffs[j + l]) + " (tmp = " + hex.toHexDigits(tmp) + ")"); + if (l == testLevel) { + System.out.println(bld.toString()); + } + } + m++; + } + 
} + } + static void implDilithiumAlmostInverseNttJava(int[] coeffs) { + HexFormat hex = HexFormat.of(); + int dimension = 256; + int m = MONT_ZETAS_FOR_NTT.length - 1; + int testLevel = 1; + for (int l = 1; l < dimension; l *= 2) { + for (int s = 0; s < dimension; s += 2 * l) { + for (int j = s; j < s + l; j++) { + StringBuilder bld = new StringBuilder(); + bld.append("l = " + l + ", m = " + m + ", j = " + j+": "); + int tmp = coeffs[j]; + bld.append(hex.toHexDigits(coeffs[j]) + ", " + hex.toHexDigits(coeffs[j + l]) + " -> " + hex.toHexDigits(tmp - coeffs[j + l]) + " * " + hex.toHexDigits(-MONT_ZETAS_FOR_NTT[m])); + coeffs[j] = (tmp + coeffs[j + l]); + coeffs[j + l] = montMul(tmp - coeffs[j + l], -MONT_ZETAS_FOR_NTT[m]); + bld.append(" -> " + hex.toHexDigits(coeffs[j]) + ", " + hex.toHexDigits(coeffs[j + l])); + if (l == testLevel) { + System.out.println(bld.toString()); + } + } + m--; + } + } + } + private static int montMul(int b, int c) { + long a = (long) b * (long) c; + int aHigh = (int) (a >> 32); + int aLow = (int) a; + int m = 58728449 * aLow; // signed low product + + // subtract signed high product + return (aHigh - (int) (((long)m * 8380417) >> 32)); + } + // Zeta values for NTT with montgomery factor precomputed + private static final int[] MONT_ZETAS_FOR_NTT = new int[]{ + 25847, //0 + -2608894, -518909, //1 + 237124, -777960, -876248, 466468, //2 + 1826347, + 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, //3 + 2725464, + 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, -2118186, + -3859737, -1399561, -3277672, 1757237, -19422, 4010497, 280005, // 4 + + 2706023, + 95776, 3077325, 3530437, -1661693, -3592148, -2537516, 3915439, -3861115, + -3043716, 3574422, -2867647, 3539968, -300467, 2348700, -539299, -1699267, + -1643818, 3505694, -3821735, 3507263, -2140649, -1600420, 3699596, 811944, + 531354, 954230, 3881043, 3900724, -2556880, 2071892, -2797779, //5 + + -3930395, + -1528703, -3677745, -3041255, -1452451, 
3475950, 2176455, -1585221, -1257611, + 1939314, -4083598, -1000202, -3190144, -3157330, -3632928, 126922, 3412210, + -983419, 2147896, 2715295, -2967645, -3693493, -411027, -2477047, -671102, + -1228525, -22981, -1308169, -381987, 1349076, 1852771, -1430430, -3343383, + 264944, 508951, 3097992, 44288, -1100098, 904516, 3958618, -3724342, + -8578, 1653064, -3249728, 2389356, -210977, 759969, -1316856, 189548, + -3553272, 3159746, -1851402, -2409325, -177440, 1315589, 1341330, 1285669, + -1584928, -812732, -1439742, -3019102, -3881060, -3628969, 3839961, //6 + + 2091667, + 3407706, 2316500, 3817976, -3342478, 2244091, -2446433, -3562462, 266997, + 2434439, -1235728, 3513181, -3520352, -3759364, -1197226, -3193378, 900702, + 1859098, 909542, 819034, 495491, -1613174, -43260, -522500, -655327, + -3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838, 342297, + 286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, 2842341, + 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 1595974, -3767016, + 1250494, 2635921, -3548272, -2994039, 1869119, 1903435, -1050970, -1333058, + 1237275, -3318210, -1430225, -451100, 1312455, 3306115, -1962642, -1279661, + 1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031, -542412, + -2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993, -2013608, + 2432395, 2454455, -164721, 1957272, 3369112, 185531, -1207385, -3183426, + 162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107, -3038916, + 3523897, 3866901, 269760, 2213111, -975884, 1717735, 472078, -426683, + 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893, -2939036, + -2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687, -554416, + 3919660, -48306, -1362209, 3937738, 1400424, -846154, 1976782 + }; +} +// java --add-opens java.base/sun.security.provider=ALL-UNNAMED -XX:+UseDilithiumIntrinsics test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java diff --git 
a/test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java b/test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java new file mode 100644 index 0000000000000..2c4f94f4a4db8 --- /dev/null +++ b/test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java @@ -0,0 +1,421 @@ +/* + * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ */ +package org.openjdk.bench.javax.crypto.full; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.OutputTimeUnit; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.Constructor; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +import javax.crypto.KeyGenerator; +import javax.crypto.Mac; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; + +@Measurement(iterations = 3, time = 10) +@Warmup(iterations = 3, time = 10) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Fork(value = 1, jvmArgs = {"--add-opens", "java.base/sun.security.provider=ALL-UNNAMED"}) +public class MLDSABench extends CryptoBase { + + public static final int SET_SIZE = 128; + private static final int ML_DSA_N = 256; + + private int[][] coeffs1; + private int[][] coeffs2; + private int[][] prod1; + private int[][] prod2; + private MethodHandle mult, multConst, decompose, almostNtt, inverseNtt; + int index = 0; + + public static int[][] fillRandom(int[][] data) { + Random rnd = new Random(); + for (int[] d : data) { + for (int j = 0; j kClazz = Class.forName("sun.security.provider.ML_DSA"); + Constructor constructor = kClazz.getDeclaredConstructor( + int.class); + constructor.setAccessible(true); + + Method m = kClazz.getDeclaredMethod("implDilithiumNttMult", + int[].class, int[].class, int[].class); + m.setAccessible(true); + mult = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumMontMulByConstant", + int[].class, int.class); + m.setAccessible(true); + multConst = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumDecomposePoly", + 
int[].class, int[].class, int[].class, int.class, int.class); + m.setAccessible(true); + decompose = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumAlmostNtt", + int[].class, int[].class); + m.setAccessible(true); + almostNtt = lookup.unreflect(m); + + m = kClazz.getDeclaredMethod("implDilithiumAlmostInverseNtt", + int[].class, int[].class); + m.setAccessible(true); + inverseNtt = lookup.unreflect(m); + } + + @Benchmark + public void mult1() throws Exception, Throwable { + mult.invoke(prod1[index], coeffs1[index], coeffs2[index]); + index = (index + 1) % SET_SIZE; + } + + @Benchmark + public void multConst() throws Exception, Throwable { + multConst.invoke(prod1[index], coeffs1[index][index]); + index = (index + 1) % SET_SIZE; + } + + @Benchmark + public void multDecompose() throws Exception, Throwable { + int gamma2 = 95232; + if (coeffs1[index][0]%2==1) { + gamma2 = coeffs1[index][1]; + } + int multiplier = (gamma2 == 95232 ? 22 : 8); + decompose.invoke(coeffs1[index], prod1[index], prod2[index], 2 * gamma2, multiplier); + index = (index + 1) % SET_SIZE; + } + + @Benchmark + public void multAlmostNtt() throws Exception, Throwable { + almostNtt.invoke(coeffs1[index], MONT_ZETAS_FOR_VECTOR_NTT); + index = (index + 1) % SET_SIZE; + } + + @Benchmark + public void multInverseNtt() throws Exception, Throwable { + inverseNtt.invoke(coeffs2[index], MONT_ZETAS_FOR_VECTOR_INVERSE_NTT); + index = (index + 1) % SET_SIZE; + } + + private static final int[] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT = new int[]{ + -1976782, 846154, -1400424, -3937738, 1362209, 48306, -3919660, 554416, + 3545687, -1612842, 976891, -183443, 2286327, 420899, 2235985, 2939036, + 3833893, 260646, 1104333, 1667432, -1910376, 1803090, -1723600, 426683, + -472078, -1717735, 975884, -2213111, -269760, -3866901, -3523897, 3038916, + 1799107, 3694233, -1652634, -810149, -3014001, -1616392, -162844, 3183426, + 1207385, -185531, -3369112, -1957272, 164721, -2454455, -2432395, 2013608, + 
3776993, -594136, 3724270, 2584293, 1846953, 1671176, 2831860, 542412, + -3406031, -2235880, -777191, -1500165, 1374803, 2546312, -1917081, 1279661, + 1962642, -3306115, -1312455, 451100, 1430225, 3318210, -1237275, 1333058, + 1050970, -1903435, -1869119, 2994039, 3548272, -2635921, -1250494, 3767016, + -1595974, -2486353, -1247620, -4055324, -1265009, 2590150, -2691481, -2842341, + -203044, -1735879, 3342277, -3437287, -4108315, 2437823, -286988, -342297, + 3595838, 768622, 525098, 3556995, -3207046, -2031748, 3122442, 655327, + 522500, 43260, 1613174, -495491, -819034, -909542, -1859098, -900702, + 3193378, 1197226, 3759364, 3520352, -3513181, 1235728, -2434439, -266997, + 3562462, 2446433, -2244091, 3342478, -3817976, -2316500, -3407706, -2091667, + + -3839961, -3839961, 3628969, 3628969, 3881060, 3881060, 3019102, 3019102, + 1439742, 1439742, 812732, 812732, 1584928, 1584928, -1285669, -1285669, + -1341330, - 1341330, -1315589, -1315589, 177440, 177440, 2409325, 2409325, + 1851402, 1851402, -3159746, -3159746, 3553272, 3553272, -189548, -189548, + 1316856, 1316856, -759969, -759969, 210977, 210977, -2389356, -2389356, + 3249728, 3249728, -1653064, -1653064, 8578, 8578, 3724342, 3724342, + -3958618, -3958618, -904516, -904516, 1100098, 1100098, -44288, -44288, + -3097992, -3097992, -508951, -508951, -264944, -264944, 3343383, 3343383, + 1430430, 1430430, -1852771, -1852771, -1349076, -1349076, 381987, 381987, + 1308169, 1308169, 22981, 22981, 1228525, 1228525, 671102, 671102, + 2477047, 2477047, 411027, 411027, 3693493, 3693493, 2967645, 2967645, + -2715295, -2715295, -2147896, -2147896, 983419, 983419, -3412210, -3412210, + -126922, -126922, 3632928, 3632928, 3157330, 3157330, 3190144, 3190144, + 1000202, 1000202, 4083598, 4083598, -1939314, -1939314, 1257611, 1257611, + 1585221, 1585221, -2176455, -2176455, -3475950, -3475950, 1452451, 1452451, + 3041255, 3041255, 3677745, 3677745, 1528703, 1528703, 3930395, 3930395, + + 2797779, 2797779, 2797779, 2797779, 
-2071892, -2071892, -2071892, -2071892, + 2556880, 2556880, 2556880, 2556880, -3900724, -3900724, -3900724, -3900724, + -3881043, -3881043, -3881043, -3881043, -954230, -954230, -954230, -954230, + -531354, -531354, -531354, -531354, -811944, -811944, -811944, -811944, + -3699596, -3699596, -3699596, -3699596, 1600420, 1600420, 1600420, 1600420, + 2140649, 2140649, 2140649, 2140649, -3507263, -3507263, -3507263, -3507263, + 3821735, 3821735, 3821735, 3821735, -3505694, -3505694, -3505694, -3505694, + 1643818, 1643818, 1643818, 1643818, 1699267, 1699267, 1699267, 1699267, + 539299, 539299, 539299, 539299, -2348700, -2348700, -2348700, -2348700, + 300467, 300467, 300467, 300467, -3539968, -3539968, -3539968, -3539968, + 2867647, 2867647, 2867647, 2867647, -3574422, -3574422, -3574422, -3574422, + 3043716, 3043716, 3043716, 3043716, 3861115, 3861115, 3861115, 3861115, + -3915439, -3915439, -3915439, -3915439, 2537516, 2537516, 2537516, 2537516, + 3592148, 3592148, 3592148, 3592148, 1661693, 1661693, 1661693, 1661693, + -3530437, -3530437, -3530437, -3530437, -3077325, -3077325, -3077325, -3077325, + -95776, -95776, -95776, -95776, -2706023, -2706023, -2706023, -2706023, + + -280005, -280005, -280005, -280005, -280005, -280005, -280005, -280005, + -4010497, -4010497, -4010497, -4010497, -4010497, -4010497, -4010497, -4010497, + 19422, 19422, 19422, 19422, 19422, 19422, 19422, 19422, + -1757237, -1757237, -1757237, -1757237, -1757237, -1757237, -1757237, -1757237, + 3277672, 3277672, 3277672, 3277672, 3277672, 3277672, 3277672, 3277672, + 1399561, 1399561, 1399561, 1399561, 1399561, 1399561, 1399561, 1399561, + 3859737, 3859737, 3859737, 3859737, 3859737, 3859737, 3859737, 3859737, + 2118186, 2118186, 2118186, 2118186, 2118186, 2118186, 2118186, 2118186, + 2108549, 2108549, 2108549, 2108549, 2108549, 2108549, 2108549, 2108549, + -2619752, -2619752, -2619752, -2619752, -2619752, -2619752, -2619752, -2619752, + 1119584, 1119584, 1119584, 1119584, 1119584, 1119584, 
1119584, 1119584, + 549488, 549488, 549488, 549488, 549488, 549488, 549488, 549488, + -3585928, -3585928, -3585928, -3585928, -3585928, -3585928, -3585928, -3585928, + 1079900, 1079900, 1079900, 1079900, 1079900, 1079900, 1079900, 1079900, + -1024112, -1024112, -1024112, -1024112, -1024112, -1024112, -1024112, -1024112, + -2725464, -2725464, -2725464, -2725464, -2725464, -2725464, -2725464, -2725464, + + -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, + -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, + -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, + -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, + 2884855, 2884855, 2884855, 2884855, 2884855, 2884855, 2884855, 2884855, + 2884855, 2884855, 2884855, 2884855, 2884855, 2884855, 2884855, 2884855, + -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, + -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, + 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, + 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, + 359251, 359251, 359251, 359251, 359251, 359251, 359251, 359251, + 359251, 359251, 359251, 359251, 359251, 359251, 359251, 359251, + -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, + -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, + -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, + -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, + + -466468, -466468, -466468, -466468, -466468, -466468, -466468, -466468, + -466468, -466468, -466468, -466468, -466468, -466468, -466468, -466468, + -466468, -466468, -466468, -466468, -466468, -466468, -466468, -466468, + -466468, -466468, -466468, -466468, -466468, -466468, -466468, -466468, + 876248, 876248, 876248, 876248, 876248, 
876248, 876248, 876248, + 876248, 876248, 876248, 876248, 876248, 876248, 876248, 876248, + 876248, 876248, 876248, 876248, 876248, 876248, 876248, 876248, + 876248, 876248, 876248, 876248, 876248, 876248, 876248, 876248, + 777960, 777960, 777960, 777960, 777960, 777960, 777960, 777960, + 777960, 777960, 777960, 777960, 777960, 777960, 777960, 777960, + 777960, 777960, 777960, 777960, 777960, 777960, 777960, 777960, + 777960, 777960, 777960, 777960, 777960, 777960, 777960, 777960, + -237124, -237124, -237124, -237124, -237124, -237124, -237124, -237124, + -237124, -237124, -237124, -237124, -237124, -237124, -237124, -237124, + -237124, -237124, -237124, -237124, -237124, -237124, -237124, -237124, + -237124, -237124, -237124, -237124, -237124, -237124, -237124, -237124, + + 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, + 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, + 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, + 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, + 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, + 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, + 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, + 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, + 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, + 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, + 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, + 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, + 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, + 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, + 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, + 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, + + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, 
-25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, + -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847 + }; + + private static final int[] MONT_ZETAS_FOR_VECTOR_NTT = new int[]{ + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, + + -2608894, -2608894, -2608894, 
-2608894, -2608894, -2608894, -2608894, -2608894, + -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, + -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, + -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, + -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, + -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, + -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, + -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, + -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, + -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, + -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, + -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, + -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, + -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, + -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, + -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, + + 237124, 237124, 237124, 237124, 237124, 237124, 237124, 237124, + 237124, 237124, 237124, 237124, 237124, 237124, 237124, 237124, + 237124, 237124, 237124, 237124, 237124, 237124, 237124, 237124, + 237124, 237124, 237124, 237124, 237124, 237124, 237124, 237124, + -777960, -777960, -777960, -777960, -777960, -777960, -777960, -777960, + -777960, -777960, -777960, -777960, -777960, -777960, -777960, -777960, + -777960, -777960, -777960, -777960, -777960, -777960, -777960, -777960, + -777960, -777960, -777960, -777960, -777960, -777960, -777960, -777960, + -876248, -876248, -876248, -876248, -876248, -876248, -876248, -876248, + -876248, -876248, -876248, -876248, -876248, -876248, -876248, -876248, + -876248, -876248, -876248, -876248, -876248, -876248, -876248, -876248, 
+ -876248, -876248, -876248, -876248, -876248, -876248, -876248, -876248, + 466468, 466468, 466468, 466468, 466468, 466468, 466468, 466468, + 466468, 466468, 466468, 466468, 466468, 466468, 466468, 466468, + 466468, 466468, 466468, 466468, 466468, 466468, 466468, 466468, + 466468, 466468, 466468, 466468, 466468, 466468, 466468, 466468, + + 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, + 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, + 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, + 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, + -359251, -359251, -359251, -359251, -359251, -359251, -359251, -359251, + -359251, -359251, -359251, -359251, -359251, -359251, -359251, -359251, + -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, + -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, + 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, + 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, + -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, + -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, + 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, + 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, + 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, + 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, + + 2725464, 2725464, 2725464, 2725464, 2725464, 2725464, 2725464, 2725464, + 1024112, 1024112, 1024112, 1024112, 1024112, 1024112, 1024112, 1024112, + -1079900, -1079900, -1079900, -1079900, -1079900, -1079900, -1079900, -1079900, + 3585928, 3585928, 3585928, 3585928, 3585928, 3585928, 3585928, 3585928, + -549488, -549488, -549488, -549488, -549488, -549488, -549488, -549488, + -1119584, -1119584, -1119584, -1119584, -1119584, -1119584, 
-1119584, -1119584, + 2619752, 2619752, 2619752, 2619752, 2619752, 2619752, 2619752, 2619752, + -2108549, -2108549, -2108549, -2108549, -2108549, -2108549, -2108549, -2108549, + -2118186, -2118186, -2118186, -2118186, -2118186, -2118186, -2118186, -2118186, + -3859737, -3859737, -3859737, -3859737, -3859737, -3859737, -3859737, -3859737, + -1399561, -1399561, -1399561, -1399561, -1399561, -1399561, -1399561, -1399561, + -3277672, -3277672, -3277672, -3277672, -3277672, -3277672, -3277672, -3277672, + 1757237, 1757237, 1757237, 1757237, 1757237, 1757237, 1757237, 1757237, + -19422, -19422, -19422, -19422, -19422, -19422, -19422, -19422, + 4010497, 4010497, 4010497, 4010497, 4010497, 4010497, 4010497, 4010497, + 280005, 280005, 280005, 280005, 280005, 280005, 280005, 280005, + + 2706023, 2706023, 2706023, 2706023, 95776, 95776, 95776, 95776, + 3077325, 3077325, 3077325, 3077325, 3530437, 3530437, 3530437, 3530437, + -1661693, -1661693, -1661693, -1661693, -3592148, -3592148, -3592148, -3592148, + -2537516, -2537516, -2537516, -2537516, 3915439, 3915439, 3915439, 3915439, + -3861115, -3861115, -3861115, -3861115, -3043716, -3043716, -3043716, -3043716, + 3574422, 3574422, 3574422, 3574422, -2867647, -2867647, -2867647, -2867647, + 3539968, 3539968, 3539968, 3539968, -300467, -300467, -300467, -300467, + 2348700, 2348700, 2348700, 2348700, -539299, -539299, -539299, -539299, + -1699267, -1699267, -1699267, -1699267, -1643818, -1643818, -1643818, -1643818, + 3505694, 3505694, 3505694, 3505694, -3821735, -3821735, -3821735, -3821735, + 3507263, 3507263, 3507263, 3507263, -2140649, -2140649, -2140649, -2140649, + -1600420, -1600420, -1600420, -1600420, 3699596, 3699596, 3699596, 3699596, + 811944, 811944, 811944, 811944, 531354, 531354, 531354, 531354, + 954230, 954230, 954230, 954230, 3881043, 3881043, 3881043, 3881043, + 3900724, 3900724, 3900724, 3900724, -2556880, -2556880, -2556880, -2556880, + 2071892, 2071892, 2071892, 2071892, -2797779, -2797779, -2797779, 
-2797779, + + -3930395, -3930395, -1528703, -1528703, -3677745, -3677745, -3041255, -3041255, + -1452451, -1452451, 3475950, 3475950, 2176455, 2176455, -1585221, -1585221, + -1257611, -1257611, 1939314, 1939314, -4083598, -4083598, -1000202, -1000202, + -3190144, -3190144, -3157330, -3157330, -3632928, -3632928, 126922, 126922, + 3412210, 3412210, -983419, -983419, 2147896, 2147896, 2715295, 2715295, + -2967645, -2967645, -3693493, -3693493, -411027, -411027, -2477047, -2477047, + -671102, -671102, -1228525, -1228525, -22981, -22981, -1308169, -1308169, + -381987, -381987, 1349076, 1349076, 1852771, 1852771, -1430430, -1430430, + -3343383, -3343383, 264944, 264944, 508951, 508951, 3097992, 3097992, + 44288, 44288, -1100098, -1100098, 904516, 904516, 3958618, 3958618, + -3724342, -3724342, -8578, -8578, 1653064, 1653064, -3249728, -3249728, + 2389356, 2389356, -210977, -210977, 759969, 759969, -1316856, -1316856, + 189548, 189548, -3553272, -3553272, 3159746, 3159746, -1851402, -1851402, + -2409325, -2409325, -177440, -177440, 1315589, 1315589, 1341330, 1341330, + 1285669, 1285669, -1584928, -1584928, -812732, -812732, -1439742, -1439742, + -3019102, -3019102, -3881060, -3881060, -3628969, -3628969, 3839961, 3839961, + + 2091667, 3407706, 2316500, 3817976, -3342478, 2244091, -2446433, -3562462, + 266997, 2434439, -1235728, 3513181, -3520352, -3759364, -1197226, -3193378, + 900702, 1859098, 909542, 819034, 495491, -1613174, -43260, -522500, + -655327, -3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838, + 342297, 286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, + 2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 1595974, + -3767016, 1250494, 2635921, -3548272, -2994039, 1869119, 1903435, -1050970, + -1333058, 1237275, -3318210, -1430225, -451100, 1312455, 3306115, -1962642, + -1279661, 1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031, + -542412, -2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993, 
+ -2013608, 2432395, 2454455, -164721, 1957272, 3369112, 185531, -1207385, + -3183426, 162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107, + -3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735, 472078, + -426683, 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893, + -2939036, -2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687, + -554416, 3919660, -48306, -1362209, 3937738, 1400424, -846154, 1976782 + }; +} From 2ff3b82326a0afa683471edd78cac93ab5c83826 Mon Sep 17 00:00:00 2001 From: Volodymyr Paprotski Date: Mon, 6 Oct 2025 21:23:40 +0000 Subject: [PATCH 2/9] Fixes and comments from Anas --- .../x86/stubGenerator_x86_64_dilithium.cpp | 21 +-- .../provider/acvp/ML_DSA_Intrinsic_Test.java | 134 ++---------------- .../bench/javax/crypto/full/MLDSABench.java | 6 +- 3 files changed, 26 insertions(+), 135 deletions(-) diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp index bca7363b9321a..da0ae464d98dc 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp @@ -200,10 +200,10 @@ static auto whole_shuffle(Register scratch, KRegister mergeMask1, KRegister merg __ vmovdqu(output2[i], input2[i], vector_len); } for (int i = 0; i < regCnt; i++) { - __ evmovshdup(output2[i], k2, input1[i], true, vector_len); + __ evmovshdup(output2[i], mergeMask2, input1[i], true, vector_len); } for (int i = 0; i < regCnt; i++) { - __ evmovsldup(input1[i], k1, input2[i], true, vector_len); + __ evmovsldup(input1[i], mergeMask1, input2[i], true, vector_len); } break; // Special cases @@ -390,7 +390,7 @@ static void storeXmms(Register destination, int offset, const XMMRegister xmmReg // static int implDilithiumAlmostNtt(int[] coeffs, int zetas[]) {} // // coeffs (int[256]) = c_rarg0 -// zetas (int[256]) = c_rarg1 +// zetas (int[128*8]) = c_rarg1 // static address 
generate_dilithiumAlmostNtt_avx(StubGenerator *stubgen, int vector_len, MacroAssembler *_masm) { @@ -647,7 +647,7 @@ static address generate_dilithiumAlmostNtt_avx(StubGenerator *stubgen, // static int implDilithiumAlmostInverseNtt(int[] coeffs, int[] zetas) {} // // coeffs (int[256]) = c_rarg0 -// zetas (int[256]) = c_rarg1 +// zetas (int[128*8]) = c_rarg1 static address generate_dilithiumAlmostInverseNtt_avx(StubGenerator *stubgen, int vector_len,MacroAssembler *_masm) { __ align(CodeEntryAlignment); @@ -1017,12 +1017,13 @@ static address generate_dilithiumMontMulByConstant_avx(StubGenerator *stubgen, __ evpbroadcastd(constant, rConstant, Assembler::AVX_512bit); // constant multiplier __ mov64(scratch, 0b0101010101010101); //dw-mask - __ kmovwl(k2, scratch); + __ kmovwl(mergeMask, scratch); } // Total payload is 256*int32s. // - memStep is number of bytes one montMul64 processes. // - loopCnt is number of iterations it will take to process entire payload. + // - (two memSteps per loop) int memStep = 4 * 64; int loopCnt = 2; if (vector_len == Assembler::AVX_256bit) { @@ -1321,15 +1322,15 @@ void StubGenerator::generate_dilithium_stubs() { } // Generate Dilithium intrinsics code if (UseDilithiumIntrinsics) { - StubRoutines::_dilithiumAlmostNtt = + StubRoutines::_dilithiumAlmostNtt = generate_dilithiumAlmostNtt_avx(this, vector_len, _masm); - StubRoutines::_dilithiumAlmostInverseNtt = + StubRoutines::_dilithiumAlmostInverseNtt = generate_dilithiumAlmostInverseNtt_avx(this, vector_len, _masm); - StubRoutines::_dilithiumNttMult = + StubRoutines::_dilithiumNttMult = generate_dilithiumNttMult_avx(this, vector_len, _masm); - StubRoutines::_dilithiumMontMulByConstant = + StubRoutines::_dilithiumMontMulByConstant = generate_dilithiumMontMulByConstant_avx(this, vector_len, _masm); - StubRoutines::_dilithiumDecomposePoly = + StubRoutines::_dilithiumDecomposePoly = generate_dilithiumDecomposePoly_avx(this, vector_len, _masm); } } diff --git 
a/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java b/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java index 0f46496aa8323..16fdff9d28878 100644 --- a/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java +++ b/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java @@ -31,6 +31,12 @@ import java.lang.reflect.Constructor; import java.util.HexFormat; +/* + * @test + * @library /test/lib + * @modules java.base/sun.security.provider:open + * @run main ML_DSA_Intrinsic_Test + */ public class ML_DSA_Intrinsic_Test { public static void main(String[] args) throws Exception { MethodHandles.Lookup lookup = MethodHandles.lookup(); @@ -129,7 +135,7 @@ public static void testMult(int[] prod1, int[] prod2, int[] coeffs1, int[] coeff mult.invoke(prod1, coeffs1, coeffs2); multJava.invoke(prod2, coeffs1, coeffs2); - if (!Arrays.equals(prod1, parseHex(formatOf(prod2)))) { + if (!Arrays.equals(prod1, prod2)) { throw new RuntimeException("[Seed "+seed+"@"+i+"] Result mult mismatch: " + formatOf(prod1) + " != " + formatOf(prod2)); } } @@ -141,7 +147,9 @@ public static void testMultConst(int[] prod1, int[] prod2, for (int j = 0; j " + hex.toHexDigits(coeffs[j]) + ", " + hex.toHexDigits(coeffs[j + l]) + " (tmp = " + hex.toHexDigits(tmp) + ")"); - if (l == testLevel) { - System.out.println(bld.toString()); - } - } - m++; - } - } - } - static void implDilithiumAlmostInverseNttJava(int[] coeffs) { - HexFormat hex = HexFormat.of(); - int dimension = 256; - int m = MONT_ZETAS_FOR_NTT.length - 1; - int testLevel = 1; - for (int l = 1; l < dimension; l *= 2) { - for (int s = 0; s < dimension; s += 2 * l) { - for (int j = s; j < s + l; j++) { - StringBuilder bld = new StringBuilder(); - bld.append("l = " + l + ", m = " + m + ", j = " + j+": "); - int tmp = coeffs[j]; - bld.append(hex.toHexDigits(coeffs[j]) + ", " + hex.toHexDigits(coeffs[j + l]) + " -> " + hex.toHexDigits(tmp - coeffs[j + l]) + " * " + hex.toHexDigits(-MONT_ZETAS_FOR_NTT[m])); - 
coeffs[j] = (tmp + coeffs[j + l]); - coeffs[j + l] = montMul(tmp - coeffs[j + l], -MONT_ZETAS_FOR_NTT[m]); - bld.append(" -> " + hex.toHexDigits(coeffs[j]) + ", " + hex.toHexDigits(coeffs[j + l])); - if (l == testLevel) { - System.out.println(bld.toString()); - } - } - m--; - } - } - } - private static int montMul(int b, int c) { - long a = (long) b * (long) c; - int aHigh = (int) (a >> 32); - int aLow = (int) a; - int m = 58728449 * aLow; // signed low product - - // subtract signed high product - return (aHigh - (int) (((long)m * 8380417) >> 32)); - } - // Zeta values for NTT with montgomery factor precomputed - private static final int[] MONT_ZETAS_FOR_NTT = new int[]{ - 25847, //0 - -2608894, -518909, //1 - 237124, -777960, -876248, 466468, //2 - 1826347, - 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, //3 - 2725464, - 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, -2118186, - -3859737, -1399561, -3277672, 1757237, -19422, 4010497, 280005, // 4 - - 2706023, - 95776, 3077325, 3530437, -1661693, -3592148, -2537516, 3915439, -3861115, - -3043716, 3574422, -2867647, 3539968, -300467, 2348700, -539299, -1699267, - -1643818, 3505694, -3821735, 3507263, -2140649, -1600420, 3699596, 811944, - 531354, 954230, 3881043, 3900724, -2556880, 2071892, -2797779, //5 - - -3930395, - -1528703, -3677745, -3041255, -1452451, 3475950, 2176455, -1585221, -1257611, - 1939314, -4083598, -1000202, -3190144, -3157330, -3632928, 126922, 3412210, - -983419, 2147896, 2715295, -2967645, -3693493, -411027, -2477047, -671102, - -1228525, -22981, -1308169, -381987, 1349076, 1852771, -1430430, -3343383, - 264944, 508951, 3097992, 44288, -1100098, 904516, 3958618, -3724342, - -8578, 1653064, -3249728, 2389356, -210977, 759969, -1316856, 189548, - -3553272, 3159746, -1851402, -2409325, -177440, 1315589, 1341330, 1285669, - -1584928, -812732, -1439742, -3019102, -3881060, -3628969, 3839961, //6 - - 2091667, - 3407706, 2316500, 3817976, -3342478, 2244091, 
-2446433, -3562462, 266997, - 2434439, -1235728, 3513181, -3520352, -3759364, -1197226, -3193378, 900702, - 1859098, 909542, 819034, 495491, -1613174, -43260, -522500, -655327, - -3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838, 342297, - 286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, 2842341, - 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 1595974, -3767016, - 1250494, 2635921, -3548272, -2994039, 1869119, 1903435, -1050970, -1333058, - 1237275, -3318210, -1430225, -451100, 1312455, 3306115, -1962642, -1279661, - 1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031, -542412, - -2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993, -2013608, - 2432395, 2454455, -164721, 1957272, 3369112, 185531, -1207385, -3183426, - 162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107, -3038916, - 3523897, 3866901, 269760, 2213111, -975884, 1717735, 472078, -426683, - 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893, -2939036, - -2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687, -554416, - 3919660, -48306, -1362209, 3937738, 1400424, -846154, 1976782 - }; } // java --add-opens java.base/sun.security.provider=ALL-UNNAMED -XX:+UseDilithiumIntrinsics test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java diff --git a/test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java b/test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java index 2c4f94f4a4db8..1eeb35c194dff 100644 --- a/test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java +++ b/test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java @@ -48,9 +48,7 @@ @OutputTimeUnit(TimeUnit.MILLISECONDS) @Fork(value = 1, jvmArgs = {"--add-opens", "java.base/sun.security.provider=ALL-UNNAMED"}) public class MLDSABench extends CryptoBase { - public static final int SET_SIZE = 128; - private static final int ML_DSA_N = 256; private int[][] coeffs1; private int[][] coeffs2; @@ -109,7 +107,7 @@ public void 
setup() throws Exception { } @Benchmark - public void mult1() throws Exception, Throwable { + public void mult() throws Exception, Throwable { mult.invoke(prod1[index], coeffs1[index], coeffs2[index]); index = (index + 1) % SET_SIZE; } @@ -143,6 +141,8 @@ public void multInverseNtt() throws Exception, Throwable { index = (index + 1) % SET_SIZE; } + // Copied constants from sun.security.provider.ML_DSA + private static final int ML_DSA_N = 256; private static final int[] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT = new int[]{ -1976782, 846154, -1400424, -3937738, 1362209, 48306, -3919660, 554416, 3545687, -1612842, 976891, -183443, 2286327, 420899, 2235985, 2939036, From f4f84b6e717cc486601574632c50f66fce24a692 Mon Sep 17 00:00:00 2001 From: Volodymyr Paprotski Date: Tue, 4 Nov 2025 16:14:23 +0000 Subject: [PATCH 3/9] add copyright, whitespace and test jtreg tags --- .../x86/stubGenerator_x86_64_dilithium.cpp | 1 + .../provider/acvp/ML_DSA_Intrinsic_Test.java | 60 +++++++++++-------- .../bench/javax/crypto/full/MLDSABench.java | 2 +- 3 files changed, 37 insertions(+), 26 deletions(-) diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp index da0ae464d98dc..2f03a4e3e391a 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp @@ -1,5 +1,6 @@ /* * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2025, Intel Corporation. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it diff --git a/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java b/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java index 16fdff9d28878..b90d93f613b92 100644 --- a/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java +++ b/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java @@ -34,17 +34,29 @@ /* * @test * @library /test/lib - * @modules java.base/sun.security.provider:open + * @key randomness + * @modules java.base/sun.security.provider:+open + * @run main/othervm ML_DSA_Intrinsic_Test -XX:+UnlockDiagnosticVMOptions -XX:-UseDilithiumIntrinsics + */ +/* + * @test + * @library /test/lib + * @key randomness + * @modules java.base/sun.security.provider:+open + * @run main/othervm -XX:UseAVX=2 ML_DSA_Intrinsic_Test + */ +/* + * @test + * @library /test/lib + * @key randomness + * @modules java.base/sun.security.provider:+open * @run main ML_DSA_Intrinsic_Test */ public class ML_DSA_Intrinsic_Test { public static void main(String[] args) throws Exception { MethodHandles.Lookup lookup = MethodHandles.lookup(); - Class kClazz = Class.forName("sun.security.provider.ML_DSA"); - Constructor constructor = kClazz.getDeclaredConstructor( - int.class); - constructor.setAccessible(true); - + Class kClazz = sun.security.provider.ML_DSA.class; + Method m = kClazz.getDeclaredMethod("implDilithiumNttMult", int[].class, int[].class, int[].class); m.setAccessible(true); @@ -123,10 +135,10 @@ public static void main(String[] args) throws Exception { } private static final int ML_DSA_N = 256; - public static void testMult(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2, - MethodHandle mult, MethodHandle multJava, Random rnd, + public static void testMult(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2, + MethodHandle mult, MethodHandle multJava, Random rnd, long seed, int i) throws Exception, Throwable { - + for (int j = 0; j Date: Mon, 17 Nov 2025 23:33:41 
+0000 Subject: [PATCH 4/9] address first comments --- src/hotspot/cpu/x86/assembler_x86.cpp | 10 ++++----- .../x86/stubGenerator_x86_64_dilithium.cpp | 22 ++++++++++++------- src/hotspot/cpu/x86/vm_version_x86.cpp | 1 - .../provider/acvp/ML_DSA_Intrinsic_Test.java | 17 +++++++------- .../bench/javax/crypto/full/MLDSABench.java | 2 +- 5 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/hotspot/cpu/x86/assembler_x86.cpp b/src/hotspot/cpu/x86/assembler_x86.cpp index ea2f1b7144c06..cbc5c6988d4bb 100644 --- a/src/hotspot/cpu/x86/assembler_x86.cpp +++ b/src/hotspot/cpu/x86/assembler_x86.cpp @@ -3861,18 +3861,14 @@ void Assembler::evmovdquq(Address dst, KRegister mask, XMMRegister src, bool mer } void Assembler::vmovsldup(XMMRegister dst, XMMRegister src, int vector_len) { - assert(vector_len == AVX_128bit ? VM_Version::supports_avx() : - (vector_len == AVX_256bit ? VM_Version::supports_avx2() : - (vector_len == AVX_512bit ? VM_Version::supports_evex() : false)), ""); + assert(vector_len == AVX_512bit ? VM_Version::supports_evex() : VM_Version::supports_avx(), ""); InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_F3, VEX_OPCODE_0F, &attributes); emit_int16(0x12, (0xC0 | encode)); } void Assembler::vmovshdup(XMMRegister dst, XMMRegister src, int vector_len) { - assert(vector_len == AVX_128bit ? VM_Version::supports_avx() : - (vector_len == AVX_256bit ? VM_Version::supports_avx2() : - (vector_len == AVX_512bit ? VM_Version::supports_evex() : false)), ""); + assert(vector_len == AVX_512bit ? 
VM_Version::supports_evex() : VM_Version::supports_avx(), ""); InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_F3, VEX_OPCODE_0F, &attributes); emit_int16(0x16, (0xC0 | encode)); @@ -3880,6 +3876,7 @@ void Assembler::vmovshdup(XMMRegister dst, XMMRegister src, int vector_len) { void Assembler::evmovsldup(XMMRegister dst, KRegister mask, XMMRegister src, bool merge, int vector_len) { assert(VM_Version::supports_evex(), ""); + assert(vector_len == AVX_512bit || VM_Version::supports_avx512vl(), ""); InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true); attributes.set_embedded_opmask_register_specifier(mask); attributes.set_is_evex_instruction(); @@ -3892,6 +3889,7 @@ void Assembler::evmovsldup(XMMRegister dst, KRegister mask, XMMRegister src, boo void Assembler::evmovshdup(XMMRegister dst, KRegister mask, XMMRegister src, bool merge, int vector_len) { assert(VM_Version::supports_evex(), ""); + assert(vector_len == AVX_512bit || VM_Version::supports_avx512vl(), ""); InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true); attributes.set_embedded_opmask_register_specifier(mask); attributes.set_is_evex_instruction(); diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp index b555adf23eb03..b4830f455dec9 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp @@ -78,16 +78,17 @@ static address unshufflePermsAddr(int offset) { // The following function swaps elements A<->B, C<->D, and so forth. // input1[] is shuffled in place; shuffle of input2[] is copied to output2[]. // Element size (in bits) is specified by size parameter. 
-// size 0 and 1 are used for initial and final shuffles respectivelly of -// dilithiumAlmostInverseNtt and dilithiumAlmostNtt. -// NOTE: For size 0 and 1, input1[] and input2[] are modified in-place -// // +-----+-----+-----+-----+----- // | | A | | C | ... // +-----+-----+-----+-----+----- // +-----+-----+-----+-----+----- // | B | | D | | ... // +-----+-----+-----+-----+----- +// +// NOTE: size 0 and 1 are used for initial and final shuffles respectivelly of +// dilithiumAlmostInverseNtt and dilithiumAlmostNtt. For size 0 and 1, input1[] +// and input2[] are modified in-place (and output2 is used as a temporary) +// // Using C++ lambdas for improved readability (to hide parameters that always repeat) static auto whole_shuffle(Register scratch, KRegister mergeMask1, KRegister mergeMask2, const XMMRegister unshuffle1, const XMMRegister unshuffle2, int vector_len, MacroAssembler *_masm) { @@ -131,9 +132,11 @@ static auto whole_shuffle(Register scratch, KRegister mergeMask1, KRegister merg __ vpblendd(input1[i], input1[i], input2[i], 0b10101010, vector_len); } break; - case 1: + // Special cases + case 1: // initial shuffle for dilithiumAlmostInverseNtt + // shuffle all even 32bit columns to input1, and odd to input2 for (int i = 0; i < regCnt; i++) { - // 0b-1-2-3-1 + // 0b-3-1-3-1 __ vshufps(output2[i], input1[i], input2[i], 0b11011101, vector_len); } for (int i = 0; i < regCnt; i++) { @@ -148,7 +151,8 @@ static auto whole_shuffle(Register scratch, KRegister mergeMask1, KRegister merg __ vpermq(input1[i], input1[i], 0b11011000, vector_len); } break; - case 0: + case 0: // final unshuffle for dilithiumAlmostNtt + // reverse case 1: all even are in input1 and odd in input2, put back for (int i = 0; i < regCnt; i++) { __ vpunpckhdq(output2[i], input1[i], input2[i], vector_len); } @@ -209,6 +213,7 @@ static auto whole_shuffle(Register scratch, KRegister mergeMask1, KRegister merg break; // Special cases case 1: // initial shuffle for dilithiumAlmostInverseNtt + // 
shuffle all even 32bit columns to input1, and odd to input2 for (int i = 0; i < regCnt; i++) { __ vmovdqu(output2[i], input2[i], vector_len); } @@ -220,6 +225,7 @@ static auto whole_shuffle(Register scratch, KRegister mergeMask1, KRegister merg } break; case 0: // final unshuffle for dilithiumAlmostNtt + // reverse case 1: all even are in input1 and odd in input2, put back for (int i = 0; i < regCnt; i++) { __ vmovdqu(output2[i], input2[i], vector_len); } @@ -394,7 +400,7 @@ static void storeXmms(Register destination, int offset, const XMMRegister xmmReg // zetas (int[128*8]) = c_rarg1 // static address generate_dilithiumAlmostNtt_avx(StubGenerator *stubgen, - int vector_len, MacroAssembler *_masm) { + int vector_len, MacroAssembler *_masm) { __ align(CodeEntryAlignment); StubId stub_id = StubId::stubgen_dilithiumAlmostNtt_id; StubCodeMark mark(stubgen, stub_id); diff --git a/src/hotspot/cpu/x86/vm_version_x86.cpp b/src/hotspot/cpu/x86/vm_version_x86.cpp index a7003002988ae..747daefd51d65 100644 --- a/src/hotspot/cpu/x86/vm_version_x86.cpp +++ b/src/hotspot/cpu/x86/vm_version_x86.cpp @@ -1271,7 +1271,6 @@ void VM_Version::get_processor_features() { } // Dilithium Intrinsics - // Currently we only have them for AVX512 if (UseAVX > 1) { if (FLAG_IS_DEFAULT(UseDilithiumIntrinsics)) { UseDilithiumIntrinsics = true; diff --git a/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java b/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java index b90d93f613b92..60c0804891d6c 100644 --- a/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java +++ b/test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2025, Intel Corporation. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -23,12 +23,9 @@ import java.util.Arrays; import java.util.Random; - import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; -import java.lang.reflect.Field; import java.lang.reflect.Method; -import java.lang.reflect.Constructor; import java.util.HexFormat; /* @@ -52,6 +49,10 @@ * @modules java.base/sun.security.provider:+open * @run main ML_DSA_Intrinsic_Test */ + +// To run manually: java --add-opens java.base/sun.security.provider=ALL-UNNAMED --add-exports java.base/sun.security.provider=ALL-UNNAMED +// -XX:+UnlockDiagnosticVMOptions -XX:+UseDilithiumIntrinsics test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java + public class ML_DSA_Intrinsic_Test { public static void main(String[] args) throws Exception { MethodHandles.Lookup lookup = MethodHandles.lookup(); @@ -107,6 +108,7 @@ public static void main(String[] args) throws Exception { m.setAccessible(true); MethodHandle inverseNttJava = lookup.unreflect(m); + // Hint: if test fails, you can hardcode the seed to make the test more reproducible Random rnd = new Random(); long seed = rnd.nextLong(); rnd.setSeed(seed); @@ -120,8 +122,8 @@ public static void main(String[] args) throws Exception { int[] prod4 = new int[ML_DSA_N]; try { for (int i = 0; i < repeat; i++) { - // seed = rnd.nextLong(); - //rnd.setSeed(seed); + // Hint: if test fails, you can hardcode the seed to make the test more reproducible: + // rnd.setSeed(seed); testMult(prod1, prod2, coeffs1, coeffs2, mult, multJava, rnd, seed, i); testMultConst(prod1, prod2, multConst, multConstJava, rnd, seed, i); testDecompose(prod1, prod2, prod3, prod4, coeffs1, coeffs2, decompose, decomposeJava, rnd, seed, i); @@ -214,9 +216,7 @@ public static void testAlmostNtt(int[] coeffs1, int[] coeffs2, public static void testInverseNtt(int[] coeffs1, int[] coeffs2, MethodHandle inverseNtt, MethodHandle inverseNttJava, Random rnd, long seed, int i) throws Exception, 
Throwable { - int[] coeffs3 = new int[ML_DSA_N]; for (int j = 0; j Date: Mon, 17 Nov 2025 23:37:52 +0000 Subject: [PATCH 5/9] whitespace --- src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp index b4830f455dec9..40c3f87d074b9 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp @@ -86,7 +86,7 @@ static address unshufflePermsAddr(int offset) { // +-----+-----+-----+-----+----- // // NOTE: size 0 and 1 are used for initial and final shuffles respectivelly of -// dilithiumAlmostInverseNtt and dilithiumAlmostNtt. For size 0 and 1, input1[] +// dilithiumAlmostInverseNtt and dilithiumAlmostNtt. For size 0 and 1, input1[] // and input2[] are modified in-place (and output2 is used as a temporary) // // Using C++ lambdas for improved readability (to hide parameters that always repeat) From b04f4f0d0de0700ff450e24b4e30144ac553b163 Mon Sep 17 00:00:00 2001 From: Volodymyr Paprotski Date: Thu, 20 Nov 2025 22:52:50 +0000 Subject: [PATCH 6/9] next set of comments --- .../x86/stubGenerator_x86_64_dilithium.cpp | 3 +- .../bench/javax/crypto/full/MLDSABench.java | 421 ------------------ 2 files changed, 1 insertion(+), 423 deletions(-) delete mode 100644 test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp index 40c3f87d074b9..39fad4c097cdc 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp @@ -1261,11 +1261,11 @@ static address generate_dilithiumDecomposePoly_avx(StubGenerator *stubgen, } // r1 in RPlus // int r1 = rplus - r0 - (dilithium_q - 1); + // r1 = (r1 | (-r1)) >> 31; // 0 if rplus - r0 == (dilithium_q - 1), 
-1 otherwise for (int i = 0; i < regCnt; i++) { __ vpsubd(RPlus[i], RPlus[i], R0[i], vector_len); } - // r1 = (r1 | (-r1)) >> 31; // 0 if rplus - r0 == (dilithium_q - 1), -1 otherwise if (vector_len == Assembler::AVX_512bit) { KRegister EqMsk[] = {k1, k2, k3, k4}; for (int i = 0; i < regCnt; i++) { @@ -1280,7 +1280,6 @@ static address generate_dilithiumDecomposePoly_avx(StubGenerator *stubgen, // r1 in Quotient // r1 = r1 & quotient; // copy 0 or keep as is, using EqMsk as filter for (int i = 0; i < regCnt; i++) { - // FIXME: replace with void evmovdqul(Address dst, KRegister mask, XMMRegister src, bool merge, int vector_len);? __ evpandd(Quotient[i], EqMsk[i], Quotient[i], zero, true, vector_len); } } else { diff --git a/test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java b/test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java deleted file mode 100644 index 0cbc15b0a0198..0000000000000 --- a/test/micro/org/openjdk/bench/javax/crypto/full/MLDSABench.java +++ /dev/null @@ -1,421 +0,0 @@ -/* - * Copyright (c) 2025, Intel Corporation. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 
- * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. - */ -package org.openjdk.bench.javax.crypto.full; - -import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.Param; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.Fork; -import org.openjdk.jmh.annotations.Warmup; -import org.openjdk.jmh.annotations.Measurement; -import org.openjdk.jmh.annotations.OutputTimeUnit; - -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.lang.reflect.Field; -import java.lang.reflect.Method; -import java.lang.reflect.Constructor; -import java.util.Random; -import java.util.concurrent.TimeUnit; - -import javax.crypto.KeyGenerator; -import javax.crypto.Mac; -import java.security.InvalidKeyException; -import java.security.NoSuchAlgorithmException; - -@Measurement(iterations = 3, time = 10) -@Warmup(iterations = 3, time = 10) -@OutputTimeUnit(TimeUnit.MILLISECONDS) -@Fork(value = 1, jvmArgs = {"--add-opens", "java.base/sun.security.provider=ALL-UNNAMED"}) -public class MLDSABench extends CryptoBase { - public static final int SET_SIZE = 128; - - private int[][] coeffs1; - private int[][] coeffs2; - private int[][] prod1; - private int[][] prod2; - private MethodHandle mult, multConst, decompose, almostNtt, inverseNtt; - int index = 0; - - public static int[][] fillRandom(int[][] data) { - Random rnd = new Random(); - for (int[] d : data) { - for (int j = 0; j kClazz = Class.forName("sun.security.provider.ML_DSA"); - Constructor constructor = kClazz.getDeclaredConstructor( - int.class); - constructor.setAccessible(true); - - Method m = kClazz.getDeclaredMethod("implDilithiumNttMult", - int[].class, int[].class, int[].class); - m.setAccessible(true); - mult = lookup.unreflect(m); - - m = kClazz.getDeclaredMethod("implDilithiumMontMulByConstant", - int[].class, int.class); - 
m.setAccessible(true); - multConst = lookup.unreflect(m); - - m = kClazz.getDeclaredMethod("implDilithiumDecomposePoly", - int[].class, int[].class, int[].class, int.class, int.class); - m.setAccessible(true); - decompose = lookup.unreflect(m); - - m = kClazz.getDeclaredMethod("implDilithiumAlmostNtt", - int[].class, int[].class); - m.setAccessible(true); - almostNtt = lookup.unreflect(m); - - m = kClazz.getDeclaredMethod("implDilithiumAlmostInverseNtt", - int[].class, int[].class); - m.setAccessible(true); - inverseNtt = lookup.unreflect(m); - } - - @Benchmark - public void mult() throws Exception, Throwable { - mult.invoke(prod1[index], coeffs1[index], coeffs2[index]); - index = (index + 1) % SET_SIZE; - } - - @Benchmark - public void multConst() throws Exception, Throwable { - multConst.invoke(prod1[index], coeffs1[index][index]); - index = (index + 1) % SET_SIZE; - } - - @Benchmark - public void multDecompose() throws Exception, Throwable { - int gamma2 = 95232; - if (coeffs1[index][0]%2==1) { - gamma2 = coeffs1[index][1]; - } - int multiplier = (gamma2 == 95232 ? 
22 : 8); - decompose.invoke(coeffs1[index], prod1[index], prod2[index], 2 * gamma2, multiplier); - index = (index + 1) % SET_SIZE; - } - - @Benchmark - public void multAlmostNtt() throws Exception, Throwable { - almostNtt.invoke(coeffs1[index], MONT_ZETAS_FOR_VECTOR_NTT); - index = (index + 1) % SET_SIZE; - } - - @Benchmark - public void multInverseNtt() throws Exception, Throwable { - inverseNtt.invoke(coeffs2[index], MONT_ZETAS_FOR_VECTOR_INVERSE_NTT); - index = (index + 1) % SET_SIZE; - } - - // Copied constants from sun.security.provider.ML_DSA - private static final int ML_DSA_N = 256; - private static final int[] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT = new int[]{ - -1976782, 846154, -1400424, -3937738, 1362209, 48306, -3919660, 554416, - 3545687, -1612842, 976891, -183443, 2286327, 420899, 2235985, 2939036, - 3833893, 260646, 1104333, 1667432, -1910376, 1803090, -1723600, 426683, - -472078, -1717735, 975884, -2213111, -269760, -3866901, -3523897, 3038916, - 1799107, 3694233, -1652634, -810149, -3014001, -1616392, -162844, 3183426, - 1207385, -185531, -3369112, -1957272, 164721, -2454455, -2432395, 2013608, - 3776993, -594136, 3724270, 2584293, 1846953, 1671176, 2831860, 542412, - -3406031, -2235880, -777191, -1500165, 1374803, 2546312, -1917081, 1279661, - 1962642, -3306115, -1312455, 451100, 1430225, 3318210, -1237275, 1333058, - 1050970, -1903435, -1869119, 2994039, 3548272, -2635921, -1250494, 3767016, - -1595974, -2486353, -1247620, -4055324, -1265009, 2590150, -2691481, -2842341, - -203044, -1735879, 3342277, -3437287, -4108315, 2437823, -286988, -342297, - 3595838, 768622, 525098, 3556995, -3207046, -2031748, 3122442, 655327, - 522500, 43260, 1613174, -495491, -819034, -909542, -1859098, -900702, - 3193378, 1197226, 3759364, 3520352, -3513181, 1235728, -2434439, -266997, - 3562462, 2446433, -2244091, 3342478, -3817976, -2316500, -3407706, -2091667, - - -3839961, -3839961, 3628969, 3628969, 3881060, 3881060, 3019102, 3019102, - 1439742, 1439742, 812732, 
812732, 1584928, 1584928, -1285669, -1285669, - -1341330, - 1341330, -1315589, -1315589, 177440, 177440, 2409325, 2409325, - 1851402, 1851402, -3159746, -3159746, 3553272, 3553272, -189548, -189548, - 1316856, 1316856, -759969, -759969, 210977, 210977, -2389356, -2389356, - 3249728, 3249728, -1653064, -1653064, 8578, 8578, 3724342, 3724342, - -3958618, -3958618, -904516, -904516, 1100098, 1100098, -44288, -44288, - -3097992, -3097992, -508951, -508951, -264944, -264944, 3343383, 3343383, - 1430430, 1430430, -1852771, -1852771, -1349076, -1349076, 381987, 381987, - 1308169, 1308169, 22981, 22981, 1228525, 1228525, 671102, 671102, - 2477047, 2477047, 411027, 411027, 3693493, 3693493, 2967645, 2967645, - -2715295, -2715295, -2147896, -2147896, 983419, 983419, -3412210, -3412210, - -126922, -126922, 3632928, 3632928, 3157330, 3157330, 3190144, 3190144, - 1000202, 1000202, 4083598, 4083598, -1939314, -1939314, 1257611, 1257611, - 1585221, 1585221, -2176455, -2176455, -3475950, -3475950, 1452451, 1452451, - 3041255, 3041255, 3677745, 3677745, 1528703, 1528703, 3930395, 3930395, - - 2797779, 2797779, 2797779, 2797779, -2071892, -2071892, -2071892, -2071892, - 2556880, 2556880, 2556880, 2556880, -3900724, -3900724, -3900724, -3900724, - -3881043, -3881043, -3881043, -3881043, -954230, -954230, -954230, -954230, - -531354, -531354, -531354, -531354, -811944, -811944, -811944, -811944, - -3699596, -3699596, -3699596, -3699596, 1600420, 1600420, 1600420, 1600420, - 2140649, 2140649, 2140649, 2140649, -3507263, -3507263, -3507263, -3507263, - 3821735, 3821735, 3821735, 3821735, -3505694, -3505694, -3505694, -3505694, - 1643818, 1643818, 1643818, 1643818, 1699267, 1699267, 1699267, 1699267, - 539299, 539299, 539299, 539299, -2348700, -2348700, -2348700, -2348700, - 300467, 300467, 300467, 300467, -3539968, -3539968, -3539968, -3539968, - 2867647, 2867647, 2867647, 2867647, -3574422, -3574422, -3574422, -3574422, - 3043716, 3043716, 3043716, 3043716, 3861115, 3861115, 3861115, 
3861115, - -3915439, -3915439, -3915439, -3915439, 2537516, 2537516, 2537516, 2537516, - 3592148, 3592148, 3592148, 3592148, 1661693, 1661693, 1661693, 1661693, - -3530437, -3530437, -3530437, -3530437, -3077325, -3077325, -3077325, -3077325, - -95776, -95776, -95776, -95776, -2706023, -2706023, -2706023, -2706023, - - -280005, -280005, -280005, -280005, -280005, -280005, -280005, -280005, - -4010497, -4010497, -4010497, -4010497, -4010497, -4010497, -4010497, -4010497, - 19422, 19422, 19422, 19422, 19422, 19422, 19422, 19422, - -1757237, -1757237, -1757237, -1757237, -1757237, -1757237, -1757237, -1757237, - 3277672, 3277672, 3277672, 3277672, 3277672, 3277672, 3277672, 3277672, - 1399561, 1399561, 1399561, 1399561, 1399561, 1399561, 1399561, 1399561, - 3859737, 3859737, 3859737, 3859737, 3859737, 3859737, 3859737, 3859737, - 2118186, 2118186, 2118186, 2118186, 2118186, 2118186, 2118186, 2118186, - 2108549, 2108549, 2108549, 2108549, 2108549, 2108549, 2108549, 2108549, - -2619752, -2619752, -2619752, -2619752, -2619752, -2619752, -2619752, -2619752, - 1119584, 1119584, 1119584, 1119584, 1119584, 1119584, 1119584, 1119584, - 549488, 549488, 549488, 549488, 549488, 549488, 549488, 549488, - -3585928, -3585928, -3585928, -3585928, -3585928, -3585928, -3585928, -3585928, - 1079900, 1079900, 1079900, 1079900, 1079900, 1079900, 1079900, 1079900, - -1024112, -1024112, -1024112, -1024112, -1024112, -1024112, -1024112, -1024112, - -2725464, -2725464, -2725464, -2725464, -2725464, -2725464, -2725464, -2725464, - - -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, - -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, -2680103, - -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, - -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, -3111497, - 2884855, 2884855, 2884855, 2884855, 2884855, 2884855, 2884855, 2884855, - 2884855, 2884855, 2884855, 2884855, 2884855, 2884855, 2884855, 
2884855, - -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, - -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, -3119733, - 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, - 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, 2091905, - 359251, 359251, 359251, 359251, 359251, 359251, 359251, 359251, - 359251, 359251, 359251, 359251, 359251, 359251, 359251, 359251, - -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, - -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, -2353451, - -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, - -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, -1826347, - - -466468, -466468, -466468, -466468, -466468, -466468, -466468, -466468, - -466468, -466468, -466468, -466468, -466468, -466468, -466468, -466468, - -466468, -466468, -466468, -466468, -466468, -466468, -466468, -466468, - -466468, -466468, -466468, -466468, -466468, -466468, -466468, -466468, - 876248, 876248, 876248, 876248, 876248, 876248, 876248, 876248, - 876248, 876248, 876248, 876248, 876248, 876248, 876248, 876248, - 876248, 876248, 876248, 876248, 876248, 876248, 876248, 876248, - 876248, 876248, 876248, 876248, 876248, 876248, 876248, 876248, - 777960, 777960, 777960, 777960, 777960, 777960, 777960, 777960, - 777960, 777960, 777960, 777960, 777960, 777960, 777960, 777960, - 777960, 777960, 777960, 777960, 777960, 777960, 777960, 777960, - 777960, 777960, 777960, 777960, 777960, 777960, 777960, 777960, - -237124, -237124, -237124, -237124, -237124, -237124, -237124, -237124, - -237124, -237124, -237124, -237124, -237124, -237124, -237124, -237124, - -237124, -237124, -237124, -237124, -237124, -237124, -237124, -237124, - -237124, -237124, -237124, -237124, -237124, -237124, -237124, -237124, - - 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, - 518909, 518909, 518909, 
518909, 518909, 518909, 518909, 518909, - 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, - 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, - 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, - 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, - 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, - 518909, 518909, 518909, 518909, 518909, 518909, 518909, 518909, - 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, - 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, - 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, - 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, - 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, - 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, - 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, - 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, 2608894, - - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, 
-25847, -25847, -25847, - -25847, -25847, -25847, -25847, -25847, -25847, -25847, -25847 - }; - - private static final int[] MONT_ZETAS_FOR_VECTOR_NTT = new int[]{ - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - 25847, 25847, 25847, 25847, 25847, 25847, 25847, 25847, - - -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, - -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, - -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, - -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, - -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, - -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, - -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, - -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, -2608894, - -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, - -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, - -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, - -518909, -518909, 
-518909, -518909, -518909, -518909, -518909, -518909, - -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, - -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, - -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, - -518909, -518909, -518909, -518909, -518909, -518909, -518909, -518909, - - 237124, 237124, 237124, 237124, 237124, 237124, 237124, 237124, - 237124, 237124, 237124, 237124, 237124, 237124, 237124, 237124, - 237124, 237124, 237124, 237124, 237124, 237124, 237124, 237124, - 237124, 237124, 237124, 237124, 237124, 237124, 237124, 237124, - -777960, -777960, -777960, -777960, -777960, -777960, -777960, -777960, - -777960, -777960, -777960, -777960, -777960, -777960, -777960, -777960, - -777960, -777960, -777960, -777960, -777960, -777960, -777960, -777960, - -777960, -777960, -777960, -777960, -777960, -777960, -777960, -777960, - -876248, -876248, -876248, -876248, -876248, -876248, -876248, -876248, - -876248, -876248, -876248, -876248, -876248, -876248, -876248, -876248, - -876248, -876248, -876248, -876248, -876248, -876248, -876248, -876248, - -876248, -876248, -876248, -876248, -876248, -876248, -876248, -876248, - 466468, 466468, 466468, 466468, 466468, 466468, 466468, 466468, - 466468, 466468, 466468, 466468, 466468, 466468, 466468, 466468, - 466468, 466468, 466468, 466468, 466468, 466468, 466468, 466468, - 466468, 466468, 466468, 466468, 466468, 466468, 466468, 466468, - - 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, - 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, 1826347, - 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, - 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, 2353451, - -359251, -359251, -359251, -359251, -359251, -359251, -359251, -359251, - -359251, -359251, -359251, -359251, -359251, -359251, -359251, -359251, - -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, 
- -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, -2091905, - 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, - 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, 3119733, - -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, - -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, -2884855, - 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, - 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, 3111497, - 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, - 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, 2680103, - - 2725464, 2725464, 2725464, 2725464, 2725464, 2725464, 2725464, 2725464, - 1024112, 1024112, 1024112, 1024112, 1024112, 1024112, 1024112, 1024112, - -1079900, -1079900, -1079900, -1079900, -1079900, -1079900, -1079900, -1079900, - 3585928, 3585928, 3585928, 3585928, 3585928, 3585928, 3585928, 3585928, - -549488, -549488, -549488, -549488, -549488, -549488, -549488, -549488, - -1119584, -1119584, -1119584, -1119584, -1119584, -1119584, -1119584, -1119584, - 2619752, 2619752, 2619752, 2619752, 2619752, 2619752, 2619752, 2619752, - -2108549, -2108549, -2108549, -2108549, -2108549, -2108549, -2108549, -2108549, - -2118186, -2118186, -2118186, -2118186, -2118186, -2118186, -2118186, -2118186, - -3859737, -3859737, -3859737, -3859737, -3859737, -3859737, -3859737, -3859737, - -1399561, -1399561, -1399561, -1399561, -1399561, -1399561, -1399561, -1399561, - -3277672, -3277672, -3277672, -3277672, -3277672, -3277672, -3277672, -3277672, - 1757237, 1757237, 1757237, 1757237, 1757237, 1757237, 1757237, 1757237, - -19422, -19422, -19422, -19422, -19422, -19422, -19422, -19422, - 4010497, 4010497, 4010497, 4010497, 4010497, 4010497, 4010497, 4010497, - 280005, 280005, 280005, 280005, 280005, 280005, 280005, 280005, - - 2706023, 2706023, 2706023, 2706023, 95776, 95776, 95776, 95776, - 3077325, 
3077325, 3077325, 3077325, 3530437, 3530437, 3530437, 3530437, - -1661693, -1661693, -1661693, -1661693, -3592148, -3592148, -3592148, -3592148, - -2537516, -2537516, -2537516, -2537516, 3915439, 3915439, 3915439, 3915439, - -3861115, -3861115, -3861115, -3861115, -3043716, -3043716, -3043716, -3043716, - 3574422, 3574422, 3574422, 3574422, -2867647, -2867647, -2867647, -2867647, - 3539968, 3539968, 3539968, 3539968, -300467, -300467, -300467, -300467, - 2348700, 2348700, 2348700, 2348700, -539299, -539299, -539299, -539299, - -1699267, -1699267, -1699267, -1699267, -1643818, -1643818, -1643818, -1643818, - 3505694, 3505694, 3505694, 3505694, -3821735, -3821735, -3821735, -3821735, - 3507263, 3507263, 3507263, 3507263, -2140649, -2140649, -2140649, -2140649, - -1600420, -1600420, -1600420, -1600420, 3699596, 3699596, 3699596, 3699596, - 811944, 811944, 811944, 811944, 531354, 531354, 531354, 531354, - 954230, 954230, 954230, 954230, 3881043, 3881043, 3881043, 3881043, - 3900724, 3900724, 3900724, 3900724, -2556880, -2556880, -2556880, -2556880, - 2071892, 2071892, 2071892, 2071892, -2797779, -2797779, -2797779, -2797779, - - -3930395, -3930395, -1528703, -1528703, -3677745, -3677745, -3041255, -3041255, - -1452451, -1452451, 3475950, 3475950, 2176455, 2176455, -1585221, -1585221, - -1257611, -1257611, 1939314, 1939314, -4083598, -4083598, -1000202, -1000202, - -3190144, -3190144, -3157330, -3157330, -3632928, -3632928, 126922, 126922, - 3412210, 3412210, -983419, -983419, 2147896, 2147896, 2715295, 2715295, - -2967645, -2967645, -3693493, -3693493, -411027, -411027, -2477047, -2477047, - -671102, -671102, -1228525, -1228525, -22981, -22981, -1308169, -1308169, - -381987, -381987, 1349076, 1349076, 1852771, 1852771, -1430430, -1430430, - -3343383, -3343383, 264944, 264944, 508951, 508951, 3097992, 3097992, - 44288, 44288, -1100098, -1100098, 904516, 904516, 3958618, 3958618, - -3724342, -3724342, -8578, -8578, 1653064, 1653064, -3249728, -3249728, - 2389356, 
2389356, -210977, -210977, 759969, 759969, -1316856, -1316856, - 189548, 189548, -3553272, -3553272, 3159746, 3159746, -1851402, -1851402, - -2409325, -2409325, -177440, -177440, 1315589, 1315589, 1341330, 1341330, - 1285669, 1285669, -1584928, -1584928, -812732, -812732, -1439742, -1439742, - -3019102, -3019102, -3881060, -3881060, -3628969, -3628969, 3839961, 3839961, - - 2091667, 3407706, 2316500, 3817976, -3342478, 2244091, -2446433, -3562462, - 266997, 2434439, -1235728, 3513181, -3520352, -3759364, -1197226, -3193378, - 900702, 1859098, 909542, 819034, 495491, -1613174, -43260, -522500, - -655327, -3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838, - 342297, 286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, - 2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 1595974, - -3767016, 1250494, 2635921, -3548272, -2994039, 1869119, 1903435, -1050970, - -1333058, 1237275, -3318210, -1430225, -451100, 1312455, 3306115, -1962642, - -1279661, 1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031, - -542412, -2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993, - -2013608, 2432395, 2454455, -164721, 1957272, 3369112, 185531, -1207385, - -3183426, 162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107, - -3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735, 472078, - -426683, 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893, - -2939036, -2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687, - -554416, 3919660, -48306, -1362209, 3937738, 1400424, -846154, 1976782 - }; -} From 691e1dfc03932dd0ec712cdff41a42cc90ec3ade Mon Sep 17 00:00:00 2001 From: Volodymyr Paprotski Date: Mon, 24 Nov 2025 21:16:52 +0000 Subject: [PATCH 7/9] comments from Ferenc --- .../x86/stubGenerator_x86_64_dilithium.cpp | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp 
b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp index 39fad4c097cdc..57e818ff3cf08 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp @@ -85,7 +85,7 @@ static address unshufflePermsAddr(int offset) { // | B | | D | | ... // +-----+-----+-----+-----+----- // -// NOTE: size 0 and 1 are used for initial and final shuffles respectivelly of +// NOTE: size 0 and 1 are used for initial and final shuffles respectively of // dilithiumAlmostInverseNtt and dilithiumAlmostNtt. For size 0 and 1, input1[] // and input2[] are modified in-place (and output2 is used as a temporary) // @@ -245,14 +245,14 @@ static auto whole_shuffle(Register scratch, KRegister mergeMask1, KRegister merg // We do Montgomery multiplications of two AVX registers in 4 steps: // 1. Do the multiplications of the corresponding even numbered slots into -// the odd numbered slots of a scratch2 register. -// 2. Swap the even and odd numbered slots of the original input registers.* -// 3. Similar to step 1, but into output register. -// 4. Combine the outputs of step 1 and step 3 into the output of the Montgomery -// multiplication. -// (*For levels 0-6 in the Ntt and levels 1-7 of the inverse Ntt, need NOT swap -// the second operand (zetas) since the odd slots contain the same number -// as the corresponding even one. This is indicated by input2NeedsShuffle=false) +// the odd numbered slots of the scratch2 register. +// 2. Swap the even and odd numbered slots of the original input registers.(*Note) +// 3. Similar to step 1, but multiplication result is placed into output register. +// 4. Combine odd/even slots respectively from the scratch2 and output registers +// into the output register for the final result of the Montgomery multiplication. 
+// (*Note: For levels 0-6 in the Ntt and levels 1-7 of the inverse Ntt, need NOT +// swap the second operand (zetas) since the odd slots contain the same number +// as the corresponding even one. This is indicated by input2NeedsShuffle=false) // // The registers to be multiplied are in input1[] and inputs2[]. The results go // into output[]. Two scratch[] register arrays are expected. input1[] can @@ -279,7 +279,7 @@ static auto whole_montMul(XMMRegister montQInvModR, XMMRegister dilithium_q, // If so, use output: const XMMRegister* scratch = scratch1 == input1 ? output: scratch1; - // scratch = input1_even*intput2_even + // scratch = input1_even * input2_even for (int i = 0; i < regCnt; i++) { __ vpmuldq(scratch[i], input1[i], input2[i], vector_len); } @@ -476,7 +476,7 @@ static address generate_dilithiumAlmostNtt_avx(StubGenerator *stubgen, // level 0-3 can be done by shuffling registers (also notice fewer zetas loads, they repeat) // level 0 - 128 // scratch1 = coeffs3 * zetas1 - // coeffs3, coeffs1 = coeffs1±scratch1 + // coeffs3, coeffs1 = coeffs1 ± scratch1 // scratch1 = coeffs4 * zetas1 // coeffs4, coeffs2 = coeffs2 ± scratch1 __ vmovdqu(Zetas1[0], Address(zetas, 0), vector_len); @@ -521,12 +521,12 @@ static address generate_dilithiumAlmostNtt_avx(StubGenerator *stubgen, // coeffs2_2 = coeffs1_2 - scratch1 // coeffs1_2 = coeffs1_2 + scratch1 loadXmms(Zetas3, zetas, level * 512, vector_len, _masm); - shuffle(Scratch1, Coeffs1_2, Coeffs2_2, distance * 32); //Coeffs2_2 freed + shuffle(Scratch1, Coeffs1_2, Coeffs2_2, distance * 32); // Coeffs2_2 freed montMul64(Scratch1, Scratch1, Zetas3, Coeffs2_2, Scratch2, level==7); sub_add(Coeffs2_2, Coeffs1_2, Coeffs1_2, Scratch1, vector_len, _masm); loadXmms(Zetas3, zetas, 4*64 + level * 512, vector_len, _masm); - shuffle(Scratch1, Coeffs3_2, Coeffs4_2, distance * 32); //Coeffs4_2 freed + shuffle(Scratch1, Coeffs3_2, Coeffs4_2, distance * 32); // Coeffs4_2 freed montMul64(Scratch1, Scratch1, Zetas3, Coeffs4_2, Scratch2,
level==7); sub_add(Coeffs4_2, Coeffs3_2, Coeffs3_2, Scratch1, vector_len, _masm); } @@ -551,15 +551,15 @@ static address generate_dilithiumAlmostNtt_avx(StubGenerator *stubgen, const XMMRegister Coeffs1_2[] = {xmm0, xmm1, xmm2, xmm3}; const XMMRegister Coeffs2_2[] = {xmm4, xmm5, xmm6, xmm7}; - // Since we cannot fit the entire payload into registers, we process - // input in two stages. First half, load 8 registers 32 integers each apart. - // With one load, we can process level 0-2 (128-, 64- and 32-integers apart) - // Remaining levels, load 8 registers from consecutive memory (16-, 8-, 4-, - // 2-, 1-integer appart) - // Levels 5, 6, 7 (4-, 2-, 1-integer appart) require shuffles within registers - // Other levels, shuffles can be done by re-aranging register order + // Since we cannot fit the entire payload into registers, we process the + // input in two stages. For the first half, load 8 registers, each 32 integers + // apart. With one load, we can process level 0-2 (128-, 64- and 32-integers + // apart). For the remaining levels, load 8 registers from consecutive memory + // (16-, 8-, 4-, 2-, 1-integer apart) + // Levels 5, 6, 7 (4-, 2-, 1-integer apart) require shuffles within registers. + // On the other levels, shuffles can be done by rearanging the register order - // Four batches of 8 registers each, 128 bytes appart + // Four batches of 8 registers each, 128 bytes apart for (int i=0; i<4; i++) { loadXmms(Coeffs1_2, coeffs, i*32 + 0*128, vector_len, _masm, 4, 128); loadXmms(Coeffs2_2, coeffs, i*32 + 4*128, vector_len, _masm, 4, 128); @@ -698,7 +698,7 @@ static address generate_dilithiumAlmostInverseNtt_avx(StubGenerator *stubgen, // Java version. // In each of these iterations half of the coefficients are added to and // subtracted from the other half of the coefficients then the result of - // the substration is (Montgomery) multiplied by the corresponding zetas. + // the subtration is (Montgomery) multiplied by the corresponding zetas. 
// In each level we just shuffle the coefficients so that the results of // the additions and subtractions go to the vector registers so that they // align with each other and the zetas. @@ -847,7 +847,7 @@ static address generate_dilithiumAlmostInverseNtt_avx(StubGenerator *stubgen, storeXmms(coeffs, 128 + i*256, Coeffs2_2, vector_len, _masm, 4); } - // Four batches of 8 registers each, 128 bytes appart + // Four batches of 8 registers each, 128 bytes apart for (int i=0; i<4; i++) { loadXmms(Coeffs1_2, coeffs, i*32 + 0*128, vector_len, _masm, 4, 128); loadXmms(Coeffs2_2, coeffs, i*32 + 4*128, vector_len, _masm, 4, 128); From bfc16f1f6f6a1190aa8ae23b984b324c4915d746 Mon Sep 17 00:00:00 2001 From: Volodymyr Paprotski Date: Mon, 24 Nov 2025 22:01:11 +0000 Subject: [PATCH 8/9] spelling --- src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp index 57e818ff3cf08..9311bd84cd5a5 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp @@ -557,7 +557,7 @@ static address generate_dilithiumAlmostNtt_avx(StubGenerator *stubgen, // apart). For the remaining levels, load 8 registers from consecutive memory // (16-, 8-, 4-, 2-, 1-integer apart) // Levels 5, 6, 7 (4-, 2-, 1-integer apart) require shuffles within registers. - // On the other levels, shuffles can be done by rearanging the register order + // On the other levels, shuffles can be done by rearranging the register order // Four batches of 8 registers each, 128 bytes apart for (int i=0; i<4; i++) { @@ -698,7 +698,7 @@ static address generate_dilithiumAlmostInverseNtt_avx(StubGenerator *stubgen, // Java version. 
// In each of these iterations half of the coefficients are added to and // subtracted from the other half of the coefficients then the result of - // the subtration is (Montgomery) multiplied by the corresponding zetas. + // the subtraction is (Montgomery) multiplied by the corresponding zetas. // In each level we just shuffle the coefficients so that the results of // the additions and subtractions go to the vector registers so that they // align with each other and the zetas. From 094051e0ae56ac3995d93f1e61f84557fda344f2 Mon Sep 17 00:00:00 2001 From: Volodymyr Paprotski Date: Tue, 25 Nov 2025 20:02:05 +0000 Subject: [PATCH 9/9] comments from Jatin --- src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp index 9311bd84cd5a5..b95909394684f 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp @@ -362,7 +362,7 @@ static void sub_add(const XMMRegister subResult[], const XMMRegister addResult[] } static void loadXmms(const XMMRegister destinationRegs[], Register source, int offset, - int vector_len, MacroAssembler *_masm, int regCnt = -1, int memStep = -1) { + int vector_len, MacroAssembler *_masm, int regCnt = -1, int memStep = -1) { if (vector_len == Assembler::AVX_256bit) { regCnt = regCnt == -1 ? 2 : regCnt; @@ -378,7 +378,7 @@ static void loadXmms(const XMMRegister destinationRegs[], Register source, int o } static void storeXmms(Register destination, int offset, const XMMRegister xmmRegs[], - int vector_len, MacroAssembler *_masm, int regCnt = -1, int memStep = -1) { + int vector_len, MacroAssembler *_masm, int regCnt = -1, int memStep = -1) { if (vector_len == Assembler::AVX_256bit) { regCnt = regCnt == -1 ? 2 : regCnt; memStep = memStep == -1 ? 
32 : memStep; @@ -656,7 +656,7 @@ static address generate_dilithiumAlmostNtt_avx(StubGenerator *stubgen, // coeffs (int[256]) = c_rarg0 // zetas (int[128*8]) = c_rarg1 static address generate_dilithiumAlmostInverseNtt_avx(StubGenerator *stubgen, - int vector_len,MacroAssembler *_masm) { + int vector_len, MacroAssembler *_masm) { __ align(CodeEntryAlignment); StubId stub_id = StubId::stubgen_dilithiumAlmostInverseNtt_id; StubCodeMark mark(stubgen, stub_id);