From f675352a8ba08f7ce0385087355f102015cdf61c Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Fri, 8 Aug 2025 07:03:26 -0600 Subject: [PATCH 01/11] chore: update readme with v0.14.x branch (#2383) --- .github/PULL_REQUEST_TEMPLATE.md | 3 +++ README.md | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 42fa9c650..09bd860bc 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,6 +2,9 @@ Thank you for your Pull Request. Please provide a description above and review the requirements below. +If this change is intended for tonic `v0.14.x` please make this PR against that branch +otherwise, it may not get included in a release for a long time. + Bug fixes and new features should include tests. Contributors guide: https://github.com/hyperium/tonic/blob/master/CONTRIBUTING.md diff --git a/README.md b/README.md index d052783df..264ddfaad 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,13 @@ ![](https://github.com/hyperium/tonic/raw/master/.github/assets/tonic-banner.svg?sanitize=true) + A rust implementation of [gRPC], a high performance, open source, general RPC framework that puts mobile and HTTP/2 first. +> **Note**: tonic's [master](https://github.com/hyperium/tonic) branch is +> currently preparing breaking changes. For the most recently *released* code, +> look to the [0.14.x branch](https://github.com/hyperium/tonic/tree/v0.14.x). + [`tonic`] is a gRPC over HTTP/2 implementation focused on high performance, interoperability, and flexibility. This library was created to have first class support of async/await and to act as a core building block for production systems written in Rust. 
[![Crates.io](https://img.shields.io/crates/v/tonic)](https://crates.io/crates/tonic) From b4c81fc82970d46da88778ac3b5d79721e63de4d Mon Sep 17 00:00:00 2001 From: cjqzhao <159480369+cjqzhao@users.noreply.github.com> Date: Fri, 8 Aug 2025 15:12:34 -0700 Subject: [PATCH 02/11] feat(grpc): add exit_idle to LbPolicy trait (#2332) --- grpc/src/client/load_balancing/child_manager.rs | 4 ++++ grpc/src/client/load_balancing/mod.rs | 4 ++++ grpc/src/client/load_balancing/pick_first.rs | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/grpc/src/client/load_balancing/child_manager.rs b/grpc/src/client/load_balancing/child_manager.rs index 0d4af6542..fea17bfe5 100644 --- a/grpc/src/client/load_balancing/child_manager.rs +++ b/grpc/src/client/load_balancing/child_manager.rs @@ -262,6 +262,10 @@ impl LbPolicy for ChildManager self.resolve_child_controller(channel_controller, child_idx); } } + + fn exit_idle(&mut self, _channel_controller: &mut dyn ChannelController) { + todo!("implement exit_idle") + } } struct WrappedController<'a> { diff --git a/grpc/src/client/load_balancing/mod.rs b/grpc/src/client/load_balancing/mod.rs index b91950c31..7cb562660 100644 --- a/grpc/src/client/load_balancing/mod.rs +++ b/grpc/src/client/load_balancing/mod.rs @@ -169,6 +169,10 @@ pub trait LbPolicy: Send { /// Called by the channel in response to a call from the LB policy to the /// WorkScheduler's request_work method. fn work(&mut self, channel_controller: &mut dyn ChannelController); + + /// Called by the channel when an LbPolicy goes idle and the channel + /// wants it to start connecting to subchannels again. + fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController); } /// Controls channel behaviors. 
diff --git a/grpc/src/client/load_balancing/pick_first.rs b/grpc/src/client/load_balancing/pick_first.rs index ed7ae76f6..54ae78711 100644 --- a/grpc/src/client/load_balancing/pick_first.rs +++ b/grpc/src/client/load_balancing/pick_first.rs @@ -99,6 +99,10 @@ impl LbPolicy for PickFirstPolicy { } fn work(&mut self, channel_controller: &mut dyn ChannelController) {} + + fn exit_idle(&mut self, _channel_controller: &mut dyn ChannelController) { + todo!("implement exit_idle") + } } struct OneSubchannelPicker { From 9bdfd225759294a74bf5a7847aff301e298dab6d Mon Sep 17 00:00:00 2001 From: tottoto Date: Sun, 10 Aug 2025 00:07:03 +0900 Subject: [PATCH 03/11] chore(ci): Update to cargo-check-external-types 0.3.0 (#2384) --- .github/workflows/CI.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index dba221904..626fde513 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -268,11 +268,11 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2025-05-04 + toolchain: nightly-2025-08-06 - name: Install cargo-check-external-types uses: taiki-e/cache-cargo-install-action@v2 with: - tool: cargo-check-external-types@0.2.0 + tool: cargo-check-external-types@0.3.0 - uses: taiki-e/install-action@cargo-hack - uses: Swatinem/rust-cache@v2 - run: cargo hack --no-private check-external-types --all-features From 7662376fff715e403ff840ab828765f98e740620 Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Tue, 12 Aug 2025 10:07:20 -0700 Subject: [PATCH 04/11] grpc: add testing utilities for LB policy tests (#2380) --- grpc/src/client/load_balancing/mod.rs | 2 + grpc/src/client/load_balancing/test_utils.rs | 147 +++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 grpc/src/client/load_balancing/test_utils.rs diff --git a/grpc/src/client/load_balancing/mod.rs b/grpc/src/client/load_balancing/mod.rs index 7cb562660..ad576e819 
100644 --- a/grpc/src/client/load_balancing/mod.rs +++ b/grpc/src/client/load_balancing/mod.rs @@ -53,6 +53,8 @@ use crate::client::{ pub mod child_manager; pub mod pick_first; +#[cfg(test)] +pub mod test_utils; pub(crate) mod registry; use super::{service_config::LbConfig, subchannel::SubchannelStateWatcher}; diff --git a/grpc/src/client/load_balancing/test_utils.rs b/grpc/src/client/load_balancing/test_utils.rs new file mode 100644 index 000000000..93ffc5511 --- /dev/null +++ b/grpc/src/client/load_balancing/test_utils.rs @@ -0,0 +1,147 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ * + */ + +use crate::client::load_balancing::{ + ChannelController, ExternalSubchannel, ForwardingSubchannel, LbState, Subchannel, WorkScheduler, +}; +use crate::client::name_resolution::Address; +use crate::service::{Message, Request, Response, Service}; +use std::hash::{Hash, Hasher}; +use std::{fmt::Debug, ops::Add, sync::Arc}; +use tokio::sync::{mpsc, Notify}; +use tokio::task::AbortHandle; + +pub(crate) struct EmptyMessage {} +impl Message for EmptyMessage {} +pub(crate) fn new_request() -> Request { + Request::new(Box::pin(tokio_stream::once( + Box::new(EmptyMessage {}) as Box + ))) +} + +// A test subchannel that forwards connect calls to a channel. +// This allows tests to verify when a subchannel is asked to connect. +pub(crate) struct TestSubchannel { + address: Address, + tx_connect: mpsc::UnboundedSender, +} + +impl TestSubchannel { + fn new(address: Address, tx_connect: mpsc::UnboundedSender) -> Self { + Self { + address, + tx_connect, + } + } +} + +impl ForwardingSubchannel for TestSubchannel { + fn delegate(&self) -> Arc { + panic!("unsupported operation on a test subchannel"); + } + + fn address(&self) -> Address { + self.address.clone() + } + + fn connect(&self) { + println!("connect called for subchannel {}", self.address); + self.tx_connect + .send(TestEvent::Connect(self.address.clone())) + .unwrap(); + } +} + +impl Hash for TestSubchannel { + fn hash(&self, state: &mut H) { + self.address.hash(state); + } +} + +impl PartialEq for TestSubchannel { + fn eq(&self, other: &Self) -> bool { + std::ptr::eq(self, other) + } +} +impl Eq for TestSubchannel {} + +pub(crate) enum TestEvent { + NewSubchannel(Arc), + UpdatePicker(LbState), + RequestResolution, + Connect(Address), + ScheduleWork, +} + +// TODO(easwars): Remove this and instead derive Debug. 
+impl Debug for TestEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NewSubchannel(sc) => write!(f, "NewSubchannel({})", sc.address()), + Self::UpdatePicker(state) => write!(f, "UpdatePicker({})", state.connectivity_state), + Self::RequestResolution => write!(f, "RequestResolution"), + Self::Connect(addr) => write!(f, "Connect({})", addr.address.to_string()), + Self::ScheduleWork => write!(f, "ScheduleWork"), + } + } +} + +/// A test channel controller that forwards calls to a channel. This allows +/// tests to verify when a channel controller is asked to create subchannels or +/// update the picker. +pub(crate) struct TestChannelController { + pub(crate) tx_events: mpsc::UnboundedSender, +} + +impl ChannelController for TestChannelController { + fn new_subchannel(&mut self, address: &Address) -> Arc { + println!("new_subchannel called for address {}", address); + let notify = Arc::new(Notify::new()); + let subchannel: Arc = + Arc::new(TestSubchannel::new(address.clone(), self.tx_events.clone())); + self.tx_events + .send(TestEvent::NewSubchannel(subchannel.clone())) + .unwrap(); + subchannel + } + fn update_picker(&mut self, update: LbState) { + println!("picker_update called with {}", update.connectivity_state); + self.tx_events + .send(TestEvent::UpdatePicker(update)) + .unwrap(); + } + fn request_resolution(&mut self) { + self.tx_events.send(TestEvent::RequestResolution).unwrap(); + } +} + +pub(crate) struct TestWorkScheduler { + pub(crate) tx_events: mpsc::UnboundedSender, +} + +impl WorkScheduler for TestWorkScheduler { + fn schedule_work(&self) { + self.tx_events.send(TestEvent::ScheduleWork).unwrap(); + } +} From 29163c2f47171c7c85d971b450b0008abbff2398 Mon Sep 17 00:00:00 2001 From: Arjan Singh Bal <46515553+arjan-bal@users.noreply.github.com> Date: Wed, 13 Aug 2025 21:50:48 +0530 Subject: [PATCH 05/11] feat(grpc): Add tonic transport (#2339) --- codegen/src/main.rs | 13 + examples/Cargo.toml | 3 +- 
grpc/Cargo.toml | 60 +- grpc/examples/inmemory.rs | 3 - grpc/examples/multiaddr.rs | 3 - grpc/proto/echo/echo.proto | 43 ++ grpc/src/client/channel.rs | 21 +- .../client/load_balancing/child_manager.rs | 9 +- grpc/src/client/load_balancing/mod.rs | 2 + grpc/src/client/load_balancing/pick_first.rs | 11 +- grpc/src/client/load_balancing/test_utils.rs | 2 +- grpc/src/client/mod.rs | 3 +- grpc/src/client/name_resolution/dns/mod.rs | 4 +- grpc/src/client/name_resolution/dns/test.rs | 8 + grpc/src/client/subchannel.rs | 86 +-- grpc/src/client/transport/mod.rs | 47 +- grpc/src/client/transport/registry.rs | 41 +- grpc/src/client/transport/tonic/mod.rs | 277 +++++++++ grpc/src/client/transport/tonic/test.rs | 165 ++++++ grpc/src/codec.rs | 53 ++ grpc/src/generated/echo_fds.rs | 61 ++ grpc/src/generated/grpc_examples_echo.rs | 547 ++++++++++++++++++ grpc/src/inmemory/mod.rs | 58 +- grpc/src/lib.rs | 8 + grpc/src/rt/hyper_wrapper.rs | 158 +++++ grpc/src/rt/mod.rs | 68 ++- grpc/src/rt/tokio/mod.rs | 70 ++- grpc/src/service.rs | 8 +- 28 files changed, 1674 insertions(+), 158 deletions(-) create mode 100644 grpc/proto/echo/echo.proto create mode 100644 grpc/src/client/transport/tonic/mod.rs create mode 100644 grpc/src/client/transport/tonic/test.rs create mode 100644 grpc/src/codec.rs create mode 100644 grpc/src/generated/echo_fds.rs create mode 100644 grpc/src/generated/grpc_examples_echo.rs create mode 100644 grpc/src/rt/hyper_wrapper.rs diff --git a/codegen/src/main.rs b/codegen/src/main.rs index 6cc193421..d0d4ef67c 100644 --- a/codegen/src/main.rs +++ b/codegen/src/main.rs @@ -68,6 +68,19 @@ fn main() { false, ); + // grpc + codegen( + &PathBuf::from(std::env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("grpc"), + &["proto/echo/echo.proto"], + &["proto"], + &PathBuf::from("src/generated"), + &PathBuf::from("src/generated/echo_fds.rs"), + true, + true, + ); println!("Codgen completed: {}ms", start.elapsed().as_millis()); } diff --git a/examples/Cargo.toml 
b/examples/Cargo.toml index b2f6d0ccf..361ec8fcd 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -263,7 +263,6 @@ tracing = ["dep:tracing", "dep:tracing-subscriber"] uds = ["dep:tokio-stream", "tokio-stream?/net", "dep:tower", "dep:hyper", "dep:hyper-util"] streaming = ["dep:tokio-stream", "dep:h2"] mock = ["dep:tokio-stream", "dep:tower", "dep:hyper-util"] -tower = ["dep:tower", "dep:http"] json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"] compression = ["tonic/gzip"] tls = ["tonic/tls-ring"] @@ -273,7 +272,7 @@ types = ["dep:tonic-types"] h2c = ["dep:hyper", "dep:tower", "dep:http", "dep:hyper-util"] cancellation = ["dep:tokio-util"] -full = ["gcp", "routeguide", "reflection", "autoreload", "health", "grpc-web", "tracing", "uds", "streaming", "mock", "tower", "json-codec", "compression", "tls", "tls-rustls", "tls-client-auth", "types", "cancellation", "h2c"] +full = ["gcp", "routeguide", "reflection", "autoreload", "health", "grpc-web", "tracing", "uds", "streaming", "mock", "json-codec", "compression", "tls", "tls-rustls", "tls-client-auth", "types", "cancellation", "h2c"] default = ["full"] [dependencies] diff --git a/grpc/Cargo.toml b/grpc/Cargo.toml index d709f622b..67dee1a98 100644 --- a/grpc/Cargo.toml +++ b/grpc/Cargo.toml @@ -6,39 +6,59 @@ authors = ["gRPC Authors"] license = "MIT" rust-version = "1.86" +[package.metadata.cargo_check_external_types] +allowed_external_types = [ + "tonic::*", + "futures_core::stream::Stream", + "tokio::sync::oneshot::Sender", +] + +[features] +default = ["dns", "_runtime-tokio"] +dns = ["dep:hickory-resolver", "_runtime-tokio"] +# The following feature is used to ensure all modules use the runtime +# abstraction instead of using tokio directly. +# Using tower/buffer enables tokio's rt feature even though it's possible to +# create Buffers with a user provided executor. 
+_runtime-tokio = [ + "tokio/rt", + "tokio/net", + "tokio/time", + "dep:socket2", + "dep:tower", +] + [dependencies] bytes = "1.10.1" hickory-resolver = { version = "0.25.1", optional = true } http = "1.1.0" http-body = "1.0.1" hyper = { version = "1.6.0", features = ["client", "http2"] } -hyper-util = "0.1.14" parking_lot = "0.12.4" pin-project-lite = "0.2.16" rand = "0.9" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" -socket2 = "0.5.10" -tokio = { version = "1.37.0", features = ["sync", "rt", "net", "time", "macros"] } -tokio-stream = "0.1.17" -tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = ["codegen", "transport"] } -tower = "0.5.2" +socket2 = { version = "0.5.10", optional = true } +tokio = { version = "1.37.0", features = ["sync", "macros"] } +tokio-stream = { version = "0.1.17", default-features = false } +tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = [ + "codegen", +] } +tower = { version = "0.5.2", features = [ + "limit", + "util", + "buffer", +], optional = true } tower-service = "0.3.3" url = "2.5.0" [dev-dependencies] async-stream = "0.3.6" -tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = ["server", "router"] } -hickory-server = "0.25.2" -prost = "0.14" - -[features] -default = ["dns"] -dns = ["dep:hickory-resolver"] - -[package.metadata.cargo_check_external_types] -allowed_external_types = [ - "tonic::*", - "futures_core::stream::Stream", - "tokio::sync::oneshot::Sender", -] +hickory-server = "0.25.2" +prost = "0.14.0" +tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = [ + "server", + "router", +] } +tonic-prost = { version = "0.14.0", path = "../tonic-prost" } diff --git a/grpc/examples/inmemory.rs b/grpc/examples/inmemory.rs index 1ffc74b9d..88b17ee01 100644 --- a/grpc/examples/inmemory.rs +++ b/grpc/examples/inmemory.rs @@ -10,11 +10,8 @@ struct Handler {} 
#[derive(Debug)] struct MyReqMessage(String); -impl Message for MyReqMessage {} - #[derive(Debug)] struct MyResMessage(String); -impl Message for MyResMessage {} #[async_trait] impl Service for Handler { diff --git a/grpc/examples/multiaddr.rs b/grpc/examples/multiaddr.rs index 9fcc8f0ed..c631d33c5 100644 --- a/grpc/examples/multiaddr.rs +++ b/grpc/examples/multiaddr.rs @@ -12,11 +12,8 @@ struct Handler { #[derive(Debug)] struct MyReqMessage(String); -impl Message for MyReqMessage {} - #[derive(Debug)] struct MyResMessage(String); -impl Message for MyResMessage {} #[async_trait] impl Service for Handler { diff --git a/grpc/proto/echo/echo.proto b/grpc/proto/echo/echo.proto new file mode 100644 index 000000000..1ed1207f0 --- /dev/null +++ b/grpc/proto/echo/echo.proto @@ -0,0 +1,43 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +syntax = "proto3"; + +package grpc.examples.echo; + +// EchoRequest is the request for echo. +message EchoRequest { + string message = 1; +} + +// EchoResponse is the response for echo. +message EchoResponse { + string message = 1; +} + +// Echo is the echo service. +service Echo { + // UnaryEcho is unary echo. + rpc UnaryEcho(EchoRequest) returns (EchoResponse) {} + // ServerStreamingEcho is server side streaming. + rpc ServerStreamingEcho(EchoRequest) returns (stream EchoResponse) {} + // ClientStreamingEcho is client side streaming. 
+ rpc ClientStreamingEcho(stream EchoRequest) returns (EchoResponse) {} + // BidirectionalStreamingEcho is bidi streaming. + rpc BidirectionalStreamingEcho(stream EchoRequest) returns (stream EchoResponse) {} +} diff --git a/grpc/src/client/channel.rs b/grpc/src/client/channel.rs index 6e41c6759..edbd55131 100644 --- a/grpc/src/client/channel.rs +++ b/grpc/src/client/channel.rs @@ -37,17 +37,16 @@ use std::{ }; use tokio::sync::{mpsc, oneshot, watch, Notify}; -use tokio::task::AbortHandle; use serde_json::json; use tonic::async_trait; use url::Url; // NOTE: http::Uri requires non-empty authority portion of URI -use crate::credentials::Credentials; +use crate::attributes::Attributes; use crate::rt; use crate::service::{Request, Response, Service}; -use crate::{attributes::Attributes, rt::tokio::TokioRuntime}; use crate::{client::ConnectivityState, rt::Runtime}; +use crate::{credentials::Credentials, rt::default_runtime}; use super::service_config::ServiceConfig; use super::transport::{TransportRegistry, GLOBAL_TRANSPORT_REGISTRY}; @@ -156,7 +155,7 @@ impl Channel { inner: Arc::new(PersistentChannel::new( target, credentials, - Arc::new(rt::tokio::TokioRuntime {}), + default_runtime(), options, )), } @@ -262,6 +261,7 @@ impl ActiveChannel { tx.clone(), picker.clone(), connectivity_state.clone(), + runtime.clone(), ); let resolver_helper = Box::new(tx.clone()); @@ -279,7 +279,7 @@ impl ActiveChannel { let resolver_opts = name_resolution::ResolverOptions { authority, work_scheduler, - runtime: Arc::new(TokioRuntime {}), + runtime: runtime.clone(), }; let resolver = rb.build(&target, resolver_opts); @@ -360,6 +360,7 @@ pub(crate) struct InternalChannelController { wqtx: WorkQueueTx, picker: Arc>>, connectivity_state: Arc>, + runtime: Arc, } impl InternalChannelController { @@ -369,8 +370,9 @@ impl InternalChannelController { wqtx: WorkQueueTx, picker: Arc>>, connectivity_state: Arc>, + runtime: Arc, ) -> Self { - let lb = 
Arc::new(GracefulSwitchBalancer::new(wqtx.clone())); + let lb = Arc::new(GracefulSwitchBalancer::new(wqtx.clone(), runtime.clone())); Self { lb, @@ -380,6 +382,7 @@ impl InternalChannelController { wqtx, picker, connectivity_state, + runtime, } } @@ -429,6 +432,7 @@ impl load_balancing::ChannelController for InternalChannelController { Box::new(move |k: SubchannelKey| { scp.unregister_subchannel(&k); }), + self.runtime.clone(), ); let _ = self.subchannel_pool.register_subchannel(&key, isc.clone()); self.new_esc_for_isc(isc) @@ -454,6 +458,7 @@ pub(super) struct GracefulSwitchBalancer { policy_builder: Mutex>>, work_scheduler: WorkQueueTx, pending: Mutex, + runtime: Arc, } impl WorkScheduler for GracefulSwitchBalancer { @@ -478,12 +483,13 @@ impl WorkScheduler for GracefulSwitchBalancer { } impl GracefulSwitchBalancer { - fn new(work_scheduler: WorkQueueTx) -> Self { + fn new(work_scheduler: WorkQueueTx, runtime: Arc) -> Self { Self { policy_builder: Mutex::default(), policy: Mutex::default(), // new(None::>), work_scheduler, pending: Mutex::default(), + runtime, } } @@ -501,6 +507,7 @@ impl GracefulSwitchBalancer { let builder = GLOBAL_LB_REGISTRY.get_policy(policy_name).unwrap(); let newpol = builder.build(LbPolicyOptions { work_scheduler: self.clone(), + runtime: self.runtime.clone(), }); *self.policy_builder.lock().unwrap() = Some(builder); *p = Some(newpol); diff --git a/grpc/src/client/load_balancing/child_manager.rs b/grpc/src/client/load_balancing/child_manager.rs index fea17bfe5..a868391bf 100644 --- a/grpc/src/client/load_balancing/child_manager.rs +++ b/grpc/src/client/load_balancing/child_manager.rs @@ -38,6 +38,7 @@ use crate::client::load_balancing::{ WeakSubchannel, WorkScheduler, }; use crate::client::name_resolution::{Address, ResolverUpdate}; +use crate::rt::Runtime; use super::{Subchannel, SubchannelState}; @@ -47,6 +48,7 @@ pub struct ChildManager { children: Vec>, update_sharder: Box>, pending_work: Arc>>, + runtime: Arc, } struct Child { @@ 
-81,12 +83,16 @@ pub trait ResolverUpdateSharder: Send { impl ChildManager { /// Creates a new ChildManager LB policy. shard_update is called whenever a /// resolver_update operation occurs. - pub fn new(update_sharder: Box>) -> Self { + pub fn new( + update_sharder: Box>, + runtime: Arc, + ) -> Self { Self { update_sharder, subchannel_child_map: Default::default(), children: Default::default(), pending_work: Default::default(), + runtime, } } @@ -197,6 +203,7 @@ impl LbPolicy for ChildManager }); let policy = builder.build(LbPolicyOptions { work_scheduler: work_scheduler.clone(), + runtime: self.runtime.clone(), }); let state = LbState::initial(); self.children.push(Child { diff --git a/grpc/src/client/load_balancing/mod.rs b/grpc/src/client/load_balancing/mod.rs index ad576e819..36f41affe 100644 --- a/grpc/src/client/load_balancing/mod.rs +++ b/grpc/src/client/load_balancing/mod.rs @@ -41,6 +41,7 @@ use tonic::{metadata::MetadataMap, Status}; use crate::{ client::channel::WorkQueueTx, + rt::Runtime, service::{Request, Response, Service}, }; @@ -66,6 +67,7 @@ pub struct LbPolicyOptions { /// A hook into the channel's work scheduler that allows the LbPolicy to /// request the ability to perform operations on the ChannelController. 
pub work_scheduler: Arc, + pub runtime: Arc, } /// Used to asynchronously request a call into the LbPolicy's work method if diff --git a/grpc/src/client/load_balancing/pick_first.rs b/grpc/src/client/load_balancing/pick_first.rs index 54ae78711..6ac901709 100644 --- a/grpc/src/client/load_balancing/pick_first.rs +++ b/grpc/src/client/load_balancing/pick_first.rs @@ -4,7 +4,6 @@ use std::{ time::Duration, }; -use tokio::time::sleep; use tonic::metadata::MetadataMap; use crate::{ @@ -13,6 +12,7 @@ use crate::{ name_resolution::{Address, ResolverUpdate}, subchannel, ConnectivityState, }, + rt::Runtime, service::Request, }; @@ -31,6 +31,7 @@ impl LbPolicyBuilder for Builder { work_scheduler: options.work_scheduler, subchannel: None, next_addresses: Vec::default(), + runtime: options.runtime, }) } @@ -47,6 +48,7 @@ struct PickFirstPolicy { work_scheduler: Arc, subchannel: Option>, next_addresses: Vec
, + runtime: Arc, } impl LbPolicy for PickFirstPolicy { @@ -72,11 +74,12 @@ impl LbPolicy for PickFirstPolicy { self.next_addresses = addresses; let work_scheduler = self.work_scheduler.clone(); + let runtime = self.runtime.clone(); // TODO: Implement Drop that cancels this task. - tokio::task::spawn(async move { - sleep(Duration::from_millis(200)).await; + self.runtime.spawn(Box::pin(async move { + runtime.sleep(Duration::from_millis(200)).await; work_scheduler.schedule_work(); - }); + })); // TODO: return a picker that queues RPCs. Ok(()) } diff --git a/grpc/src/client/load_balancing/test_utils.rs b/grpc/src/client/load_balancing/test_utils.rs index 93ffc5511..1a7e33fbf 100644 --- a/grpc/src/client/load_balancing/test_utils.rs +++ b/grpc/src/client/load_balancing/test_utils.rs @@ -32,8 +32,8 @@ use std::{fmt::Debug, ops::Add, sync::Arc}; use tokio::sync::{mpsc, Notify}; use tokio::task::AbortHandle; +#[derive(Debug)] pub(crate) struct EmptyMessage {} -impl Message for EmptyMessage {} pub(crate) fn new_request() -> Request { Request::new(Box::pin(tokio_stream::once( Box::new(EmptyMessage {}) as Box diff --git a/grpc/src/client/mod.rs b/grpc/src/client/mod.rs index 66c809e62..e896412ae 100644 --- a/grpc/src/client/mod.rs +++ b/grpc/src/client/mod.rs @@ -28,9 +28,8 @@ pub mod channel; pub(crate) mod load_balancing; pub(crate) mod name_resolution; pub mod service_config; -pub mod transport; - mod subchannel; +pub(crate) mod transport; pub use channel::Channel; pub use channel::ChannelOptions; diff --git a/grpc/src/client/name_resolution/dns/mod.rs b/grpc/src/client/name_resolution/dns/mod.rs index 461b6d02b..6475d62c3 100644 --- a/grpc/src/client/name_resolution/dns/mod.rs +++ b/grpc/src/client/name_resolution/dns/mod.rs @@ -41,7 +41,7 @@ use url::Host; use crate::{ byte_str::ByteStr, client::name_resolution::{global_registry, ChannelController, ResolverBuilder, Target}, - rt::{self, TaskHandle}, + rt::{self, BoxedTaskHandle}, }; use super::{ @@ -243,7 +243,7 @@ impl 
ResolverBuilder for Builder { struct DnsResolver { state: Arc>, - task_handle: Box, + task_handle: BoxedTaskHandle, resolve_now_notifier: Arc, channel_update_notifier: Arc, } diff --git a/grpc/src/client/name_resolution/dns/test.rs b/grpc/src/client/name_resolution/dns/test.rs index beda2ea32..135e8ccfa 100644 --- a/grpc/src/client/name_resolution/dns/test.rs +++ b/grpc/src/client/name_resolution/dns/test.rs @@ -290,6 +290,14 @@ impl rt::Runtime for FakeRuntime { fn sleep(&self, duration: std::time::Duration) -> Pin> { self.inner.sleep(duration) } + + fn tcp_stream( + &self, + target: std::net::SocketAddr, + opts: rt::TcpOptions, + ) -> Pin, String>> + Send>> { + self.inner.tcp_stream(target, opts) + } } #[tokio::test] diff --git a/grpc/src/client/subchannel.rs b/grpc/src/client/subchannel.rs index d9bef839b..4b5c314fc 100644 --- a/grpc/src/client/subchannel.rs +++ b/grpc/src/client/subchannel.rs @@ -2,14 +2,20 @@ use super::{ channel::{InternalChannelController, WorkQueueTx}, load_balancing::{self, ExternalSubchannel, Picker, Subchannel, SubchannelState}, name_resolution::Address, - transport::{self, ConnectedTransport, Transport, TransportRegistry}, + transport::{self, Transport, TransportRegistry}, ConnectivityState, }; use crate::{ - client::{channel::WorkQueueItem, subchannel}, + client::{ + channel::WorkQueueItem, + subchannel, + transport::{ConnectedTransport, TransportOptions}, + }, + rt::{BoxedTaskHandle, Runtime}, service::{Request, Response, Service}, }; use core::panic; +use std::time::{Duration, Instant}; use std::{ collections::BTreeMap, error::Error, @@ -17,14 +23,10 @@ use std::{ ops::Sub, sync::{Arc, Mutex, RwLock, Weak}, }; -use tokio::{ - sync::{mpsc, watch, Notify}, - task::{AbortHandle, JoinHandle}, - time::{Duration, Instant}, -}; +use tokio::sync::{mpsc, oneshot, watch, Notify}; use tonic::async_trait; -type SharedService = Arc; +type SharedService = Arc; pub trait Backoff: Send + Sync { fn backoff_until(&self) -> Instant; @@ -52,16 +54,16 @@ 
enum InternalSubchannelState { } struct InternalSubchannelConnectingState { - abort_handle: Option, + abort_handle: Option, } struct InternalSubchannelReadyState { - abort_handle: Option, + abort_handle: Option, svc: SharedService, } struct InternalSubchannelTransientFailureState { - abort_handle: Option, + task_handle: Option, error: String, } @@ -163,7 +165,7 @@ impl Drop for InternalSubchannelState { } } Self::TransientFailure(st) => { - if let Some(ah) = &st.abort_handle { + if let Some(ah) = &st.task_handle { ah.abort(); } } @@ -178,13 +180,14 @@ pub(crate) struct InternalSubchannel { unregister_fn: Option>, state_machine_event_sender: mpsc::UnboundedSender, inner: Mutex, + runtime: Arc, } struct InnerSubchannel { state: InternalSubchannelState, watchers: Vec>, // TODO(easwars): Revisit the choice for this data structure. - backoff_task: Option>, - disconnect_task: Option>, + backoff_task: Option, + disconnect_task: Option, } #[async_trait] @@ -204,7 +207,7 @@ impl Service for InternalSubchannel { enum SubchannelStateMachineEvent { ConnectionRequested, - ConnectionSucceeded(SharedService), + ConnectionSucceeded(SharedService, oneshot::Receiver>), ConnectionTimedOut, ConnectionFailed(String), ConnectionTerminated, @@ -214,7 +217,7 @@ impl Debug for SubchannelStateMachineEvent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::ConnectionRequested => write!(f, "ConnectionRequested"), - Self::ConnectionSucceeded(_) => write!(f, "ConnectionSucceeded"), + Self::ConnectionSucceeded(_, _) => write!(f, "ConnectionSucceeded"), Self::ConnectionTimedOut => write!(f, "ConnectionTimedOut"), Self::ConnectionFailed(_) => write!(f, "ConnectionFailed"), Self::ConnectionTerminated => write!(f, "ConnectionTerminated"), @@ -229,6 +232,7 @@ impl InternalSubchannel { transport: Arc, backoff: Arc, unregister_fn: Box, + runtime: Arc, ) -> Arc { println!("creating new internal subchannel for: {:?}", &key); let (tx, mut rx) = 
mpsc::unbounded_channel::(); @@ -244,6 +248,7 @@ impl InternalSubchannel { backoff_task: None, disconnect_task: None, }), + runtime: runtime.clone(), }); // This long running task implements the subchannel state machine. When @@ -251,7 +256,7 @@ impl InternalSubchannel { // closed, and therefore this task exits because rx.recv() returns None // in that case. let arc_to_self = Arc::clone(&isc); - tokio::task::spawn(async move { + runtime.spawn(Box::pin(async move { println!("starting subchannel state machine for: {:?}", &key); while let Some(m) = rx.recv().await { println!("subchannel {:?} received event {:?}", &key, &m); @@ -259,8 +264,8 @@ impl InternalSubchannel { SubchannelStateMachineEvent::ConnectionRequested => { arc_to_self.move_to_connecting(); } - SubchannelStateMachineEvent::ConnectionSucceeded(svc) => { - arc_to_self.move_to_ready(svc); + SubchannelStateMachineEvent::ConnectionSucceeded(svc, rx) => { + arc_to_self.move_to_ready(svc, rx); } SubchannelStateMachineEvent::ConnectionTimedOut => { arc_to_self.move_to_transient_failure("connect timeout expired".into()); @@ -277,7 +282,7 @@ impl InternalSubchannel { } } println!("exiting work queue task in subchannel"); - }); + })); isc } @@ -345,15 +350,19 @@ impl InternalSubchannel { let transport = self.transport.clone(); let address = self.address().address; let state_machine_tx = self.state_machine_event_sender.clone(); - let connect_task = tokio::task::spawn(async move { + // TODO: All these options to be configured by users. + let transport_opts = TransportOptions::default(); + let runtime = self.runtime.clone(); + + let connect_task = self.runtime.spawn(Box::pin(async move { tokio::select! 
{ - _ = tokio::time::sleep(min_connect_timeout) => { + _ = runtime.sleep(min_connect_timeout) => { let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionTimedOut); } - result = transport.connect(address.to_string().clone()) => { + result = transport.connect(address.to_string().clone(), runtime, &transport_opts) => { match result { Ok(s) => { - let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionSucceeded(Arc::from(s))); + let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionSucceeded(Arc::from(s.service), s.disconnection_listener)); } Err(e) => { let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionFailed(e)); @@ -361,14 +370,14 @@ impl InternalSubchannel { } }, } - }); + })); let mut inner = self.inner.lock().unwrap(); inner.state = InternalSubchannelState::Connecting(InternalSubchannelConnectingState { - abort_handle: Some(connect_task.abort_handle()), + abort_handle: Some(connect_task), }); } - fn move_to_ready(&self, svc: SharedService) { + fn move_to_ready(&self, svc: SharedService, closed_rx: oneshot::Receiver>) { let svc2 = svc.clone(); { let mut inner = self.inner.lock().unwrap(); @@ -383,17 +392,19 @@ impl InternalSubchannel { }); let state_machine_tx = self.state_machine_event_sender.clone(); - let disconnect_task = tokio::task::spawn(async move { + let task_handle = self.runtime.spawn(Box::pin(async move { // TODO(easwars): Does it make sense for disconnected() to return an // error string containing information about why the connection // terminated? But what can we do with that error other than logging // it, which the transport can do as well? 
- svc.disconnected().await; + if let Err(e) = closed_rx.await { + eprintln!("Transport closed with error: {e}",) + }; let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionTerminated); - }); + })); let mut inner = self.inner.lock().unwrap(); inner.state = InternalSubchannelState::Ready(InternalSubchannelReadyState { - abort_handle: Some(disconnect_task.abort_handle()), + abort_handle: Some(task_handle), svc: svc2.clone(), }); } @@ -403,7 +414,7 @@ impl InternalSubchannel { let mut inner = self.inner.lock().unwrap(); inner.state = InternalSubchannelState::TransientFailure( InternalSubchannelTransientFailureState { - abort_handle: None, + task_handle: None, error: err.clone(), }, ); @@ -417,14 +428,17 @@ impl InternalSubchannel { let backoff_interval = self.backoff.backoff_until(); let state_machine_tx = self.state_machine_event_sender.clone(); - let backoff_task = tokio::task::spawn(async move { - tokio::time::sleep_until(backoff_interval).await; + let runtime = self.runtime.clone(); + let backoff_task = self.runtime.spawn(Box::pin(async move { + runtime + .sleep(backoff_interval.saturating_duration_since(Instant::now())) + .await; let _ = state_machine_tx.send(SubchannelStateMachineEvent::BackoffExpired); - }); + })); let mut inner = self.inner.lock().unwrap(); inner.state = InternalSubchannelState::TransientFailure(InternalSubchannelTransientFailureState { - abort_handle: Some(backoff_task.abort_handle()), + task_handle: Some(backoff_task), error: err.clone(), }); } diff --git a/grpc/src/client/transport/mod.rs b/grpc/src/client/transport/mod.rs index 4c5b021b8..411a2954b 100644 --- a/grpc/src/client/transport/mod.rs +++ b/grpc/src/client/transport/mod.rs @@ -1,16 +1,49 @@ -use crate::service::Service; +use crate::{rt::Runtime, service::Service}; +use std::time::Instant; +use std::{sync::Arc, time::Duration}; mod registry; +// Using tower/buffer enables tokio's rt feature even though it's possible to +// create Buffers with a user provided executor. 
+#[cfg(feature = "_runtime-tokio")] +mod tonic; + use ::tonic::async_trait; -pub use registry::{TransportRegistry, GLOBAL_TRANSPORT_REGISTRY}; +pub(crate) use registry::TransportRegistry; +pub(crate) use registry::GLOBAL_TRANSPORT_REGISTRY; +use tokio::sync::oneshot; -#[async_trait] -pub trait Transport: Send + Sync { - async fn connect(&self, address: String) -> Result, String>; +pub(crate) struct ConnectedTransport { + pub service: Box, + pub disconnection_listener: oneshot::Receiver>, +} + +// TODO: The following options are specific to HTTP/2. We should +// instead pass an `Attribute` like struct to the connect method instead which +// can hold config relevant to a particular transport. +#[derive(Default)] +pub(crate) struct TransportOptions { + pub(crate) init_stream_window_size: Option, + pub(crate) init_connection_window_size: Option, + pub(crate) http2_keep_alive_interval: Option, + pub(crate) http2_keep_alive_timeout: Option, + pub(crate) http2_keep_alive_while_idle: Option, + pub(crate) http2_max_header_list_size: Option, + pub(crate) http2_adaptive_window: Option, + pub(crate) concurrency_limit: Option, + pub(crate) rate_limit: Option<(u64, Duration)>, + pub(crate) tcp_keepalive: Option, + pub(crate) tcp_nodelay: bool, + pub(crate) connect_deadline: Option, } #[async_trait] -pub trait ConnectedTransport: Service { - async fn disconnected(&self); +pub(crate) trait Transport: Send + Sync { + async fn connect( + &self, + address: String, + runtime: Arc, + opts: &TransportOptions, + ) -> Result; } diff --git a/grpc/src/client/transport/registry.rs b/grpc/src/client/transport/registry.rs index e5f7f7fe0..0b4f614ef 100644 --- a/grpc/src/client/transport/registry.rs +++ b/grpc/src/client/transport/registry.rs @@ -1,20 +1,17 @@ -use std::{ - collections::HashMap, - sync::{Arc, LazyLock, Mutex}, -}; - use super::Transport; +use std::sync::{Arc, LazyLock, Mutex}; +use std::{collections::HashMap, fmt::Debug}; /// A registry to store and retrieve transports. 
Transports are indexed by /// the address type they are intended to handle. -#[derive(Clone)] -pub struct TransportRegistry { - m: Arc>>>, +#[derive(Default, Clone)] +pub(crate) struct TransportRegistry { + inner: Arc>>>, } -impl std::fmt::Debug for TransportRegistry { +impl Debug for TransportRegistry { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let m = self.m.lock().unwrap(); + let m = self.inner.lock().unwrap(); for key in m.keys() { write!(f, "k: {key:?}")? } @@ -24,21 +21,21 @@ impl std::fmt::Debug for TransportRegistry { impl TransportRegistry { /// Construct an empty name resolver registry. - pub fn new() -> Self { - Self { m: Arc::default() } + pub(crate) fn new() -> Self { + Self::default() } - /// Add a name resolver into the registry. - pub fn add_transport(&self, address_type: &str, transport: impl Transport + 'static) { - //let a: Arc = transport; - //let a: Arc> = transport; - self.m + + /// Add a transport into the registry. + pub(crate) fn add_transport(&self, address_type: &str, transport: impl Transport + 'static) { + self.inner .lock() .unwrap() .insert(address_type.to_string(), Arc::new(transport)); } + /// Retrieve a name resolver from the registry, or None if not found. - pub fn get_transport(&self, address_type: &str) -> Result, String> { - self.m + pub(crate) fn get_transport(&self, address_type: &str) -> Result, String> { + self.inner .lock() .unwrap() .get(address_type) @@ -49,12 +46,6 @@ impl TransportRegistry { } } -impl Default for TransportRegistry { - fn default() -> Self { - Self::new() - } -} - /// The registry used if a local registry is not provided to a channel or if it /// does not exist in the local registry. 
pub static GLOBAL_TRANSPORT_REGISTRY: LazyLock = diff --git a/grpc/src/client/transport/tonic/mod.rs b/grpc/src/client/transport/tonic/mod.rs new file mode 100644 index 000000000..11fcd24e1 --- /dev/null +++ b/grpc/src/client/transport/tonic/mod.rs @@ -0,0 +1,277 @@ +use crate::client::transport::registry::GLOBAL_TRANSPORT_REGISTRY; +use crate::client::transport::ConnectedTransport; +use crate::client::transport::Transport; +use crate::client::transport::TransportOptions; +use crate::codec::BytesCodec; +use crate::rt::hyper_wrapper::{HyperCompatExec, HyperCompatTimer, HyperStream}; +use crate::rt::BoxedTaskHandle; +use crate::rt::Runtime; +use crate::rt::TcpOptions; +use crate::service::Message; +use crate::service::Request as GrpcRequest; +use crate::service::Response as GrpcResponse; +use crate::{client::name_resolution::TCP_IP_NETWORK_TYPE, service::Service}; +use bytes::Bytes; +use http::uri::PathAndQuery; +use http::Request as HttpRequest; +use http::Response as HttpResponse; +use http::Uri; +use hyper::client::conn::http2::Builder; +use hyper::client::conn::http2::SendRequest; +use std::any::Any; +use std::task::{Context, Poll}; +use std::time::Instant; +use std::{error::Error, future::Future, net::SocketAddr, pin::Pin, str::FromStr, sync::Arc}; +use tokio::sync::oneshot; +use tokio_stream::Stream; +use tokio_stream::StreamExt; +use tonic::client::GrpcService; +use tonic::Request as TonicRequest; +use tonic::Response as TonicResponse; +use tonic::Streaming; +use tonic::{async_trait, body::Body, client::Grpc, Status}; +use tower::buffer::{future::ResponseFuture as BufferResponseFuture, Buffer}; +use tower::limit::{ConcurrencyLimitLayer, RateLimitLayer}; +use tower::{util::BoxService, ServiceBuilder}; +use tower_service::Service as TowerService; + +#[cfg(test)] +mod test; + +const DEFAULT_BUFFER_SIZE: usize = 1024; +pub(crate) type BoxError = Box; + +type BoxFuture<'a, T> = Pin + Send + 'a>>; +type BoxStream = Pin> + Send>>; + +pub(crate) fn reg() { + 
GLOBAL_TRANSPORT_REGISTRY.add_transport(TCP_IP_NETWORK_TYPE, TransportBuilder {}); +} + +struct TransportBuilder {} + +struct TonicTransport { + grpc: Grpc, + task_handle: BoxedTaskHandle, +} + +impl Drop for TonicTransport { + fn drop(&mut self) { + self.task_handle.abort(); + } +} + +#[async_trait] +impl Service for TonicTransport { + async fn call(&self, method: String, request: GrpcRequest) -> GrpcResponse { + let Ok(path) = PathAndQuery::from_maybe_shared(method) else { + let err = Status::internal("Failed to parse path"); + return create_error_response(err); + }; + let mut grpc = self.grpc.clone(); + if let Err(e) = grpc.ready().await { + // TODO: Figure out the exact situations under which the service + // may return an error and re-evaluate the status code returned + // below. + let err = Status::unknown(format!("Service was not ready: {e}")); + return create_error_response(err); + }; + let request = convert_request(request); + let response = grpc.streaming(request, path, BytesCodec {}).await; + convert_response(response) + } +} + +/// Helper function to create an error response stream. +fn create_error_response(status: Status) -> GrpcResponse { + let stream = tokio_stream::once(Err(status)); + TonicResponse::new(Box::pin(stream)) +} + +fn convert_request(req: GrpcRequest) -> TonicRequest + Send>>> { + let (metadata, extensions, stream) = req.into_parts(); + + let bytes_stream = Box::pin(stream.filter_map(|msg| { + if let Ok(bytes) = (msg as Box).downcast::() { + Some(*bytes) + } else { + // If it fails, log the error and return None to filter it out. 
+ eprintln!("A message could not be downcast to Bytes and was skipped."); + None + } + })); + + TonicRequest::from_parts(metadata, extensions, bytes_stream as _) +} + +fn convert_response(res: Result>, Status>) -> GrpcResponse { + let response = match res { + Ok(s) => s, + Err(e) => { + let stream = tokio_stream::once(Err(e)); + return TonicResponse::new(Box::pin(stream)); + } + }; + let (metadata, stream, extensions) = response.into_parts(); + let message_stream: BoxStream> = Box::pin(stream.map(|msg| { + msg.map(|b| { + let msg: Box = Box::new(b); + msg + }) + })); + TonicResponse::from_parts(metadata, message_stream, extensions) +} + +#[async_trait] +impl Transport for TransportBuilder { + async fn connect( + &self, + address: String, + runtime: Arc, + opts: &TransportOptions, + ) -> Result { + let runtime = runtime.clone(); + let mut settings = Builder::::new(HyperCompatExec { + inner: runtime.clone(), + }) + .timer(HyperCompatTimer { + inner: runtime.clone(), + }) + .initial_stream_window_size(opts.init_stream_window_size) + .initial_connection_window_size(opts.init_connection_window_size) + .keep_alive_interval(opts.http2_keep_alive_interval) + .clone(); + + if let Some(val) = opts.http2_keep_alive_timeout { + settings.keep_alive_timeout(val); + } + + if let Some(val) = opts.http2_keep_alive_while_idle { + settings.keep_alive_while_idle(val); + } + + if let Some(val) = opts.http2_adaptive_window { + settings.adaptive_window(val); + } + + if let Some(val) = opts.http2_max_header_list_size { + settings.max_header_list_size(val); + } + + let addr: SocketAddr = SocketAddr::from_str(&address).map_err(|err| err.to_string())?; + let tcp_stream_fut = runtime.tcp_stream( + addr, + TcpOptions { + enable_nodelay: opts.tcp_nodelay, + keepalive: opts.tcp_keepalive, + }, + ); + let tcp_stream = if let Some(deadline) = opts.connect_deadline { + let timeout = deadline.saturating_duration_since(Instant::now()); + tokio::select! 
{ + _ = runtime.sleep(timeout) => { + return Err("timed out waiting for TCP stream to connect".to_string()) + } + tcp_stream = tcp_stream_fut => { tcp_stream? } + } + } else { + tcp_stream_fut.await? + }; + let tcp_stream = HyperStream::new(tcp_stream); + + let (sender, connection) = settings + .handshake(tcp_stream) + .await + .map_err(|err| err.to_string())?; + let (tx, rx) = oneshot::channel(); + + let task_handle = runtime.spawn(Box::pin(async move { + if let Err(err) = connection.await { + let _ = tx.send(Err(err.to_string())); + } else { + let _ = tx.send(Ok(())); + } + })); + let sender = SendRequestWrapper::from(sender); + + let service = ServiceBuilder::new() + .option_layer(opts.concurrency_limit.map(ConcurrencyLimitLayer::new)) + .option_layer(opts.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) + .map_err(Into::::into) + .service(sender); + + let service = BoxService::new(service); + let (service, worker) = Buffer::pair(service, DEFAULT_BUFFER_SIZE); + runtime.spawn(Box::pin(worker)); + let uri = + Uri::from_maybe_shared(format!("http://{}", &address)).map_err(|e| e.to_string())?; // TODO: err msg + let grpc = Grpc::with_origin(TonicService { inner: service }, uri); + + let service = TonicTransport { grpc, task_handle }; + Ok(ConnectedTransport { + service: Box::new(service), + disconnection_listener: rx, + }) + } +} + +struct SendRequestWrapper { + inner: SendRequest, +} + +impl From> for SendRequestWrapper { + fn from(inner: SendRequest) -> Self { + Self { inner } + } +} + +impl TowerService> for SendRequestWrapper { + type Response = HttpResponse; + type Error = BoxError; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + let fut = self.inner.send_request(req); + Box::pin(async move { fut.await.map_err(Into::into).map(|res| res.map(Body::new)) }) + } +} + +#[derive(Clone)] 
+struct TonicService { + inner: Buffer, BoxFuture<'static, Result, BoxError>>>, +} + +impl GrpcService for TonicService { + type ResponseBody = Body; + type Error = BoxError; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + tower::Service::poll_ready(&mut self.inner, cx) + } + + fn call(&mut self, request: http::Request) -> Self::Future { + ResponseFuture { + inner: tower::Service::call(&mut self.inner, request), + } + } +} + +/// A future that resolves to an HTTP response. +/// +/// This is returned by the `Service::call` on [`Channel`]. +pub struct ResponseFuture { + inner: BufferResponseFuture, BoxError>>>, +} + +impl Future for ResponseFuture { + type Output = Result, BoxError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.inner).poll(cx) + } +} diff --git a/grpc/src/client/transport/tonic/test.rs b/grpc/src/client/transport/tonic/test.rs new file mode 100644 index 000000000..678280e34 --- /dev/null +++ b/grpc/src/client/transport/tonic/test.rs @@ -0,0 +1,165 @@ +use crate::client::name_resolution::TCP_IP_NETWORK_TYPE; +use crate::client::transport::registry::GLOBAL_TRANSPORT_REGISTRY; +use crate::echo_pb::echo_server::{Echo, EchoServer}; +use crate::echo_pb::{EchoRequest, EchoResponse}; +use crate::service::Message; +use crate::service::Request as GrpcRequest; +use crate::{client::transport::TransportOptions, rt::tokio::TokioRuntime}; +use bytes::Bytes; +use std::any::Any; +use std::{pin::Pin, sync::Arc, time::Duration}; +use tokio::net::TcpListener; +use tokio::sync::{mpsc, oneshot, Notify}; +use tokio::time::timeout; +use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt}; +use tonic::async_trait; +use tonic::{transport::Server, Request, Response, Status}; +use tonic_prost::prost::Message as ProstMessage; + +const DEFAULT_TEST_DURATION: Duration = Duration::from_secs(10); +const DEFAULT_TEST_SHORT_DURATION: Duration = Duration::from_millis(10); + +// Tests 
the tonic transport by creating a bi-di stream with a tonic server. +#[tokio::test] +pub async fn tonic_transport_rpc() { + super::reg(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); // get the assigned address + let shutdown_notify = Arc::new(Notify::new()); + let shutdown_notify_copy = shutdown_notify.clone(); + println!("EchoServer listening on: {addr}"); + let server_handle = tokio::spawn(async move { + let echo_server = EchoService {}; + let svc = EchoServer::new(echo_server); + let _ = Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + shutdown_notify_copy.notified(), + ) + .await; + }); + + let builder = GLOBAL_TRANSPORT_REGISTRY + .get_transport(TCP_IP_NETWORK_TYPE) + .unwrap(); + let config = Arc::new(TransportOptions::default()); + let mut connected_transport = builder + .connect(addr.to_string(), Arc::new(TokioRuntime {}), &config) + .await + .unwrap(); + let conn = connected_transport.service; + + let (tx, rx) = mpsc::channel::>(1); + + // Convert the mpsc receiver into a Stream + let outbound: GrpcRequest = + Request::new(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx))); + + let mut inbound = conn + .call( + "/grpc.examples.echo.Echo/BidirectionalStreamingEcho".to_string(), + outbound, + ) + .await + .into_inner(); + + // Spawn a sender task + let client_handle = tokio::spawn(async move { + for i in 0..5 { + let message = format!("message {i}"); + let request = EchoRequest { + message: message.clone(), + }; + + let bytes = Bytes::from(request.encode_to_vec()); + + println!("Sent request: {request:?}"); + assert!(tx.send(Box::new(bytes)).await.is_ok(), "Receiver dropped"); + + // Wait for the reply + let resp = inbound + .next() + .await + .expect("server unexpectedly closed the stream!") + .expect("server returned error"); + + let bytes = (resp as Box).downcast::().unwrap(); + let echo_response = 
EchoResponse::decode(bytes).unwrap(); + println!("Got response: {echo_response:?}"); + assert_eq!(echo_response.message, message); + } + }); + + client_handle.await.unwrap(); + // The connection should break only after the server is stopped. + assert_eq!( + connected_transport.disconnection_listener.try_recv(), + Err(oneshot::error::TryRecvError::Empty), + ); + shutdown_notify.notify_waiters(); + let res = timeout( + DEFAULT_TEST_DURATION, + connected_transport.disconnection_listener, + ) + .await + .unwrap() + .unwrap(); + assert_eq!(res, Ok(())); + server_handle.await.unwrap(); +} + +#[derive(Debug)] +pub struct EchoService {} + +#[async_trait] +impl Echo for EchoService { + async fn unary_echo( + &self, + _: tonic::Request, + ) -> std::result::Result, tonic::Status> { + unimplemented!() + } + + type ServerStreamingEchoStream = ReceiverStream>; + + async fn server_streaming_echo( + &self, + _: tonic::Request, + ) -> std::result::Result, tonic::Status> { + unimplemented!() + } + + async fn client_streaming_echo( + &self, + _: tonic::Request>, + ) -> std::result::Result, tonic::Status> { + unimplemented!() + } + type BidirectionalStreamingEchoStream = + Pin> + Send + 'static>>; + + async fn bidirectional_streaming_echo( + &self, + request: tonic::Request>, + ) -> std::result::Result, tonic::Status> + { + let mut inbound = request.into_inner(); + + // Map each request to a corresponding EchoResponse + let outbound = async_stream::try_stream! 
{ + while let Some(req) = inbound.next().await { + let req = req?; // Return Err(Status) if stream item is error + let reply = EchoResponse { + message: req.message.clone(), + }; + yield reply; + } + println!("Server closing stream"); + }; + + Ok(Response::new( + Box::pin(outbound) as Self::BidirectionalStreamingEchoStream + )) + } +} diff --git a/grpc/src/codec.rs b/grpc/src/codec.rs new file mode 100644 index 000000000..eb9cc03e7 --- /dev/null +++ b/grpc/src/codec.rs @@ -0,0 +1,53 @@ +use bytes::{Buf, BufMut, Bytes}; +use tonic::{ + codec::{Codec, Decoder, EncodeBuf, Encoder}, + Status, +}; + +/// An adapter for sending and receiving messages as bytes using tonic. +/// Coding/decoding is handled within gRPC. +/// TODO: Remove this when tonic allows access to bytes without requiring a +/// codec. +pub(crate) struct BytesCodec {} + +impl Codec for BytesCodec { + type Encode = Bytes; + type Decode = Bytes; + type Encoder = BytesEncoder; + type Decoder = BytesDecoder; + + fn encoder(&mut self) -> Self::Encoder { + BytesEncoder {} + } + + fn decoder(&mut self) -> Self::Decoder { + BytesDecoder {} + } +} + +pub struct BytesEncoder {} + +impl Encoder for BytesEncoder { + type Item = Bytes; + type Error = Status; + + fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> { + dst.put_slice(&item); + Ok(()) + } +} + +#[derive(Debug)] +pub struct BytesDecoder {} + +impl Decoder for BytesDecoder { + type Item = Bytes; + type Error = Status; + + fn decode( + &mut self, + src: &mut tonic::codec::DecodeBuf<'_>, + ) -> Result, Self::Error> { + Ok(Some(src.copy_to_bytes(src.remaining()))) + } +} diff --git a/grpc/src/generated/echo_fds.rs b/grpc/src/generated/echo_fds.rs new file mode 100644 index 000000000..9833d2636 --- /dev/null +++ b/grpc/src/generated/echo_fds.rs @@ -0,0 +1,61 @@ +// This file is @generated by codegen. +// +// +// Copyright 2018 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// +/// Byte encoded FILE_DESCRIPTOR_SET. +pub const FILE_DESCRIPTOR_SET: &[u8] = &[ + 10u8, 246u8, 3u8, 10u8, 15u8, 101u8, 99u8, 104u8, 111u8, 47u8, 101u8, 99u8, 104u8, + 111u8, 46u8, 112u8, 114u8, 111u8, 116u8, 111u8, 18u8, 18u8, 103u8, 114u8, 112u8, + 99u8, 46u8, 101u8, 120u8, 97u8, 109u8, 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, + 104u8, 111u8, 34u8, 39u8, 10u8, 11u8, 69u8, 99u8, 104u8, 111u8, 82u8, 101u8, 113u8, + 117u8, 101u8, 115u8, 116u8, 18u8, 24u8, 10u8, 7u8, 109u8, 101u8, 115u8, 115u8, 97u8, + 103u8, 101u8, 24u8, 1u8, 32u8, 1u8, 40u8, 9u8, 82u8, 7u8, 109u8, 101u8, 115u8, 115u8, + 97u8, 103u8, 101u8, 34u8, 40u8, 10u8, 12u8, 69u8, 99u8, 104u8, 111u8, 82u8, 101u8, + 115u8, 112u8, 111u8, 110u8, 115u8, 101u8, 18u8, 24u8, 10u8, 7u8, 109u8, 101u8, 115u8, + 115u8, 97u8, 103u8, 101u8, 24u8, 1u8, 32u8, 1u8, 40u8, 9u8, 82u8, 7u8, 109u8, 101u8, + 115u8, 115u8, 97u8, 103u8, 101u8, 50u8, 243u8, 2u8, 10u8, 4u8, 69u8, 99u8, 104u8, + 111u8, 18u8, 78u8, 10u8, 9u8, 85u8, 110u8, 97u8, 114u8, 121u8, 69u8, 99u8, 104u8, + 111u8, 18u8, 31u8, 46u8, 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, 97u8, 109u8, + 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, 99u8, 104u8, + 111u8, 82u8, 101u8, 113u8, 117u8, 101u8, 115u8, 116u8, 26u8, 32u8, 46u8, 103u8, + 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, 97u8, 109u8, 112u8, 108u8, 101u8, 115u8, + 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, 99u8, 
104u8, 111u8, 82u8, 101u8, 115u8, + 112u8, 111u8, 110u8, 115u8, 101u8, 18u8, 90u8, 10u8, 19u8, 83u8, 101u8, 114u8, 118u8, + 101u8, 114u8, 83u8, 116u8, 114u8, 101u8, 97u8, 109u8, 105u8, 110u8, 103u8, 69u8, + 99u8, 104u8, 111u8, 18u8, 31u8, 46u8, 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, + 97u8, 109u8, 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, + 99u8, 104u8, 111u8, 82u8, 101u8, 113u8, 117u8, 101u8, 115u8, 116u8, 26u8, 32u8, 46u8, + 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, 97u8, 109u8, 112u8, 108u8, 101u8, + 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, 99u8, 104u8, 111u8, 82u8, 101u8, + 115u8, 112u8, 111u8, 110u8, 115u8, 101u8, 48u8, 1u8, 18u8, 90u8, 10u8, 19u8, 67u8, + 108u8, 105u8, 101u8, 110u8, 116u8, 83u8, 116u8, 114u8, 101u8, 97u8, 109u8, 105u8, + 110u8, 103u8, 69u8, 99u8, 104u8, 111u8, 18u8, 31u8, 46u8, 103u8, 114u8, 112u8, 99u8, + 46u8, 101u8, 120u8, 97u8, 109u8, 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, + 104u8, 111u8, 46u8, 69u8, 99u8, 104u8, 111u8, 82u8, 101u8, 113u8, 117u8, 101u8, + 115u8, 116u8, 26u8, 32u8, 46u8, 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, 97u8, + 109u8, 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, 99u8, + 104u8, 111u8, 82u8, 101u8, 115u8, 112u8, 111u8, 110u8, 115u8, 101u8, 40u8, 1u8, 18u8, + 99u8, 10u8, 26u8, 66u8, 105u8, 100u8, 105u8, 114u8, 101u8, 99u8, 116u8, 105u8, 111u8, + 110u8, 97u8, 108u8, 83u8, 116u8, 114u8, 101u8, 97u8, 109u8, 105u8, 110u8, 103u8, + 69u8, 99u8, 104u8, 111u8, 18u8, 31u8, 46u8, 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, + 120u8, 97u8, 109u8, 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, + 46u8, 69u8, 99u8, 104u8, 111u8, 82u8, 101u8, 113u8, 117u8, 101u8, 115u8, 116u8, 26u8, + 32u8, 46u8, 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, 97u8, 109u8, 112u8, 108u8, + 101u8, 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, 99u8, 104u8, 111u8, 82u8, + 101u8, 115u8, 112u8, 111u8, 110u8, 115u8, 101u8, 40u8, 1u8, 48u8, 1u8, 
98u8, 6u8, + 112u8, 114u8, 111u8, 116u8, 111u8, 51u8, +]; diff --git a/grpc/src/generated/grpc_examples_echo.rs b/grpc/src/generated/grpc_examples_echo.rs new file mode 100644 index 000000000..5545928b0 --- /dev/null +++ b/grpc/src/generated/grpc_examples_echo.rs @@ -0,0 +1,547 @@ +// This file is @generated by prost-build. +/// EchoRequest is the request for echo. +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct EchoRequest { + #[prost(string, tag = "1")] + pub message: ::prost::alloc::string::String, +} +/// EchoResponse is the response for echo. +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct EchoResponse { + #[prost(string, tag = "1")] + pub message: ::prost::alloc::string::String, +} +/// Generated client implementations. +pub mod echo_client { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value, + )] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + /// Echo is the echo service. + #[derive(Debug, Clone)] + pub struct EchoClient { + inner: tonic::client::Grpc, + } + impl EchoClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + std::marker::Send + 'static, + ::Error: Into + std::marker::Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> EchoClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + std::marker::Send + std::marker::Sync, + { + EchoClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. 
+ /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + /// UnaryEcho is unary echo. + pub async fn unary_echo( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/grpc.examples.echo.Echo/UnaryEcho", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("grpc.examples.echo.Echo", "UnaryEcho")); + self.inner.unary(req, path, codec).await + } + /// ServerStreamingEcho is server side streaming. 
+ pub async fn server_streaming_echo( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/grpc.examples.echo.Echo/ServerStreamingEcho", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("grpc.examples.echo.Echo", "ServerStreamingEcho"), + ); + self.inner.server_streaming(req, path, codec).await + } + /// ClientStreamingEcho is client side streaming. + pub async fn client_streaming_echo( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/grpc.examples.echo.Echo/ClientStreamingEcho", + ); + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("grpc.examples.echo.Echo", "ClientStreamingEcho"), + ); + self.inner.client_streaming(req, path, codec).await + } + /// BidirectionalStreamingEcho is bidi streaming. 
+ pub async fn bidirectional_streaming_echo( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/grpc.examples.echo.Echo/BidirectionalStreamingEcho", + ); + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert( + GrpcMethod::new( + "grpc.examples.echo.Echo", + "BidirectionalStreamingEcho", + ), + ); + self.inner.streaming(req, path, codec).await + } + } +} +/// Generated server implementations. +pub mod echo_server { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value, + )] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with EchoServer. + #[async_trait] + pub trait Echo: std::marker::Send + std::marker::Sync + 'static { + /// UnaryEcho is unary echo. + async fn unary_echo( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the ServerStreamingEcho method. + type ServerStreamingEchoStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, + > + + std::marker::Send + + 'static; + /// ServerStreamingEcho is server side streaming. + async fn server_streaming_echo( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + /// ClientStreamingEcho is client side streaming. + async fn client_streaming_echo( + &self, + request: tonic::Request>, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the BidirectionalStreamingEcho method. 
+ type BidirectionalStreamingEchoStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, + > + + std::marker::Send + + 'static; + /// BidirectionalStreamingEcho is bidi streaming. + async fn bidirectional_streaming_echo( + &self, + request: tonic::Request>, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + } + /// Echo is the echo service. + #[derive(Debug)] + pub struct EchoServer { + inner: Arc, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + impl EchoServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. 
+ /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> for EchoServer + where + T: Echo, + B: Body + std::marker::Send + 'static, + B::Error: Into + std::marker::Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + match req.uri().path() { + "/grpc.examples.echo.Echo/UnaryEcho" => { + #[allow(non_camel_case_types)] + struct UnaryEchoSvc(pub Arc); + impl tonic::server::UnaryService + for UnaryEchoSvc { + type Response = super::EchoResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::unary_echo(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = UnaryEchoSvc(inner); + let codec = tonic_prost::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/grpc.examples.echo.Echo/ServerStreamingEcho" => { + #[allow(non_camel_case_types)] + struct ServerStreamingEchoSvc(pub Arc); + impl< + T: Echo, + > 
tonic::server::ServerStreamingService + for ServerStreamingEchoSvc { + type Response = super::EchoResponse; + type ResponseStream = T::ServerStreamingEchoStream; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::server_streaming_echo(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = ServerStreamingEchoSvc(inner); + let codec = tonic_prost::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/grpc.examples.echo.Echo/ClientStreamingEcho" => { + #[allow(non_camel_case_types)] + struct ClientStreamingEchoSvc(pub Arc); + impl< + T: Echo, + > tonic::server::ClientStreamingService + for ClientStreamingEchoSvc { + type Response = super::EchoResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::client_streaming_echo(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = 
self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = ClientStreamingEchoSvc(inner); + let codec = tonic_prost::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.client_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/grpc.examples.echo.Echo/BidirectionalStreamingEcho" => { + #[allow(non_camel_case_types)] + struct BidirectionalStreamingEchoSvc(pub Arc); + impl tonic::server::StreamingService + for BidirectionalStreamingEchoSvc { + type Response = super::EchoResponse; + type ResponseStream = T::BidirectionalStreamingEchoStream; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::bidirectional_streaming_echo(&inner, request) + .await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = BidirectionalStreamingEchoSvc(inner); + let codec = tonic_prost::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => { + Box::pin(async move { + let mut response = http::Response::new( + tonic::body::Body::default(), + ); + 
let headers = response.headers_mut(); + headers + .insert( + tonic::Status::GRPC_STATUS, + (tonic::Code::Unimplemented as i32).into(), + ); + headers + .insert( + http::header::CONTENT_TYPE, + tonic::metadata::GRPC_CONTENT_TYPE, + ); + Ok(response) + }) + } + } + } + } + impl Clone for EchoServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } + } + } + /// Generated gRPC service name + pub const SERVICE_NAME: &str = "grpc.examples.echo.Echo"; + impl tonic::server::NamedService for EchoServer { + const NAME: &'static str = SERVICE_NAME; + } +} diff --git a/grpc/src/inmemory/mod.rs b/grpc/src/inmemory/mod.rs index 48e97ceb5..b9dae99e0 100644 --- a/grpc/src/inmemory/mod.rs +++ b/grpc/src/inmemory/mod.rs @@ -1,11 +1,6 @@ -use std::{ - collections::HashMap, - ops::Add, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, LazyLock, - }, -}; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::{Arc, LazyLock, Mutex}; +use std::{collections::HashMap, ops::Add}; use crate::{ client::{ @@ -13,20 +8,22 @@ use crate::{ self, global_registry, Address, ChannelController, Endpoint, Resolver, ResolverBuilder, ResolverOptions, ResolverUpdate, }, - transport::{self, ConnectedTransport, GLOBAL_TRANSPORT_REGISTRY}, + transport::{self, ConnectedTransport, TransportOptions, GLOBAL_TRANSPORT_REGISTRY}, }, + rt::Runtime, server, service::{Request, Response, Service}, }; -use tokio::sync::{mpsc, oneshot, Mutex, Notify}; +use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex, Notify}; use tonic::async_trait; pub struct Listener { id: String, s: Box>>, - r: Arc>>>, + r: Arc>>>, // List of notifiers to call when closed. 
- closed: Notify, + #[allow(clippy::type_complexity)] + closed_tx: Arc>>>>, } static ID: AtomicU32 = AtomicU32::new(0); @@ -37,8 +34,8 @@ impl Listener { let s = Arc::new(Self { id: format!("{}", ID.fetch_add(1, Ordering::Relaxed)), s: Box::new(tx), - r: Arc::new(Mutex::new(rx)), - closed: Notify::new(), + r: Arc::new(AsyncMutex::new(rx)), + closed_tx: Arc::new(Mutex::new(Vec::new())), }); LISTENERS.lock().unwrap().insert(s.id.clone(), s.clone()); s @@ -59,7 +56,10 @@ impl Listener { impl Drop for Listener { fn drop(&mut self) { - self.closed.notify_waiters(); + let txs = std::mem::take(&mut *self.closed_tx.lock().unwrap()); + for rx in txs { + let _ = rx.send(Ok(())); + } LISTENERS.lock().unwrap().remove(&self.id); } } @@ -75,25 +75,17 @@ impl Service for Arc { } } -#[async_trait] -impl ConnectedTransport for Arc { - async fn disconnected(&self) { - self.closed.notified().await; - } -} - #[async_trait] impl crate::server::Listener for Arc { async fn accept(&self) -> Option { let mut recv = self.r.lock().await; let r = recv.recv().await; - r.as_ref()?; - r.unwrap() + // Listener may be closed. + r? } } -static LISTENERS: LazyLock>>> = - LazyLock::new(std::sync::Mutex::default); +static LISTENERS: LazyLock>>> = LazyLock::new(Mutex::default); struct ClientTransport {} @@ -105,14 +97,24 @@ impl ClientTransport { #[async_trait] impl transport::Transport for ClientTransport { - async fn connect(&self, address: String) -> Result, String> { + async fn connect( + &self, + address: String, + _: Arc, + _: &TransportOptions, + ) -> Result { let lis = LISTENERS .lock() .unwrap() .get(&address) .ok_or(format!("Could not find listener for address {address}"))? 
.clone(); - Ok(Box::new(lis)) + let (tx, rx) = oneshot::channel(); + lis.closed_tx.lock().unwrap().push(tx); + Ok(ConnectedTransport { + service: Box::new(lis), + disconnection_listener: rx, + }) } } diff --git a/grpc/src/lib.rs b/grpc/src/lib.rs index f56fd2cab..512adbc8f 100644 --- a/grpc/src/lib.rs +++ b/grpc/src/lib.rs @@ -41,3 +41,11 @@ pub mod service; pub(crate) mod attributes; pub(crate) mod byte_str; +pub(crate) mod codec; +#[cfg(test)] +pub(crate) mod echo_pb { + include!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/generated/grpc_examples_echo.rs" + )); +} diff --git a/grpc/src/rt/hyper_wrapper.rs b/grpc/src/rt/hyper_wrapper.rs new file mode 100644 index 000000000..6bdaad48f --- /dev/null +++ b/grpc/src/rt/hyper_wrapper.rs @@ -0,0 +1,158 @@ +use super::{Runtime, TcpStream}; +use hyper::rt::{Executor, Timer}; +use pin_project_lite::pin_project; +use std::task::{Context, Poll}; +use std::{future::Future, io, pin::Pin, sync::Arc, time::Instant}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// Adapts a runtime to a hyper compatible executor. +#[derive(Clone)] +pub(crate) struct HyperCompatExec { + pub(crate) inner: Arc, +} + +impl Executor for HyperCompatExec +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + fn execute(&self, fut: F) { + self.inner.spawn(Box::pin(async { + let _ = fut.await; + })); + } +} + +struct HyperCompatSleep { + inner: Pin>, +} + +impl Future for HyperCompatSleep { + type Output = (); + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + self.inner.as_mut().poll(cx) + } +} + +impl hyper::rt::Sleep for HyperCompatSleep {} + +/// Adapts a runtime to a hyper compatible timer. 
+pub(crate) struct HyperCompatTimer { + pub(crate) inner: Arc, +} + +impl Timer for HyperCompatTimer { + fn sleep(&self, duration: std::time::Duration) -> Pin> { + let sleep = self.inner.sleep(duration); + Box::pin(HyperCompatSleep { inner: sleep }) + } + + fn sleep_until(&self, deadline: Instant) -> Pin> { + let now = Instant::now(); + let duration = deadline.saturating_duration_since(now); + self.sleep(duration) + } +} + +// The following adapters are copied from hyper: +// https://github.com/hyperium/hyper/blob/v1.6.0/benches/support/tokiort.rs + +pin_project! { + /// A wrapper to make any `TcpStream` compatible with Hyper. It implements + /// Tokio's async IO traits. + pub(crate) struct HyperStream { + #[pin] + inner: Box, + } +} + +impl HyperStream { + /// Creates a new `HyperStream` from a type implementing `TcpStream`. + pub fn new(stream: Box) -> Self { + Self { inner: stream } + } +} + +impl AsyncRead for HyperStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + // Delegate the poll_read call to the inner stream. 
+ Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for HyperStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +impl hyper::rt::Read for HyperStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::Write for HyperStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} diff --git a/grpc/src/rt/mod.rs b/grpc/src/rt/mod.rs index 78accb53f..81d22ff7c 100644 --- a/grpc/src/rt/mod.rs +++ b/grpc/src/rt/mod.rs @@ -23,10 +23,14 @@ */ use ::tokio::io::{AsyncRead, AsyncWrite}; +use std::{future::Future, net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; -use 
std::{future::Future, net::SocketAddr, pin::Pin, time::Duration}; +pub(crate) mod hyper_wrapper; +#[cfg(feature = "_runtime-tokio")] +pub(crate) mod tokio; -pub mod tokio; +type BoxFuture = Pin + Send>>; +pub(crate) type BoxedTaskHandle = Box; /// An abstraction over an asynchronous runtime. /// @@ -37,10 +41,7 @@ pub mod tokio; /// and testable infrastructure. pub(super) trait Runtime: Send + Sync { /// Spawns the given asynchronous task to run in the background. - fn spawn( - &self, - task: Pin + Send + 'static>>, - ) -> Box; + fn spawn(&self, task: Pin + Send + 'static>>) -> BoxedTaskHandle; /// Creates and returns an instance of a DNSResolver, optionally /// configured by the ResolverOptions struct. This method may return an @@ -49,6 +50,14 @@ pub(super) trait Runtime: Send + Sync { /// Returns a future that completes after the specified duration. fn sleep(&self, duration: std::time::Duration) -> Pin>; + + /// Establishes a TCP connection to the given `target` address with the + /// specified `opts`. + fn tcp_stream( + &self, + target: SocketAddr, + opts: TcpOptions, + ) -> BoxFuture, String>>; } /// A future that resolves after a specified duration. @@ -77,7 +86,48 @@ pub(super) struct ResolverOptions { } #[derive(Default)] -pub struct TcpOptions { - pub enable_nodelay: bool, - pub keepalive: Option, +pub(crate) struct TcpOptions { + pub(crate) enable_nodelay: bool, + pub(crate) keepalive: Option, +} + +pub(crate) trait TcpStream: AsyncRead + AsyncWrite + Send + Unpin {} + +/// A fake runtime to satisfy the compiler when no runtime is enabled. This will +/// +/// # Panics +/// +/// Panics if any of its functions are called. 
+#[derive(Default)] +pub(crate) struct NoOpRuntime {} + +impl Runtime for NoOpRuntime { + fn spawn(&self, task: Pin + Send + 'static>>) -> BoxedTaskHandle { + unimplemented!() + } + + fn get_dns_resolver(&self, opts: ResolverOptions) -> Result, String> { + unimplemented!() + } + + fn sleep(&self, duration: std::time::Duration) -> Pin> { + unimplemented!() + } + + fn tcp_stream( + &self, + target: SocketAddr, + opts: TcpOptions, + ) -> Pin, String>> + Send>> { + unimplemented!() + } +} + +pub(crate) fn default_runtime() -> Arc { + #[cfg(feature = "_runtime-tokio")] + { + return Arc::new(tokio::TokioRuntime {}); + } + #[allow(unreachable_code)] + Arc::new(NoOpRuntime::default()) } diff --git a/grpc/src/rt/tokio/mod.rs b/grpc/src/rt/tokio/mod.rs index b0a66ae39..8caec4cf3 100644 --- a/grpc/src/rt/tokio/mod.rs +++ b/grpc/src/rt/tokio/mod.rs @@ -31,10 +31,11 @@ use std::{ use tokio::{ io::{AsyncRead, AsyncWrite}, + net::TcpStream, task::JoinHandle, }; -use super::{DnsResolver, ResolverOptions, Runtime, Sleep, TaskHandle}; +use super::{BoxedTaskHandle, DnsResolver, ResolverOptions, Runtime, Sleep, TaskHandle}; #[cfg(feature = "dns")] mod hickory_resolver; @@ -74,10 +75,7 @@ impl TaskHandle for JoinHandle<()> { impl Sleep for tokio::time::Sleep {} impl Runtime for TokioRuntime { - fn spawn( - &self, - task: Pin + Send + 'static>>, - ) -> Box { + fn spawn(&self, task: Pin + Send + 'static>>) -> BoxedTaskHandle { Box::new(tokio::spawn(task)) } @@ -95,6 +93,28 @@ impl Runtime for TokioRuntime { fn sleep(&self, duration: Duration) -> Pin> { Box::pin(tokio::time::sleep(duration)) } + + fn tcp_stream( + &self, + target: SocketAddr, + opts: super::TcpOptions, + ) -> Pin, String>> + Send>> { + Box::pin(async move { + let stream = TcpStream::connect(target) + .await + .map_err(|err| err.to_string())?; + if let Some(duration) = opts.keepalive { + let sock_ref = socket2::SockRef::from(&stream); + let mut ka = socket2::TcpKeepalive::new(); + ka = ka.with_time(duration); + sock_ref + 
.set_tcp_keepalive(&ka) + .map_err(|err| err.to_string())?; + } + let stream: Box = Box::new(TokioTcpStream { inner: stream }); + Ok(stream) + }) + } } impl TokioDefaultDnsResolver { @@ -106,6 +126,46 @@ impl TokioDefaultDnsResolver { } } +struct TokioTcpStream { + inner: TcpStream, +} + +impl AsyncRead for TokioTcpStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for TokioTcpStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +impl super::TcpStream for TokioTcpStream {} + #[cfg(test)] mod tests { use super::{DnsResolver, ResolverOptions, Runtime, TokioDefaultDnsResolver, TokioRuntime}; diff --git a/grpc/src/service.rs b/grpc/src/service.rs index b16f9c9b7..64d02ed17 100644 --- a/grpc/src/service.rs +++ b/grpc/src/service.rs @@ -22,14 +22,14 @@ * */ -use std::{any::Any, pin::Pin}; +use std::{any::Any, fmt::Debug, pin::Pin}; use tokio_stream::Stream; use tonic::{async_trait, Request as TonicRequest, Response as TonicResponse, Status}; pub type Request = TonicRequest> + Send + Sync>>>; pub type Response = - TonicResponse, Status>> + Send + Sync>>>; + TonicResponse, Status>> + Send>>>; #[async_trait] pub trait Service: Send + Sync { @@ -37,4 +37,6 @@ pub trait Service: Send + Sync { } // TODO: define methods that will allow serialization/deserialization. 
-pub trait Message: Any + Send + Sync {} +pub trait Message: Any + Send + Sync + Debug {} + +impl Message for T where T: Any + Send + Sync + Debug {} From c6e42ebe425f1afe70da41c0b387e2b0807f4672 Mon Sep 17 00:00:00 2001 From: Shaun Houlihan <1269284+Shaun1@users.noreply.github.com> Date: Wed, 13 Aug 2025 11:59:05 -0600 Subject: [PATCH 06/11] Update helloworld-tutorial.md (#2385) Update the hello-world tutorial readme to reflect changes to package names and remove reference to deprecated tools. ## Motivation As someone relatively new to Rust and new to Tonic, I was working through the hello-world tutorial and found a few issues. I've successfully built the project with the changes in this PR. ## Solution - update example Cargo.toml to reflect package name updates - update example build.rs with correct build tool name - remove reference to deprecated GUI tool I've tested this and it runs and builds correctly. I sort of hate to reference Postman as the recommended GUI but it was the first one on the list referenced by the now-deprecated Bloom GUI. --- examples/helloworld-tutorial.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/helloworld-tutorial.md b/examples/helloworld-tutorial.md index 972509a5f..bd6614e20 100644 --- a/examples/helloworld-tutorial.md +++ b/examples/helloworld-tutorial.md @@ -114,10 +114,11 @@ path = "src/client.rs" [dependencies] tonic = "*" prost = "0.14" +tonic-prost = "*" tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } [build-dependencies] -tonic-build = "*" +tonic-prost-build = "*" ``` We include `tonic-build` as a useful way to incorporate the generation of our client and server gRPC code into the build process of our application. 
We will setup this build process now: @@ -128,7 +129,7 @@ At the root of your project (not /src), create a `build.rs` file and add the fol ```rust fn main() -> Result<(), Box> { - tonic_build::compile_protos("proto/helloworld.proto")?; + tonic_prost_build::compile_protos("proto/helloworld.proto")?; Ok(()) } ``` @@ -239,7 +240,7 @@ async fn main() -> Result<(), Box> { You should now be able to run your HelloWorld gRPC server using the command `cargo run --bin helloworld-server`. This uses the [[bin]] we defined earlier in our `Cargo.toml` to run specifically the server. -If you have a gRPC GUI client such as [Bloom RPC] you should be able to send requests to the server and get back greetings! +If you have a gRPC GUI client such as [Postman] you should be able to send requests to the server and get back greetings! Or if you use [grpcurl] then you can simply try send requests like this: ``` @@ -252,7 +253,7 @@ And receiving responses like this: } ``` -[bloom rpc]: https://github.com/uw-labs/bloomrpc +[postman]: https://www.postman.com/ [grpcurl]: https://github.com/fullstorydev/grpcurl ## Writing our Client From dee02feaa7caf1fa8559d0fcbe9a5c8fb5e63437 Mon Sep 17 00:00:00 2001 From: Arjan Singh Bal <46515553+arjan-bal@users.noreply.github.com> Date: Tue, 19 Aug 2025 01:15:14 +0530 Subject: [PATCH 07/11] chore(ci): pin protoc version (#2389) Fixes: https://github.com/hyperium/tonic/issues/2386 Interop tests are failing since the [protobuf codegen requires the same version as protoc](https://github.com/protocolbuffers/protobuf/blob/2ae8154f366a6d776bcc3ac931413bc4d99f578e/rust/release_crates/protobuf_codegen/src/lib.rs#L136-L148) to work. This PR pins the protoc version to fix the tests. 
--- .github/workflows/CI.yml | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 626fde513..608a885d8 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -82,7 +82,9 @@ jobs: - uses: hecrj/setup-rust-action@v2 with: components: clippy - - uses: taiki-e/install-action@protoc + - uses: taiki-e/install-action@v2 + with: + tool: protoc@3.31.1 - name: Restore protoc plugin from cache id: cache-plugin uses: actions/cache@v4 @@ -115,7 +117,9 @@ jobs: toolchain: nightly-2025-03-27 - uses: taiki-e/install-action@cargo-hack - uses: taiki-e/install-action@cargo-udeps - - uses: taiki-e/install-action@protoc + - uses: taiki-e/install-action@v2 + with: + tool: protoc@3.31.1 - name: Restore protoc plugin from cache id: cache-plugin uses: actions/cache@v4 @@ -147,7 +151,9 @@ jobs: - uses: actions/checkout@v4 - uses: hecrj/setup-rust-action@v2 - uses: taiki-e/install-action@cargo-hack - - uses: taiki-e/install-action@protoc + - uses: taiki-e/install-action@v2 + with: + tool: protoc@3.31.1 - name: Restore protoc plugin from cache id: cache-plugin uses: actions/cache@v4 @@ -198,7 +204,9 @@ jobs: steps: - uses: actions/checkout@v4 - uses: hecrj/setup-rust-action@v2 - - uses: taiki-e/install-action@protoc + - uses: taiki-e/install-action@v2 + with: + tool: protoc@3.31.1 - name: Restore protoc plugin from cache id: cache-plugin uses: actions/cache@v4 @@ -235,7 +243,9 @@ jobs: steps: - uses: actions/checkout@v4 - uses: hecrj/setup-rust-action@v2 - - uses: taiki-e/install-action@protoc + - uses: taiki-e/install-action@v2 + with: + tool: protoc@3.31.1 - name: Restore protoc plugin from cache id: cache-plugin uses: actions/cache@v4 From 13fae2fefb12d0749a9fbde8457828ed9ab8fc4d Mon Sep 17 00:00:00 2001 From: Arjan Singh Bal <46515553+arjan-bal@users.noreply.github.com> Date: Wed, 20 Aug 2025 00:42:49 +0530 Subject: [PATCH 08/11] chore: update protobuf to 4.32.0 (#2391) --- 
.github/workflows/CI.yml | 10 +++++----- interop/Cargo.toml | 2 +- tonic-protobuf-build/Cargo.toml | 2 +- tonic-protobuf-build/src/lib.rs | 2 -- tonic-protobuf/Cargo.toml | 2 +- 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 608a885d8..273f9e495 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -84,7 +84,7 @@ jobs: components: clippy - uses: taiki-e/install-action@v2 with: - tool: protoc@3.31.1 + tool: protoc@3.32.0 - name: Restore protoc plugin from cache id: cache-plugin uses: actions/cache@v4 @@ -119,7 +119,7 @@ jobs: - uses: taiki-e/install-action@cargo-udeps - uses: taiki-e/install-action@v2 with: - tool: protoc@3.31.1 + tool: protoc@3.32.0 - name: Restore protoc plugin from cache id: cache-plugin uses: actions/cache@v4 @@ -153,7 +153,7 @@ jobs: - uses: taiki-e/install-action@cargo-hack - uses: taiki-e/install-action@v2 with: - tool: protoc@3.31.1 + tool: protoc@3.32.0 - name: Restore protoc plugin from cache id: cache-plugin uses: actions/cache@v4 @@ -206,7 +206,7 @@ jobs: - uses: hecrj/setup-rust-action@v2 - uses: taiki-e/install-action@v2 with: - tool: protoc@3.31.1 + tool: protoc@3.32.0 - name: Restore protoc plugin from cache id: cache-plugin uses: actions/cache@v4 @@ -245,7 +245,7 @@ jobs: - uses: hecrj/setup-rust-action@v2 - uses: taiki-e/install-action@v2 with: - tool: protoc@3.31.1 + tool: protoc@3.32.0 - name: Restore protoc plugin from cache id: cache-plugin uses: actions/cache@v4 diff --git a/interop/Cargo.toml b/interop/Cargo.toml index dfc69aa38..957eb8af9 100644 --- a/interop/Cargo.toml +++ b/interop/Cargo.toml @@ -32,7 +32,7 @@ grpc = {path = "../grpc"} # We also need the protobuf-codegen crate to support configuring the path # to the protobuf crate used in the generated message code, instead of # defaulting to `::protobuf`. 
-protobuf = { version = "4.31.1-release" } +protobuf = { version = "4.32.0-release" } tonic-protobuf = {path = "../tonic-protobuf"} [build-dependencies] diff --git a/tonic-protobuf-build/Cargo.toml b/tonic-protobuf-build/Cargo.toml index 744da8d61..0b31beb87 100644 --- a/tonic-protobuf-build/Cargo.toml +++ b/tonic-protobuf-build/Cargo.toml @@ -8,5 +8,5 @@ publish = false [dependencies] prettyplease = "0.2.35" -protobuf-codegen = { version = "4.31.1-release" } +protobuf-codegen = { version = "4.32.0-release" } syn = "2.0.104" diff --git a/tonic-protobuf-build/src/lib.rs b/tonic-protobuf-build/src/lib.rs index b3a54b69f..a8049091c 100644 --- a/tonic-protobuf-build/src/lib.rs +++ b/tonic-protobuf-build/src/lib.rs @@ -95,8 +95,6 @@ impl From<&Dependency> for protobuf_codegen::Dependency { protobuf_codegen::Dependency { crate_name: val.crate_name.clone(), proto_import_paths: val.proto_import_paths.clone(), - // The following field is not used by protobuf codegen. - c_include_paths: Vec::new(), proto_files: val.proto_files.clone(), } } diff --git a/tonic-protobuf/Cargo.toml b/tonic-protobuf/Cargo.toml index c573f1062..9d1ccfb7d 100644 --- a/tonic-protobuf/Cargo.toml +++ b/tonic-protobuf/Cargo.toml @@ -9,7 +9,7 @@ publish = false [dependencies] tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = ["codegen"] } bytes = "1" -protobuf = { version = "4.31.1-release" } +protobuf = { version = "4.32.0-release" } [package.metadata.cargo_check_external_types] allowed_external_types = [ From 688522ae0f873a1cf4ccd6dcda16fccfee80da49 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Wed, 20 Aug 2025 10:02:27 -0600 Subject: [PATCH 09/11] chore: use local protoc install (#2392) --- flake.lock | 24 ++++++++++++------------ flake.nix | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/flake.lock b/flake.lock index a65081b5f..035b4272b 100644 --- a/flake.lock +++ b/flake.lock @@ -8,11 +8,11 @@ "rust-analyzer-src": "rust-analyzer-src" 
}, "locked": { - "lastModified": 1754635447, - "narHash": "sha256-lslpNlJacd38xcTjS0j2CamRQ1XBbT8CEcVPBXTUFJQ=", + "lastModified": 1755585599, + "narHash": "sha256-tl/0cnsqB/Yt7DbaGMel2RLa7QG5elA8lkaOXli6VdY=", "owner": "nix-community", "repo": "fenix", - "rev": "8eff3e316c68c36eb783f49a13bb500755c4b544", + "rev": "6ed03ef4c8ec36d193c18e06b9ecddde78fb7e42", "type": "github" }, "original": { @@ -64,11 +64,11 @@ ] }, "locked": { - "lastModified": 1754416808, - "narHash": "sha256-c6yg0EQ9xVESx6HGDOCMcyRSjaTpNJP10ef+6fRcofA=", + "lastModified": 1755446520, + "narHash": "sha256-I0Ok1OGDwc1jPd8cs2VvAYZsHriUVFGIUqW+7uSsOUM=", "owner": "cachix", "repo": "git-hooks.nix", - "rev": "9c52372878df6911f9afc1e2a1391f55e4dfc864", + "rev": "4b04db83821b819bbbe32ed0a025b31e7971f22e", "type": "github" }, "original": { @@ -100,11 +100,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1754498491, - "narHash": "sha256-erbiH2agUTD0Z30xcVSFcDHzkRvkRXOQ3lb887bcVrs=", + "lastModified": 1755186698, + "narHash": "sha256-wNO3+Ks2jZJ4nTHMuks+cxAiVBGNuEBXsT29Bz6HASo=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "c2ae88e026f9525daf89587f3cbee584b92b6134", + "rev": "fbcf476f790d8a217c3eab4e12033dc4a0f6d23c", "type": "github" }, "original": { @@ -140,11 +140,11 @@ "rust-analyzer-src": { "flake": false, "locked": { - "lastModified": 1754573113, - "narHash": "sha256-rGSOEooq0u38dzvn677G14DGm3ue9UDQ9c1E4qyfcuU=", + "lastModified": 1755504847, + "narHash": "sha256-VX0B9hwhJypCGqncVVLC+SmeMVd/GAYbJZ0MiiUn2Pk=", "owner": "rust-lang", "repo": "rust-analyzer", - "rev": "caef0f46fd97175e6c1894189689c098c4cb536f", + "rev": "a905e3b21b144d77e1b304e49f3264f6f8d4db75", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 99660c7f5..49a4947e3 100644 --- a/flake.nix +++ b/flake.nix @@ -92,13 +92,13 @@ "rustfmt" "rust-analyzer" ]) - protobuf + # protobuf ]; hardeningDisable = [ "fortify" ]; shellHook = '' - export PATH="$PWD/protoc-gen-rust-grpc/bazel-bin/src:$PATH" + export 
PATH="$PWD/protoc-gen-rust-grpc/bazel-bin/src:$HOME/code/install/bin:$PATH" ${config.pre-commit.installationScript} ''; From 61bb3daa50cd4decbb6095abc6703b454eb4b646 Mon Sep 17 00:00:00 2001 From: victor Date: Thu, 21 Aug 2025 16:36:53 +0100 Subject: [PATCH 10/11] transport: unify request modifiers and reduce allocations --- tonic/src/transport/channel/endpoint.rs | 4 + .../transport/channel/service/add_origin.rs | 69 ------ .../transport/channel/service/connection.rs | 23 +- tonic/src/transport/channel/service/mod.rs | 7 +- .../channel/service/request_modifiers.rs | 228 ++++++++++++++++++ .../transport/channel/service/user_agent.rs | 159 ------------ 6 files changed, 248 insertions(+), 242 deletions(-) delete mode 100644 tonic/src/transport/channel/service/add_origin.rs create mode 100644 tonic/src/transport/channel/service/request_modifiers.rs delete mode 100644 tonic/src/transport/channel/service/user_agent.rs diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index f5c386899..92ea86f3a 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -585,6 +585,10 @@ impl Endpoint { pub fn get_tcp_keepalive_retries(&self) -> Option { self.tcp_keepalive_retries } + + pub(crate) fn get_origin(&self) -> &Uri { + self.origin.as_ref().unwrap_or(self.uri()) + } } impl From for Endpoint { diff --git a/tonic/src/transport/channel/service/add_origin.rs b/tonic/src/transport/channel/service/add_origin.rs deleted file mode 100644 index 9e3791de7..000000000 --- a/tonic/src/transport/channel/service/add_origin.rs +++ /dev/null @@ -1,69 +0,0 @@ -use crate::transport::channel::BoxFuture; -use http::uri::Authority; -use http::uri::Scheme; -use http::{Request, Uri}; -use std::task::{Context, Poll}; -use tower_service::Service; - -#[derive(Debug)] -pub(crate) struct AddOrigin { - inner: T, - scheme: Option, - authority: Option, -} - -impl AddOrigin { - pub(crate) fn new(inner: T, origin: Uri) -> 
Self { - let http::uri::Parts { - scheme, authority, .. - } = origin.into_parts(); - - Self { - inner, - scheme, - authority, - } - } -} - -impl Service> for AddOrigin -where - T: Service>, - T::Future: Send + 'static, - T::Error: Into, -{ - type Response = T::Response; - type Error = crate::BoxError; - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx).map_err(Into::into) - } - - fn call(&mut self, req: Request) -> Self::Future { - if self.scheme.is_none() || self.authority.is_none() { - let err = crate::transport::Error::new_invalid_uri(); - return Box::pin(async move { Err::(err.into()) }); - } - - // Split the request into the head and the body. - let (mut head, body) = req.into_parts(); - - // Update the request URI - head.uri = { - // Split the request URI into parts. - let mut uri: http::uri::Parts = head.uri.into(); - // Update the URI parts, setting the scheme and authority - uri.scheme = self.scheme.clone(); - uri.authority = self.authority.clone(); - - http::Uri::from_parts(uri).expect("valid uri") - }; - - let request = Request::from_parts(head, body); - - let fut = self.inner.call(request); - - Box::pin(async move { fut.await.map_err(Into::into) }) - } -} diff --git a/tonic/src/transport/channel/service/connection.rs b/tonic/src/transport/channel/service/connection.rs index c4ce9408e..bc2f0431d 100644 --- a/tonic/src/transport/channel/service/connection.rs +++ b/tonic/src/transport/channel/service/connection.rs @@ -1,4 +1,5 @@ use super::{AddOrigin, Reconnect, SharedExec, UserAgent}; +use crate::transport::channel::service::Modifier; use crate::{ body::Body, transport::{channel::BoxFuture, service::GrpcTimeout, Endpoint}, @@ -25,7 +26,7 @@ pub(crate) struct Connection { } impl Connection { - fn new(connector: C, endpoint: Endpoint, is_lazy: bool) -> Self + fn new(connector: C, endpoint: Endpoint, is_lazy: bool) -> Result where C: Service + Send + 'static, C::Error: Into + 
Send, @@ -55,13 +56,17 @@ impl Connection { settings.max_header_list_size(val); } + // We shift detecting absence of both scheme and authority here + let add_origin = AddOrigin::new(endpoint.get_origin())?; let stack = ServiceBuilder::new() .layer_fn(|s| { - let origin = endpoint.origin.as_ref().unwrap_or(endpoint.uri()).clone(); - - AddOrigin::new(s, origin) + // The clone here is just &Uri + Modifier::new(s, add_origin.clone().into_fn()) + }) + .layer_fn(|s| { + let ua = UserAgent::new(endpoint.user_agent.clone()); + Modifier::new(s, ua.into_fn()) }) - .layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone())) .layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout)) .option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new)) .option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) @@ -72,9 +77,9 @@ impl Connection { let conn = Reconnect::new(make_service, endpoint.uri().clone(), is_lazy); - Self { + Ok(Self { inner: BoxService::new(stack.layer(conn)), - } + }) } pub(crate) async fn connect( @@ -87,7 +92,7 @@ impl Connection { C::Future: Unpin + Send, C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { - Self::new(connector, endpoint, false).ready_oneshot().await + Self::new(connector, endpoint, false)?.ready_oneshot().await } pub(crate) fn lazy(connector: C, endpoint: Endpoint) -> Self @@ -97,7 +102,7 @@ impl Connection { C::Future: Send, C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { - Self::new(connector, endpoint, true) + Self::new(connector, endpoint, true).expect("Endpoint origin scheme and authority are set") } } diff --git a/tonic/src/transport/channel/service/mod.rs b/tonic/src/transport/channel/service/mod.rs index e4306f033..0c89fe4fd 100644 --- a/tonic/src/transport/channel/service/mod.rs +++ b/tonic/src/transport/channel/service/mod.rs @@ -1,8 +1,5 @@ -mod add_origin; -use self::add_origin::AddOrigin; - -mod user_agent; -use self::user_agent::UserAgent; +mod request_modifiers; +use
self::request_modifiers::*; mod reconnect; use self::reconnect::Reconnect; diff --git a/tonic/src/transport/channel/service/request_modifiers.rs b/tonic/src/transport/channel/service/request_modifiers.rs new file mode 100644 index 000000000..1cbf13e80 --- /dev/null +++ b/tonic/src/transport/channel/service/request_modifiers.rs @@ -0,0 +1,228 @@ +use http::{header::USER_AGENT, HeaderValue, Request, Uri}; +use std::task::{Context, Poll}; +use tower_service::Service; +use crate::body::Body; + +#[derive(Debug)] +pub(crate) struct Modifier { + modifier_fn: M, + next: T, +} + +impl Modifier { + pub(crate) fn new(next: T, modifier_fn: M) -> Self { + Self { next, modifier_fn } + } +} + +impl Service> for Modifier +where + T: Service>, + M: FnOnce(Request) -> Request + Clone, + Body: Send + 'static, +{ + type Response = T::Response; + type Error = T::Error; + type Future = T::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.next.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let modifier_fn = self.modifier_fn.clone(); + self.next.call(modifier_fn(req)) + } +} + +// We're borrowing to avoid cloning the Uri more than once in layer which expects Fn +// and not FnOnce +#[derive(Debug, Clone)] +pub(crate) struct AddOrigin<'a> { + origin: &'a Uri +} + +impl<'a> AddOrigin<'a> { + pub(crate) fn new(origin: &'a Uri) -> Result { + // We catch error right at initiation... This single line + // eliminates countless heap allocations at `runtime` + if origin.scheme().is_none() || origin.authority().is_none() { + return Err(crate::transport::Error::new_invalid_uri().into()); + } + + Ok(Self { origin }) + } + + pub(crate) fn into_fn( + self, + ) -> impl FnOnce(Request) -> Request + Clone { + let http::uri::Parts { + scheme, authority, .. 
+ } = self.origin.clone().into_parts(); + + // Both have been checked + let scheme = scheme.unwrap(); + let authority = authority.unwrap(); + + move |req| { + // Split the request into the head and the body. + let (mut head, body) = req.into_parts(); + + // Update the request URI + head.uri = { + // Split the request URI into parts. + let mut uri: http::uri::Parts = head.uri.into(); + // Update the URI parts, setting the scheme and authority + uri.scheme = Some(scheme); + uri.authority = Some(authority); + + http::Uri::from_parts(uri).expect("valid uri") + }; + + Request::from_parts(head, body) + } + } +} + +const TONIC_USER_AGENT: &str = concat!("tonic/", env!("CARGO_PKG_VERSION")); + +#[derive(Debug)] +pub(crate) struct UserAgent { + user_agent: HeaderValue, +} + +impl UserAgent { + pub(crate) fn new(user_agent: Option) -> Self { + let user_agent = user_agent + .map(|value| { + let mut buf = Vec::new(); + buf.extend(value.as_bytes()); + buf.push(b' '); + buf.extend(TONIC_USER_AGENT.as_bytes()); + HeaderValue::from_bytes(&buf).expect("user-agent should be valid") + }) + .unwrap_or_else(|| HeaderValue::from_static(TONIC_USER_AGENT)); + + Self { user_agent } + } + + pub(crate) fn into_fn( + self, + ) -> impl FnOnce(Request) -> Request + Clone { + move |mut req| { + use http::header::Entry; + + // The former code uses try_insert so we'll respect that + if let Ok(entry) = req.headers_mut().try_entry(USER_AGENT) { + // This is to avoid anticipative cloning which happened + // in the former code + match entry { + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(self.user_agent); + } + Entry::Occupied(occupied_entry) => { + // The User-Agent header has already been set on the request. Let's + // append our user agent to the end. 
+ let occupied_entry = occupied_entry.into_mut(); + + let mut buf = + Vec::with_capacity(occupied_entry.len() + 1 + self.user_agent.len()); + buf.extend(occupied_entry.as_bytes()); + buf.push(b' '); + buf.extend(self.user_agent.as_bytes()); + + // with try_into http uses from_shared internally to probably minimize + // allocations + *occupied_entry = buf.try_into().expect("user-agent should be valid") + } + } + } + + req + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sets_default_if_no_custom_user_agent() { + assert_eq!( + UserAgent::new(None).user_agent, + HeaderValue::from_static(TONIC_USER_AGENT) + ) + } + + #[test] + fn prepends_custom_user_agent_to_default() { + assert_eq!( + UserAgent::new(Some(HeaderValue::from_static("Greeter 1.1"))).user_agent, + HeaderValue::from_str(&format!("Greeter 1.1 {TONIC_USER_AGENT}")).unwrap() + ) + } + + async fn assert_user_agent_modified( + genesis_user_agent: Option>, + expected_user_agent: impl TryInto, + request: Option>, + ) { + let ua = UserAgent::new(genesis_user_agent.map(|v| { + v.try_into() + .unwrap_or_else(|_| panic!("invalid header value")) + })) + .into_fn(); + + let modified_request = ua(request.unwrap_or_default()); + let user_agent = modified_request.headers().get(USER_AGENT).unwrap(); + assert_eq!( + user_agent, + expected_user_agent + .try_into() + .unwrap_or_else(|_| panic!("invalid header value")) + ); + } + + #[tokio::test] + async fn sets_default_user_agent_if_none_present() { + let genesis_user_agent = Option::<&str>::None; + let expected_user_agent = TONIC_USER_AGENT.to_string(); + let request = None; + + assert_user_agent_modified(genesis_user_agent, expected_user_agent, request).await + } + + #[tokio::test] + async fn sets_custom_user_agent_if_none_present() { + let genesis_user_agent = Some("Greeter 1.1"); + let expected_user_agent = format!("Greeter 1.1 {TONIC_USER_AGENT}"); + let request = None; + + assert_user_agent_modified(genesis_user_agent, expected_user_agent, 
request).await + } + + #[tokio::test] + async fn appends_default_user_agent_to_request_fn_user_agent() { + let genesis_user_agent = Option::<&str>::None; + let expected_user_agent = format!("request-ua/x.y {TONIC_USER_AGENT}"); + let mut request = Request::default(); + request + .headers_mut() + .insert(USER_AGENT, HeaderValue::from_static("request-ua/x.y")); + + assert_user_agent_modified(genesis_user_agent, expected_user_agent, Some(request)).await + } + + #[tokio::test] + async fn appends_custom_user_agent_to_request_fn_user_agent() { + let genesis_user_agent = Some("Greeter 1.1"); + let expected_user_agent = format!("request-ua/x.y Greeter 1.1 {TONIC_USER_AGENT}"); + let mut request = Request::default(); + request + .headers_mut() + .insert(USER_AGENT, HeaderValue::from_static("request-ua/x.y")); + + assert_user_agent_modified(genesis_user_agent, expected_user_agent, Some(request)).await + } +} diff --git a/tonic/src/transport/channel/service/user_agent.rs b/tonic/src/transport/channel/service/user_agent.rs deleted file mode 100644 index 8217d55e6..000000000 --- a/tonic/src/transport/channel/service/user_agent.rs +++ /dev/null @@ -1,159 +0,0 @@ -use http::{header::USER_AGENT, HeaderValue, Request}; -use std::task::{Context, Poll}; -use tower_service::Service; - -const TONIC_USER_AGENT: &str = concat!("tonic/", env!("CARGO_PKG_VERSION")); - -#[derive(Debug)] -pub(crate) struct UserAgent { - inner: T, - user_agent: HeaderValue, -} - -impl UserAgent { - pub(crate) fn new(inner: T, user_agent: Option) -> Self { - let user_agent = user_agent - .map(|value| { - let mut buf = Vec::new(); - buf.extend(value.as_bytes()); - buf.push(b' '); - buf.extend(TONIC_USER_AGENT.as_bytes()); - HeaderValue::from_bytes(&buf).expect("user-agent should be valid") - }) - .unwrap_or_else(|| HeaderValue::from_static(TONIC_USER_AGENT)); - - Self { inner, user_agent } - } -} - -impl Service> for UserAgent -where - T: Service>, -{ - type Response = T::Response; - type Error = T::Error; - 
type Future = T::Future; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, mut req: Request) -> Self::Future { - if let Ok(Some(user_agent)) = req - .headers_mut() - .try_insert(USER_AGENT, self.user_agent.clone()) - { - // The User-Agent header has already been set on the request. Let's - // append our user agent to the end. - let mut buf = Vec::new(); - buf.extend(user_agent.as_bytes()); - buf.push(b' '); - buf.extend(self.user_agent.as_bytes()); - req.headers_mut().insert( - USER_AGENT, - HeaderValue::from_bytes(&buf).expect("user-agent should be valid"), - ); - } - - self.inner.call(req) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - struct Svc; - - #[test] - fn sets_default_if_no_custom_user_agent() { - assert_eq!( - UserAgent::new(Svc, None).user_agent, - HeaderValue::from_static(TONIC_USER_AGENT) - ) - } - - #[test] - fn prepends_custom_user_agent_to_default() { - assert_eq!( - UserAgent::new(Svc, Some(HeaderValue::from_static("Greeter 1.1"))).user_agent, - HeaderValue::from_str(&format!("Greeter 1.1 {TONIC_USER_AGENT}")).unwrap() - ) - } - - struct TestSvc { - pub expected_user_agent: String, - } - - impl Service> for TestSvc { - type Response = (); - type Error = (); - type Future = std::future::Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request<()>) -> Self::Future { - let user_agent = req.headers().get(USER_AGENT).unwrap().to_str().unwrap(); - assert_eq!(user_agent, self.expected_user_agent); - std::future::ready(Ok(())) - } - } - - #[tokio::test] - async fn sets_default_user_agent_if_none_present() { - let expected_user_agent = TONIC_USER_AGENT.to_string(); - let mut ua = UserAgent::new( - TestSvc { - expected_user_agent, - }, - None, - ); - let _ = ua.call(Request::default()).await; - } - - #[tokio::test] - async fn sets_custom_user_agent_if_none_present() { - let expected_user_agent = 
format!("Greeter 1.1 {TONIC_USER_AGENT}"); - let mut ua = UserAgent::new( - TestSvc { - expected_user_agent, - }, - Some(HeaderValue::from_static("Greeter 1.1")), - ); - let _ = ua.call(Request::default()).await; - } - - #[tokio::test] - async fn appends_default_user_agent_to_request_user_agent() { - let mut req = Request::default(); - req.headers_mut() - .insert(USER_AGENT, HeaderValue::from_static("request-ua/x.y")); - - let expected_user_agent = format!("request-ua/x.y {TONIC_USER_AGENT}"); - let mut ua = UserAgent::new( - TestSvc { - expected_user_agent, - }, - None, - ); - let _ = ua.call(req).await; - } - - #[tokio::test] - async fn appends_custom_user_agent_to_request_user_agent() { - let mut req = Request::default(); - req.headers_mut() - .insert(USER_AGENT, HeaderValue::from_static("request-ua/x.y")); - - let expected_user_agent = format!("request-ua/x.y Greeter 1.1 {TONIC_USER_AGENT}"); - let mut ua = UserAgent::new( - TestSvc { - expected_user_agent, - }, - Some(HeaderValue::from_static("Greeter 1.1")), - ); - let _ = ua.call(req).await; - } -} From 469af1d6f330a2711d9c1d9803f15aa02d17a04c Mon Sep 17 00:00:00 2001 From: victor Date: Thu, 21 Aug 2025 16:49:59 +0100 Subject: [PATCH 11/11] transport: unify request modifiers and reduce allocations --- .../channel/service/request_modifiers.rs | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tonic/src/transport/channel/service/request_modifiers.rs b/tonic/src/transport/channel/service/request_modifiers.rs index 1cbf13e80..8ae207148 100644 --- a/tonic/src/transport/channel/service/request_modifiers.rs +++ b/tonic/src/transport/channel/service/request_modifiers.rs @@ -1,8 +1,17 @@ +use crate::body::Body; use http::{header::USER_AGENT, HeaderValue, Request, Uri}; use std::task::{Context, Poll}; use tower_service::Service; -use crate::body::Body; +/// A generic request modifier. 
+/// +/// `Modifier` wraps an inner service `T` and applies the +/// modifier `M` to each outgoing `Request`. +/// +/// This type centralizes the boilerplate for implementing +/// request middleware. A modifier is a closure which receives +/// the request and mutates it before forwarding to the +/// inner service. #[derive(Debug)] pub(crate) struct Modifier { modifier_fn: M, next: T, } @@ -39,7 +48,7 @@ where // and not FnOnce #[derive(Debug, Clone)] pub(crate) struct AddOrigin<'a> { - origin: &'a Uri + origin: &'a Uri, } impl<'a> AddOrigin<'a> { @@ -53,9 +62,7 @@ impl<'a> AddOrigin<'a> { Ok(Self { origin }) } - pub(crate) fn into_fn( - self, - ) -> impl FnOnce(Request) -> Request + Clone { + pub(crate) fn into_fn(self) -> impl FnOnce(Request) -> Request + Clone { let http::uri::Parts { scheme, authority, .. } = self.origin.clone().into_parts(); @@ -106,9 +113,7 @@ impl UserAgent { Self { user_agent } } - pub(crate) fn into_fn( - self, - ) -> impl FnOnce(Request) -> Request + Clone { + pub(crate) fn into_fn(self) -> impl FnOnce(Request) -> Request + Clone { move |mut req| { use http::header::Entry;