Skip to content

Use Iterator<Item = u8> instead of Expander #1317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions ed448-goldilocks/src/field/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ impl FieldElement {
mod tests {
use super::*;
use elliptic_curve::consts::U32;
use hash2curve::{ExpandMsg, ExpandMsgXof, Expander};
use hash2curve::{ExpandMsg, ExpandMsgXof};
use hex_literal::hex;
use sha3::Shake256;

Expand All @@ -463,18 +463,16 @@ mod tests {
(84 * 2).try_into().unwrap(),
)
.unwrap();
let mut data = Array::<u8, U84>::default();
expander.fill_bytes(&mut data);
// TODO: This should be `Curve448FieldElement`.
let u0 = Ed448FieldElement::from_okm(&data).0;
let u0 = Ed448FieldElement::from_okm(&expander.by_ref().take(84).collect()).0;
let mut e_u0 = *expected_u0;
e_u0.reverse();
let mut e_u1 = *expected_u1;
e_u1.reverse();
assert_eq!(u0.to_bytes(), e_u0);
expander.fill_bytes(&mut data);

// TODO: This should be `Curve448FieldElement`.
let u1 = Ed448FieldElement::from_okm(&data).0;
let u1 = Ed448FieldElement::from_okm(&expander.collect()).0;
assert_eq!(u1.to_bytes(), e_u1);
}
}
Expand All @@ -497,16 +495,13 @@ mod tests {
(84 * 2).try_into().unwrap(),
)
.unwrap();
let mut data = Array::<u8, U84>::default();
expander.fill_bytes(&mut data);
let u0 = Ed448FieldElement::from_okm(&data).0;
let u0 = Ed448FieldElement::from_okm(&expander.by_ref().take(84).collect()).0;
let mut e_u0 = *expected_u0;
e_u0.reverse();
let mut e_u1 = *expected_u1;
e_u1.reverse();
assert_eq!(u0.to_bytes(), e_u0);
expander.fill_bytes(&mut data);
let u1 = Ed448FieldElement::from_okm(&data).0;
let u1 = Ed448FieldElement::from_okm(&expander.collect()).0;
assert_eq!(u1.to_bytes(), e_u1);
}
}
Expand Down
4 changes: 1 addition & 3 deletions hash2curve/src/hash2field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,8 @@ where
.and_then(|len| len.try_into().ok())
.and_then(NonZeroU16::new)
.ok_or(Error)?;
let mut tmp = Array::<u8, <T as FromOkm>::Length>::default();
let mut expander = E::expand_message(data, domain, len_in_bytes)?;
Ok(core::array::from_fn(|_| {
expander.fill_bytes(&mut tmp);
T::from_okm(&tmp)
T::from_okm(&expander.by_ref().take(T::Length::USIZE).collect())
}))
}
12 changes: 3 additions & 9 deletions hash2curve/src/hash2field/expand_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ const MAX_DST_LEN: usize = 255;
/// # Errors
/// See implementors of [`ExpandMsg`] for errors.
pub trait ExpandMsg<K> {
/// Type holding data for the [`Expander`].
type Expander<'dst>: Expander + Sized;
/// The expanded message.
type Expanded<'a>: Iterator<Item = u8>;

/// Expands `msg` to the required number of bytes.
///
Expand All @@ -34,13 +34,7 @@ pub trait ExpandMsg<K> {
msg: &[&[u8]],
dst: &'dst [&[u8]],
len_in_bytes: NonZero<u16>,
) -> Result<Self::Expander<'dst>>;
}

/// Expander that, call `read` until enough bytes have been consumed.
pub trait Expander {
/// Fill the array with the expanded bytes
fn fill_bytes(&mut self, okm: &mut [u8]);
) -> Result<Self::Expanded<'dst>>;
}

/// The domain separation tag
Expand Down
72 changes: 28 additions & 44 deletions hash2curve/src/hash2field/expand_msg/xmd.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! `expand_message_xmd` based on a hash function.

use core::{marker::PhantomData, num::NonZero, ops::Mul};
use core::{num::NonZero, ops::Mul};

