diff --git a/hash2curve/src/group_digest.rs b/hash2curve/src/group_digest.rs index da493503c..b85d54efe 100644 --- a/hash2curve/src/group_digest.rs +++ b/hash2curve/src/group_digest.rs @@ -1,8 +1,7 @@ //! Traits for handling hash to curve. use super::{ExpandMsg, MapToCurve, hash_to_field}; -use elliptic_curve::array::typenum::Unsigned; -use elliptic_curve::{ProjectivePoint, Result}; +use elliptic_curve::{ProjectivePoint, array::typenum::Unsigned}; /// Hash arbitrary byte sequences to a valid group element. pub trait GroupDigest: MapToCurve { @@ -22,17 +21,17 @@ pub trait GroupDigest: MapToCurve { /// > oracle returning points in G assuming a cryptographically secure /// > hash function is used. /// + /// For the `expand_message` call, `len_in_bytes = ::Length * 2`. + /// This value must be less than `u16::MAX` or otherwise a compiler error will occur. + /// /// # Errors - /// - `len_in_bytes > u16::MAX` - /// - See implementors of [`ExpandMsg`] for additional errors: - /// - [`ExpandMsgXmd`] - /// - [`ExpandMsgXof`] /// - /// `len_in_bytes = ::Length * 2` + /// When the chosen [`ExpandMsg`] implementation returns an error. See [`ExpandMsgXmdError`] + /// and [`ExpandMsgXofError`] for examples. /// - /// [`ExpandMsgXmd`]: crate::ExpandMsgXmd - /// [`ExpandMsgXof`]: crate::ExpandMsgXof - fn hash_from_bytes(msg: &[&[u8]], dst: &[&[u8]]) -> Result> + /// [`ExpandMsgXmdError`]: crate::ExpandMsgXmdError + /// [`ExpandMsgXofError`]: crate::ExpandMsgXofError + fn hash_from_bytes(msg: &[&[u8]], dst: &[&[u8]]) -> Result, X::Error> where X: ExpandMsg, { @@ -52,17 +51,17 @@ pub trait GroupDigest: MapToCurve { /// > encode_to_curve is only a fraction of the points in G, and some /// > points in this set are more likely to be output than others. /// + /// For the `expand_message` call, `len_in_bytes = ::Length`. + /// This value must be less than `u16::MAX` or otherwise a compiler error will occur. + /// /// # Errors - /// - `len_in_bytes > u16::MAX` - /// - See implementors of [`ExpandMsg`] for additional errors: - /// - [`ExpandMsgXmd`] - /// - [`ExpandMsgXof`] /// - /// `len_in_bytes = ::Length` + /// When the chosen [`ExpandMsg`] implementation returns an error. See [`ExpandMsgXmdError`] + /// and [`ExpandMsgXofError`] for examples. /// - /// [`ExpandMsgXmd`]: crate::ExpandMsgXmd - /// [`ExpandMsgXof`]: crate::ExpandMsgXof - fn encode_from_bytes(msg: &[&[u8]], dst: &[&[u8]]) -> Result> + /// [`ExpandMsgXmdError`]: crate::ExpandMsgXmdError + /// [`ExpandMsgXofError`]: crate::ExpandMsgXofError + fn encode_from_bytes(msg: &[&[u8]], dst: &[&[u8]]) -> Result, X::Error> where X: ExpandMsg, { @@ -74,18 +73,18 @@ pub trait GroupDigest: MapToCurve { /// Computes the hash to field routine according to /// /// and returns a scalar. + /// + /// For the `expand_message` call, `len_in_bytes = ::Length`. + /// This value must be less than `u16::MAX` or otherwise a compiler error will occur. /// /// # Errors - /// - `len_in_bytes > u16::MAX` - /// - See implementors of [`ExpandMsg`] for additional errors: - /// - [`ExpandMsgXmd`] - /// - [`ExpandMsgXof`] /// - /// `len_in_bytes = ::Length` + /// When the chosen [`ExpandMsg`] implementation returns an error. See [`ExpandMsgXmdError`] + /// and [`ExpandMsgXofError`] for examples. /// - /// [`ExpandMsgXmd`]: crate::ExpandMsgXmd - /// [`ExpandMsgXof`]: crate::ExpandMsgXof - fn hash_to_scalar(msg: &[&[u8]], dst: &[&[u8]]) -> Result + /// [`ExpandMsgXmdError`]: crate::ExpandMsgXmdError + /// [`ExpandMsgXofError`]: crate::ExpandMsgXofError + fn hash_to_scalar(msg: &[&[u8]], dst: &[&[u8]]) -> Result where X: ExpandMsg, { diff --git a/hash2curve/src/hash2field.rs b/hash2curve/src/hash2field.rs index 267cdefc6..cdc3dc14e 100644 --- a/hash2curve/src/hash2field.rs +++ b/hash2curve/src/hash2field.rs @@ -12,7 +12,6 @@ use elliptic_curve::array::{ Array, ArraySize, typenum::{NonZero, Unsigned}, }; -use elliptic_curve::{Error, Result}; /// The trait for helping to convert to a field element. pub trait FromOkm { @@ -27,27 +26,28 @@ pub trait FromOkm { /// /// /// -/// # Errors -/// - `len_in_bytes > u16::MAX` -/// - See implementors of [`ExpandMsg`] for additional errors: -/// - [`ExpandMsgXmd`] -/// - [`ExpandMsgXof`] +/// For the `expand_message` call, `len_in_bytes = T::Length * N`. /// -/// `len_in_bytes = T::Length * out.len()` +/// # Errors /// -/// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd -/// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof +/// Returns an error if the [`ExpandMsg`] implementation fails. #[doc(hidden)] -pub fn hash_to_field(data: &[&[u8]], domain: &[&[u8]]) -> Result<[T; N]> +pub fn hash_to_field( + data: &[&[u8]], + domain: &[&[u8]], +) -> Result<[T; N], E::Error> where E: ExpandMsg, T: FromOkm + Default, { - let len_in_bytes = T::Length::USIZE - .checked_mul(N) - .and_then(|len| len.try_into().ok()) - .and_then(NonZeroU16::new) - .ok_or(Error)?; + // Completely degenerate case; `N` and `T::Length` would need to be extremely large. + let len_in_bytes = const { + assert!( + T::Length::USIZE.saturating_mul(N) <= u16::MAX as usize, + "The product of `T::Length` and `N` must not exceed `u16::MAX`." + ); + NonZeroU16::new(T::Length::U16 * N as u16).expect("N is greater than 0") + }; let mut tmp = Array::::Length>::default(); let mut expander = E::expand_message(data, domain, len_in_bytes)?; Ok(core::array::from_fn(|_| { diff --git a/hash2curve/src/hash2field/expand_msg.rs b/hash2curve/src/hash2field/expand_msg.rs index 5db42b73a..dea04d087 100644 --- a/hash2curve/src/hash2field/expand_msg.rs +++ b/hash2curve/src/hash2field/expand_msg.rs @@ -7,7 +7,8 @@ use core::num::NonZero; use digest::{Digest, ExtendableOutput, Update, XofReader}; use elliptic_curve::array::{Array, ArraySize}; -use elliptic_curve::{Error, Result}; +use xmd::ExpandMsgXmdError; +use xof::ExpandMsgXofError; /// Salt when the DST is too long const OVERSIZE_DST_SALT: &[u8] = b"H2C-OVERSIZE-DST-"; @@ -25,6 +26,8 @@ const MAX_DST_LEN: usize = 255; pub trait ExpandMsg { /// Type holding data for the [`Expander`]. type Expander<'dst>: Expander + Sized; + /// Error returned by [`ExpandMsg::expand_message`]. + type Error: core::error::Error; /// Expands `msg` to the required number of bytes. /// @@ -34,7 +37,7 @@ pub trait ExpandMsg { msg: &[&[u8]], dst: &'dst [&[u8]], len_in_bytes: NonZero, - ) -> Result>; + ) -> Result, Self::Error>; } /// Expander that, call `read` until enough bytes have been consumed. @@ -57,18 +60,17 @@ pub(crate) enum Domain<'a, L: ArraySize> { } impl<'a, L: ArraySize> Domain<'a, L> { - pub fn xof(dst: &'a [&'a [u8]]) -> Result + pub fn xof(dst: &'a [&'a [u8]]) -> Result where X: Default + ExtendableOutput + Update, { // https://www.rfc-editor.org/rfc/rfc9380.html#section-3.1-4.2 if dst.iter().map(|slice| slice.len()).sum::() == 0 { - Err(Error) + Err(ExpandMsgXofError::EmptyDst) } else if dst.iter().map(|slice| slice.len()).sum::() > MAX_DST_LEN { if L::USIZE > u8::MAX.into() { - return Err(Error); + return Err(ExpandMsgXofError::DstSecurityLevel); } - let mut data = Array::::default(); let mut hash = X::default(); hash.update(OVERSIZE_DST_SALT); @@ -85,18 +87,17 @@ impl<'a, L: ArraySize> Domain<'a, L> { } } - pub fn xmd(dst: &'a [&'a [u8]]) -> Result + pub fn xmd(dst: &'a [&'a [u8]]) -> Result where X: Digest, { // https://www.rfc-editor.org/rfc/rfc9380.html#section-3.1-4.2 if dst.iter().map(|slice| slice.len()).sum::() == 0 { - Err(Error) + Err(ExpandMsgXmdError::EmptyDst) } else if dst.iter().map(|slice| slice.len()).sum::() > MAX_DST_LEN { if L::USIZE > u8::MAX.into() { - return Err(Error); + return Err(ExpandMsgXmdError::DstHash); } - Ok(Self::Hashed({ let mut hash = X::new(); hash.update(OVERSIZE_DST_SALT); diff --git a/hash2curve/src/hash2field/expand_msg/xmd.rs b/hash2curve/src/hash2field/expand_msg/xmd.rs index c464700d2..749c4f2e7 100644 --- a/hash2curve/src/hash2field/expand_msg/xmd.rs +++ b/hash2curve/src/hash2field/expand_msg/xmd.rs @@ -11,15 +11,13 @@ use digest::{ }, block_api::BlockSizeUser, }; -use elliptic_curve::{Error, Result}; /// Implements `expand_message_xof` via the [`ExpandMsg`] trait: /// /// /// # Errors -/// - `dst` contains no bytes -/// - `dst > 255 && HashT::OutputSize > 255` -/// - `len_in_bytes > 255 * HashT::OutputSize` +/// +/// See [`ExpandMsgXmdError`] for details. #[derive(Debug)] pub struct ExpandMsgXmd(PhantomData) where @@ -38,17 +36,18 @@ where HashT::OutputSize: IsGreaterOrEqual, Output = True>, { type Expander<'dst> = ExpanderXmd<'dst, HashT>; + type Error = ExpandMsgXmdError; fn expand_message<'dst>( msg: &[&[u8]], dst: &'dst [&[u8]], len_in_bytes: NonZero, - ) -> Result> { + ) -> Result, ExpandMsgXmdError> { let b_in_bytes = HashT::OutputSize::USIZE; // `255 * ` can not exceed `u16::MAX` if usize::from(len_in_bytes.get()) > 255 * b_in_bytes { - return Err(Error); + return Err(ExpandMsgXmdError::Length); } let ell = u8::try_from(usize::from(len_in_bytes.get()).div_ceil(b_in_bytes)) @@ -146,6 +145,32 @@ where } } +/// Error type for [`ExpandMsgXmd`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ExpandMsgXmdError { + /// The domain separation tag must not be empty. + EmptyDst, + /// The hash's output size must not be greater then `255` + /// if the domain separation tag is longer than `255`. + DstHash, + /// The length in bytes is too large. + /// + /// `len_in_bytes` must be at most `255 * HashT::OutputSize`. + Length, +} + +impl core::fmt::Display for ExpandMsgXmdError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::EmptyDst => write!(f, "the domain separation tag is empty"), + Self::DstHash => write!(f, "hash output size is too large"), + Self::Length => write!(f, "`len_in_bytes` is too large"), + } + } +} + +impl core::error::Error for ExpandMsgXmdError {} + #[cfg(test)] mod test { use super::*; @@ -196,11 +221,7 @@ mod test { impl TestVector { #[allow(clippy::panic_in_result_fn)] - fn assert( - &self, - dst: &'static [u8], - domain: &Domain<'_, HashT::OutputSize>, - ) -> Result<()> + fn assert(&self, dst: &'static [u8], domain: &Domain<'_, HashT::OutputSize>) where HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLessOrEqual, @@ -213,24 +234,24 @@ mod test { let mut expander = as ExpandMsg>::expand_message( &[self.msg], &dst, - NonZero::new(L::U16).ok_or(Error)?, - )?; + NonZero::new(L::U16).unwrap(), + ) + .unwrap(); let mut uniform_bytes = Array::::default(); expander.fill_bytes(&mut uniform_bytes); assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes); - Ok(()) } } #[test] - fn expand_message_xmd_sha_256() -> Result<()> { + fn expand_message_xmd_sha_256() { const DST: &[u8] = b"QUUX-V01-CS02-with-expander-SHA256-128"; const DST_PRIME: &[u8] = &hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348413235362d31323826"); - let dst_prime = Domain::xmd::(&[DST])?; + let dst_prime = Domain::xmd::(&[DST]).unwrap(); dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ @@ -262,7 +283,7 @@ mod test { ]; for test_vector in TEST_VECTORS_32 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } const TEST_VECTORS_128: &[TestVector] = &[ @@ -290,19 +311,17 @@ mod test { ]; for test_vector in TEST_VECTORS_128 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } - - Ok(()) } #[test] - fn expand_message_xmd_sha_256_long() -> Result<()> { + fn expand_message_xmd_sha_256_long() { const DST: &[u8] = b"QUUX-V01-CS02-with-expander-SHA256-128-long-DST-1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"; const DST_PRIME: &[u8] = &hex!("412717974da474d0f8c420f320ff81e8432adb7c927d9bd082b4fb4d16c0a23620"); - let dst_prime = Domain::xmd::(&[DST])?; + let dst_prime = Domain::xmd::(&[DST]).unwrap(); dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ @@ -334,7 +353,7 @@ mod test { ]; for test_vector in TEST_VECTORS_32 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } const TEST_VECTORS_128: &[TestVector] = &[ @@ -366,21 +385,19 @@ mod test { ]; for test_vector in TEST_VECTORS_128 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } - - Ok(()) } #[test] - fn expand_message_xmd_sha_512() -> Result<()> { + fn expand_message_xmd_sha_512() { use sha2::Sha512; const DST: &[u8] = b"QUUX-V01-CS02-with-expander-SHA512-256"; const DST_PRIME: &[u8] = &hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348413531322d32353626"); - let dst_prime = Domain::xmd::(&[DST])?; + let dst_prime = Domain::xmd::(&[DST]).unwrap(); dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ @@ -412,7 +429,7 @@ mod test { ]; for test_vector in TEST_VECTORS_32 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } const TEST_VECTORS_128: &[TestVector] = &[ @@ -444,9 +461,7 @@ mod test { ]; for test_vector in TEST_VECTORS_128 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } - - Ok(()) } } diff --git a/hash2curve/src/hash2field/expand_msg/xof.rs b/hash2curve/src/hash2field/expand_msg/xof.rs index 32d00a7c1..20ed8d8b8 100644 --- a/hash2curve/src/hash2field/expand_msg/xof.rs +++ b/hash2curve/src/hash2field/expand_msg/xof.rs @@ -2,21 +2,18 @@ use super::{Domain, ExpandMsg, Expander}; use core::{fmt, num::NonZero, ops::Mul}; -use digest::{ - CollisionResistance, ExtendableOutput, HashMarker, Update, XofReader, typenum::IsGreaterOrEqual, -}; -use elliptic_curve::Result; +use digest::{CollisionResistance, ExtendableOutput, HashMarker, Update, XofReader}; use elliptic_curve::array::{ ArraySize, - typenum::{Prod, True, U2}, + typenum::{IsGreaterOrEqual, Prod, True, U2}, }; /// Implements `expand_message_xof` via the [`ExpandMsg`] trait: /// /// /// # Errors -/// - `dst` contains no bytes -/// - `dst > 255 && K * 2 > 255` +/// +/// See [`ExpandMsgXofError`] for details. pub struct ExpandMsgXof where HashT: Default + ExtendableOutput + Update + HashMarker, @@ -47,12 +44,13 @@ where HashT: CollisionResistance>, { type Expander<'dst> = Self; + type Error = ExpandMsgXofError; fn expand_message<'dst>( msg: &[&[u8]], dst: &'dst [&[u8]], len_in_bytes: NonZero, - ) -> Result> { + ) -> Result, ExpandMsgXofError> { let len_in_bytes = len_in_bytes.get(); let domain = Domain::>::xof::(dst)?; @@ -79,10 +77,29 @@ where } } +/// Error type for [`ExpandMsgXof`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ExpandMsgXofError { + /// The domain separation tag is invalid because it is empty. + EmptyDst, + /// The target security level (`K`) must not be greater then `127` + /// if the domain separation tag is longer than `255`. + DstSecurityLevel, +} + +impl core::fmt::Display for ExpandMsgXofError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::EmptyDst => write!(f, "the domain separation tag is empty"), + Self::DstSecurityLevel => write!(f, "target security level is too large"), + } + } +} + +impl core::error::Error for ExpandMsgXofError {} + #[cfg(test)] mod test { - use elliptic_curve::Error; - use super::*; use core::mem::size_of; use elliptic_curve::array::{ @@ -119,7 +136,7 @@ mod test { impl TestVector { #[allow(clippy::panic_in_result_fn)] - fn assert(&self, dst: &'static [u8], domain: &Domain<'_, U32>) -> Result<()> + fn assert(&self, dst: &'static [u8], domain: &Domain<'_, U32>) where HashT: Default + ExtendableOutput @@ -133,24 +150,24 @@ mod test { let mut expander = as ExpandMsg>::expand_message( &[self.msg], &[dst], - NonZero::new(L::U16).ok_or(Error)?, - )?; + NonZero::new(L::U16).unwrap(), + ) + .unwrap(); let mut uniform_bytes = Array::::default(); expander.fill_bytes(&mut uniform_bytes); assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes); - Ok(()) } } #[test] - fn expand_message_xof_shake_128() -> Result<()> { + fn expand_message_xof_shake_128() { const DST: &[u8] = b"QUUX-V01-CS02-with-expander-SHAKE128"; const DST_PRIME: &[u8] = &hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348414b4531323824"); - let dst_prime = Domain::::xof::(&[DST])?; + let dst_prime = Domain::::xof::(&[DST]).unwrap(); dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ @@ -182,7 +199,7 @@ mod test { ]; for test_vector in TEST_VECTORS_32 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } const TEST_VECTORS_128: &[TestVector] = &[ @@ -214,19 +231,17 @@ mod test { ]; for test_vector in TEST_VECTORS_128 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } - - Ok(()) } #[test] - fn expand_message_xof_shake_128_long() -> Result<()> { + fn expand_message_xof_shake_128_long() { const DST: &[u8] = b"QUUX-V01-CS02-with-expander-SHAKE128-long-DST-111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"; const DST_PRIME: &[u8] = &hex!("acb9736c0867fdfbd6385519b90fc8c034b5af04a958973212950132d035792f20"); - let dst_prime = Domain::::xof::(&[DST])?; + let dst_prime = Domain::::xof::(&[DST]).unwrap(); dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ @@ -258,7 +273,7 @@ mod test { ]; for test_vector in TEST_VECTORS_32 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } const TEST_VECTORS_128: &[TestVector] = &[ @@ -290,21 +305,19 @@ mod test { ]; for test_vector in TEST_VECTORS_128 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } - - Ok(()) } #[test] - fn expand_message_xof_shake_256() -> Result<()> { + fn expand_message_xof_shake_256() { use sha3::Shake256; const DST: &[u8] = b"QUUX-V01-CS02-with-expander-SHAKE256"; const DST_PRIME: &[u8] = &hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348414b4532353624"); - let dst_prime = Domain::::xof::(&[DST])?; + let dst_prime = Domain::::xof::(&[DST]).unwrap(); dst_prime.assert_dst(DST_PRIME); const TEST_VECTORS_32: &[TestVector] = &[ @@ -336,7 +349,7 @@ mod test { ]; for test_vector in TEST_VECTORS_32 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } const TEST_VECTORS_128: &[TestVector] = &[ @@ -368,9 +381,7 @@ mod test { ]; for test_vector in TEST_VECTORS_128 { - test_vector.assert::(DST, &dst_prime)?; + test_vector.assert::(DST, &dst_prime); } - - Ok(()) } }