Skip to content

Commit f2a9e58

Browse files
Merge pull request #10 from jayvdb/wait-for-write
Allow waiting for write
2 parents c5a65b8 + 74f44a0 commit f2a9e58

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

src/lib.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,23 @@ use std::cell::RefCell;
66
use std::io::{Cursor, Error, ErrorKind, Read, Result, Write};
77
use std::mem::swap;
88
use std::rc::Rc;
9+
use std::sync::atomic::{AtomicBool, Ordering};
910
use std::sync::{Arc, Mutex};
11+
use std::thread::sleep;
12+
use std::time;
1013

1114
#[cfg(test)]
1215
mod tests;
1316

17+
fn find_subsequence<T>(haystack: &[T], needle: &[T]) -> Option<usize>
18+
where
19+
for<'a> &'a [T]: PartialEq,
20+
{
21+
haystack
22+
.windows(needle.len())
23+
.position(|window| window == needle)
24+
}
25+
1426
/// MockStream is Read+Write stream that stores the data written and provides the data to be read.
1527
#[derive(Clone)]
1628
pub struct MockStream {
@@ -37,6 +49,11 @@ impl MockStream {
3749
}
3850
}
3951

52+
/// Extract all bytes written by Write trait calls.
53+
pub fn peek_bytes_written(&mut self) -> &Vec<u8> {
54+
self.writer.get_ref()
55+
}
56+
4057
/// Extract all bytes written by Write trait calls.
4158
pub fn pop_bytes_written(&mut self) -> Vec<u8> {
4259
let mut result = Vec::new();
@@ -114,6 +131,8 @@ impl Write for SharedMockStream {
114131
#[derive(Clone, Default)]
115132
pub struct SyncMockStream {
116133
pimpl: Arc<Mutex<MockStream>>,
134+
pub waiting_for_write: Arc<AtomicBool>,
135+
pub expected_bytes: Vec<u8>,
117136
}
118137

119138
impl SyncMockStream {
@@ -122,6 +141,12 @@ impl SyncMockStream {
122141
SyncMockStream::default()
123142
}
124143

144+
/// Block reads until expected bytes are written.
145+
pub fn wait_for(&mut self, expected_bytes: &[u8]) {
146+
self.expected_bytes = expected_bytes.to_vec();
147+
self.waiting_for_write.store(true, Ordering::Relaxed);
148+
}
149+
125150
/// Extract all bytes written by Write trait calls.
126151
pub fn push_bytes_to_read(&mut self, bytes: &[u8]) {
127152
self.pimpl.lock().unwrap().push_bytes_to_read(bytes)
@@ -135,13 +160,27 @@ impl SyncMockStream {
135160

136161
impl Read for SyncMockStream {
137162
fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
163+
while self.waiting_for_write.load(Ordering::Relaxed) {
164+
sleep(time::Duration::from_millis(10));
165+
}
138166
self.pimpl.lock().unwrap().read(buf)
139167
}
140168
}
141169

142170
impl Write for SyncMockStream {
143171
fn write(&mut self, buf: &[u8]) -> Result<usize> {
144-
self.pimpl.lock().unwrap().write(buf)
172+
let mut x = self.pimpl.lock().unwrap();
173+
match x.write(buf) {
174+
Ok(rv) => {
175+
if self.waiting_for_write.load(Ordering::Relaxed)
176+
&& find_subsequence(x.peek_bytes_written(), &self.expected_bytes).is_some()
177+
{
178+
self.waiting_for_write.store(false, Ordering::Relaxed);
179+
}
180+
Ok(rv)
181+
}
182+
Err(rv) => Err(rv),
183+
}
145184
}
146185

147186
fn flush(&mut self) -> Result<()> {

src/tests.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,37 @@ fn test_sync_mock_stream() {
177177
assert_eq!(s.pop_bytes_written(), &[1, 2, 3, 4]);
178178
assert_eq!(read, &[5, 6, 7, 8]);
179179
}
180+
181+
#[test]
182+
fn test_sync_mock_stream_wait_for() {
183+
use std::thread;
184+
use std::time::{Duration, SystemTime};
185+
let mut s = SyncMockStream::new();
186+
187+
s.wait_for(&[3, 4]);
188+
189+
let mut s2 = s.clone();
190+
191+
// thread will write some bytes, and then read some bytes
192+
s.push_bytes_to_read(&[5, 6, 7, 8]);
193+
let read_thread = thread::spawn(move || {
194+
let mut buf = Vec::new();
195+
s2.read_to_end(&mut buf).unwrap();
196+
let read_ts = SystemTime::now();
197+
(buf, read_ts)
198+
});
199+
200+
let write_thread = thread::spawn(move || {
201+
thread::sleep(Duration::new(2, 0));
202+
s.write_all(&[1, 2, 3, 4]).unwrap();
203+
let write_ts = SystemTime::now();
204+
(s.pop_bytes_written(), write_ts)
205+
});
206+
207+
let (read, read_ts) = read_thread.join().unwrap();
208+
let (write, write_ts) = write_thread.join().unwrap();
209+
210+
assert_eq!(write, &[1, 2, 3, 4]);
211+
assert_eq!(read, &[5, 6, 7, 8]);
212+
assert!(read_ts > write_ts);
213+
}

0 commit comments

Comments
 (0)