use super::{Domain, ExpandMsg, Expander};
use super::{Domain, ExpandMsg};
use digest::{
FixedOutput, HashMarker,
array::{
Expand All @@ -20,11 +20,8 @@ use elliptic_curve::{Error, Result};
/// - `dst` contains no bytes
/// - `dst > 255 && HashT::OutputSize > 255`
/// - `len_in_bytes > 255 * HashT::OutputSize`
#[derive(Debug)]
pub struct ExpandMsgXmd<HashT>(PhantomData<HashT>)
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExpandMsgXmd<HashT>(core::marker::PhantomData<HashT>);

impl<HashT, K> ExpandMsg<K> for ExpandMsgXmd<HashT>
where
Expand All @@ -37,23 +34,20 @@ where
K: Mul<U2>,
HashT::OutputSize: IsGreaterOrEqual<Prod<K, U2>, Output = True>,
{
type Expander<'dst> = ExpanderXmd<'dst, HashT>;
type Expanded<'a> = ExpandedXmd<'a, HashT>;

fn expand_message<'dst>(
fn expand_message<'a>(
msg: &[&[u8]],
dst: &'dst [&[u8]],
dst: &'a [&[u8]],
len_in_bytes: NonZero<u16>,
) -> Result<Self::Expander<'dst>> {
) -> Result<Self::Expanded<'a>> {
let b_in_bytes = HashT::OutputSize::USIZE;

// `255 * <b_in_bytes>` can not exceed `u16::MAX`
if usize::from(len_in_bytes.get()) > 255 * b_in_bytes {
return Err(Error);
}

let ell = u8::try_from(usize::from(len_in_bytes.get()).div_ceil(b_in_bytes))
.expect("should never pass the previous check");
Comment on lines -54 to -55
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets turn this into a debug_assert!.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the logic here is succinct enough to not need a dbg assert. Since len_in_bytes <= 255 * b_in_bytes, len_in_bytes / b_in_bytes <= 255.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems fine to me!


let domain = Domain::xmd::<HashT>(dst)?;
let mut b_0 = HashT::default();
b_0.update(&Array::<u8, HashT::BlockSize>::default());
Expand All @@ -75,20 +69,20 @@ where
b_vals.update(&[domain.len()]);
let b_vals = b_vals.finalize_fixed();

Ok(ExpanderXmd {
Ok(ExpandedXmd {
b_0,
b_vals,
domain,
index: 1,
offset: 0,
ell,
remaining: len_in_bytes.get(),
})
}
}

/// [`Expander`] type for [`ExpandMsgXmd`].
/// The expanded bytes of `expand_message_xmd`.
#[derive(Debug)]
pub struct ExpanderXmd<'a, HashT>
pub struct ExpandedXmd<'a, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
Expand All @@ -98,16 +92,22 @@ where
domain: Domain<'a, HashT::OutputSize>,
index: u8,
offset: usize,
ell: u8,
remaining: u16,
}

