From 655c6fe9cf5a87f64c4945309f28291c9cc1e74e Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Thu, 3 Aug 2023 19:49:35 +0200 Subject: [PATCH] Remove all connect and accept functions Those functions were blurring the line between setup failures (which are caused by the developer misusing the API) and actual failures encountered on the stream while trying to achieve the TLS handshake, especially on SslConnector and SslAcceptor. Removing them allows for the removal of HandshakeError, as the HandshakeError::SetupFailure variant becomes useless, and there is no real need to distinguish in that error type between Failure and WouldBlock when we can just check the error stored in MidHandshakeSslStream. This then allow us to simplify tokio_boring's own entry points, also making them distinguish between setup failures and failures on the stream. --- boring/src/ssl/connector.rs | 47 +------- boring/src/ssl/error.rs | 67 ----------- boring/src/ssl/mod.rs | 137 ++++++++-------------- boring/src/ssl/test/custom_verify.rs | 18 +-- boring/src/ssl/test/ech.rs | 6 +- boring/src/ssl/test/mod.rs | 63 +++++++--- boring/src/ssl/test/private_key_method.rs | 31 ++--- boring/src/ssl/test/server.rs | 15 +-- boring/src/ssl/test/session.rs | 10 +- hyper-boring/tests/v1.rs | 10 +- tokio-boring/examples/simple-async.rs | 2 +- tokio-boring/src/lib.rs | 85 ++++++-------- tokio-boring/tests/async_get_session.rs | 4 + tokio-boring/tests/client_server.rs | 9 +- tokio-boring/tests/common/mod.rs | 6 +- tokio-boring/tests/rpk.rs | 4 +- 16 files changed, 178 insertions(+), 336 deletions(-) diff --git a/boring/src/ssl/connector.rs b/boring/src/ssl/connector.rs index 111b45c2a..493f91186 100644 --- a/boring/src/ssl/connector.rs +++ b/boring/src/ssl/connector.rs @@ -4,8 +4,8 @@ use std::ops::{Deref, DerefMut}; use crate::dh::Dh; use crate::error::ErrorStack; use crate::ssl::{ - HandshakeError, Ssl, SslContext, SslContextBuilder, SslContextRef, SslMethod, SslMode, - SslOptions, SslRef, SslStream, SslVerifyMode, + Ssl, SslContext, SslContextBuilder, SslContextRef, SslMethod, SslMode, SslOptions, SslRef, + SslVerifyMode, }; use crate::version; use std::net::IpAddr; @@ -112,21 +112,6 @@ impl SslConnector { self.configure()?.setup_connect(domain, stream) } - /// Attempts a client-side TLS session on a stream. - /// - /// The domain is used for SNI (if it is not an IP address) and hostname verification if enabled. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn connect(&self, domain: &str, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_connect(domain, stream) - .map_err(HandshakeError::SetupFailure)? - .handshake() - } - /// Returns a structure allowing for configuration of a single TLS session before connection. pub fn configure(&self) -> Result { Ssl::new(&self.0).map(|ssl| ConnectConfiguration { @@ -253,21 +238,6 @@ impl ConnectConfiguration { { Ok(self.into_ssl(domain)?.setup_connect(stream)) } - - /// Attempts a client-side TLS session on a stream. - /// - /// The domain is used for SNI (if it is not an IP address) and hostname verification if enabled. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn connect(self, domain: &str, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_connect(domain, stream) - .map_err(HandshakeError::SetupFailure)? - .handshake() - } } impl Deref for ConnectConfiguration { @@ -387,19 +357,6 @@ impl SslAcceptor { Ok(ssl.setup_accept(stream)) } - /// Attempts a server-side TLS handshake on a stream. - /// - /// This is a convenience method which combines [`Self::setup_accept`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn accept(&self, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_accept(stream) - .map_err(HandshakeError::SetupFailure)? - .handshake() - } - /// Consumes the `SslAcceptor`, returning the inner raw `SslContext`. #[must_use] pub fn into_context(self) -> SslContext { diff --git a/boring/src/ssl/error.rs b/boring/src/ssl/error.rs index 5acad8200..7416dab87 100644 --- a/boring/src/ssl/error.rs +++ b/boring/src/ssl/error.rs @@ -1,15 +1,12 @@ use crate::ffi; -use crate::x509::X509VerifyError; use libc::c_int; use openssl_macros::corresponds; use std::error; -use std::error::Error as StdError; use std::ffi::CStr; use std::fmt; use std::io; use crate::error::ErrorStack; -use crate::ssl::MidHandshakeSslStream; /// `SSL_ERROR_*` error code returned from SSL functions. /// @@ -206,67 +203,3 @@ impl error::Error for Error { } } } - -/// An error or intermediate state after a TLS handshake attempt. -// FIXME overhaul -#[derive(Debug)] -pub enum HandshakeError { - /// Setup failed. - SetupFailure(ErrorStack), - /// The handshake failed. - Failure(MidHandshakeSslStream), - /// The handshake encountered a `WouldBlock` error midway through. - /// - /// This error will never be returned for blocking streams. - WouldBlock(MidHandshakeSslStream), -} - -impl StdError for HandshakeError { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - match *self { - HandshakeError::SetupFailure(ref e) => Some(e), - HandshakeError::Failure(ref s) | HandshakeError::WouldBlock(ref s) => Some(s.error()), - } - } -} - -impl fmt::Display for HandshakeError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - HandshakeError::SetupFailure(ref e) => { - write!(f, "TLS stream setup failed {e}") - } - HandshakeError::Failure(ref s) => fmt_mid_handshake_error(s, f, "TLS handshake failed"), - HandshakeError::WouldBlock(ref s) => { - fmt_mid_handshake_error(s, f, "TLS handshake interrupted") - } - } - } -} - -fn fmt_mid_handshake_error( - s: &MidHandshakeSslStream, - f: &mut fmt::Formatter, - prefix: &str, -) -> fmt::Result { - #[cfg(feature = "rpk")] - if s.ssl().ssl_context().is_rpk() { - write!(f, "{}", prefix)?; - return write!(f, " {}", s.error()); - } - - match s.ssl().verify_result() { - // INVALID_CALL is returned if no verification took place, - // such as before a cert is sent. - Ok(()) | Err(X509VerifyError::INVALID_CALL) => write!(f, "{prefix}")?, - Err(verify) => write!(f, "{prefix}: cert verification failed - {verify}")?, - } - - write!(f, " {}", s.error()) -} - -impl From for HandshakeError { - fn from(e: ErrorStack) -> HandshakeError { - HandshakeError::SetupFailure(e) - } -} diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index bd864bdae..a71886771 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -15,7 +15,11 @@ //! let connector = SslConnector::builder(SslMethod::tls()).unwrap().build(); //! //! let stream = TcpStream::connect("google.com:443").unwrap(); -//! let mut stream = connector.connect("google.com", stream).unwrap(); +//! let mut stream = connector +//! .setup_connect("google.com", stream) +//! .unwrap() +//! .handshake() +//! .unwrap(); //! //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); //! let mut res = vec![]; @@ -49,7 +53,12 @@ //! Ok(stream) => { //! let acceptor = acceptor.clone(); //! thread::spawn(move || { -//! let stream = acceptor.accept(stream).unwrap(); +//! let stream = acceptor +//! .setup_accept(stream) +//! .unwrap() +//! .handshake() +//! .unwrap(); +//! //! handle_client(stream); //! }); //! } @@ -107,7 +116,7 @@ pub use self::connector::{ ConnectConfiguration, SslAcceptor, SslAcceptorBuilder, SslConnector, SslConnectorBuilder, }; pub use self::ech::{SslEchKeys, SslEchKeysRef}; -pub use self::error::{Error, ErrorCode, HandshakeError}; +pub use self::error::{Error, ErrorCode}; mod async_callbacks; mod bio; @@ -2742,22 +2751,6 @@ impl Ssl { SslStreamBuilder::new(self, stream).setup_connect() } - /// Attempts a client-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - /// - /// # Warning - /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// [`SslConnector`] rather than `Ssl` directly, as it manages that configuration. - pub fn connect(self, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_connect(stream).handshake() - } - /// Initiates a server-side TLS handshake. /// /// This method is guaranteed to return without calling any callback defined @@ -2790,24 +2783,6 @@ impl Ssl { SslStreamBuilder::new(self, stream).setup_accept() } - - /// Attempts a server-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_accept`] and - /// [`MidHandshakeSslStream::handshake`]. - /// - /// # Warning - /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// `SslAcceptor` rather than `Ssl` directly, as it manages that configuration. - /// - /// [`SSL_accept`]: https://www.openssl.org/docs/manmaster/man3/SSL_accept.html - pub fn accept(self, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_accept(stream).handshake() - } } impl fmt::Debug for SslRef { @@ -3903,18 +3878,43 @@ impl MidHandshakeSslStream { /// Restarts the handshake process. #[corresponds(SSL_do_handshake)] - pub fn handshake(mut self) -> Result, HandshakeError> { + pub fn handshake(mut self) -> Result, Self> { let ret = unsafe { ffi::SSL_do_handshake(self.stream.ssl.as_ptr()) }; if ret > 0 { Ok(self.stream) } else { self.error = self.stream.make_error(ret); - Err(if self.error.would_block() { - HandshakeError::WouldBlock(self) - } else { - HandshakeError::Failure(self) - }) + + Err(self) + } + } + + /// An `impl Display` suitable to represent the current error. + pub fn display_error<'a>(&'a self) -> impl fmt::Display + 'a { + struct Display<'a, S>(&'a MidHandshakeSslStream); + + impl fmt::Display for Display<'_, S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.write_str("TLS handshake failed")?; + + #[cfg(feature = "rpk")] + if self.0.ssl().ssl_context().is_rpk() { + return self.0.error().fmt(fmt); + } + + match self.0.ssl().verify_result() { + // INVALID_CALL is returned if no verification took place, + // such as before a cert is sent. + Ok(()) | Err(X509VerifyError::INVALID_CALL) => {} + Err(verify) => write!(fmt, ": cert verification failed - {verify}")?, + } + + fmt.write_str(" ")?; + self.0.error().fmt(fmt) + } } + + Display(self) } } @@ -4249,22 +4249,20 @@ where /// Configure as an outgoing stream from a client. #[corresponds(SSL_set_connect_state)] - pub fn set_connect_state(&mut self) { + fn set_connect_state(&mut self) { unsafe { ffi::SSL_set_connect_state(self.inner.ssl.as_ptr()) } } /// Configure as an incoming stream to a server. #[corresponds(SSL_set_accept_state)] - pub fn set_accept_state(&mut self) { + fn set_accept_state(&mut self) { unsafe { ffi::SSL_set_accept_state(self.inner.ssl.as_ptr()) } } /// Initiates a client-side TLS handshake, returning a [`MidHandshakeSslStream`]. /// - /// This method calls [`Self::set_connect_state`] and returns without actually - /// initiating the handshake. The caller is then free to call - /// [`MidHandshakeSslStream`] and loop on [`HandshakeError::WouldBlock`]. - #[must_use] + /// The caller is then free to call [`MidHandshakeSslStream::handshake`] and retry + /// on blocking errors. pub fn setup_connect(mut self) -> MidHandshakeSslStream { self.set_connect_state(); @@ -4280,20 +4278,10 @@ where } } - /// Attempts a client-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn connect(self) -> Result, HandshakeError> { - self.setup_connect().handshake() - } - /// Initiates a server-side TLS handshake, returning a [`MidHandshakeSslStream`]. /// - /// This method calls [`Self::set_accept_state`] and returns without actually - /// initiating the handshake. The caller is then free to call - /// [`MidHandshakeSslStream`] and loop on [`HandshakeError::WouldBlock`]. - #[must_use] + /// The caller is then free to call [`MidHandshakeSslStream::handshake`] and retry + /// on blocking errors. pub fn setup_accept(mut self) -> MidHandshakeSslStream { self.set_accept_state(); @@ -4308,33 +4296,6 @@ where }, } } - - /// Attempts a server-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_accept`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn accept(self) -> Result, HandshakeError> { - self.setup_accept().handshake() - } - - /// Initiates the handshake. - /// - /// This will fail if `set_accept_state` or `set_connect_state` was not called first. - #[corresponds(SSL_do_handshake)] - pub fn handshake(self) -> Result, HandshakeError> { - let mut stream = self.inner; - let ret = unsafe { ffi::SSL_do_handshake(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - Err(if error.would_block() { - HandshakeError::WouldBlock(MidHandshakeSslStream { stream, error }) - } else { - HandshakeError::Failure(MidHandshakeSslStream { stream, error }) - }) - } - } } impl SslStreamBuilder { diff --git a/boring/src/ssl/test/custom_verify.rs b/boring/src/ssl/test/custom_verify.rs index 64e8f89b5..b30e3a563 100644 --- a/boring/src/ssl/test/custom_verify.rs +++ b/boring/src/ssl/test/custom_verify.rs @@ -1,5 +1,5 @@ use super::server::Server; -use crate::ssl::{ErrorCode, HandshakeError, SslAlert, SslVerifyMode}; +use crate::ssl::{ErrorCode, SslAlert, SslVerifyMode}; use crate::x509::X509StoreContext; use crate::{hash::MessageDigest, ssl::SslVerifyError}; use hex; @@ -9,11 +9,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; fn untrusted_callback_override_bad() { let mut server = Server::builder(); - server.err_cb(|err| { - let HandshakeError::Failure(handshake) = err else { - panic!("expected failure error"); - }; - + server.err_cb(|handshake| { assert_eq!( handshake.error().to_string(), "[SSLV3_ALERT_CERTIFICATE_REVOKED]" @@ -242,11 +238,7 @@ fn both_callback() { fn retry() { let mut server = Server::builder(); - server.err_cb(|err| { - let HandshakeError::Failure(handshake) = err else { - panic!("expected failure error"); - }; - + server.err_cb(|handshake| { assert_eq!( handshake.error().to_string(), "[SSLV3_ALERT_CERTIFICATE_REVOKED]" @@ -267,9 +259,7 @@ fn retry() { Err(SslVerifyError::Invalid(SslAlert::CERTIFICATE_REVOKED)) }); - let HandshakeError::WouldBlock(handshake) = client.connect_err() else { - panic!("should be WouldBlock"); - }; + let handshake = client.connect_err(); assert!(CALLED_BACK.load(Ordering::SeqCst)); assert!(handshake.error().would_block()); diff --git a/boring/src/ssl/test/ech.rs b/boring/src/ssl/test/ech.rs index d2797d427..815fe5cc3 100644 --- a/boring/src/ssl/test/ech.rs +++ b/boring/src/ssl/test/ech.rs @@ -1,7 +1,6 @@ use crate::hpke::HpkeKey; use crate::ssl::ech::SslEchKeys; use crate::ssl::test::server::{ClientSslBuilder, Server}; -use crate::ssl::HandshakeError; // For future reference, these configs are generated by building the bssl tool (the binary is built // alongside boringssl) and running the following command: @@ -49,9 +48,8 @@ fn ech_rejection() { // `ECH_CONFIG_LIST_2` should trigger rejection. let (_server, client) = bootstrap_ech(ECH_CONFIG_2, ECH_KEY_2, ECH_CONFIG_LIST); - let HandshakeError::Failure(failed_ssl_stream) = client.connect_err() else { - panic!("wrong HandshakeError failure variant!"); - }; + let failed_ssl_stream = client.connect_err(); + assert_eq!( failed_ssl_stream.ssl().get_ech_name_override(), Some(b"ech.com".to_vec().as_ref()) diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index aded182d1..bd85672fb 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -125,7 +125,7 @@ fn test_connect_with_srtp_ctx() { .unwrap(); let mut ssl = Ssl::new(&ctx.build()).unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.accept(stream).unwrap(); + let mut stream = ssl.setup_accept(stream).handshake().unwrap(); let mut buf = [0; 60]; stream @@ -144,7 +144,7 @@ fn test_connect_with_srtp_ctx() { .unwrap(); let mut ssl = Ssl::new(&ctx.build()).unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.connect(stream).unwrap(); + let mut stream = ssl.setup_connect(stream).handshake().unwrap(); let mut buf = [1; 60]; { @@ -194,7 +194,7 @@ fn test_connect_with_srtp_ssl() { profilenames ); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.accept(stream).unwrap(); + let mut stream = ssl.setup_accept(stream).handshake().unwrap(); let mut buf = [0; 60]; stream @@ -213,7 +213,7 @@ fn test_connect_with_srtp_ssl() { ssl.set_tlsext_use_srtp("SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32") .unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.connect(stream).unwrap(); + let mut stream = ssl.setup_connect(stream).handshake().unwrap(); let mut buf = [1; 60]; { @@ -445,7 +445,10 @@ fn write_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -476,7 +479,10 @@ fn read_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -507,7 +513,10 @@ fn flush_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -537,7 +546,7 @@ fn default_verify_paths() { }; let mut ssl = Ssl::new(&ctx).unwrap(); ssl.set_hostname("google.com").unwrap(); - let mut socket = ssl.connect(s).unwrap(); + let mut socket = ssl.setup_connect(s).handshake().unwrap(); socket.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); let mut result = vec![]; @@ -668,7 +677,12 @@ fn connector_valid_hostname() { connector.set_ca_file("test/root-ca.pem").unwrap(); let s = server.connect_tcp(); - let mut s = connector.build().connect("foobar.com", s).unwrap(); + let mut s = connector + .build() + .setup_connect("foobar.com", s) + .unwrap() + .handshake() + .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -682,7 +696,12 @@ fn connector_invalid_hostname() { connector.set_ca_file("test/root-ca.pem").unwrap(); let s = server.connect_tcp(); - connector.build().connect("bogus.com", s).unwrap_err(); + connector + .build() + .setup_connect("bogus.com", s) + .unwrap() + .handshake() + .unwrap_err(); } #[test] @@ -698,7 +717,9 @@ fn connector_invalid_no_hostname_verification() { .configure() .unwrap() .verify_hostname(false) - .connect("bogus.com", s) + .setup_connect("bogus.com", s) + .unwrap() + .handshake() .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -716,7 +737,9 @@ fn connector_no_hostname_still_verifies() { .configure() .unwrap() .verify_hostname(false) - .connect("fizzbuzz.com", s) + .setup_connect("fizzbuzz.com", s) + .unwrap() + .handshake() .is_err()); } @@ -733,7 +756,9 @@ fn connector_no_hostname_can_disable_verify() { .configure() .unwrap() .verify_hostname(false) - .connect("foobar.com", s) + .setup_connect("foobar.com", s) + .unwrap() + .handshake() .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -750,7 +775,7 @@ fn test_mozilla_server(new: fn(SslMethod) -> Result Result, io_cb: Box) + Send>, - err_cb: Box) + Send>, + err_cb: Box) + Send>, should_error: bool, expected_connections_count: usize, } @@ -102,7 +103,7 @@ impl Builder { self.io_cb = Box::new(cb); } - pub fn err_cb(&mut self, cb: impl FnMut(HandshakeError) + Send + 'static) { + pub fn err_cb(&mut self, cb: impl FnMut(MidHandshakeSslStream) + Send + 'static) { self.should_error(); self.err_cb = Box::new(cb); @@ -133,7 +134,7 @@ impl Builder { ssl_cb(&mut ssl); - let r = ssl.accept(socket); + let r = ssl.setup_accept(socket).handshake(); if should_error { err_cb(r.unwrap_err()); @@ -176,7 +177,7 @@ impl ClientBuilder { self.build().builder().connect() } - pub fn connect_err(self) -> HandshakeError { + pub fn connect_err(self) -> MidHandshakeSslStream { self.build().builder().connect_err() } } @@ -207,12 +208,12 @@ impl ClientSslBuilder { pub fn connect(self) -> SslStream { let socket = TcpStream::connect(self.addr).unwrap(); - let mut s = self.ssl.connect(socket).unwrap(); + let mut s = self.ssl.setup_connect(socket).handshake().unwrap(); s.read_exact(&mut [0]).unwrap(); s } - pub fn connect_err(self) -> HandshakeError { + pub fn connect_err(self) -> MidHandshakeSslStream { let socket = TcpStream::connect(self.addr).unwrap(); self.ssl.setup_connect(socket).handshake().unwrap_err() diff --git a/boring/src/ssl/test/session.rs b/boring/src/ssl/test/session.rs index 23c0f4d5d..873a7f236 100644 --- a/boring/src/ssl/test/session.rs +++ b/boring/src/ssl/test/session.rs @@ -4,8 +4,8 @@ use std::sync::OnceLock; use crate::ssl::test::server::Server; use crate::ssl::{ - ErrorCode, GetSessionPendingError, HandshakeError, Ssl, SslContext, SslContextBuilder, - SslMethod, SslOptions, SslSession, SslSessionCacheMode, SslVersion, + ErrorCode, GetSessionPendingError, Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, + SslSession, SslSessionCacheMode, SslVersion, }; #[test] @@ -125,11 +125,7 @@ fn new_get_session_callback_pending() { }); } server.ctx().set_session_id_context(b"foo").unwrap(); - server.err_cb(|error| { - let HandshakeError::WouldBlock(mid_handshake) = error else { - panic!("should be WouldBlock"); - }; - + server.err_cb(|mid_handshake| { assert!(mid_handshake.error().would_block()); assert_eq!(mid_handshake.error().code(), ErrorCode::PENDING_SESSION); diff --git a/hyper-boring/tests/v1.rs b/hyper-boring/tests/v1.rs index 4082d2cef..90769e13c 100644 --- a/hyper-boring/tests/v1.rs +++ b/hyper-boring/tests/v1.rs @@ -47,7 +47,10 @@ async fn localhost() { for _ in 0..3 { let stream = listener.accept().await.unwrap().0; - let stream = tokio_boring::accept(&acceptor, stream).await.unwrap(); + let stream = tokio_boring::accept(&acceptor, stream) + .unwrap() + .await + .unwrap(); let service = service::service_fn(|_| async { Ok::<_, io::Error>(Response::new(>::new())) @@ -117,7 +120,10 @@ async fn alpn_h2() { let acceptor = acceptor.build(); let stream = listener.accept().await.unwrap().0; - let stream = tokio_boring::accept(&acceptor, stream).await.unwrap(); + let stream = tokio_boring::accept(&acceptor, stream) + .unwrap() + .await + .unwrap(); assert_eq!(stream.ssl().selected_alpn_protocol().unwrap(), b"h2"); let service = service::service_fn(|_| async { diff --git a/tokio-boring/examples/simple-async.rs b/tokio-boring/examples/simple-async.rs index f4a69a1c9..d9b159ded 100644 --- a/tokio-boring/examples/simple-async.rs +++ b/tokio-boring/examples/simple-async.rs @@ -11,6 +11,6 @@ async fn main() -> anyhow::Result<()> { ssl_builder.set_default_verify_paths()?; ssl_builder.set_verify(ssl::SslVerifyMode::PEER); let acceptor = ssl_builder.build(); - let _ssl_stream = tokio_boring::accept(&acceptor, tcp_stream).await?; + let _ssl_stream = tokio_boring::accept(&acceptor, tcp_stream)?.await?; Ok(()) } diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index 374a0bde0..513cc0389 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -13,6 +13,7 @@ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +use boring::error::ErrorStack; use boring::ssl::{ self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor, SslRef, @@ -42,34 +43,40 @@ pub use boring::ssl::{ /// /// This function automatically sets the task waker on the `Ssl` from `config` to /// allow to make use of async callbacks provided by the boring crate. -pub async fn connect( +pub fn connect( config: ConnectConfiguration, domain: &str, stream: S, -) -> Result, HandshakeError> +) -> Result, ErrorStack> where S: AsyncRead + AsyncWrite + Unpin, { - let mid_handshake = config - .setup_connect(domain, AsyncStreamBridge::new(stream)) - .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?; - - HandshakeFuture(Some(mid_handshake)).await + handshake(|s| config.setup_connect(domain, s), stream) } /// Asynchronously performs a server-side TLS handshake over the provided stream. /// /// This function automatically sets the task waker on the `Ssl` from `config` to /// allow to make use of async callbacks provided by the boring crate. -pub async fn accept(acceptor: &SslAcceptor, stream: S) -> Result, HandshakeError> +pub fn accept(acceptor: &SslAcceptor, stream: S) -> Result, ErrorStack> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + handshake(|s| acceptor.setup_accept(s), stream) +} + +fn handshake( + f: impl FnOnce( + AsyncStreamBridge, + ) -> Result>, ErrorStack>, + stream: S, +) -> Result, ErrorStack> where S: AsyncRead + AsyncWrite + Unpin, { - let mid_handshake = acceptor - .setup_accept(AsyncStreamBridge::new(stream)) - .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?; + let ongoing_handshake = Some(f(AsyncStreamBridge::new(stream))?); - HandshakeFuture(Some(mid_handshake)).await + Ok(HandshakeFuture(ongoing_handshake)) } fn cvt(r: io::Result) -> Poll> { @@ -252,52 +259,33 @@ where } /// The error type returned after a failed handshake. -pub struct HandshakeError(ssl::HandshakeError>); +pub struct HandshakeError(MidHandshakeSslStream>); impl HandshakeError { /// Returns a shared reference to the `Ssl` object associated with this error. - #[must_use] - pub fn ssl(&self) -> Option<&SslRef> { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(s.ssl()), - _ => None, - } + pub fn ssl(&self) -> &SslRef { + self.0.ssl() } /// Converts error to the source data stream that was used for the handshake. - #[must_use] - pub fn into_source_stream(self) -> Option { - match self.0 { - ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream), - _ => None, - } + pub fn into_source_stream(self) -> S { + self.0.into_source_stream().stream } /// Returns a reference to the source data stream. - #[must_use] - pub fn as_source_stream(&self) -> Option<&S> { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream), - _ => None, - } + pub fn as_source_stream(&self) -> &S { + &self.0.get_ref().stream } /// Returns the error code, if any. - #[must_use] - pub fn code(&self) -> Option { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(s.error().code()), - _ => None, - } + pub fn code(&self) -> ErrorCode { + self.0.error().code() } /// Returns a reference to the inner I/O error, if any. #[must_use] pub fn as_io_error(&self) -> Option<&io::Error> { - match &self.0 { - ssl::HandshakeError::Failure(s) => s.error().io_error(), - _ => None, - } + self.0.error().io_error() } } @@ -312,7 +300,7 @@ where impl fmt::Display for HandshakeError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.0, fmt) + self.0.display_error().fmt(fmt) } } @@ -321,7 +309,7 @@ where S: fmt::Debug, { fn source(&self) -> Option<&(dyn Error + 'static)> { - self.0.source() + self.0.error().source() } } @@ -351,7 +339,7 @@ where Poll::Ready(Ok(SslStream(stream))) } - Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => { + Err(mut mid_handshake) if mid_handshake.error().would_block() => { mid_handshake.get_mut().set_waker(None); mid_handshake.ssl_mut().set_task_waker(None); @@ -359,15 +347,10 @@ where Poll::Pending } - Err(ssl::HandshakeError::Failure(mut mid_handshake)) => { + Err(mut mid_handshake) => { mid_handshake.get_mut().set_waker(None); - Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure( - mid_handshake, - )))) - } - Err(err @ ssl::HandshakeError::SetupFailure(_)) => { - Poll::Ready(Err(HandshakeError(err))) + Poll::Ready(Err(HandshakeError(mid_handshake))) } } } diff --git a/tokio-boring/tests/async_get_session.rs b/tokio-boring/tests/async_get_session.rs index 0ab9b396e..3cb0c7a61 100644 --- a/tokio-boring/tests/async_get_session.rs +++ b/tokio-boring/tests/async_get_session.rs @@ -57,6 +57,7 @@ async fn test() { let server = async move { tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0) + .unwrap() .await .unwrap(); @@ -64,6 +65,7 @@ async fn test() { assert!(!FOUND_SESSION.load(Ordering::SeqCst)); tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0) + .unwrap() .await .unwrap(); @@ -76,6 +78,7 @@ async fn test() { "localhost", TcpStream::connect(&addr).await.unwrap(), ) + .unwrap() .await .unwrap(); @@ -94,6 +97,7 @@ async fn test() { "localhost", TcpStream::connect(&addr).await.unwrap(), ) + .unwrap() .await .unwrap(); }; diff --git a/tokio-boring/tests/client_server.rs b/tokio-boring/tests/client_server.rs index 925f9875e..872033f8e 100644 --- a/tokio-boring/tests/client_server.rs +++ b/tokio-boring/tests/client_server.rs @@ -19,6 +19,7 @@ async fn google() { .configure() .unwrap(); let mut stream = tokio_boring::connect(config, "google.com", stream) + .unwrap() .await .unwrap(); @@ -44,15 +45,11 @@ async fn handshake_error() { let (stream, addr) = create_server(|_| ()); let server = async { - let err = stream.await.unwrap_err(); - - assert!(err.into_source_stream().is_some()); + let _err = stream.await.unwrap_err(); }; let client = async { - let err = connect(addr, |_| Ok(())).await.unwrap_err(); - - assert!(err.into_source_stream().is_some()); + let _err = connect(addr, |_| Ok(())).await.unwrap_err(); }; future::join(server, client).await; diff --git a/tokio-boring/tests/common/mod.rs b/tokio-boring/tests/common/mod.rs index b28917b42..a94787baa 100644 --- a/tokio-boring/tests/common/mod.rs +++ b/tokio-boring/tests/common/mod.rs @@ -24,7 +24,7 @@ pub(crate) fn create_server( let stream = listener.accept().await.unwrap().0; - tokio_boring::accept(&acceptor, stream).await + tokio_boring::accept(&acceptor, stream).unwrap().await }; (server, addr) @@ -65,7 +65,9 @@ pub(crate) async fn connect( let stream = TcpStream::connect(&addr).await.unwrap(); - tokio_boring::connect(config, "localhost", stream).await + tokio_boring::connect(config, "localhost", stream) + .unwrap() + .await } pub(crate) fn create_connector( diff --git a/tokio-boring/tests/rpk.rs b/tokio-boring/tests/rpk.rs index 5492767ab..48f53655c 100644 --- a/tokio-boring/tests/rpk.rs +++ b/tokio-boring/tests/rpk.rs @@ -34,7 +34,7 @@ mod test_rpk { let stream = listener.accept().await.unwrap().0; - tokio_boring::accept(&acceptor, stream).await + tokio_boring::accept(&acceptor, stream).unwrap().await }; (server, addr) @@ -66,6 +66,7 @@ mod test_rpk { let stream = TcpStream::connect(&addr).await.unwrap(); let mut stream = tokio_boring::connect(config, "localhost", stream) + .unwrap() .await .unwrap(); @@ -97,6 +98,7 @@ mod test_rpk { let stream = TcpStream::connect(&addr).await.unwrap(); let err = tokio_boring::connect(config, "localhost", stream) + .unwrap() .await .unwrap_err();