diff --git a/crates/libm-test/src/f8_impl.rs b/crates/libm-test/src/f8_impl.rs index 0683d839..6772e092 100644 --- a/crates/libm-test/src/f8_impl.rs +++ b/crates/libm-test/src/f8_impl.rs @@ -3,8 +3,6 @@ use std::cmp::{self, Ordering}; use std::{fmt, ops}; -use libm::support::hex_float::parse_any; - use crate::Float; /// Sometimes verifying float logic is easiest when all values can quickly be checked exhaustively @@ -499,5 +497,6 @@ impl fmt::LowerHex for f8 { } pub const fn hf8(s: &str) -> f8 { - f8(parse_any(s, 8, 3) as u8) + let Ok(bits) = libm::support::hex_float::parse_hex_exact(s, 8, 3) else { panic!() }; + f8(bits as u8) } diff --git a/src/math/support/env.rs b/src/math/support/env.rs index c05890d9..79630937 100644 --- a/src/math/support/env.rs +++ b/src/math/support/env.rs @@ -46,7 +46,7 @@ pub enum Round { } /// IEEE 754 exception status flags. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct Status(u8); impl Status { @@ -90,16 +90,22 @@ impl Status { /// True if `UNDERFLOW` is set. #[cfg_attr(not(feature = "unstable-public-internals"), allow(dead_code))] - pub fn underflow(self) -> bool { + pub const fn underflow(self) -> bool { self.0 & Self::UNDERFLOW.0 != 0 } + /// True if `OVERFLOW` is set. + #[cfg_attr(not(feature = "unstable-public-internals"), allow(dead_code))] + pub const fn overflow(self) -> bool { + self.0 & Self::OVERFLOW.0 != 0 + } + pub fn set_underflow(&mut self, val: bool) { self.set_flag(val, Self::UNDERFLOW); } /// True if `INEXACT` is set. - pub fn inexact(self) -> bool { + pub const fn inexact(self) -> bool { self.0 & Self::INEXACT.0 != 0 } @@ -114,4 +120,8 @@ impl Status { self.0 &= !mask.0; } } + + pub(crate) const fn with(self, rhs: Self) -> Self { + Self(self.0 | rhs.0) + } } diff --git a/src/math/support/hex_float.rs b/src/math/support/hex_float.rs index be7d7607..819e2f56 100644 --- a/src/math/support/hex_float.rs +++ b/src/math/support/hex_float.rs @@ -2,149 +2,260 @@ use core::fmt; -use super::{Float, f32_from_bits, f64_from_bits}; +use super::{Float, Round, Status, f32_from_bits, f64_from_bits}; /// Construct a 16-bit float from hex float representation (C-style) #[cfg(f16_enabled)] pub const fn hf16(s: &str) -> f16 { - f16::from_bits(parse_any(s, 16, 10) as u16) + match parse_hex_exact(s, 16, 10) { + Ok(bits) => f16::from_bits(bits as u16), + Err(HexFloatParseError(s)) => panic!("{}", s), + } } /// Construct a 32-bit float from hex float representation (C-style) #[allow(unused)] pub const fn hf32(s: &str) -> f32 { - f32_from_bits(parse_any(s, 32, 23) as u32) + match parse_hex_exact(s, 32, 23) { + Ok(bits) => f32_from_bits(bits as u32), + Err(HexFloatParseError(s)) => panic!("{}", s), + } } /// Construct a 64-bit float from hex float representation (C-style) pub const fn hf64(s: &str) -> f64 { - f64_from_bits(parse_any(s, 64, 52) as u64) + match parse_hex_exact(s, 64, 52) { + Ok(bits) => f64_from_bits(bits as u64), + Err(HexFloatParseError(s)) => panic!("{}", s), + } } /// Construct a 128-bit float from hex float representation (C-style) #[cfg(f128_enabled)] pub const fn hf128(s: &str) -> f128 { - f128::from_bits(parse_any(s, 128, 112)) + match parse_hex_exact(s, 128, 112) { + Ok(bits) => f128::from_bits(bits), + Err(HexFloatParseError(s)) => panic!("{}", s), + } +} +#[derive(Copy, Clone, Debug)] +pub struct HexFloatParseError(&'static str); + +/// Parses any float to its bitwise representation, returning an error if it cannot be represented exactly +pub const fn parse_hex_exact( + s: &str, + bits: u32, + sig_bits: u32, +) -> Result { + match parse_any(s, bits, sig_bits, Round::Nearest) { + Err(e) => Err(e), + Ok((bits, Status::OK)) => Ok(bits), + Ok((_, status)) if status.overflow() => Err(HexFloatParseError("the value is too huge")), + Ok((_, status)) if status.underflow() => Err(HexFloatParseError("the value is too tiny")), + Ok((_, status)) if status.inexact() => Err(HexFloatParseError("the value is too precise")), + Ok(_) => unreachable!(), + } } /// Parse any float from hex to its bitwise representation. -/// -/// `nan_repr` is passed rather than constructed so the platform-specific NaN is returned. -pub const fn parse_any(s: &str, bits: u32, sig_bits: u32) -> u128 { +pub const fn parse_any( + s: &str, + bits: u32, + sig_bits: u32, + round: Round, +) -> Result<(u128, Status), HexFloatParseError> { + let mut b = s.as_bytes(); + + if sig_bits > 119 || bits > 128 || bits < sig_bits + 3 || bits > sig_bits + 30 { + return Err(HexFloatParseError("unsupported target float configuration")); + } + + let neg = matches!(b, [b'-', ..]); + if let &[b'-' | b'+', ref rest @ ..] = b { + b = rest; + } + + let sign_bit = 1 << (bits - 1); + let quiet_bit = 1 << (sig_bits - 1); + let nan = sign_bit - quiet_bit; + let inf = nan - quiet_bit; + + let (mut x, status) = match *b { + [b'i' | b'I', b'n' | b'N', b'f' | b'F'] => (inf, Status::OK), + [b'n' | b'N', b'a' | b'A', b'n' | b'N'] => (nan, Status::OK), + [b'0', b'x' | b'X', ref rest @ ..] => { + let round = match (neg, round) { + // parse("-x", Round::Positive) == -parse("x", Round::Negative) + (true, Round::Positive) => Round::Negative, + (true, Round::Negative) => Round::Positive, + // rounding toward nearest or zero are symmetric + (true, Round::Nearest | Round::Zero) | (false, _) => round, + }; + match parse_finite(rest, bits, sig_bits, round) { + Err(e) => return Err(e), + Ok(res) => res, + } + } + _ => return Err(HexFloatParseError("no hex indicator")), + }; + + if neg { + x ^= sign_bit; + } + + Ok((x, status)) +} + +const fn parse_finite( + b: &[u8], + bits: u32, + sig_bits: u32, + rounding_mode: Round, +) -> Result<(u128, Status), HexFloatParseError> { let exp_bits: u32 = bits - sig_bits - 1; let max_msb: i32 = (1 << (exp_bits - 1)) - 1; // The exponent of one ULP in the subnormals let min_lsb: i32 = 1 - max_msb - sig_bits as i32; - let exp_mask = ((1 << exp_bits) - 1) << sig_bits; + let (mut sig, mut exp) = match parse_hex(b) { + Err(e) => return Err(e), + Ok(Parsed { sig: 0, .. }) => return Ok((0, Status::OK)), + Ok(Parsed { sig, exp }) => (sig, exp), + }; + + let mut round_bits = u128_ilog2(sig) as i32 - sig_bits as i32; + + // Round at least up to min_lsb + if exp < min_lsb - round_bits { + round_bits = min_lsb - exp; + } + + let mut status = Status::OK; - let (neg, mut sig, exp) = match parse_hex(s.as_bytes()) { - Parsed::Finite { neg, sig: 0, .. } => return (neg as u128) << (bits - 1), - Parsed::Finite { neg, sig, exp } => (neg, sig, exp), - Parsed::Infinite { neg } => return ((neg as u128) << (bits - 1)) | exp_mask, - Parsed::Nan { neg } => { - return ((neg as u128) << (bits - 1)) | exp_mask | (1 << (sig_bits - 1)); + exp += round_bits; + + if round_bits > 0 { + // first, prepare for rounding exactly two bits + if round_bits == 1 { + sig <<= 1; + } else if round_bits > 2 { + sig = shr_odd_rounding(sig, (round_bits - 2) as u32); } - }; - // exponents of the least and most significant bits in the value - let lsb = sig.trailing_zeros() as i32; - let msb = u128_ilog2(sig) as i32; - let sig_bits = sig_bits as i32; + if sig & 0b11 != 0 { + status = Status::INEXACT; + } - assert!(msb - lsb <= sig_bits, "the value is too precise"); - assert!(msb + exp <= max_msb, "the value is too huge"); - assert!(lsb + exp >= min_lsb, "the value is too tiny"); + sig = shr2_round(sig, rounding_mode); + } else if round_bits < 0 { + sig <<= -round_bits; + } // The parsed value is X = sig * 2^exp // Expressed as a multiple U of the smallest subnormal value: // X = U * 2^min_lsb, so U = sig * 2^(exp-min_lsb) - let mut uexp = exp - min_lsb; + let uexp = (exp - min_lsb) as u128; + let uexp = uexp << sig_bits; - let shift = if uexp + msb >= sig_bits { - // normal, shift msb to position sig_bits - sig_bits - msb - } else { - // subnormal, shift so that uexp becomes 0 - uexp + // Note that it is possible for the exponent bits to equal 2 here + // if the value rounded up, but that means the mantissa is all zeroes + // so the value is still correct + debug_assert!(sig <= 2 << sig_bits); + + let inf = ((1 << exp_bits) - 1) << sig_bits; + + let bits = match sig.checked_add(uexp) { + Some(bits) if bits < inf => { + // inexact subnormal or zero? + if status.inexact() && bits < (1 << sig_bits) { + status = status.with(Status::UNDERFLOW); + } + bits + } + _ => { + // overflow to infinity + status = status.with(Status::OVERFLOW).with(Status::INEXACT); + match rounding_mode { + Round::Positive | Round::Nearest => inf, + Round::Negative | Round::Zero => inf - 1, + } + } }; + Ok((bits, status)) +} - if shift >= 0 { - sig <<= shift; +/// Shift right, rounding all inexact divisions to the nearest odd number +/// E.g. (0 >> 4) -> 0, (1..=31 >> 4) -> 1, (32 >> 4) -> 2, ... +/// +/// Useful for reducing a number before rounding the last two bits, since +/// the result of the final rounding is preserved for all rounding modes. +const fn shr_odd_rounding(x: u128, k: u32) -> u128 { + if k < 128 { + let inexact = x.trailing_zeros() < k; + (x >> k) | (inexact as u128) } else { - sig >>= -shift; + (x != 0) as u128 } - uexp -= shift; - - // the most significant bit is like having 1 in the exponent bits - // add any leftover exponent to that - assert!(uexp >= 0 && uexp < (1 << exp_bits) - 2); - sig += (uexp as u128) << sig_bits; +} - // finally, set the sign bit if necessary - sig | ((neg as u128) << (bits - 1)) +/// Divide by 4, rounding with the given mode +const fn shr2_round(mut x: u128, round: Round) -> u128 { + let t = (x as u32) & 0b111; + x >>= 2; + match round { + // Look-up-table on the last three bits for when to round up + Round::Nearest => x + ((0b11001000_u8 >> t) & 1) as u128, + + Round::Negative => x, + Round::Zero => x, + Round::Positive => x + (t & 0b11 != 0) as u128, + } } -/// A parsed floating point number. -enum Parsed { - /// Absolute value sig * 2^e - Finite { - neg: bool, - sig: u128, - exp: i32, - }, - Infinite { - neg: bool, - }, - Nan { - neg: bool, - }, +/// A parsed finite and unsigned floating point number. +struct Parsed { + /// Absolute value sig * 2^exp + sig: u128, + exp: i32, } /// Parse a hexadecimal float x -const fn parse_hex(mut b: &[u8]) -> Parsed { - let mut neg = false; +const fn parse_hex(mut b: &[u8]) -> Result { let mut sig: u128 = 0; let mut exp: i32 = 0; - if let &[c @ (b'-' | b'+'), ref rest @ ..] = b { - b = rest; - neg = c == b'-'; - } - - match *b { - [b'i' | b'I', b'n' | b'N', b'f' | b'F'] => return Parsed::Infinite { neg }, - [b'n' | b'N', b'a' | b'A', b'n' | b'N'] => return Parsed::Nan { neg }, - _ => (), - } - - if let &[b'0', b'x' | b'X', ref rest @ ..] = b { - b = rest; - } else { - panic!("no hex indicator"); - } - let mut seen_point = false; let mut some_digits = false; + let mut inexact = false; while let &[c, ref rest @ ..] = b { b = rest; match c { b'.' => { - assert!(!seen_point); + if seen_point { + return Err(HexFloatParseError("unexpected '.' parsing fractional digits")); + } seen_point = true; continue; } b'p' | b'P' => break, c => { - let digit = hex_digit(c); + let digit = match hex_digit(c) { + Some(d) => d, + None => return Err(HexFloatParseError("expected hexadecimal digit")), + }; some_digits = true; - let of; - (sig, of) = sig.overflowing_mul(16); - assert!(!of, "too many digits"); - sig |= digit as u128; - // up until the fractional point, the value grows + + if (sig >> 124) == 0 { + sig <<= 4; + sig |= digit as u128; + } else { + // FIXME: it is technically possible for exp to overflow if parsing a string with >500M digits + exp += 4; + inexact |= digit != 0; + } + // Up until the fractional point, the value grows // with more digits, but after it the exponent is // compensated to match. if seen_point { @@ -153,49 +264,79 @@ const fn parse_hex(mut b: &[u8]) -> Parsed { } } } - assert!(some_digits, "at least one digit is required"); + // If we've set inexact, the exact value has more than 125 + // significant bits, and lies somewhere between sig and sig + 1. + // Because we'll round off at least two of the trailing bits, + // setting the last bit gives correct rounding for inexact values. + sig |= inexact as u128; + + if !some_digits { + return Err(HexFloatParseError("at least one digit is required")); + }; + some_digits = false; - let mut negate_exp = false; - if let &[c @ (b'-' | b'+'), ref rest @ ..] = b { + let negate_exp = matches!(b, [b'-', ..]); + if let &[b'-' | b'+', ref rest @ ..] = b { b = rest; - negate_exp = c == b'-'; } - let mut pexp: i32 = 0; + let mut pexp: u32 = 0; while let &[c, ref rest @ ..] = b { b = rest; - let digit = dec_digit(c); + let digit = match dec_digit(c) { + Some(d) => d, + None => return Err(HexFloatParseError("expected decimal digit")), + }; some_digits = true; - let of; - (pexp, of) = pexp.overflowing_mul(10); - assert!(!of, "too many exponent digits"); - pexp += digit as i32; + pexp = pexp.saturating_mul(10); + pexp += digit as u32; } - assert!(some_digits, "at least one exponent digit is required"); + if !some_digits { + return Err(HexFloatParseError("at least one exponent digit is required")); + }; + + { + let e; + if negate_exp { + e = (exp as i64) - (pexp as i64); + } else { + e = (exp as i64) + (pexp as i64); + }; + + exp = if e < i32::MIN as i64 { + i32::MIN + } else if e > i32::MAX as i64 { + i32::MAX + } else { + e as i32 + }; + } + /* FIXME(msrv): once MSRV >= 1.66, replace the above workaround block with: if negate_exp { - exp -= pexp; + exp = exp.saturating_sub_unsigned(pexp); } else { - exp += pexp; - } + exp = exp.saturating_add_unsigned(pexp); + }; + */ - Parsed::Finite { neg, sig, exp } + Ok(Parsed { sig, exp }) } -const fn dec_digit(c: u8) -> u8 { +const fn dec_digit(c: u8) -> Option { match c { - b'0'..=b'9' => c - b'0', - _ => panic!("bad char"), + b'0'..=b'9' => Some(c - b'0'), + _ => None, } } -const fn hex_digit(c: u8) -> u8 { +const fn hex_digit(c: u8) -> Option { match c { - b'0'..=b'9' => c - b'0', - b'a'..=b'f' => c - b'a' + 10, - b'A'..=b'F' => c - b'A' + 10, - _ => panic!("bad char"), + b'0'..=b'9' => Some(c - b'0'), + b'a'..=b'f' => Some(c - b'a' + 10), + b'A'..=b'F' => Some(c - b'A' + 10), + _ => None, } } @@ -341,6 +482,61 @@ mod parse_tests { use super::*; + #[cfg(f16_enabled)] + fn rounding_properties(s: &str) -> Result<(), HexFloatParseError> { + let (xd, s0) = parse_any(s, 16, 10, Round::Negative)?; + let (xu, s1) = parse_any(s, 16, 10, Round::Positive)?; + let (xz, s2) = parse_any(s, 16, 10, Round::Zero)?; + let (xn, s3) = parse_any(s, 16, 10, Round::Nearest)?; + + // FIXME: A value between the least normal and largest subnormal + // could have underflow status depend on rounding mode. + + if let Status::OK = s0 { + // an exact result is the same for all rounding modes + assert_eq!(s0, s1); + assert_eq!(s0, s2); + assert_eq!(s0, s3); + + assert_eq!(xd, xu); + assert_eq!(xd, xz); + assert_eq!(xd, xn); + } else { + assert!([s0, s1, s2, s3].into_iter().all(Status::inexact)); + + let xd = f16::from_bits(xd as u16); + let xu = f16::from_bits(xu as u16); + let xz = f16::from_bits(xz as u16); + let xn = f16::from_bits(xn as u16); + + assert_biteq!(xd.next_up(), xu, "s={s}, xd={xd:?}, xu={xu:?}"); + + let signs = [xd, xu, xz, xn].map(f16::is_sign_negative); + + if signs == [true; 4] { + assert_biteq!(xz, xu); + } else { + assert_eq!(signs, [false; 4]); + assert_biteq!(xz, xd); + } + + if xn.to_bits() != xd.to_bits() { + assert_biteq!(xn, xu); + } + } + Ok(()) + } + #[test] + #[cfg(f16_enabled)] + fn test_rounding() { + let n = 1_i32 << 14; + for i in -n..n { + let u = i.rotate_right(11) as u32; + let s = format!("{}", Hexf(f32::from_bits(u))); + assert!(rounding_properties(&s).is_ok()); + } + } + #[test] fn test_parse_any() { for k in -149..=127 { @@ -397,6 +593,48 @@ mod parse_tests { } } + // FIXME: this test is causing failures that are likely UB on various platforms + #[cfg(all(target_arch = "x86_64", target_os = "linux"))] + #[test] + #[cfg(f128_enabled)] + fn rounding() { + let pi = std::f128::consts::PI; + let s = format!("{}", Hexf(pi)); + + for k in 0..=111 { + let (bits, status) = parse_any(&s, 128 - k, 112 - k, Round::Nearest).unwrap(); + let scale = (1u128 << (112 - k - 1)) as f128; + let expected = (pi * scale).round_ties_even() / scale; + assert_eq!(bits << k, expected.to_bits(), "k = {k}, s = {s}"); + assert_eq!(expected != pi, status.inexact()); + } + } + #[test] + fn rounding_extreme_underflow() { + for k in 1..1000 { + let s = format!("0x1p{}", -149 - k); + let Ok((bits, status)) = parse_any(&s, 32, 23, Round::Nearest) else { unreachable!() }; + assert_eq!(bits, 0, "{s} should round to zero, got bits={bits}"); + assert!(status.underflow(), "should indicate underflow when parsing {s}"); + assert!(status.inexact(), "should indicate inexact when parsing {s}"); + } + } + #[test] + fn long_tail() { + for k in 1..1000 { + let s = format!("0x1.{}p0", "0".repeat(k)); + let Ok(bits) = parse_hex_exact(&s, 32, 23) else { panic!("parsing {s} failed") }; + assert_eq!(f32::from_bits(bits as u32), 1.0); + + let s = format!("0x1.{}1p0", "0".repeat(k)); + let Ok((bits, status)) = parse_any(&s, 32, 23, Round::Nearest) else { unreachable!() }; + if status.inexact() { + assert!(1.0 == f32::from_bits(bits as u32)); + } else { + assert!(1.0 < f32::from_bits(bits as u32)); + } + } + } // HACK(msrv): 1.63 rejects unknown width float literals at an AST level, so use a macro to // hide them from the AST. #[cfg(f16_enabled)] @@ -434,6 +672,7 @@ mod parse_tests { ]; for (s, exp) in checks { println!("parsing {s}"); + assert!(rounding_properties(s).is_ok()); let act = hf16(s).to_bits(); assert_eq!( act, exp, @@ -749,7 +988,13 @@ mod tests_panicking { #[test] #[should_panic(expected = "the value is too precise")] fn test_f128_extra_precision() { - // One bit more than the above. + // Just below the maximum finite. + hf128("0x1.fffffffffffffffffffffffffffe8p+16383"); + } + #[test] + #[should_panic(expected = "the value is too huge")] + fn test_f128_extra_precision_overflow() { + // One bit more than the above. Should overflow. hf128("0x1.ffffffffffffffffffffffffffff8p+16383"); } @@ -822,6 +1067,46 @@ mod print_tests { } } + #[test] + #[cfg(f16_enabled)] + fn test_f16_to_f32() { + use std::format; + // Exhaustively check that these are equivalent for all `f16`: + // - `f16 -> f32` + // - `f16 -> str -> f32` + // - `f16 -> f32 -> str -> f32` + // - `f16 -> f32 -> str -> f16 -> f32` + for x in 0..=u16::MAX { + let f16 = f16::from_bits(x); + let s16 = format!("{}", Hexf(f16)); + let f32 = f16 as f32; + let s32 = format!("{}", Hexf(f32)); + + let a = hf32(&s16); + let b = hf32(&s32); + let c = hf16(&s32); + + if f32.is_nan() && a.is_nan() && b.is_nan() && c.is_nan() { + continue; + } + + assert_eq!( + f32.to_bits(), + a.to_bits(), + "{f16:?} : f16 formatted as {s16} which parsed as {a:?} : f16" + ); + assert_eq!( + f32.to_bits(), + b.to_bits(), + "{f32:?} : f32 formatted as {s32} which parsed as {b:?} : f32" + ); + assert_eq!( + f32.to_bits(), + (c as f32).to_bits(), + "{f32:?} : f32 formatted as {s32} which parsed as {c:?} : f16" + ); + } + } #[test] fn spot_checks() { assert_eq!(Hexf(f32::MAX).to_string(), "0x1.fffffep+127");