
Commit 01bab43

RV64: Avoid repeated VLEN-evaluation in rejection sampling
For VLEN >= 512, there are tail iterations in the rejection handling loop where we require fewer coefficients than fit into a vector, requiring an adjustment of the dynamic VL. The previous code re-evaluated the dynamic VL in every iteration, which incurred a significant runtime cost. This commit instead splits the rejection sampling loop into two nested loops, where the inner loop proceeds with a fixed VL and the outer loop re-evaluates the VL. For VLEN <= 256, there is only one iteration of the outer loop, rendering it as efficient as the original version.

Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
1 parent 6dc4a61 commit 01bab43
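
As a rough illustration of the restructuring described in the commit message, here is a minimal scalar C sketch; it is not code from mlkem-native, and the function name rej_uniform_sketch, its parameters, and the batch width max_vl standing in for the hardware VL are all hypothetical. The outer loop (re-)evaluates the batch width, and the inner loop consumes whole batches at that fixed width, accepting only candidates below MLKEM_Q:

#include <stddef.h>
#include <stdint.h>

#define MLKEM_Q 3329 /* ML-KEM modulus */

/* Hypothetical scalar sketch of the nested-loop rejection sampling:
 * the outer loop picks the batch width, the inner loop keeps it fixed.
 * Assumes max_vl >= 1. */
static size_t rej_uniform_sketch(int16_t *r, size_t len, const uint16_t *cand,
                                 size_t cand_len, size_t max_vl)
{
  size_t ctr = 0, pos = 0;
  while (ctr < len && pos < cand_len)
  {
    /* Outer loop: (re-)evaluate the batch width ("VL") once. */
    size_t vl = max_vl;
    if (vl > len - ctr)
    {
      vl = len - ctr;
    }
    if (vl > cand_len - pos)
    {
      vl = cand_len - pos;
    }
    /* Inner loop: process whole batches of vl candidates at fixed width. */
    while (ctr < len && pos + vl <= cand_len)
    {
      size_t n = 0;
      for (size_t i = 0; i < vl; i++)
      {
        if (cand[pos + i] < MLKEM_Q && ctr + n < len)
        {
          r[ctr + n] = (int16_t)cand[pos + i]; /* accept candidate */
          n++;
        }
      }
      pos += vl;
      ctr += n;
      if (len - ctr < vl)
      {
        break; /* tail: return to the outer loop to shrink the batch width */
      }
    }
  }
  return ctr;
}

In the common case the first outer iteration already drains most of the buffer, so the batch width is recomputed only a handful of times.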

File tree

1 file changed (+22, -16 lines)

mlkem/src/native/riscv64/src/rv64v_poly.c

Lines changed: 22 additions & 16 deletions
@@ -751,24 +751,30 @@ unsigned int mlk_rv64v_rej_uniform(int16_t *r, unsigned int len,
     const vuint16m1_t sel12v = __riscv_vsrl_vx_u16m1(srl12v, 4, vl);
     const vuint16m1_t sll12v = __riscv_vsll_vx_u16m1(vid, 2, vl);
 
-    x = __riscv_vle16_v_u16m1((uint16_t *)&buf[pos], vl23);
-    pos += vl23 * 2;
-    x = __riscv_vrgather_vv_u16m1(x, sel12v, vl);
-    x = __riscv_vor_vv_u16m1(
-        __riscv_vsrl_vv_u16m1(x, srl12v, vl),
-        __riscv_vsll_vv_u16m1(__riscv_vslidedown(x, 1, vl), sll12v, vl), vl);
-    x = __riscv_vand_vx_u16m1(x, 0xFFF, vl);
-
-    lt = __riscv_vmsltu_vx_u16m1_b16(x, MLKEM_Q, vl);
-    y = __riscv_vcompress_vm_u16m1(x, lt, vl);
-    n = __riscv_vcpop_m_b16(lt, vl);
-
-    if (ctr + n > len)
+    /* Functionally, this loop is not necessary, but it avoids re-evaluating
+     * the VL too many times. In particular, in the first outer iteration,
+     * the inner loop will process the bulk of the data with fixed VL. */
+    while (ctr < len && vl23 * 2 <= buflen - pos)
     {
-      n = len - ctr;
+      x = __riscv_vle16_v_u16m1((uint16_t *)&buf[pos], vl23);
+      pos += vl23 * 2;
+      x = __riscv_vrgather_vv_u16m1(x, sel12v, vl);
+      x = __riscv_vor_vv_u16m1(
+          __riscv_vsrl_vv_u16m1(x, srl12v, vl),
+          __riscv_vsll_vv_u16m1(__riscv_vslidedown(x, 1, vl), sll12v, vl), vl);
+      x = __riscv_vand_vx_u16m1(x, 0xFFF, vl);
+
+      lt = __riscv_vmsltu_vx_u16m1_b16(x, MLKEM_Q, vl);
+      y = __riscv_vcompress_vm_u16m1(x, lt, vl);
+      n = __riscv_vcpop_m_b16(lt, vl);
+
+      if (ctr + n > len)
+      {
+        n = len - ctr;
+      }
+      __riscv_vse16_v_u16m1((uint16_t *)&r[ctr], y, n);
+      ctr += n;
     }
-    __riscv_vse16_v_u16m1((uint16_t *)&r[ctr], y, n);
-    ctr += n;
   }
 
   return ctr;
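
The hunk does not show the enclosing outer loop in which the dynamic VL is (re-)evaluated. As a hedged, self-contained sketch of the intrinsics-level pattern the commit moves towards (hoisting the VL computation out of the hot loop rather than recomputing it per iteration), the following 16-bit copy is illustrative only and not taken from rv64v_poly.c; it assumes a vector-capable toolchain (e.g. -march=rv64gcv):

#include <riscv_vector.h>
#include <stddef.h>
#include <stdint.h>

/* Illustrative only: compute the VL once with __riscv_vsetvl_e16m1 and reuse
 * it for the bulk loop, instead of re-evaluating it on every iteration. */
static void copy16_fixed_vl(int16_t *dst, const int16_t *src, size_t n)
{
  size_t i = 0;
  size_t vl;

  if (n == 0)
  {
    return;
  }

  vl = __riscv_vsetvl_e16m1(n); /* evaluate the VL once for e16/m1 */

  /* Bulk loop: whole vectors at the fixed vl. */
  for (; i + vl <= n; i += vl)
  {
    vint16m1_t v = __riscv_vle16_v_i16m1(&src[i], vl);
    __riscv_vse16_v_i16m1(&dst[i], v, vl);
  }

  /* Tail: re-evaluate the VL once for the remaining elements. */
  if (i < n)
  {
    size_t tail = __riscv_vsetvl_e16m1(n - i);
    vint16m1_t v = __riscv_vle16_v_i16m1(&src[i], tail);
    __riscv_vse16_v_i16m1(&dst[i], v, tail);
  }
}

The same idea underlies the rejection loop above: any VL re-evaluation is confined to the outer loop, while the inner loop reuses vl and vl23 unchanged.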
