diff --git a/grpc/examples/inmemory.rs b/grpc/examples/inmemory.rs index 88b17ee01..e71ddc4d7 100644 --- a/grpc/examples/inmemory.rs +++ b/grpc/examples/inmemory.rs @@ -1,6 +1,6 @@ use std::any::Any; -use grpc::service::{Message, Request, Response, Service}; +use grpc::service::{Message, MessageAllocator, Request, Response, Service}; use grpc::{client::ChannelOptions, inmemory}; use tokio_stream::StreamExt; use tonic::async_trait; @@ -10,12 +10,37 @@ struct Handler {} #[derive(Debug)] struct MyReqMessage(String); +impl Message for MyReqMessage { + fn encode(&self, _: &mut bytes::BytesMut) -> Result<(), String> { + Err("not implemented".to_string()) + } + + fn decode(&mut self, _: &bytes::Bytes) -> Result<(), String> { + Err("not implemented".to_string()) + } +} + #[derive(Debug)] struct MyResMessage(String); +impl Message for MyResMessage { + fn encode(&self, _: &mut bytes::BytesMut) -> Result<(), String> { + Err("not implemented".to_string()) + } + + fn decode(&mut self, _: &bytes::Bytes) -> Result<(), String> { + Err("not implemented".to_string()) + } +} + #[async_trait] impl Service for Handler { - async fn call(&self, method: String, request: Request) -> Response { + async fn call( + &self, + method: String, + request: Request, + _: Box, + ) -> Response { let mut stream = request.into_inner(); let output = async_stream::try_stream! { while let Some(req) = stream.next().await { @@ -30,6 +55,15 @@ impl Service for Handler { } } +#[derive(Debug, Default)] +struct MyResMessageAllocator {} + +impl MessageAllocator for MyResMessageAllocator { + fn allocate(&self) -> Box { + Box::new(MyResMessage(String::new())) + } +} + #[tokio::main] async fn main() { inmemory::reg(); @@ -55,7 +89,13 @@ async fn main() { }; let req = Request::new(Box::pin(outbound)); - let res = chan.call("/some/method".to_string(), req).await; + let res = chan + .call( + "/some/method".to_string(), + req, + Box::new(MyResMessageAllocator {}), + ) + .await; let mut res = res.into_inner(); while let Some(resp) = res.next().await { diff --git a/grpc/examples/multiaddr.rs b/grpc/examples/multiaddr.rs index c631d33c5..e339a9b86 100644 --- a/grpc/examples/multiaddr.rs +++ b/grpc/examples/multiaddr.rs @@ -1,6 +1,6 @@ use std::any::Any; -use grpc::service::{Message, Request, Response, Service}; +use grpc::service::{Message, MessageAllocator, Request, Response, Service}; use grpc::{client::ChannelOptions, inmemory}; use tokio_stream::StreamExt; use tonic::async_trait; @@ -15,9 +15,43 @@ struct MyReqMessage(String); #[derive(Debug)] struct MyResMessage(String); +impl Message for MyReqMessage { + fn encode(&self, _: &mut bytes::BytesMut) -> Result<(), String> { + Err("not implemented".to_string()) + } + + fn decode(&mut self, _: &bytes::Bytes) -> Result<(), String> { + Err("not implemented".to_string()) + } +} + +#[derive(Debug, Default)] +struct MyResMessageAllocator {} + +impl Message for MyResMessage { + fn encode(&self, _: &mut bytes::BytesMut) -> Result<(), String> { + Err("not implemented".to_string()) + } + + fn decode(&mut self, _: &bytes::Bytes) -> Result<(), String> { + Err("not implemented".to_string()) + } +} + +impl MessageAllocator for MyResMessageAllocator { + fn allocate(&self) -> Box { + Box::new(MyResMessage(String::new())) + } +} + #[async_trait] impl Service for Handler { - async fn call(&self, method: String, request: Request) -> Response { + async fn call( + &self, + method: String, + request: Request, + _: Box, + ) -> Response { let id = self.id.clone(); let mut stream = request.into_inner(); let output = async_stream::try_stream! { @@ -79,7 +113,13 @@ async fn main() { }; let req = Request::new(Box::pin(outbound)); - let res = chan.call("/some/method".to_string(), req).await; + let res = chan + .call( + "/some/method".to_string(), + req, + Box::new(MyResMessageAllocator {}), + ) + .await; let mut res = res.into_inner(); while let Some(resp) = res.next().await { diff --git a/grpc/src/client/channel.rs b/grpc/src/client/channel.rs index edbd55131..e398061b0 100644 --- a/grpc/src/client/channel.rs +++ b/grpc/src/client/channel.rs @@ -42,9 +42,9 @@ use serde_json::json; use tonic::async_trait; use url::Url; // NOTE: http::Uri requires non-empty authority portion of URI -use crate::attributes::Attributes; use crate::rt; use crate::service::{Request, Response, Service}; +use crate::{attributes::Attributes, service::MessageAllocator}; use crate::{client::ConnectivityState, rt::Runtime}; use crate::{credentials::Credentials, rt::default_runtime}; @@ -204,9 +204,14 @@ impl Channel { s.clone().unwrap() } - pub async fn call(&self, method: String, request: Request) -> Response { + pub async fn call( + &self, + method: String, + request: Request, + response_allocator: Box, + ) -> Response { let ac = self.get_or_create_active_channel(); - ac.call(method, request).await + ac.call(method, request, response_allocator).await } } @@ -302,7 +307,12 @@ impl ActiveChannel { }) } - async fn call(&self, method: String, request: Request) -> Response { + async fn call( + &self, + method: String, + request: Request, + response_allocator: Box, + ) -> Response { // TODO: pre-pick tasks (e.g. deadlines, interceptors, retry) let mut i = self.picker.iter(); loop { @@ -314,7 +324,12 @@ impl ActiveChannel { if let Some(sc) = (pr.subchannel.as_ref() as &dyn Any) .downcast_ref::() { - return sc.isc.as_ref().unwrap().call(method, request).await; + return sc + .isc + .as_ref() + .unwrap() + .call(method, request, response_allocator) + .await; } else { panic!("picked subchannel is not an implementation provided by the channel"); } diff --git a/grpc/src/client/load_balancing/test_utils.rs b/grpc/src/client/load_balancing/test_utils.rs index 1a7e33fbf..e311c6c84 100644 --- a/grpc/src/client/load_balancing/test_utils.rs +++ b/grpc/src/client/load_balancing/test_utils.rs @@ -40,6 +40,16 @@ pub(crate) fn new_request() -> Request { ))) } +impl Message for EmptyMessage { + fn encode(&self, buf: &mut bytes::BytesMut) -> Result<(), String> { + Ok(()) + } + + fn decode(&mut self, buf: &bytes::Bytes) -> Result<(), String> { + Ok(()) + } +} + // A test subchannel that forwards connect calls to a channel. // This allows tests to verify when a subchannel is asked to connect. pub(crate) struct TestSubchannel { diff --git a/grpc/src/client/name_resolution/mod.rs b/grpc/src/client/name_resolution/mod.rs index 6898f5d13..5378785c5 100644 --- a/grpc/src/client/name_resolution/mod.rs +++ b/grpc/src/client/name_resolution/mod.rs @@ -39,7 +39,7 @@ use std::{ }; mod backoff; -mod dns; +pub mod dns; mod registry; pub use registry::global_registry; use url::Url; diff --git a/grpc/src/client/subchannel.rs b/grpc/src/client/subchannel.rs index 4b5c314fc..f30dd6dd4 100644 --- a/grpc/src/client/subchannel.rs +++ b/grpc/src/client/subchannel.rs @@ -12,7 +12,7 @@ use crate::{ transport::{ConnectedTransport, TransportOptions}, }, rt::{BoxedTaskHandle, Runtime}, - service::{Request, Response, Service}, + service::{MessageAllocator, Request, Response, Service}, }; use core::panic; use std::time::{Duration, Instant}; @@ -192,7 +192,12 @@ struct InnerSubchannel { #[async_trait] impl Service for InternalSubchannel { - async fn call(&self, method: String, request: Request) -> Response { + async fn call( + &self, + method: String, + request: Request, + response_allocator: Box, + ) -> Response { let svc = self.inner.lock().unwrap().state.connected_transport(); if svc.is_none() { // TODO(easwars): Change the signature of this method to return a @@ -201,7 +206,7 @@ impl Service for InternalSubchannel { } let svc = svc.unwrap().clone(); - return svc.call(method, request).await; + return svc.call(method, request, response_allocator).await; } } diff --git a/grpc/src/client/transport/tonic/mod.rs b/grpc/src/client/transport/tonic/mod.rs index 11fcd24e1..532be3fd0 100644 --- a/grpc/src/client/transport/tonic/mod.rs +++ b/grpc/src/client/transport/tonic/mod.rs @@ -8,10 +8,12 @@ use crate::rt::BoxedTaskHandle; use crate::rt::Runtime; use crate::rt::TcpOptions; use crate::service::Message; +use crate::service::MessageAllocator; use crate::service::Request as GrpcRequest; use crate::service::Response as GrpcResponse; use crate::{client::name_resolution::TCP_IP_NETWORK_TYPE, service::Service}; use bytes::Bytes; +use bytes::BytesMut; use http::uri::PathAndQuery; use http::Request as HttpRequest; use http::Response as HttpResponse; @@ -63,7 +65,12 @@ impl Drop for TonicTransport { #[async_trait] impl Service for TonicTransport { - async fn call(&self, method: String, request: GrpcRequest) -> GrpcResponse { + async fn call( + &self, + method: String, + request: GrpcRequest, + response_allocator: Box, + ) -> GrpcResponse { let Ok(path) = PathAndQuery::from_maybe_shared(method) else { let err = Status::internal("Failed to parse path"); return create_error_response(err); @@ -78,7 +85,7 @@ impl Service for TonicTransport { }; let request = convert_request(request); let response = grpc.streaming(request, path, BytesCodec {}).await; - convert_response(response) + convert_response(response, response_allocator) } } @@ -88,23 +95,22 @@ fn create_error_response(status: Status) -> GrpcResponse { TonicResponse::new(Box::pin(stream)) } -fn convert_request(req: GrpcRequest) -> TonicRequest + Send>>> { +fn convert_request(req: GrpcRequest) -> TonicRequest> { let (metadata, extensions, stream) = req.into_parts(); - let bytes_stream = Box::pin(stream.filter_map(|msg| { - if let Ok(bytes) = (msg as Box).downcast::() { - Some(*bytes) - } else { - // If it fails, log the error and return None to filter it out. - eprintln!("A message could not be downcast to Bytes and was skipped."); - None - } + let bytes_stream = Box::pin(stream.map(|msg| { + let mut buf = BytesMut::with_capacity(msg.encoded_message_size_hint().unwrap_or(0)); + msg.encode(&mut buf).map_err(Status::internal)?; + Ok(buf.freeze()) })); TonicRequest::from_parts(metadata, extensions, bytes_stream as _) } -fn convert_response(res: Result>, Status>) -> GrpcResponse { +fn convert_response( + res: Result>, Status>, + allocator: Box, +) -> GrpcResponse { let response = match res { Ok(s) => s, Err(e) => { @@ -113,11 +119,14 @@ fn convert_response(res: Result>, Status>) -> Grp } }; let (metadata, stream, extensions) = response.into_parts(); - let message_stream: BoxStream> = Box::pin(stream.map(|msg| { - msg.map(|b| { - let msg: Box = Box::new(b); - msg - }) + let allocator: Arc = Arc::from(allocator); + let allocator_copy = allocator.clone(); + let message_stream: BoxStream> = Box::pin(stream.map(move |msg| { + let allocator = allocator_copy.clone(); + let buf = msg?; + let mut msg = allocator.allocate(); + msg.decode(&buf).map_err(Status::internal)?; + Ok(msg) })); TonicResponse::from_parts(metadata, message_stream, extensions) } diff --git a/grpc/src/client/transport/tonic/test.rs b/grpc/src/client/transport/tonic/test.rs index 678280e34..e292063ad 100644 --- a/grpc/src/client/transport/tonic/test.rs +++ b/grpc/src/client/transport/tonic/test.rs @@ -1,12 +1,15 @@ -use crate::client::name_resolution::TCP_IP_NETWORK_TYPE; +use crate::client::name_resolution::{self, dns, TCP_IP_NETWORK_TYPE}; use crate::client::transport::registry::GLOBAL_TRANSPORT_REGISTRY; +use crate::client::{Channel, ChannelOptions}; use crate::echo_pb::echo_server::{Echo, EchoServer}; use crate::echo_pb::{EchoRequest, EchoResponse}; -use crate::service::Message; use crate::service::Request as GrpcRequest; +use crate::service::{Message, MessageAllocator}; use crate::{client::transport::TransportOptions, rt::tokio::TokioRuntime}; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use std::any::Any; +use std::fmt::Debug; +use std::marker::PhantomData; use std::{pin::Pin, sync::Arc, time::Duration}; use tokio::net::TcpListener; use tokio::sync::{mpsc, oneshot, Notify}; @@ -19,6 +22,99 @@ use tonic_prost::prost::Message as ProstMessage; const DEFAULT_TEST_DURATION: Duration = Duration::from_secs(10); const DEFAULT_TEST_SHORT_DURATION: Duration = Duration::from_millis(10); +impl Message for T +where + T: ProstMessage + Send + Sync + Any + Debug + Default, +{ + fn encode(&self, buf: &mut BytesMut) -> Result<(), String> { + T::encode(&self, buf).map_err(|err| err.to_string()) + } + + fn decode(&mut self, buf: &Bytes) -> Result<(), String> { + T::merge(self, buf.as_ref()).map_err(|err| err.to_string())?; + Ok(()) + } + + fn encoded_message_size_hint(&self) -> Option { + Some(T::encoded_len(&self)) + } +} + +#[derive(Default)] +struct ProstMessageAllocator { + _pd: PhantomData, +} + +impl MessageAllocator for ProstMessageAllocator +where + T: ProstMessage + Send + Sync + Any + Debug + Default, +{ + fn allocate(&self) -> Box { + Box::new(T::default()) + } +} + +// Tests the tonic transport by creating a bi-di stream with a tonic server. +#[tokio::test] +pub async fn grpc_channel_rpc() { + super::reg(); + dns::reg(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); // get the assigned address + println!("EchoServer listening on: {addr}"); + let server_handle = tokio::spawn(async move { + let echo_server = EchoService {}; + let svc = EchoServer::new(echo_server); + let _ = Server::builder() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await; + }); + let chan_opts = ChannelOptions::default(); + let chan = Channel::new(&format!("dns:///{addr}"), None, chan_opts); + + let (tx, rx) = mpsc::channel::>(1); + + // Convert the mpsc receiver into a Stream + let outbound: GrpcRequest = + Request::new(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx))); + + let mut inbound = chan + .call( + "/grpc.examples.echo.Echo/BidirectionalStreamingEcho".to_string(), + outbound, + Box::new(ProstMessageAllocator::::default()), + ) + .await + .into_inner(); + + // Spawn a sender task + let client_handle = tokio::spawn(async move { + for i in 0..5 { + let message = format!("message {i}"); + let request = EchoRequest { + message: message.clone(), + }; + + println!("Sent request: {request:?}"); + assert!(tx.send(Box::new(request)).await.is_ok(), "Receiver dropped"); + + // Wait for the reply + let resp = inbound + .next() + .await + .expect("server unexpectedly closed the stream!") + .expect("server returned error"); + + let echo_response = (resp as Box).downcast::().unwrap(); + println!("Got response: {echo_response:?}"); + assert_eq!(echo_response.message, message); + } + }); + + client_handle.await.unwrap(); +} + // Tests the tonic transport by creating a bi-di stream with a tonic server. #[tokio::test] pub async fn tonic_transport_rpc() { @@ -60,6 +156,7 @@ pub async fn tonic_transport_rpc() { .call( "/grpc.examples.echo.Echo/BidirectionalStreamingEcho".to_string(), outbound, + Box::new(ProstMessageAllocator::::default()), ) .await .into_inner(); @@ -72,10 +169,8 @@ pub async fn tonic_transport_rpc() { message: message.clone(), }; - let bytes = Bytes::from(request.encode_to_vec()); - println!("Sent request: {request:?}"); - assert!(tx.send(Box::new(bytes)).await.is_ok(), "Receiver dropped"); + assert!(tx.send(Box::new(request)).await.is_ok(), "Receiver dropped"); // Wait for the reply let resp = inbound @@ -84,8 +179,7 @@ pub async fn tonic_transport_rpc() { .expect("server unexpectedly closed the stream!") .expect("server returned error"); - let bytes = (resp as Box).downcast::().unwrap(); - let echo_response = EchoResponse::decode(bytes).unwrap(); + let echo_response = (resp as Box).downcast::().unwrap(); println!("Got response: {echo_response:?}"); assert_eq!(echo_response.message, message); } diff --git a/grpc/src/codec.rs b/grpc/src/codec.rs index eb9cc03e7..eb638fd83 100644 --- a/grpc/src/codec.rs +++ b/grpc/src/codec.rs @@ -11,7 +11,7 @@ use tonic::{ pub(crate) struct BytesCodec {} impl Codec for BytesCodec { - type Encode = Bytes; + type Encode = Result; type Decode = Bytes; type Encoder = BytesEncoder; type Decoder = BytesDecoder; @@ -28,11 +28,11 @@ impl Codec for BytesCodec { pub struct BytesEncoder {} impl Encoder for BytesEncoder { - type Item = Bytes; + type Item = Result; type Error = Status; fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> { - dst.put_slice(&item); + dst.put_slice(&item?); Ok(()) } } diff --git a/grpc/src/inmemory/mod.rs b/grpc/src/inmemory/mod.rs index b9dae99e0..ae5cfec87 100644 --- a/grpc/src/inmemory/mod.rs +++ b/grpc/src/inmemory/mod.rs @@ -2,6 +2,7 @@ use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, LazyLock, Mutex}; use std::{collections::HashMap, ops::Add}; +use crate::service::MessageAllocator; use crate::{ client::{ name_resolution::{ @@ -66,7 +67,12 @@ impl Drop for Listener { #[async_trait] impl Service for Arc { - async fn call(&self, method: String, request: Request) -> Response { + async fn call( + &self, + method: String, + request: Request, + _: Box, + ) -> Response { // 1. unblock accept, giving it a func back to me // 2. return what that func had let (s, r) = oneshot::channel(); diff --git a/grpc/src/server/mod.rs b/grpc/src/server/mod.rs index 18da685ca..d8bf09746 100644 --- a/grpc/src/server/mod.rs +++ b/grpc/src/server/mod.rs @@ -3,12 +3,20 @@ use std::sync::Arc; use tokio::sync::oneshot; use tonic::async_trait; -use crate::service::{Request, Response, Service}; +use crate::service::{MessageAllocator, Request, Response, Service}; pub struct Server { handler: Option>, } +struct NoOpAllocator {} + +impl MessageAllocator for NoOpAllocator { + fn allocate(&self) -> Box { + unimplemented!() + } +} + pub type Call = (String, Request, oneshot::Sender); #[async_trait] @@ -28,7 +36,13 @@ impl Server { pub async fn serve(&self, l: &impl Listener) { while let Some((method, req, reply_on)) = l.accept().await { reply_on - .send(self.handler.as_ref().unwrap().call(method, req).await) + .send( + self.handler + .as_ref() + .unwrap() + .call(method, req, Box::new(NoOpAllocator {})) + .await, + ) .ok(); // TODO: log error } } diff --git a/grpc/src/service.rs b/grpc/src/service.rs index 64d02ed17..23de81458 100644 --- a/grpc/src/service.rs +++ b/grpc/src/service.rs @@ -24,6 +24,7 @@ use std::{any::Any, fmt::Debug, pin::Pin}; +use bytes::{BufMut, Bytes, BytesMut}; use tokio_stream::Stream; use tonic::{async_trait, Request as TonicRequest, Response as TonicResponse, Status}; @@ -33,10 +34,32 @@ pub type Response = #[async_trait] pub trait Service: Send + Sync { - async fn call(&self, method: String, request: Request) -> Response; + async fn call( + &self, + method: String, + request: Request, + response_allocator: Box, + ) -> Response; } -// TODO: define methods that will allow serialization/deserialization. -pub trait Message: Any + Send + Sync + Debug {} +pub trait Message: Any + Send + Sync + Debug { + /// Encodes the message into the provided buffer. + fn encode(&self, buf: &mut BytesMut) -> Result<(), String>; + /// Decodes the message from the provided buffer. + fn decode(&mut self, buf: &Bytes) -> Result<(), String>; + /// Provides a hint for the expected size of the encoded message. + /// + /// This method can be used by encoders to pre-allocate buffer space, + /// potentially improving performance by reducing reallocations. It's a + /// best-effort hint and implementations may return `None` if an + /// accurate size cannot be easily determined without encoding. + fn encoded_message_size_hint(&self) -> Option { + None + } +} -impl Message for T where T: Any + Send + Sync + Debug {} +/// Allocates messages for responses on the client side and requests on the +/// server. +pub trait MessageAllocator: Send + Sync { + fn allocate(&self) -> Box; +}