diff --git a/fearless_simd/src/generated/avx2.rs b/fearless_simd/src/generated/avx2.rs index 44998dc8..6446cbcf 100644 --- a/fearless_simd/src/generated/avx2.rs +++ b/fearless_simd/src/generated/avx2.rs @@ -168,6 +168,10 @@ impl Simd for Avx2 { unsafe { _mm_fmsub_ps(a.into(), b.into(), c.into()).simd_into(self) } } #[inline(always)] + fn mul_neg_add_f32x4(self, a: f32x4, b: f32x4, c: f32x4) -> f32x4 { + unsafe { _mm_fnmadd_ps(a.into(), b.into(), c.into()).simd_into(self) } + } + #[inline(always)] fn floor_f32x4(self, a: f32x4) -> f32x4 { unsafe { _mm_round_ps::<{ _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC }>(a.into()).simd_into(self) @@ -1338,6 +1342,10 @@ impl Simd for Avx2 { unsafe { _mm_fmsub_pd(a.into(), b.into(), c.into()).simd_into(self) } } #[inline(always)] + fn mul_neg_add_f64x2(self, a: f64x2, b: f64x2, c: f64x2) -> f64x2 { + unsafe { _mm_fnmadd_pd(a.into(), b.into(), c.into()).simd_into(self) } + } + #[inline(always)] fn floor_f64x2(self, a: f64x2) -> f64x2 { unsafe { _mm_round_pd::<{ _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC }>(a.into()).simd_into(self) @@ -1559,6 +1567,10 @@ impl Simd for Avx2 { unsafe { _mm256_fmsub_ps(a.into(), b.into(), c.into()).simd_into(self) } } #[inline(always)] + fn mul_neg_add_f32x8(self, a: f32x8, b: f32x8, c: f32x8) -> f32x8 { + unsafe { _mm256_fnmadd_ps(a.into(), b.into(), c.into()).simd_into(self) } + } + #[inline(always)] fn floor_f32x8(self, a: f32x8) -> f32x8 { unsafe { _mm256_round_ps::<{ _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC }>(a.into()) @@ -3025,6 +3037,10 @@ impl Simd for Avx2 { unsafe { _mm256_fmsub_pd(a.into(), b.into(), c.into()).simd_into(self) } } #[inline(always)] + fn mul_neg_add_f64x4(self, a: f64x4, b: f64x4, c: f64x4) -> f64x4 { + unsafe { _mm256_fnmadd_pd(a.into(), b.into(), c.into()).simd_into(self) } + } + #[inline(always)] fn floor_f64x4(self, a: f64x4) -> f64x4 { unsafe { _mm256_round_pd::<{ _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC }>(a.into()) @@ -3301,6 +3317,16 @@ impl Simd for Avx2 { ) } #[inline(always)] + fn mul_neg_add_f32x16(self, a: f32x16, b: f32x16, c: f32x16) -> f32x16 { + let (a0, a1) = self.split_f32x16(a); + let (b0, b1) = self.split_f32x16(b); + let (c0, c1) = self.split_f32x16(c); + self.combine_f32x8( + self.mul_neg_add_f32x8(a0, b0, c0), + self.mul_neg_add_f32x8(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f32x16(self, a: f32x16) -> f32x16 { let (a0, a1) = self.split_f32x16(a); self.combine_f32x8(self.floor_f32x8(a0), self.floor_f32x8(a1)) @@ -4986,6 +5012,16 @@ impl Simd for Avx2 { ) } #[inline(always)] + fn mul_neg_add_f64x8(self, a: f64x8, b: f64x8, c: f64x8) -> f64x8 { + let (a0, a1) = self.split_f64x8(a); + let (b0, b1) = self.split_f64x8(b); + let (c0, c1) = self.split_f64x8(c); + self.combine_f64x4( + self.mul_neg_add_f64x4(a0, b0, c0), + self.mul_neg_add_f64x4(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f64x8(self, a: f64x8) -> f64x8 { let (a0, a1) = self.split_f64x8(a); self.combine_f64x4(self.floor_f64x4(a0), self.floor_f64x4(a1)) diff --git a/fearless_simd/src/generated/fallback.rs b/fearless_simd/src/generated/fallback.rs index dfa9f489..3d86f2c4 100644 --- a/fearless_simd/src/generated/fallback.rs +++ b/fearless_simd/src/generated/fallback.rs @@ -311,6 +311,10 @@ impl Simd for Fallback { a.mul(b).sub(c) } #[inline(always)] + fn mul_neg_add_f32x4(self, a: f32x4, b: f32x4, c: f32x4) -> f32x4 { + c.sub(a.mul(b)) + } + #[inline(always)] fn floor_f32x4(self, a: f32x4) -> f32x4 { [ f32::floor(a[0usize]), @@ -3198,6 +3202,10 @@ impl Simd for Fallback { a.mul(b).sub(c) } 
#[inline(always)] + fn mul_neg_add_f64x2(self, a: f64x2, b: f64x2, c: f64x2) -> f64x2 { + c.sub(a.mul(b)) + } + #[inline(always)] fn floor_f64x2(self, a: f64x2) -> f64x2 { [f64::floor(a[0usize]), f64::floor(a[1usize])].simd_into(self) } @@ -3471,6 +3479,16 @@ impl Simd for Fallback { ) } #[inline(always)] + fn mul_neg_add_f32x8(self, a: f32x8, b: f32x8, c: f32x8) -> f32x8 { + let (a0, a1) = self.split_f32x8(a); + let (b0, b1) = self.split_f32x8(b); + let (c0, c1) = self.split_f32x8(c); + self.combine_f32x4( + self.mul_neg_add_f32x4(a0, b0, c0), + self.mul_neg_add_f32x4(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f32x8(self, a: f32x8) -> f32x8 { let (a0, a1) = self.split_f32x8(a); self.combine_f32x4(self.floor_f32x4(a0), self.floor_f32x4(a1)) @@ -5019,6 +5037,16 @@ impl Simd for Fallback { ) } #[inline(always)] + fn mul_neg_add_f64x4(self, a: f64x4, b: f64x4, c: f64x4) -> f64x4 { + let (a0, a1) = self.split_f64x4(a); + let (b0, b1) = self.split_f64x4(b); + let (c0, c1) = self.split_f64x4(c); + self.combine_f64x2( + self.mul_neg_add_f64x2(a0, b0, c0), + self.mul_neg_add_f64x2(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f64x4(self, a: f64x4) -> f64x4 { let (a0, a1) = self.split_f64x4(a); self.combine_f64x2(self.floor_f64x2(a0), self.floor_f64x2(a1)) @@ -5315,6 +5343,16 @@ impl Simd for Fallback { ) } #[inline(always)] + fn mul_neg_add_f32x16(self, a: f32x16, b: f32x16, c: f32x16) -> f32x16 { + let (a0, a1) = self.split_f32x16(a); + let (b0, b1) = self.split_f32x16(b); + let (c0, c1) = self.split_f32x16(c); + self.combine_f32x8( + self.mul_neg_add_f32x8(a0, b0, c0), + self.mul_neg_add_f32x8(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f32x16(self, a: f32x16) -> f32x16 { let (a0, a1) = self.split_f32x16(a); self.combine_f32x8(self.floor_f32x8(a0), self.floor_f32x8(a1)) @@ -6985,6 +7023,16 @@ impl Simd for Fallback { ) } #[inline(always)] + fn mul_neg_add_f64x8(self, a: f64x8, b: f64x8, c: f64x8) -> f64x8 { + let (a0, a1) = self.split_f64x8(a); + let (b0, b1) = self.split_f64x8(b); + let (c0, c1) = self.split_f64x8(c); + self.combine_f64x4( + self.mul_neg_add_f64x4(a0, b0, c0), + self.mul_neg_add_f64x4(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f64x8(self, a: f64x8) -> f64x8 { let (a0, a1) = self.split_f64x8(a); self.combine_f64x4(self.floor_f64x4(a0), self.floor_f64x4(a1)) diff --git a/fearless_simd/src/generated/neon.rs b/fearless_simd/src/generated/neon.rs index 8cd2d9fb..0cfe41f9 100644 --- a/fearless_simd/src/generated/neon.rs +++ b/fearless_simd/src/generated/neon.rs @@ -159,6 +159,10 @@ impl Simd for Neon { unsafe { vnegq_f32(vfmsq_f32(c.into(), b.into(), a.into())).simd_into(self) } } #[inline(always)] + fn mul_neg_add_f32x4(self, a: f32x4, b: f32x4, c: f32x4) -> f32x4 { + unsafe { vfmsq_f32(c.into(), b.into(), a.into()).simd_into(self) } + } + #[inline(always)] fn floor_f32x4(self, a: f32x4) -> f32x4 { unsafe { vrndmq_f32(a.into()).simd_into(self) } } @@ -1227,6 +1231,10 @@ impl Simd for Neon { unsafe { vnegq_f64(vfmsq_f64(c.into(), b.into(), a.into())).simd_into(self) } } #[inline(always)] + fn mul_neg_add_f64x2(self, a: f64x2, b: f64x2, c: f64x2) -> f64x2 { + unsafe { vfmsq_f64(c.into(), b.into(), a.into()).simd_into(self) } + } + #[inline(always)] fn floor_f64x2(self, a: f64x2) -> f64x2 { unsafe { vrndmq_f64(a.into()).simd_into(self) } } @@ -1476,6 +1484,16 @@ impl Simd for Neon { ) } #[inline(always)] + fn mul_neg_add_f32x8(self, a: f32x8, b: f32x8, c: f32x8) -> f32x8 { + let (a0, a1) = self.split_f32x8(a); + let (b0, b1) = self.split_f32x8(b); + let (c0, c1) 
= self.split_f32x8(c); + self.combine_f32x4( + self.mul_neg_add_f32x4(a0, b0, c0), + self.mul_neg_add_f32x4(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f32x8(self, a: f32x8) -> f32x8 { let (a0, a1) = self.split_f32x8(a); self.combine_f32x4(self.floor_f32x4(a0), self.floor_f32x4(a1)) @@ -3011,6 +3029,16 @@ impl Simd for Neon { ) } #[inline(always)] + fn mul_neg_add_f64x4(self, a: f64x4, b: f64x4, c: f64x4) -> f64x4 { + let (a0, a1) = self.split_f64x4(a); + let (b0, b1) = self.split_f64x4(b); + let (c0, c1) = self.split_f64x4(c); + self.combine_f64x2( + self.mul_neg_add_f64x2(a0, b0, c0), + self.mul_neg_add_f64x2(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f64x4(self, a: f64x4) -> f64x4 { let (a0, a1) = self.split_f64x4(a); self.combine_f64x2(self.floor_f64x2(a0), self.floor_f64x2(a1)) @@ -3307,6 +3335,16 @@ impl Simd for Neon { ) } #[inline(always)] + fn mul_neg_add_f32x16(self, a: f32x16, b: f32x16, c: f32x16) -> f32x16 { + let (a0, a1) = self.split_f32x16(a); + let (b0, b1) = self.split_f32x16(b); + let (c0, c1) = self.split_f32x16(c); + self.combine_f32x8( + self.mul_neg_add_f32x8(a0, b0, c0), + self.mul_neg_add_f32x8(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f32x16(self, a: f32x16) -> f32x16 { let (a0, a1) = self.split_f32x16(a); self.combine_f32x8(self.floor_f32x8(a0), self.floor_f32x8(a1)) @@ -4816,6 +4854,16 @@ impl Simd for Neon { ) } #[inline(always)] + fn mul_neg_add_f64x8(self, a: f64x8, b: f64x8, c: f64x8) -> f64x8 { + let (a0, a1) = self.split_f64x8(a); + let (b0, b1) = self.split_f64x8(b); + let (c0, c1) = self.split_f64x8(c); + self.combine_f64x4( + self.mul_neg_add_f64x4(a0, b0, c0), + self.mul_neg_add_f64x4(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f64x8(self, a: f64x8) -> f64x8 { let (a0, a1) = self.split_f64x8(a); self.combine_f64x4(self.floor_f64x4(a0), self.floor_f64x4(a1)) diff --git a/fearless_simd/src/generated/simd_trait.rs b/fearless_simd/src/generated/simd_trait.rs index b9ed4e38..f88d56dd 100644 --- a/fearless_simd/src/generated/simd_trait.rs +++ b/fearless_simd/src/generated/simd_trait.rs @@ -166,6 +166,8 @@ pub trait Simd: Sized + Clone + Copy + Send + Sync + Seal + 'static { fn mul_add_f32x4(self, a: f32x4, b: f32x4, c: f32x4) -> f32x4; #[doc = "Compute `(a * b) - c` (fused multiply-subtract) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a subtract, which will result in two rounding errors."] fn mul_sub_f32x4(self, a: f32x4, b: f32x4, c: f32x4) -> f32x4; + #[doc = "Compute `c - (a * b)` (fused negated multiply-add) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a negated add, which will result in two rounding errors."] + fn mul_neg_add_f32x4(self, a: f32x4, b: f32x4, c: f32x4) -> f32x4; #[doc = "Return the largest integer less than or equal to each element, that is, round towards negative infinity."] fn floor_f32x4(self, a: f32x4) -> f32x4; #[doc = "Return the smallest integer greater than or equal to each element, that is, round towards positive infinity."] @@ -659,6 +661,8 @@ pub trait Simd: Sized + Clone + Copy + Send + Sync + Seal + 'static { fn mul_add_f64x2(self, a: f64x2, b: f64x2, c: f64x2) -> f64x2; #[doc = "Compute `(a * b) - c` (fused multiply-subtract) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a 
regular multiply followed by a subtract, which will result in two rounding errors."] fn mul_sub_f64x2(self, a: f64x2, b: f64x2, c: f64x2) -> f64x2; + #[doc = "Compute `c - (a * b)` (fused negated multiply-add) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a negated add, which will result in two rounding errors."] + fn mul_neg_add_f64x2(self, a: f64x2, b: f64x2, c: f64x2) -> f64x2; #[doc = "Return the largest integer less than or equal to each element, that is, round towards negative infinity."] fn floor_f64x2(self, a: f64x2) -> f64x2; #[doc = "Return the smallest integer greater than or equal to each element, that is, round towards positive infinity."] @@ -752,6 +756,8 @@ pub trait Simd: Sized + Clone + Copy + Send + Sync + Seal + 'static { fn mul_add_f32x8(self, a: f32x8, b: f32x8, c: f32x8) -> f32x8; #[doc = "Compute `(a * b) - c` (fused multiply-subtract) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a subtract, which will result in two rounding errors."] fn mul_sub_f32x8(self, a: f32x8, b: f32x8, c: f32x8) -> f32x8; + #[doc = "Compute `c - (a * b)` (fused negated multiply-add) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a negated add, which will result in two rounding errors."] + fn mul_neg_add_f32x8(self, a: f32x8, b: f32x8, c: f32x8) -> f32x8; #[doc = "Return the largest integer less than or equal to each element, that is, round towards negative infinity."] fn floor_f32x8(self, a: f32x8) -> f32x8; #[doc = "Return the smallest integer greater than or equal to each element, that is, round towards positive infinity."] @@ -1267,6 +1273,8 @@ pub trait Simd: Sized + Clone + Copy + Send + Sync + Seal + 'static { fn mul_add_f64x4(self, a: f64x4, b: f64x4, c: f64x4) -> f64x4; #[doc = "Compute `(a * b) - c` (fused multiply-subtract) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a subtract, which will result in two rounding errors."] fn mul_sub_f64x4(self, a: f64x4, b: f64x4, c: f64x4) -> f64x4; + #[doc = "Compute `c - (a * b)` (fused negated multiply-add) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a negated add, which will result in two rounding errors."] + fn mul_neg_add_f64x4(self, a: f64x4, b: f64x4, c: f64x4) -> f64x4; #[doc = "Return the largest integer less than or equal to each element, that is, round towards negative infinity."] fn floor_f64x4(self, a: f64x4) -> f64x4; #[doc = "Return the smallest integer greater than or equal to each element, that is, round towards positive infinity."] @@ -1364,6 +1372,8 @@ pub trait Simd: Sized + Clone + Copy + Send + Sync + Seal + 'static { fn mul_add_f32x16(self, a: f32x16, b: f32x16, c: f32x16) -> f32x16; #[doc = "Compute `(a * b) - c` (fused multiply-subtract) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a subtract, which will result in two rounding errors."] fn mul_sub_f32x16(self, a: f32x16, b: f32x16, c: f32x16) -> f32x16; + #[doc = 
"Compute `c - (a * b)` (fused negated multiply-add) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a negated add, which will result in two rounding errors."] + fn mul_neg_add_f32x16(self, a: f32x16, b: f32x16, c: f32x16) -> f32x16; #[doc = "Return the largest integer less than or equal to each element, that is, round towards negative infinity."] fn floor_f32x16(self, a: f32x16) -> f32x16; #[doc = "Return the smallest integer greater than or equal to each element, that is, round towards positive infinity."] @@ -1873,6 +1883,8 @@ pub trait Simd: Sized + Clone + Copy + Send + Sync + Seal + 'static { fn mul_add_f64x8(self, a: f64x8, b: f64x8, c: f64x8) -> f64x8; #[doc = "Compute `(a * b) - c` (fused multiply-subtract) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a subtract, which will result in two rounding errors."] fn mul_sub_f64x8(self, a: f64x8, b: f64x8, c: f64x8) -> f64x8; + #[doc = "Compute `c - (a * b)` (fused negated multiply-add) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a negated add, which will result in two rounding errors."] + fn mul_neg_add_f64x8(self, a: f64x8, b: f64x8, c: f64x8) -> f64x8; #[doc = "Return the largest integer less than or equal to each element, that is, round towards negative infinity."] fn floor_f64x8(self, a: f64x8) -> f64x8; #[doc = "Return the smallest integer greater than or equal to each element, that is, round towards positive infinity."] @@ -2024,6 +2036,8 @@ pub trait SimdFloat: fn mul_add(self, op1: impl SimdInto, op2: impl SimdInto) -> Self; #[doc = "Compute `(self * op1) - op2` (fused multiply-subtract) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a subtract, which will result in two rounding errors."] fn mul_sub(self, op1: impl SimdInto, op2: impl SimdInto) -> Self; + #[doc = "Compute `op2 - (self * op1)` (fused negated multiply-add) for each element.\n\nDepending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a negated add, which will result in two rounding errors."] + fn mul_neg_add(self, op1: impl SimdInto, op2: impl SimdInto) -> Self; #[doc = "Return the largest integer less than or equal to each element, that is, round towards negative infinity."] fn floor(self) -> Self; #[doc = "Return the smallest integer greater than or equal to each element, that is, round towards positive infinity."] diff --git a/fearless_simd/src/generated/simd_types.rs b/fearless_simd/src/generated/simd_types.rs index ddd7037d..1eea17c3 100644 --- a/fearless_simd/src/generated/simd_types.rs +++ b/fearless_simd/src/generated/simd_types.rs @@ -199,6 +199,11 @@ impl crate::SimdFloat for f32x4 { .mul_sub_f32x4(self, op1.simd_into(self.simd), op2.simd_into(self.simd)) } #[inline(always)] + fn mul_neg_add(self, op1: impl SimdInto, op2: impl SimdInto) -> Self { + self.simd + .mul_neg_add_f32x4(self, op1.simd_into(self.simd), op2.simd_into(self.simd)) + } + #[inline(always)] fn floor(self) -> Self { self.simd.floor_f32x4(self) } @@ -1995,6 +2000,11 @@ impl crate::SimdFloat for f64x2 { .mul_sub_f64x2(self, 
op1.simd_into(self.simd), op2.simd_into(self.simd)) } #[inline(always)] + fn mul_neg_add(self, op1: impl SimdInto, op2: impl SimdInto) -> Self { + self.simd + .mul_neg_add_f64x2(self, op1.simd_into(self.simd), op2.simd_into(self.simd)) + } + #[inline(always)] fn floor(self) -> Self { self.simd.floor_f64x2(self) } @@ -2374,6 +2384,11 @@ impl crate::SimdFloat for f32x8 { .mul_sub_f32x8(self, op1.simd_into(self.simd), op2.simd_into(self.simd)) } #[inline(always)] + fn mul_neg_add(self, op1: impl SimdInto, op2: impl SimdInto) -> Self { + self.simd + .mul_neg_add_f32x8(self, op1.simd_into(self.simd), op2.simd_into(self.simd)) + } + #[inline(always)] fn floor(self) -> Self { self.simd.floor_f32x8(self) } @@ -4339,6 +4354,11 @@ impl crate::SimdFloat for f64x4 { .mul_sub_f64x4(self, op1.simd_into(self.simd), op2.simd_into(self.simd)) } #[inline(always)] + fn mul_neg_add(self, op1: impl SimdInto, op2: impl SimdInto) -> Self { + self.simd + .mul_neg_add_f64x4(self, op1.simd_into(self.simd), op2.simd_into(self.simd)) + } + #[inline(always)] fn floor(self) -> Self { self.simd.floor_f64x4(self) } @@ -4741,6 +4761,11 @@ impl crate::SimdFloat for f32x16 { .mul_sub_f32x16(self, op1.simd_into(self.simd), op2.simd_into(self.simd)) } #[inline(always)] + fn mul_neg_add(self, op1: impl SimdInto, op2: impl SimdInto) -> Self { + self.simd + .mul_neg_add_f32x16(self, op1.simd_into(self.simd), op2.simd_into(self.simd)) + } + #[inline(always)] fn floor(self) -> Self { self.simd.floor_f32x16(self) } @@ -6823,6 +6848,11 @@ impl crate::SimdFloat for f64x8 { .mul_sub_f64x8(self, op1.simd_into(self.simd), op2.simd_into(self.simd)) } #[inline(always)] + fn mul_neg_add(self, op1: impl SimdInto, op2: impl SimdInto) -> Self { + self.simd + .mul_neg_add_f64x8(self, op1.simd_into(self.simd), op2.simd_into(self.simd)) + } + #[inline(always)] fn floor(self) -> Self { self.simd.floor_f64x8(self) } diff --git a/fearless_simd/src/generated/sse4_2.rs b/fearless_simd/src/generated/sse4_2.rs index 8c311f40..b3055e84 100644 --- a/fearless_simd/src/generated/sse4_2.rs +++ b/fearless_simd/src/generated/sse4_2.rs @@ -173,6 +173,10 @@ impl Simd for Sse4_2 { a * b - c } #[inline(always)] + fn mul_neg_add_f32x4(self, a: f32x4, b: f32x4, c: f32x4) -> f32x4 { + c - a * b + } + #[inline(always)] fn floor_f32x4(self, a: f32x4) -> f32x4 { unsafe { _mm_round_ps::<{ _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC }>(a.into()).simd_into(self) @@ -1378,6 +1382,10 @@ impl Simd for Sse4_2 { a * b - c } #[inline(always)] + fn mul_neg_add_f64x2(self, a: f64x2, b: f64x2, c: f64x2) -> f64x2 { + c - a * b + } + #[inline(always)] fn floor_f64x2(self, a: f64x2) -> f64x2 { unsafe { _mm_round_pd::<{ _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC }>(a.into()).simd_into(self) @@ -1632,6 +1640,16 @@ impl Simd for Sse4_2 { ) } #[inline(always)] + fn mul_neg_add_f32x8(self, a: f32x8, b: f32x8, c: f32x8) -> f32x8 { + let (a0, a1) = self.split_f32x8(a); + let (b0, b1) = self.split_f32x8(b); + let (c0, c1) = self.split_f32x8(c); + self.combine_f32x4( + self.mul_neg_add_f32x4(a0, b0, c0), + self.mul_neg_add_f32x4(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f32x8(self, a: f32x8) -> f32x8 { let (a0, a1) = self.split_f32x8(a); self.combine_f32x4(self.floor_f32x4(a0), self.floor_f32x4(a1)) @@ -3169,6 +3187,16 @@ impl Simd for Sse4_2 { ) } #[inline(always)] + fn mul_neg_add_f64x4(self, a: f64x4, b: f64x4, c: f64x4) -> f64x4 { + let (a0, a1) = self.split_f64x4(a); + let (b0, b1) = self.split_f64x4(b); + let (c0, c1) = self.split_f64x4(c); + self.combine_f64x2( + 
self.mul_neg_add_f64x2(a0, b0, c0), + self.mul_neg_add_f64x2(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f64x4(self, a: f64x4) -> f64x4 { let (a0, a1) = self.split_f64x4(a); self.combine_f64x2(self.floor_f64x2(a0), self.floor_f64x2(a1)) @@ -3465,6 +3493,16 @@ impl Simd for Sse4_2 { ) } #[inline(always)] + fn mul_neg_add_f32x16(self, a: f32x16, b: f32x16, c: f32x16) -> f32x16 { + let (a0, a1) = self.split_f32x16(a); + let (b0, b1) = self.split_f32x16(b); + let (c0, c1) = self.split_f32x16(c); + self.combine_f32x8( + self.mul_neg_add_f32x8(a0, b0, c0), + self.mul_neg_add_f32x8(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f32x16(self, a: f32x16) -> f32x16 { let (a0, a1) = self.split_f32x16(a); self.combine_f32x8(self.floor_f32x8(a0), self.floor_f32x8(a1)) @@ -5142,6 +5180,16 @@ impl Simd for Sse4_2 { ) } #[inline(always)] + fn mul_neg_add_f64x8(self, a: f64x8, b: f64x8, c: f64x8) -> f64x8 { + let (a0, a1) = self.split_f64x8(a); + let (b0, b1) = self.split_f64x8(b); + let (c0, c1) = self.split_f64x8(c); + self.combine_f64x4( + self.mul_neg_add_f64x4(a0, b0, c0), + self.mul_neg_add_f64x4(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f64x8(self, a: f64x8) -> f64x8 { let (a0, a1) = self.split_f64x8(a); self.combine_f64x4(self.floor_f64x4(a0), self.floor_f64x4(a1)) diff --git a/fearless_simd/src/generated/wasm.rs b/fearless_simd/src/generated/wasm.rs index 076bb9a0..9e53b6fb 100644 --- a/fearless_simd/src/generated/wasm.rs +++ b/fearless_simd/src/generated/wasm.rs @@ -175,6 +175,16 @@ impl Simd for WasmSimd128 { fn mul_sub_f32x4(self, a: f32x4, b: f32x4, c: f32x4) -> f32x4 { self.sub_f32x4(self.mul_f32x4(a, b), c) } + #[cfg(target_feature = "relaxed-simd")] + #[inline(always)] + fn mul_neg_add_f32x4(self, a: f32x4, b: f32x4, c: f32x4) -> f32x4 { + f32x4_relaxed_nmadd(a.into(), b.into(), c.into()).simd_into(self) + } + #[cfg(not(target_feature = "relaxed-simd"))] + #[inline(always)] + fn mul_neg_add_f32x4(self, a: f32x4, b: f32x4, c: f32x4) -> f32x4 { + self.sub_f32x4(c, self.mul_f32x4(a, b)) + } #[inline(always)] fn floor_f32x4(self, a: f32x4) -> f32x4 { f32x4_floor(a.into()).simd_into(self) @@ -1316,6 +1326,16 @@ impl Simd for WasmSimd128 { fn mul_sub_f64x2(self, a: f64x2, b: f64x2, c: f64x2) -> f64x2 { self.sub_f64x2(self.mul_f64x2(a, b), c) } + #[cfg(target_feature = "relaxed-simd")] + #[inline(always)] + fn mul_neg_add_f64x2(self, a: f64x2, b: f64x2, c: f64x2) -> f64x2 { + f64x2_relaxed_nmadd(a.into(), b.into(), c.into()).simd_into(self) + } + #[cfg(not(target_feature = "relaxed-simd"))] + #[inline(always)] + fn mul_neg_add_f64x2(self, a: f64x2, b: f64x2, c: f64x2) -> f64x2 { + self.sub_f64x2(c, self.mul_f64x2(a, b)) + } #[inline(always)] fn floor_f64x2(self, a: f64x2) -> f64x2 { f64x2_floor(a.into()).simd_into(self) @@ -1579,6 +1599,16 @@ impl Simd for WasmSimd128 { ) } #[inline(always)] + fn mul_neg_add_f32x8(self, a: f32x8, b: f32x8, c: f32x8) -> f32x8 { + let (a0, a1) = self.split_f32x8(a); + let (b0, b1) = self.split_f32x8(b); + let (c0, c1) = self.split_f32x8(c); + self.combine_f32x4( + self.mul_neg_add_f32x4(a0, b0, c0), + self.mul_neg_add_f32x4(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f32x8(self, a: f32x8) -> f32x8 { let (a0, a1) = self.split_f32x8(a); self.combine_f32x4(self.floor_f32x4(a0), self.floor_f32x4(a1)) @@ -3114,6 +3144,16 @@ impl Simd for WasmSimd128 { ) } #[inline(always)] + fn mul_neg_add_f64x4(self, a: f64x4, b: f64x4, c: f64x4) -> f64x4 { + let (a0, a1) = self.split_f64x4(a); + let (b0, b1) = self.split_f64x4(b); + let (c0, c1) = 
self.split_f64x4(c); + self.combine_f64x2( + self.mul_neg_add_f64x2(a0, b0, c0), + self.mul_neg_add_f64x2(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f64x4(self, a: f64x4) -> f64x4 { let (a0, a1) = self.split_f64x4(a); self.combine_f64x2(self.floor_f64x2(a0), self.floor_f64x2(a1)) @@ -3410,6 +3450,16 @@ impl Simd for WasmSimd128 { ) } #[inline(always)] + fn mul_neg_add_f32x16(self, a: f32x16, b: f32x16, c: f32x16) -> f32x16 { + let (a0, a1) = self.split_f32x16(a); + let (b0, b1) = self.split_f32x16(b); + let (c0, c1) = self.split_f32x16(c); + self.combine_f32x8( + self.mul_neg_add_f32x8(a0, b0, c0), + self.mul_neg_add_f32x8(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f32x16(self, a: f32x16) -> f32x16 { let (a0, a1) = self.split_f32x16(a); self.combine_f32x8(self.floor_f32x8(a0), self.floor_f32x8(a1)) @@ -5079,6 +5129,16 @@ impl Simd for WasmSimd128 { ) } #[inline(always)] + fn mul_neg_add_f64x8(self, a: f64x8, b: f64x8, c: f64x8) -> f64x8 { + let (a0, a1) = self.split_f64x8(a); + let (b0, b1) = self.split_f64x8(b); + let (c0, c1) = self.split_f64x8(c); + self.combine_f64x4( + self.mul_neg_add_f64x4(a0, b0, c0), + self.mul_neg_add_f64x4(a1, b1, c1), + ) + } + #[inline(always)] fn floor_f64x8(self, a: f64x8) -> f64x8 { let (a0, a1) = self.split_f64x8(a); self.combine_f64x4(self.floor_f64x4(a0), self.floor_f64x4(a1)) diff --git a/fearless_simd_gen/src/arch/neon.rs b/fearless_simd_gen/src/arch/neon.rs index db62a2ad..3386c603 100644 --- a/fearless_simd_gen/src/arch/neon.rs +++ b/fearless_simd_gen/src/arch/neon.rs @@ -37,6 +37,7 @@ fn translate_op(op: &str) -> Option<&'static str> { "min_precise" => "vminnm", "mul_add" => "vfma", "mul_sub" => "vfms", + "mul_neg_add" => "vfms", _ => return None, }) } diff --git a/fearless_simd_gen/src/mk_avx2.rs b/fearless_simd_gen/src/mk_avx2.rs index af45fd3a..9e2942d6 100644 --- a/fearless_simd_gen/src/mk_avx2.rs +++ b/fearless_simd_gen/src/mk_avx2.rs @@ -201,6 +201,14 @@ fn make_method(method: &str, sig: OpSig, vec_ty: &VecType) -> TokenStream { } } } + "mul_neg_add" => { + let intrinsic = simple_intrinsic("fnmadd", vec_ty); + quote! { + #method_sig { + unsafe { #intrinsic(a.into(), b.into(), c.into()).simd_into(self) } + } + } + } _ => mk_sse4_2::handle_ternary(method_sig, &method_ident, method, vec_ty), }, OpSig::Select => mk_sse4_2::handle_select(method_sig, vec_ty), diff --git a/fearless_simd_gen/src/mk_fallback.rs b/fearless_simd_gen/src/mk_fallback.rs index 59e7d0e1..4c8b0686 100644 --- a/fearless_simd_gen/src/mk_fallback.rs +++ b/fearless_simd_gen/src/mk_fallback.rs @@ -246,6 +246,13 @@ fn mk_simd_impl() -> TokenStream { a.mul(b).sub(c) } } + } else if method == "mul_neg_add" { + // TODO: Same as above + quote! { + #method_sig { + c.sub(a.mul(b)) + } + } } else { let args = [ quote! { a.into() }, diff --git a/fearless_simd_gen/src/mk_neon.rs b/fearless_simd_gen/src/mk_neon.rs index e4fa6896..dd5a0c07 100644 --- a/fearless_simd_gen/src/mk_neon.rs +++ b/fearless_simd_gen/src/mk_neon.rs @@ -294,7 +294,7 @@ fn mk_simd_impl(level: Level) -> TokenStream { } OpSig::Ternary => { let args = match method { - "mul_add" | "mul_sub" => [ + "mul_add" | "mul_sub" | "mul_neg_add" => [ quote! { c.into() }, quote! { b.into() }, quote! { a.into() }, @@ -312,6 +312,7 @@ fn mk_simd_impl(level: Level) -> TokenStream { let neg = simple_intrinsic("vneg", vec_ty); expr = quote! { #neg(#expr) }; } + // mul_neg_add computes c - (a * b), which is exactly what vfms does quote! 
{ #method_sig { unsafe { diff --git a/fearless_simd_gen/src/mk_sse4_2.rs b/fearless_simd_gen/src/mk_sse4_2.rs index 9af15ec5..11c75acb 100644 --- a/fearless_simd_gen/src/mk_sse4_2.rs +++ b/fearless_simd_gen/src/mk_sse4_2.rs @@ -526,6 +526,13 @@ pub(crate) fn handle_ternary( } } } + "mul_neg_add" => { + quote! { + #method_sig { + c - a * b + } + } + } _ => { let args = [ quote! { a.into() }, diff --git a/fearless_simd_gen/src/mk_wasm.rs b/fearless_simd_gen/src/mk_wasm.rs index f1e075eb..3665bb63 100644 --- a/fearless_simd_gen/src/mk_wasm.rs +++ b/fearless_simd_gen/src/mk_wasm.rs @@ -188,34 +188,55 @@ fn mk_simd_impl(level: Level) -> TokenStream { } } OpSig::Ternary => { - if matches!(method, "mul_add" | "mul_sub") { - let add_sub = generic_op_name( - if method == "mul_add" { "add" } else { "sub" }, - vec_ty, - ); + if matches!(method, "mul_add" | "mul_sub" | "mul_neg_add") { + let fallback_op = if method == "mul_neg_add" { + "sub" + } else if method == "mul_add" { + "add" + } else { + "sub" + }; + let add_sub = generic_op_name(fallback_op, vec_ty); let mul = generic_op_name("mul", vec_ty); - let c = if method == "mul_sub" { + let (relaxed_intrinsic, c_expr, fallback_expr) = if method == "mul_add" { + let relaxed_madd = simple_intrinsic("relaxed_madd", vec_ty); + ( + relaxed_madd, + quote! { c.into() }, + quote! { self.#add_sub(self.#mul(a, b), c) }, + ) + } else if method == "mul_sub" { // WebAssembly just... forgot fused multiply-subtract? It seems the // initial proposal // (https://github.com/WebAssembly/relaxed-simd/issues/27) confused it // with negate multiply-add, and nobody ever resolved the confusion. let negate = simple_intrinsic("neg", vec_ty); - quote! { #negate(c.into()) } + let relaxed_madd = simple_intrinsic("relaxed_madd", vec_ty); + ( + relaxed_madd, + quote! { #negate(c.into()) }, + quote! { self.#add_sub(self.#mul(a, b), c) }, + ) } else { - quote! { c.into() } + // mul_neg_add: c - (a * b) + let relaxed_nmadd = simple_intrinsic("relaxed_nmadd", vec_ty); + ( + relaxed_nmadd, + quote! { c.into() }, + quote! { self.#add_sub(c, self.#mul(a, b)) }, + ) }; - let relaxed_madd = simple_intrinsic("relaxed_madd", vec_ty); quote! 
{ #[cfg(target_feature = "relaxed-simd")] #method_sig { - #relaxed_madd(a.into(), b.into(), #c).simd_into(self) + #relaxed_intrinsic(a.into(), b.into(), #c_expr).simd_into(self) } #[cfg(not(target_feature = "relaxed-simd"))] #method_sig { - self.#add_sub(self.#mul(a, b), c) + #fallback_expr } } } else { diff --git a/fearless_simd_gen/src/ops.rs b/fearless_simd_gen/src/ops.rs index bfd02bea..e8cad7e3 100644 --- a/fearless_simd_gen/src/ops.rs +++ b/fearless_simd_gen/src/ops.rs @@ -292,6 +292,13 @@ const FLOAT_OPS: &[Op] = &[ "Compute `({arg0} * {arg1}) - {arg2}` (fused multiply-subtract) for each element.\n\n\ Depending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a subtract, which will result in two rounding errors.", ), + Op::new( + "mul_neg_add", + OpKind::VecTraitMethod, + OpSig::Ternary, + "Compute `{arg2} - ({arg0} * {arg1})` (fused negated multiply-add) for each element.\n\n\ + Depending on hardware support, the result may be computed with only one rounding error, or may be implemented as a regular multiply followed by a negated add, which will result in two rounding errors.", + ), Op::new( "floor", OpKind::VecTraitMethod, diff --git a/fearless_simd_tests/tests/harness/mod.rs b/fearless_simd_tests/tests/harness/mod.rs index a52d31e9..7d78b8db 100644 --- a/fearless_simd_tests/tests/harness/mod.rs +++ b/fearless_simd_tests/tests/harness/mod.rs @@ -145,6 +145,14 @@ fn msub_f32x4(simd: S) { assert_eq!(a.mul_sub(b, c).val, [19.0, 28.0, 37.0, 46.0]); } +#[simd_test] +fn mul_neg_add_f32x4(simd: S) { + let a = f32x4::from_slice(simd, &[2.0, 3.0, 4.0, 5.0]); + let b = f32x4::from_slice(simd, &[10.0, 10.0, 10.0, 10.0]); + let c = f32x4::from_slice(simd, &[100.0, 50.0, 25.0, 10.0]); + assert_eq!(a.mul_neg_add(b, c).val, [80.0, 20.0, -15.0, -40.0]); +} + #[simd_test] fn max_precise_f32x4_with_nan(simd: S) { let a = f32x4::from_slice(simd, &[f32::NAN, -3.0, f32::INFINITY, 0.5]); @@ -2712,6 +2720,14 @@ fn madd_f64x2(simd: S) { assert_eq!(a.mul_add(b, c).val, [6.0, 13.0]); } +#[simd_test] +fn mul_neg_add_f64x2(simd: S) { + let a = f64x2::from_slice(simd, &[2.0, 3.0]); + let b = f64x2::from_slice(simd, &[4.0, 5.0]); + let c = f64x2::from_slice(simd, &[20.0, 30.0]); + assert_eq!(a.mul_neg_add(b, c).val, [12.0, 15.0]); +} + #[simd_test] fn floor_f64x2(simd: S) { let a = f64x2::from_slice(simd, &[1.7, -2.3]);
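
Note (not part of the diff): every `mul_neg_add_*` implementation above is meant to agree on the value `c - (a * b)`; they differ only in rounding. The FMA-class paths round once (`_mm*_fnmadd_*` on AVX2, `vfmsq_f32`/`vfmsq_f64` on NEON, where the accumulator comes first so the generated call is `vfmsq_*(c, b, a) = c - b * a`, and `*_relaxed_nmadd` under WASM relaxed-simd), while the SSE4.2, fallback, and non-relaxed WASM paths multiply and subtract separately and round twice. Unlike `mul_sub`, which needs an extra negate because relaxed-simd only exposes `relaxed_madd`/`relaxed_nmadd`, `mul_neg_add` maps directly onto `relaxed_nmadd`. A minimal scalar sketch of that contract, using only the standard library (illustrative only, not generated code):

// Illustrative scalar model of mul_neg_add; not part of the generated backends.
fn mul_neg_add_fused(a: f32, b: f32, c: f32) -> f32 {
    // One rounding: (-a) * b + c, matching the fnmadd / vfms / relaxed_nmadd paths.
    f32::mul_add(-a, b, c)
}

fn mul_neg_add_unfused(a: f32, b: f32, c: f32) -> f32 {
    // Two roundings: matches the SSE4.2 `c - a * b` and fallback `c.sub(a.mul(b))` paths.
    c - a * b
}

fn main() {
    // Same values as the first lane of the new mul_neg_add_f32x4 harness test.
    let (a, b, c) = (2.0_f32, 10.0, 100.0);
    assert_eq!(mul_neg_add_fused(a, b, c), 80.0);
    assert_eq!(mul_neg_add_unfused(a, b, c), 80.0);
}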
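
A typical consumer of this operation is Newton-Raphson refinement, where the residual factor `2 - d * x` is exactly a negated multiply-add. The sketch below is a hypothetical follow-up harness test, not part of this change; it assumes the vector types are `Copy` and implement `Mul`, as the SSE4.2 backend's `a * b - c` expression already relies on, and reuses the `#[simd_test]`/`from_slice`/`.val` conventions visible in the existing tests.

// Hypothetical follow-up test, illustrative only.
#[simd_test]
fn mul_neg_add_newton_reciprocal(simd: S) {
    let d = f32x4::from_slice(simd, &[2.0, 4.0, 8.0, 10.0]);
    // Rough initial guesses for 1/d.
    let x0 = f32x4::from_slice(simd, &[0.45, 0.23, 0.12, 0.09]);
    let two = f32x4::from_slice(simd, &[2.0, 2.0, 2.0, 2.0]);
    // One Newton step: x1 = x0 * (2 - d * x0); the parenthesised factor is d.mul_neg_add(x0, two).
    let x1 = (x0 * d.mul_neg_add(x0, two)).val;
    let expected = [0.5_f32, 0.25, 0.125, 0.1];
    for i in 0..4 {
        assert!((x1[i] - expected[i]).abs() < 1e-2);
    }
}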