Skip to content

Commit 663b2cf

Browse files
thedavekwonmeta-codesync[bot]
authored andcommitted
try_post to have SendError<M> return_channel (#1568)
Summary: Pull Request resolved: #1568 Refactor so that caller does not have to maintain additional error handling. Reviewed By: mariusae Differential Revision: D84834014 fbshipit-source-id: 4c50e54d24310e59c63a8c57d69714665f2f6445
1 parent 06e651f commit 663b2cf

File tree

10 files changed

+183
-166
lines changed

10 files changed

+183
-166
lines changed

docs/source/books/hyperactor-book/src/channels/transports/local.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,3 @@
1212

1313
**notes:**
1414
- `Tx::send` completes after local enqueue (oneshot dropped).
15-
- if the receiver is dropped, `try_post` fails immediately with `Err(SendError(ChannelError::Closed, message))`.

docs/source/books/hyperactor-book/src/channels/tx_rx.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Under the hood, network transports use a length-prefixed, multipart frame with c
2222
```rust
2323
#[async_trait]
2424
pub trait Tx<M: RemoteMessage>: std::fmt::Debug {
25-
fn try_post(&self, message: M, return_channel: oneshot::Sender<M>) -> Result<(), SendError<M>>;
25+
fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>);
2626
fn post(&self, message: M);
2727
async fn send(&self, message: M) -> Result<(), SendError<M>>;
2828
fn addr(&self) -> ChannelAddr;
@@ -32,8 +32,7 @@ pub trait Tx<M: RemoteMessage>: std::fmt::Debug {
3232

3333
- **`try_post(message, return_channel)`**
3434
Enqueues locally.
35-
- Immediate failure → `Err(SendError(ChannelError::Closed, message))`.
36-
- `Ok(())` means queued; if delivery later fails, the original message is sent back on `return_channel`.
35+
- If delivery later fails, the original message is sent back on `return_channel` as SendError.
3736

3837
- **`post(message)`**
3938
Fire-and-forget wrapper around `try_post`. The caller should monitor `status()` for health instead of relying on return values.
@@ -91,7 +90,7 @@ pub trait Rx<M: RemoteMessage>: std::fmt::Debug {
9190
### Failure semantics
9291
- **Closed receiver:** `recv()` returns `Err(ChannelError::Closed)`.
9392
- **Network transports:** disconnects trigger exponential backoff reconnects; unacked messages are retried. If recovery ultimately fails (e.g., connection cannot be re-established within the delivery timeout window), the client closes and returns all undelivered/unacked messages via their `return_channel`. `status()` flips to `Closed`.
94-
- **Local transport:** no delayed return path; if the receiver is gone, `try_post` fails immediately with `Err(SendError(ChannelError::Closed, message))`.
93+
- **Local transport:** no delayed return path.
9594
- **Network disconnects (EOF/I/O error/temporary break):** the client reconnects with exponential backoff and resends any unacked messages; the server deduplicates by `seq`.
9695
- **Delivery timeout:** see [Size & time limits](#size--time-limits).
9796

@@ -104,7 +103,7 @@ pub trait Rx<M: RemoteMessage>: std::fmt::Debug {
104103

105104
Concrete channel implementations that satisfy `Tx<M>` / `Rx<M>`:
106105

107-
- **Local** — in-process only; uses `tokio::sync::mpsc`. No network framing/acks. `try_post` fails immediately if the receiver is gone.
106+
- **Local** — in-process only; uses `tokio::sync::mpsc`. No network framing/acks.
108107
_Dial/serve:_ `serve_local::<M>()`, `ChannelAddr::Local(_)`.
109108

110109
- **TCP**`tokio::net::TcpStream` with 8-byte BE length-prefixed frames; `seq`/`ack` for exactly-once into the server queue; reconnects with backoff.

docs/source/books/hyperactor-book/src/mailboxes/mailbox_client.md

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -124,23 +124,18 @@ impl MailboxClient {
124124
let return_handle_0 = return_handle.clone();
125125
tokio::spawn(async move {
126126
let result = return_receiver.await;
127-
if let Ok(message) = result {
128-
let _ = return_handle_0.send(Undeliverable(message));
129-
} else {
130-
// Sender dropped, this task can end.
127+
if let Ok(SendError(e, message)) = result {
128+
message.undeliverable(
129+
DeliveryError::BrokenLink(format!(
130+
"failed to enqueue in MailboxClient when processing buffer: {e}"
131+
)),
132+
return_handle_0,
133+
);
131134
}
132135
});
133136
// Send the message for transmission.
134-
let return_handle_1 = return_handle.clone();
135-
async move {
136-
if let Err(SendError(_, envelope)) = tx.try_post(envelope, return_channel) {
137-
// Failed to enqueue.
138-
envelope.undeliverable(
139-
DeliveryError::BrokenLink("failed to enqueue in MailboxClient".to_string()),
140-
return_handle_1.clone(),
141-
);
142-
}
143-
}
137+
tx.try_post(envelope, return_channel);
138+
future::ready(())
144139
});
145140
let this = Self {
146141
buffer,

hyperactor/benches/main.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,7 @@ fn bench_message_rates(c: &mut Criterion) {
150150
Vec::with_capacity(rate as usize);
151151
for _ in 0..rate {
152152
let (return_sender, return_receiver) = oneshot::channel();
153-
if let Err(e) = tx.try_post(message.clone(), return_sender) {
154-
panic!("Failed to send message: {:?}", e);
155-
}
153+
tx.try_post(message.clone(), return_sender);
156154

157155
let handle = tokio::spawn(async move {
158156
_ = tokio::time::timeout(

hyperactor/src/channel.rs

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -109,28 +109,24 @@ pub enum TxStatus {
109109
pub trait Tx<M: RemoteMessage>: std::fmt::Debug {
110110
/// Enqueue a `message` on the local end of the channel. The
111111
/// message is either delivered, or we eventually discover that
112-
/// the channel has failed and it will be sent back on `return_handle`.
113-
// TODO: the return channel should be SendError<M> directly, and we should drop
114-
// the returned result.
112+
/// the channel has failed and it will be sent back on `return_channel`.
115113
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SendError`.
116-
fn try_post(&self, message: M, return_channel: oneshot::Sender<M>) -> Result<(), SendError<M>>;
114+
// TODO: Consider making return channel optional to indicate that the log can be dropped.
115+
fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>);
117116

