From c748be2883e8440cc2dd712e6b528177a3c5287f Mon Sep 17 00:00:00 2001 From: Kevin Swiber Date: Wed, 26 Apr 2023 11:00:52 -0700 Subject: [PATCH] Initial prep for hyper@1.0. --- Cargo.toml | 3 ++- src/main.rs | 11 +++++++---- src/routes/request_inspection.rs | 7 ++++--- src/routes/response_formats.rs | 3 ++- src/routes/status_codes.rs | 29 ++++++++++++----------------- src/server.rs | 7 ++++--- 6 files changed, 31 insertions(+), 29 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 55a0ed8..d8bbe9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" license = "Apache-2.0" [dependencies] -axum = { version = "0.6.16", features = ["headers", "http2"] } +axum = { git="https://github.com/tokio-rs/axum.git", branch="david/hyper-1.0-rc.x", features = ["headers", "http2"] } mime = "0.3" minijinja = "0.32.0" rand = "0.8.5" @@ -17,4 +17,5 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } [dev-dependencies] hyper = { version = "0.14", features = ["full"] } +http-body-util = "0.1.0-rc.2" tower = { version = "0.4", features = ["util"] } diff --git a/src/main.rs b/src/main.rs index 3e64a6d..786f755 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,8 +22,11 @@ async fn main() { let addr = SocketAddr::from(([127, 0, 0, 1], port)); tracing::info!("listening on {}", addr); - axum::Server::bind(&addr) - .serve(server::app().into_make_service_with_connect_info::()) - .await - .unwrap(); + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + axum::serve( + listener, + server::app().into_make_service_with_connect_info::(), + ) + .await + .unwrap(); } diff --git a/src/routes/request_inspection.rs b/src/routes/request_inspection.rs index 44a34d3..0ae06d9 100644 --- a/src/routes/request_inspection.rs +++ b/src/routes/request_inspection.rs @@ -43,6 +43,7 @@ mod tests { extract::connect_info::MockConnectInfo, http::{header, HeaderValue, Request, StatusCode}, }; + use http_body_util::BodyExt; use std::net::SocketAddr; use tower::ServiceExt; @@ -68,7 +69,7 @@ mod tests { Some(&HeaderValue::from_static(mime::APPLICATION_JSON.as_ref())) ); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body = response.collect().await.unwrap().to_bytes(); let response_json = serde_json::from_slice::(&body.to_vec()).unwrap(); let headers = Value::as_object(&response_json["headers"]).unwrap(); assert_eq!(headers["foo"], "value-foo"); @@ -90,7 +91,7 @@ mod tests { Some(&HeaderValue::from_static(mime::APPLICATION_JSON.as_ref())) ); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body = response.collect().await.unwrap().to_bytes(); let response_json = serde_json::from_slice::(&body.to_vec()).unwrap(); assert_eq!(&response_json["origin"], "10.10.32.1"); } @@ -116,7 +117,7 @@ mod tests { Some(&HeaderValue::from_static(mime::APPLICATION_JSON.as_ref())) ); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body = response.collect().await.unwrap().to_bytes(); let response_json = serde_json::from_slice::(&body.to_vec()).unwrap(); assert_eq!(&response_json["user_agent"], "foo-bar"); } diff --git a/src/routes/response_formats.rs b/src/routes/response_formats.rs index 32e9e16..dad2697 100644 --- a/src/routes/response_formats.rs +++ b/src/routes/response_formats.rs @@ -17,6 +17,7 @@ mod tests { body::Body, http::{header, HeaderValue, Request, StatusCode}, }; + use http_body_util::BodyExt; use tower::ServiceExt; #[tokio::test] @@ -39,7 +40,7 @@ mod tests { Some(&HeaderValue::from_static(mime::TEXT_HTML_UTF_8.as_ref())) ); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body = response.collect().await.unwrap().to_bytes(); assert!(std::str::from_utf8(&body).is_ok()) } } diff --git a/src/routes/status_codes.rs b/src/routes/status_codes.rs index 489a711..3119b89 100644 --- a/src/routes/status_codes.rs +++ b/src/routes/status_codes.rs @@ -135,7 +135,8 @@ mod tests { body::Body, http::{header, HeaderValue, Method, Request, StatusCode}, }; - use tower::{Service, ServiceExt}; + use http_body_util::BodyExt; + use tower::ServiceExt; #[tokio::test] async fn selects_a_single_status_code() { @@ -156,7 +157,7 @@ mod tests { #[tokio::test] async fn supports_multiple_http_methods() { - let mut app = routes(); + let app = routes(); let methods = vec![ Method::GET, @@ -170,10 +171,8 @@ mod tests { for method in methods { let response = app - .ready() - .await - .unwrap() - .call( + .clone() + .oneshot( Request::builder() .method(method) .uri("/status/200") @@ -237,7 +236,7 @@ mod tests { #[tokio::test] async fn chooses_a_higher_weighted_random_status_code_more_often() { - let mut app = routes(); + let app = routes(); let mut ok_returns: u16 = 0; let mut created_returns: u16 = 0; let mut accepted_returns: u16 = 0; @@ -245,10 +244,8 @@ mod tests { for _num in 0..1000 { let response = app - .ready() - .await - .unwrap() - .call( + .clone() + .oneshot( Request::builder() .uri("/status/200:0.1,201:0.25,202:0.75,204:1") .body(Body::empty()) @@ -318,7 +315,7 @@ mod tests { #[tokio::test] async fn redirects_have_location_header() { - let mut app = routes(); + let app = routes(); let redirects = vec![ StatusCode::MOVED_PERMANENTLY, @@ -330,10 +327,8 @@ mod tests { for redirect in redirects { let response = app - .ready() - .await - .unwrap() - .call( + .clone() + .oneshot( Request::builder() .uri(format!("/status/{}", redirect.as_str())) .body(Body::empty()) @@ -444,7 +439,7 @@ mod tests { Some(&HeaderValue::from_static(mime::TEXT_PLAIN.as_ref())) ); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body = response.collect().await.unwrap().to_bytes(); assert!(std::str::from_utf8(&body).is_ok()) } } diff --git a/src/server.rs b/src/server.rs index b538aa6..0819c5a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,6 +1,7 @@ use crate::routes::{request_inspection, response_formats, root, status_codes}; use axum::{ - http::{header, HeaderValue, Method, Request, StatusCode}, + extract::Request, + http::{header, HeaderValue, Method, StatusCode}, middleware::{from_fn, Next}, response::Response, Router, @@ -17,7 +18,7 @@ pub fn app() -> Router { .layer(from_fn(inject_cors_headers)) } -async fn inject_server_header(request: Request, next: Next) -> Response { +async fn inject_server_header(request: Request, next: Next) -> Response { let mut response = next.run(request).await; let headers = response.headers_mut(); @@ -28,7 +29,7 @@ async fn inject_server_header(request: Request, next: Next) -> Response response } -async fn inject_cors_headers(request: Request, next: Next) -> Response { +async fn inject_cors_headers(request: Request, next: Next) -> Response { let method = request.method().clone(); let request_headers = request.headers().clone(); let mut response = next.run(request).await;