Commit bde3535
RV64: Avoid repeated VLEN-evaluation in rejection sampling
For VLEN >= 512, there are tail iterations in the rejection handling loop where fewer coefficients are needed than fit into a vector, requiring an adjustment of the dynamic VL. The previous code re-evaluated the dynamic VL in every iteration, which incurred a significant runtime cost. This commit instead splits the rejection sampling loop into two nested loops, where the inner loop proceeds with a fixed VL and the outer loop re-evaluates the VL. For VLEN <= 256, there is only one iteration of the outer loop, making it as efficient as the original version.

Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
1 parent 2567720 commit bde3535
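To illustrate the loop-splitting pattern the commit message describes, here is a minimal, self-contained sketch. It is not the mlkem-native code itself: the function name process_all, the element count remaining, and the elided vector body are hypothetical; only the __riscv_vsetvl_e16m1 intrinsic from riscv_vector.h is a real RVV API.

/* Sketch only: hypothetical workload of `remaining` 16-bit elements. */
#include <stddef.h>
#include <riscv_vector.h>

static size_t process_all(size_t remaining)
{
  size_t done = 0;
  while (remaining > 0)
  {
    /* Outer loop: evaluate the dynamic VL once per pass. */
    size_t vl = __riscv_vsetvl_e16m1(remaining);
    /* Inner loop: reuse the same VL while at least a full vector's
     * worth of elements remains; the bulk of the data is handled here. */
    while (remaining >= vl)
    {
      /* ... vector body operating on exactly `vl` elements ... */
      done += vl;
      remaining -= vl;
    }
  }
  return done;
}

With this shape, the VL-setting step runs once per outer iteration rather than once per vector of data; only the final short tail triggers a second outer iteration with a reduced VL.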

mlkem/src/native/riscv64/src/rv64v_poly.c

Lines changed: 22 additions & 16 deletions
@@ -757,24 +757,30 @@ unsigned int mlk_rv64v_rej_uniform(int16_t *r, unsigned int len,
     const vuint16m1_t sel12v = __riscv_vsrl_vx_u16m1(srl12v, 4, vl);
     const vuint16m1_t sll12v = __riscv_vsll_vx_u16m1(vid, 2, vl);
 
-    x = __riscv_vle16_v_u16m1((uint16_t *)&buf[pos], vl23);
-    pos += vl23 * 2;
-    x = __riscv_vrgather_vv_u16m1(x, sel12v, vl);
-    x = __riscv_vor_vv_u16m1(
-        __riscv_vsrl_vv_u16m1(x, srl12v, vl),
-        __riscv_vsll_vv_u16m1(__riscv_vslidedown(x, 1, vl), sll12v, vl), vl);
-    x = __riscv_vand_vx_u16m1(x, 0xFFF, vl);
-
-    lt = __riscv_vmsltu_vx_u16m1_b16(x, MLKEM_Q, vl);
-    y = __riscv_vcompress_vm_u16m1(x, lt, vl);
-    n = __riscv_vcpop_m_b16(lt, vl);
-
-    if (ctr + n > len)
+    /* Functionally, this loop is not necessary, but it avoids re-evaluating
+     * the VL too many times. In particular, in the first outer iteration,
+     * the inner loop will process the bulk of the data with fixed VL. */
+    while (ctr < len && vl23 * 2 <= buflen - pos)
     {
-      n = len - ctr;
+      x = __riscv_vle16_v_u16m1((uint16_t *)&buf[pos], vl23);
+      pos += vl23 * 2;
+      x = __riscv_vrgather_vv_u16m1(x, sel12v, vl);
+      x = __riscv_vor_vv_u16m1(
+          __riscv_vsrl_vv_u16m1(x, srl12v, vl),
+          __riscv_vsll_vv_u16m1(__riscv_vslidedown(x, 1, vl), sll12v, vl), vl);
+      x = __riscv_vand_vx_u16m1(x, 0xFFF, vl);
+
+      lt = __riscv_vmsltu_vx_u16m1_b16(x, MLKEM_Q, vl);
+      y = __riscv_vcompress_vm_u16m1(x, lt, vl);
+      n = __riscv_vcpop_m_b16(lt, vl);
+
+      if (ctr + n > len)
+      {
+        n = len - ctr;
+      }
+      __riscv_vse16_v_u16m1((uint16_t *)&r[ctr], y, n);
+      ctr += n;
     }
-    __riscv_vse16_v_u16m1((uint16_t *)&r[ctr], y, n);
-    ctr += n;
   }
 
   return ctr;