118-
/// Enqueue a message to be sent on the channel. The caller is expected to monitor
119-
/// the channel status for failures.
117+
/// Enqueue a message to be sent on the channel.
120118
fn post(&self, message: M) {
121-
// We ignore errors here because the caller is meant to monitor the channel's
122-
// status, rather than rely on this function to report errors.
123-
let _ignore = self.try_post(message, oneshot::channel().0);
119+
self.try_post(message, oneshot::channel().0);
124120
}
125121

126122
/// Send a message synchronously, returning when the messsage has
127123
/// been delivered to the remote end of the channel.
128124
async fn send(&self, message: M) -> Result<(), SendError<M>> {
129125
let (tx, rx) = oneshot::channel();
130-
self.try_post(message, tx)?;
126+
self.try_post(message, tx);
131127
match rx.await {
132128
// Channel was closed; the message was not delivered.
133-
Ok(m) => Err(SendError(ChannelError::Closed, m)),
129+
Ok(err) => Err(err),
134130

135131
// Channel was dropped; the message was successfully enqueued
136132
// on the remote end of the channel.
@@ -179,14 +175,12 @@ impl<M: RemoteMessage> MpscTx<M> {
179175

180176
#[async_trait]
181177
impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
182-
fn try_post(
183-
&self,
184-
message: M,
185-
_return_channel: oneshot::Sender<M>,
186-
) -> Result<(), SendError<M>> {
187-
self.tx
188-
.send(message)
189-
.map_err(|mpsc::error::SendError(message)| SendError(ChannelError::Closed, message))
178+
fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
179+
if let Err(mpsc::error::SendError(message)) = self.tx.send(message) {
180+
if let Err(m) = return_channel.send(SendError(ChannelError::Closed, message)) {
181+
tracing::warn!("failed to deliver SendError: {}", m);
182+
}
183+
}
190184
}
191185

192186
fn addr(&self) -> ChannelAddr {
@@ -749,7 +743,7 @@ enum ChannelTxKind<M: RemoteMessage> {
749743

750744
#[async_trait]
751745
impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
752-
fn try_post(&self, message: M, return_channel: oneshot::Sender<M>) -> Result<(), SendError<M>> {
746+
fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
753747
match &self.inner {
754748
ChannelTxKind::Local(tx) => tx.try_post(message, return_channel),
755749
ChannelTxKind::Tcp(tx) => tx.try_post(message, return_channel),
@@ -1054,7 +1048,7 @@ mod tests {
10541048
let addr = listen_addr.clone();
10551049
sends.spawn(async move {
10561050
let tx = dial::<u64>(addr).unwrap();
1057-
tx.try_post(message, oneshot::channel().0).unwrap();
1051+
tx.post(message);
10581052
});
10591053
}
10601054

@@ -1089,7 +1083,7 @@ mod tests {
10891083
let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();
10901084

10911085
let tx = dial::<u64>(listen_addr).unwrap();
1092-
tx.try_post(123, oneshot::channel().0).unwrap();
1086+
tx.post(123);
10931087
drop(rx);
10941088

10951089
// New transmits should fail... but there is buffering, etc.,
@@ -1099,12 +1093,15 @@ mod tests {
10991093
let start = RealClock.now();
11001094

11011095
let result = loop {
1102-
let result = tx.try_post(123, oneshot::channel().0);
1103-
if result.is_err() || start.elapsed() > Duration::from_secs(10) {
1096+
let (return_tx, return_rx) = oneshot::channel();
1097+
tx.try_post(123, return_tx);
1098+
let result = return_rx.await;
1099+
1100+
if result.is_ok() || start.elapsed() > Duration::from_secs(10) {
11041101
break result;
11051102
}
11061103
};
1107-
assert_matches!(result, Err(SendError(ChannelError::Closed, 123)));
1104+
assert_matches!(result, Ok(SendError(ChannelError::Closed, 123)));
11081105
}
11091106
}
11101107

@@ -1137,7 +1134,7 @@ mod tests {
11371134
for addr in addrs() {
11381135
let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
11391136
let tx = crate::channel::dial(listen_addr).unwrap();
1140-
tx.try_post(123, oneshot::channel().0).unwrap();
1137+
tx.post(123);
11411138
assert_eq!(rx.recv().await.unwrap(), 123);
11421139
}
11431140
}

hyperactor/src/channel/local.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,21 @@ pub struct LocalTx<M: RemoteMessage> {
7272

7373
#[async_trait]
7474
impl<M: RemoteMessage> Tx<M> for LocalTx<M> {
75-
fn try_post(
76-
&self,
77-
message: M,
78-
_return_channel: oneshot::Sender<M>,
79-
) -> Result<(), SendError<M>> {
75+
fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
8076
let data: Data = match bincode::serialize(&message) {
8177
Ok(data) => data,
82-
Err(err) => return Err(SendError(err.into(), message)),
78+
Err(err) => {
79+
if let Err(m) = return_channel.send(SendError(err.into(), message)) {
80+
tracing::warn!("failed to deliver SendError: {}", m);
81+
}
82+
return;
83+
}
8384
};
84-
self.tx
85-
.send(data)
86-
.map_err(|_| SendError(ChannelError::Closed, message))
85+
if self.tx.send(data).is_err() {
86+
if let Err(m) = return_channel.send(SendError(ChannelError::Closed, message)) {
87+
tracing::warn!("failed to deliver SendError: {}", m);
88+
}
89+
}
8790
}
8891

8992
fn addr(&self) -> ChannelAddr {
@@ -167,7 +170,7 @@ mod tests {
167170
async fn test_local_basic() {
168171
let (tx, mut rx) = local::new::<u64>();
169172

170-
tx.try_post(123, unused_return_channel()).unwrap();
173+
tx.try_post(123, unused_return_channel());
171174
assert_eq!(rx.recv().await.unwrap(), 123);
172175
}
173176

@@ -178,23 +181,22 @@ mod tests {
178181

179182
let tx = local::dial::<u64>(port).unwrap();
180183

181-
tx.try_post(123, unused_return_channel()).unwrap();
184+
tx.try_post(123, unused_return_channel());
182185
assert_eq!(rx.recv().await.unwrap(), 123);
183186

184187
drop(rx);
185188

186-
assert_matches!(
187-
tx.try_post(123, unused_return_channel()),
188-
Err(SendError(ChannelError::Closed, 123))
189-
);
189+
let (return_tx, return_rx) = oneshot::channel();
190+
tx.try_post(123, return_tx);
191+
assert_matches!(return_rx.await, Ok(SendError(ChannelError::Closed, 123)));
190192
}
191193

192194
#[tokio::test]
193195
async fn test_local_drop() {
194196
let (port, mut rx) = local::serve::<u64>();
195197
let tx = local::dial::<u64>(port).unwrap();
196198

197-
tx.try_post(123, unused_return_channel()).unwrap();
199+
tx.try_post(123, unused_return_channel());
198200
assert_eq!(rx.recv().await.unwrap(), 123);
199201

200202
drop(rx);

0 commit comments

Comments
 (0)