Skip to content

Commit 9ed6117

Browse files
committed
Use Iterator instead of Expander
1 parent 56a3fee commit 9ed6117

File tree

6 files changed

+91
-110
lines changed

6 files changed

+91
-110
lines changed

hash2curve/src/group_digest.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ pub trait GroupDigest: MapToCurve {
3232
///
3333
/// [`ExpandMsgXmd`]: crate::ExpandMsgXmd
3434
/// [`ExpandMsgXof`]: crate::ExpandMsgXof
35-
fn hash_from_bytes<X>(msg: &[&[u8]], dst: &[&[u8]]) -> Result<ProjectivePoint<Self>>
35+
fn hash_from_bytes<'dst, X>(msg: &[&[u8]], dst: &'dst [&[u8]]) -> Result<ProjectivePoint<Self>>
3636
where
37-
X: ExpandMsg<Self::K>,
37+
X: ExpandMsg<'dst, Self::K>,
3838
{
3939
let [u0, u1] = hash_to_field::<2, X, _, Self::FieldElement>(msg, dst)?;
4040
let q0 = Self::map_to_curve(u0);
@@ -62,9 +62,12 @@ pub trait GroupDigest: MapToCurve {
6262
///
6363
/// [`ExpandMsgXmd`]: crate::ExpandMsgXmd
6464
/// [`ExpandMsgXof`]: crate::ExpandMsgXof
65-
fn encode_from_bytes<X>(msg: &[&[u8]], dst: &[&[u8]]) -> Result<ProjectivePoint<Self>>
65+
fn encode_from_bytes<'dst, X>(
66+
msg: &[&[u8]],
67+
dst: &'dst [&[u8]],
68+
) -> Result<ProjectivePoint<Self>>
6669
where
67-
X: ExpandMsg<Self::K>,
70+
X: ExpandMsg<'dst, Self::K>,
6871
{
6972
let [u] = hash_to_field::<1, X, _, Self::FieldElement>(msg, dst)?;
7073
let q0 = Self::map_to_curve(u);
@@ -85,9 +88,9 @@ pub trait GroupDigest: MapToCurve {
8588
///
8689
/// [`ExpandMsgXmd`]: crate::ExpandMsgXmd
8790
/// [`ExpandMsgXof`]: crate::ExpandMsgXof
88-
fn hash_to_scalar<X>(msg: &[&[u8]], dst: &[&[u8]]) -> Result<Self::Scalar>
91+
fn hash_to_scalar<'dst, X>(msg: &[&[u8]], dst: &'dst [&[u8]]) -> Result<Self::Scalar>
8992
where
90-
X: ExpandMsg<Self::K>,
93+
X: ExpandMsg<'dst, Self::K>,
9194
{
9295
let [u] = hash_to_field::<1, X, _, Self::Scalar>(msg, dst)?;
9396
Ok(u)

hash2curve/src/hash2field.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,24 @@ pub trait FromOkm {
3838
/// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd
3939
/// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof
4040
#[doc(hidden)]
41-
pub fn hash_to_field<const N: usize, E, K, T>(data: &[&[u8]], domain: &[&[u8]]) -> Result<[T; N]>
41+
pub fn hash_to_field<'dst, const N: usize, E, K, T>(
42+
data: &[&[u8]],
43+
domain: &'dst [&[u8]],
44+
) -> Result<[T; N]>
4245
where
43-
E: ExpandMsg<K>,
46+
E: ExpandMsg<'dst, K>,
4447
T: FromOkm + Default,
4548
{
4649
let len_in_bytes = T::Length::USIZE
4750
.checked_mul(N)
4851
.and_then(|len| len.try_into().ok())
4952
.and_then(NonZeroU16::new)
5053
.ok_or(Error)?;
51-
let mut tmp = Array::<u8, <T as FromOkm>::Length>::default();
5254
let mut expander = E::expand_message(data, domain, len_in_bytes)?;
5355
Ok(core::array::from_fn(|_| {
54-
expander.fill_bytes(&mut tmp);
56+
let tmp = Array::<u8, <T as FromOkm>::Length>::from_iter(
57+
expander.by_ref().take(T::Length::USIZE),
58+
);
5559
T::from_okm(&tmp)
5660
}))
5761
}

hash2curve/src/hash2field/expand_msg.rs

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,16 @@ const MAX_DST_LEN: usize = 255;
2222
///
2323
/// # Errors
2424
/// See implementors of [`ExpandMsg`] for errors.
25-
pub trait ExpandMsg<K> {
26-
/// Type holding data for the [`Expander`].
27-
type Expander<'dst>: Expander + Sized;
28-
25+
pub trait ExpandMsg<'dst, K>: Iterator<Item = u8> + Sized {
2926
/// Expands `msg` to the required number of bytes.
3027
///
3128
/// Returns an expander that can be used to call `read` until enough
3229
/// bytes have been consumed
33-
fn expand_message<'dst>(
30+
fn expand_message(
3431
msg: &[&[u8]],
3532
dst: &'dst [&[u8]],
3633
len_in_bytes: NonZero<u16>,
37-
) -> Result<Self::Expander<'dst>>;
38-
}
39-
40-
/// Expander that, call `read` until enough bytes have been consumed.
41-
pub trait Expander {
42-
/// Fill the array with the expanded bytes
43-
fn fill_bytes(&mut self, okm: &mut [u8]);
34+
) -> Result<Self>;
4435
}
4536

4637
/// The domain separation tag

hash2curve/src/hash2field/expand_msg/xmd.rs

Lines changed: 46 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
//! `expand_message_xmd` based on a hash function.
22
3-
use core::{marker::PhantomData, num::NonZero, ops::Mul};
3+
use core::{num::NonZero, ops::Mul};
44

5-
use super::{Domain, ExpandMsg, Expander};
5+
use super::{Domain, ExpandMsg};
66
use digest::{
77
FixedOutput, HashMarker,
88
array::{
@@ -21,12 +21,20 @@ use elliptic_curve::{Error, Result};
2121
/// - `dst > 255 && HashT::OutputSize > 255`
2222
/// - `len_in_bytes > 255 * HashT::OutputSize`
2323
#[derive(Debug)]
24-
pub struct ExpandMsgXmd<HashT>(PhantomData<HashT>)
24+
pub struct ExpandMsgXmd<'a, HashT>
2525
where
2626
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
27-
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>;
27+
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
28+
{
29+
b_0: Array<u8, HashT::OutputSize>,
30+
b_vals: Array<u8, HashT::OutputSize>,
31+
domain: Domain<'a, HashT::OutputSize>,
32+
index: u8,
33+
offset: usize,
34+
length: u16,
35+
}
2836

29-
impl<HashT, K> ExpandMsg<K> for ExpandMsgXmd<HashT>
37+
impl<'dst, HashT, K> ExpandMsg<'dst, K> for ExpandMsgXmd<'dst, HashT>
3038
where
3139
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
3240
// The number of bits output by `HashT` MUST be at most `HashT::BlockSize`.
@@ -37,23 +45,18 @@ where
3745
K: Mul<U2>,
3846
HashT::OutputSize: IsGreaterOrEqual<Prod<K, U2>, Output = True>,
3947
{
40-
type Expander<'dst> = ExpanderXmd<'dst, HashT>;
41-
42-
fn expand_message<'dst>(
48+
fn expand_message(
4349
msg: &[&[u8]],
4450
dst: &'dst [&[u8]],
4551
len_in_bytes: NonZero<u16>,
46-
) -> Result<Self::Expander<'dst>> {
52+
) -> Result<Self> {
4753
let b_in_bytes = HashT::OutputSize::USIZE;
4854

4955
// `255 * <b_in_bytes>` can not exceed `u16::MAX`
5056
if usize::from(len_in_bytes.get()) > 255 * b_in_bytes {
5157
return Err(Error);
5258
}
5359

54-
let ell = u8::try_from(usize::from(len_in_bytes.get()).div_ceil(b_in_bytes))
55-
.expect("should never pass the previous check");
56-
5760
let domain = Domain::xmd::<HashT>(dst)?;
5861
let mut b_0 = HashT::default();
5962
b_0.update(&Array::<u8, HashT::BlockSize>::default());
@@ -75,74 +78,51 @@ where
7578
b_vals.update(&[domain.len()]);
7679
let b_vals = b_vals.finalize_fixed();
7780

78-
Ok(ExpanderXmd {
81+
Ok(Self {
7982
b_0,
8083
b_vals,
8184
domain,
8285
index: 1,
8386
offset: 0,
84-
ell,
87+
length: len_in_bytes.get(),
8588
})
8689
}
8790
}
8891

89-
/// [`Expander`] type for [`ExpandMsgXmd`].
90-
#[derive(Debug)]
91-
pub struct ExpanderXmd<'a, HashT>
92+
impl<HashT> Iterator for ExpandMsgXmd<'_, HashT>
9293
where
9394
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
9495
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
9596
{
96-
b_0: Array<u8, HashT::OutputSize>,
97-
b_vals: Array<u8, HashT::OutputSize>,
98-
domain: Domain<'a, HashT::OutputSize>,
99-
index: u8,
100-
offset: usize,
101-
ell: u8,
102-
}
97+
type Item = u8;
10398

104-
impl<HashT> ExpanderXmd<'_, HashT>
105-
where
106-
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
107-
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
108-
{
109-
fn next(&mut self) -> bool {
110-
if self.index < self.ell {
111-
self.index += 1;
112-
self.offset = 0;
113-
// b_0 XOR b_(idx - 1)
114-
let mut tmp = Array::<u8, HashT::OutputSize>::default();
115-
self.b_0
116-
.iter()
117-
.zip(&self.b_vals[..])
118-
.enumerate()
119-
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
120-
let mut b_vals = HashT::default();
121-
b_vals.update(&tmp);
122-
b_vals.update(&[self.index]);
123-
self.domain.update_hash(&mut b_vals);
124-
b_vals.update(&[self.domain.len()]);
125-
self.b_vals = b_vals.finalize_fixed();
126-
true
127-
} else {
128-
false
129-
}
130-
}
131-
}
132-
133-
impl<HashT> Expander for ExpanderXmd<'_, HashT>
134-
where
135-
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
136-
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
137-
{
138-
fn fill_bytes(&mut self, okm: &mut [u8]) {
139-
for b in okm {
140-
if self.offset == self.b_vals.len() && !self.next() {
141-
return;
142-
}
143-
*b = self.b_vals[self.offset];
99+
fn next(&mut self) -> Option<u8> {
100+
if (self.index as u16 - 1) * HashT::OutputSize::U16 + self.offset as u16
101+
== self.length
102+
{
103+
return None;
104+
} else if self.offset != self.b_vals.len() {
105+
let byte = self.b_vals[self.offset];
144106
self.offset += 1;
107+
return Some(byte);
145108
}
109+
110+
self.index += 1;
111+
self.offset = 1;
112+
// b_0 XOR b_(idx - 1)
113+
let mut tmp = Array::<u8, HashT::OutputSize>::default();
114+
self.b_0
115+
.iter()
116+
.zip(&self.b_vals[..])
117+
.enumerate()
118+
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
119+
let mut b_vals = HashT::default();
120+
b_vals.update(&tmp);
121+
b_vals.update(&[self.index]);
122+
self.domain.update_hash(&mut b_vals);
123+
b_vals.update(&[self.domain.len()]);
124+
self.b_vals = b_vals.finalize_fixed();
125+
Some(self.b_vals[0])
146126
}
147127
}
148128

@@ -210,15 +190,13 @@ mod test {
210190
assert_message::<HashT>(self.msg, domain, L::U16, self.msg_prime);
211191

212192
let dst = [dst];
213-
let mut expander = <ExpandMsgXmd<HashT> as ExpandMsg<U4>>::expand_message(
193+
let expander = <ExpandMsgXmd<HashT> as ExpandMsg<U4>>::expand_message(
214194
&[self.msg],
215195
&dst,
216196
NonZero::new(L::U16).ok_or(Error)?,
217197
)?;
218198

219-
let mut uniform_bytes = Array::<u8, L>::default();
220-
expander.fill_bytes(&mut uniform_bytes);
221-
199+
let uniform_bytes = Array::<u8, L>::from_iter(expander);
222200
assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
223201
Ok(())
224202
}

hash2curve/src/hash2field/expand_msg/xof.rs

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
//! `expand_message_xof` for the `ExpandMsg` trait
22
3-
use super::{Domain, ExpandMsg, Expander};
3+
use super::{Domain, ExpandMsg};
44
use core::{fmt, num::NonZero, ops::Mul};
5+
use digest::XofReader;
56
use digest::{
6-
CollisionResistance, ExtendableOutput, HashMarker, Update, XofReader, typenum::IsGreaterOrEqual,
7+
CollisionResistance, ExtendableOutput, HashMarker, Update, typenum::IsGreaterOrEqual,
78
};
89
use elliptic_curve::Result;
910
use elliptic_curve::array::{
@@ -21,7 +22,8 @@ pub struct ExpandMsgXof<HashT>
2122
where
2223
HashT: Default + ExtendableOutput + Update + HashMarker,
2324
{
24-
reader: <HashT as ExtendableOutput>::Reader,
25+
reader: HashT::Reader,
26+
length: u16,
2527
}
2628

2729
impl<HashT> fmt::Debug for ExpandMsgXof<HashT>
@@ -36,7 +38,7 @@ where
3638
}
3739
}
3840

39-
impl<HashT, K> ExpandMsg<K> for ExpandMsgXof<HashT>
41+
impl<HashT, K> ExpandMsg<'_, K> for ExpandMsgXof<HashT>
4042
where
4143
HashT: Default + ExtendableOutput + Update + HashMarker,
4244
// If DST is larger than 255 bytes, the length of the computed DST is calculated by `K * 2`.
@@ -46,13 +48,7 @@ where
4648
// https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.2-2.1
4749
HashT: CollisionResistance<CollisionResistance: IsGreaterOrEqual<K, Output = True>>,
4850
{
49-
type Expander<'dst> = Self;
50-
51-
fn expand_message<'dst>(
52-
msg: &[&[u8]],
53-
dst: &'dst [&[u8]],
54-
len_in_bytes: NonZero<u16>,
55-
) -> Result<Self::Expander<'dst>> {
51+
fn expand_message(msg: &[&[u8]], dst: &[&[u8]], len_in_bytes: NonZero<u16>) -> Result<Self> {
5652
let len_in_bytes = len_in_bytes.get();
5753

5854
let domain = Domain::<Prod<K, U2>>::xof::<HashT>(dst)?;
@@ -66,16 +62,27 @@ where
6662
domain.update_hash(&mut reader);
6763
reader.update(&[domain.len()]);
6864
let reader = reader.finalize_xof();
69-
Ok(Self { reader })
65+
Ok(Self {
66+
reader,
67+
length: len_in_bytes,
68+
})
7069
}
7170
}
7271

73-
impl<HashT> Expander for ExpandMsgXof<HashT>
72+
impl<HashT> Iterator for ExpandMsgXof<HashT>
7473
where
7574
HashT: Default + ExtendableOutput + Update + HashMarker,
7675
{
77-
fn fill_bytes(&mut self, okm: &mut [u8]) {
78-
self.reader.read(okm);
76+
type Item = u8;
77+
78+
fn next(&mut self) -> Option<Self::Item> {
79+
if self.length == 0 {
80+
return None;
81+
}
82+
self.length -= 1;
83+
let mut buf = [0u8; 1];
84+
self.reader.read(&mut buf);
85+
Some(buf[0])
7986
}
8087
}
8188

@@ -130,15 +137,13 @@ mod test {
130137
{
131138
assert_message(self.msg, domain, L::to_u16(), self.msg_prime);
132139

133-
let mut expander = <ExpandMsgXof<HashT> as ExpandMsg<U16>>::expand_message(
140+
let expander = <ExpandMsgXof<HashT> as ExpandMsg<U16>>::expand_message(
134141
&[self.msg],
135142
&[dst],
136143
NonZero::new(L::U16).ok_or(Error)?,
137144
)?;
138145

139-
let mut uniform_bytes = Array::<u8, L>::default();
140-
expander.fill_bytes(&mut uniform_bytes);
141-
146+
let uniform_bytes = Array::<u8, L>::from_iter(expander);
142147
assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
143148
Ok(())
144149
}

hash2curve/src/oprf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ pub trait OprfParameters: GroupDigest + PrimeCurve {
2525
/// and `HashToScalar` as defined in [section 4 of RFC9497][oprf].
2626
///
2727
/// [oprf]: https://www.rfc-editor.org/rfc/rfc9497.html#name-ciphersuites
28-
type ExpandMsg: ExpandMsg<<Self as GroupDigest>::K>;
28+
type ExpandMsg<'a>: ExpandMsg<'a, <Self as GroupDigest>::K>;
2929
}

0 commit comments

Comments
 (0)