impl<HashT> ExpanderXmd<'_, HashT>
impl<HashT> Iterator for ExpandedXmd<'_, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
{
fn next(&mut self) -> bool {
if self.index < self.ell {
type Item = u8;

fn next(&mut self) -> Option<u8> {
if self.remaining == 0 {
return None;
}

if self.offset == self.b_vals.len() {
self.index += 1;
self.offset = 0;
// b_0 XOR b_(idx - 1)
Expand All @@ -123,26 +123,12 @@ where
self.domain.update_hash(&mut b_vals);
b_vals.update(&[self.domain.len()]);
self.b_vals = b_vals.finalize_fixed();
true
} else {
false
}
}
}

impl<HashT> Expander for ExpanderXmd<'_, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
{
fn fill_bytes(&mut self, okm: &mut [u8]) {
for b in okm {
if self.offset == self.b_vals.len() && !self.next() {
return;
}
*b = self.b_vals[self.offset];
self.offset += 1;
}
let byte = self.b_vals[self.offset];
self.offset += 1;
self.remaining -= 1;
Some(byte)
}
}

Expand Down Expand Up @@ -210,15 +196,13 @@ mod test {
assert_message::<HashT>(self.msg, domain, L::U16, self.msg_prime);

let dst = [dst];
let mut expander = <ExpandMsgXmd<HashT> as ExpandMsg<U4>>::expand_message(
let expander = <ExpandMsgXmd<HashT> as ExpandMsg<U4>>::expand_message(
&[self.msg],
&dst,
NonZero::new(L::U16).ok_or(Error)?,
)?;

let mut uniform_bytes = Array::<u8, L>::default();
expander.fill_bytes(&mut uniform_bytes);

let uniform_bytes: Array<u8, L> = expander.collect();
assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
Ok(())
}
Expand Down
43 changes: 25 additions & 18 deletions hash2curve/src/hash2field/expand_msg/xof.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! `expand_message_xof` for the `ExpandMsg` trait

use super::{Domain, ExpandMsg, Expander};
use core::{fmt, num::NonZero, ops::Mul};
use super::{Domain, ExpandMsg};
use core::{array, fmt, num::NonZero, ops::Mul};
use digest::XofReader;
use digest::{
CollisionResistance, ExtendableOutput, HashMarker, Update, XofReader, typenum::IsGreaterOrEqual,
CollisionResistance, ExtendableOutput, HashMarker, Update, typenum::IsGreaterOrEqual,
};
use elliptic_curve::Result;
use elliptic_curve::array::{
Expand All @@ -21,7 +22,8 @@ pub struct ExpandMsgXof<HashT>
where
HashT: Default + ExtendableOutput + Update + HashMarker,
{
reader: <HashT as ExtendableOutput>::Reader,
reader: HashT::Reader,
length: u16,
}

impl<HashT> fmt::Debug for ExpandMsgXof<HashT>
Expand All @@ -46,13 +48,9 @@ where
// https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.2-2.1
HashT: CollisionResistance<CollisionResistance: IsGreaterOrEqual<K, Output = True>>,
{
type Expander<'dst> = Self;
type Expanded<'a> = Self;

fn expand_message<'dst>(
msg: &[&[u8]],
dst: &'dst [&[u8]],
len_in_bytes: NonZero<u16>,
) -> Result<Self::Expander<'dst>> {
fn expand_message(msg: &[&[u8]], dst: &[&[u8]], len_in_bytes: NonZero<u16>) -> Result<Self> {
let len_in_bytes = len_in_bytes.get();

let domain = Domain::<Prod<K, U2>>::xof::<HashT>(dst)?;
Expand All @@ -66,16 +64,27 @@ where
domain.update_hash(&mut reader);
reader.update(&[domain.len()]);
let reader = reader.finalize_xof();
Ok(Self { reader })
Ok(Self {
reader,
length: len_in_bytes,
})
}
}

impl<HashT> Expander for ExpandMsgXof<HashT>
impl<HashT> Iterator for ExpandMsgXof<HashT>
where
HashT: Default + ExtendableOutput + Update + HashMarker,
{
fn fill_bytes(&mut self, okm: &mut [u8]) {
self.reader.read(okm);
type Item = u8;

fn next(&mut self) -> Option<Self::Item> {
if self.length == 0 {
return None;
}
self.length -= 1;
let mut byte = 0;
self.reader.read(array::from_mut(&mut byte));
Some(byte)
}
}

Expand Down Expand Up @@ -130,15 +139,13 @@ mod test {
{
assert_message(self.msg, domain, L::to_u16(), self.msg_prime);

let mut expander = <ExpandMsgXof<HashT> as ExpandMsg<U16>>::expand_message(
let expander = <ExpandMsgXof<HashT> as ExpandMsg<U16>>::expand_message(
&[self.msg],
&[dst],
NonZero::new(L::U16).ok_or(Error)?,
)?;

let mut uniform_bytes = Array::<u8, L>::default();
expander.fill_bytes(&mut uniform_bytes);

let uniform_bytes: Array<u8, L> = expander.collect();
assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
Ok(())
}
Expand Down