diff --git a/boring/src/ssl/connector.rs b/boring/src/ssl/connector.rs index 111b45c2a..493f91186 100644 --- a/boring/src/ssl/connector.rs +++ b/boring/src/ssl/connector.rs @@ -4,8 +4,8 @@ use std::ops::{Deref, DerefMut}; use crate::dh::Dh; use crate::error::ErrorStack; use crate::ssl::{ - HandshakeError, Ssl, SslContext, SslContextBuilder, SslContextRef, SslMethod, SslMode, - SslOptions, SslRef, SslStream, SslVerifyMode, + Ssl, SslContext, SslContextBuilder, SslContextRef, SslMethod, SslMode, SslOptions, SslRef, + SslVerifyMode, }; use crate::version; use std::net::IpAddr; @@ -112,21 +112,6 @@ impl SslConnector { self.configure()?.setup_connect(domain, stream) } - /// Attempts a client-side TLS session on a stream. - /// - /// The domain is used for SNI (if it is not an IP address) and hostname verification if enabled. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn connect(&self, domain: &str, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_connect(domain, stream) - .map_err(HandshakeError::SetupFailure)? - .handshake() - } - /// Returns a structure allowing for configuration of a single TLS session before connection. pub fn configure(&self) -> Result { Ssl::new(&self.0).map(|ssl| ConnectConfiguration { @@ -253,21 +238,6 @@ impl ConnectConfiguration { { Ok(self.into_ssl(domain)?.setup_connect(stream)) } - - /// Attempts a client-side TLS session on a stream. - /// - /// The domain is used for SNI (if it is not an IP address) and hostname verification if enabled. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn connect(self, domain: &str, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_connect(domain, stream) - .map_err(HandshakeError::SetupFailure)? - .handshake() - } } impl Deref for ConnectConfiguration { @@ -387,19 +357,6 @@ impl SslAcceptor { Ok(ssl.setup_accept(stream)) } - /// Attempts a server-side TLS handshake on a stream. - /// - /// This is a convenience method which combines [`Self::setup_accept`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn accept(&self, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_accept(stream) - .map_err(HandshakeError::SetupFailure)? - .handshake() - } - /// Consumes the `SslAcceptor`, returning the inner raw `SslContext`. #[must_use] pub fn into_context(self) -> SslContext { diff --git a/boring/src/ssl/error.rs b/boring/src/ssl/error.rs index 5acad8200..7416dab87 100644 --- a/boring/src/ssl/error.rs +++ b/boring/src/ssl/error.rs @@ -1,15 +1,12 @@ use crate::ffi; -use crate::x509::X509VerifyError; use libc::c_int; use openssl_macros::corresponds; use std::error; -use std::error::Error as StdError; use std::ffi::CStr; use std::fmt; use std::io; use crate::error::ErrorStack; -use crate::ssl::MidHandshakeSslStream; /// `SSL_ERROR_*` error code returned from SSL functions. /// @@ -206,67 +203,3 @@ impl error::Error for Error { } } } - -/// An error or intermediate state after a TLS handshake attempt. -// FIXME overhaul -#[derive(Debug)] -pub enum HandshakeError { - /// Setup failed. - SetupFailure(ErrorStack), - /// The handshake failed. - Failure(MidHandshakeSslStream), - /// The handshake encountered a `WouldBlock` error midway through. - /// - /// This error will never be returned for blocking streams. - WouldBlock(MidHandshakeSslStream), -} - -impl StdError for HandshakeError { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - match *self { - HandshakeError::SetupFailure(ref e) => Some(e), - HandshakeError::Failure(ref s) | HandshakeError::WouldBlock(ref s) => Some(s.error()), - } - } -} - -impl fmt::Display for HandshakeError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - HandshakeError::SetupFailure(ref e) => { - write!(f, "TLS stream setup failed {e}") - } - HandshakeError::Failure(ref s) => fmt_mid_handshake_error(s, f, "TLS handshake failed"), - HandshakeError::WouldBlock(ref s) => { - fmt_mid_handshake_error(s, f, "TLS handshake interrupted") - } - } - } -} - -fn fmt_mid_handshake_error( - s: &MidHandshakeSslStream, - f: &mut fmt::Formatter, - prefix: &str, -) -> fmt::Result { - #[cfg(feature = "rpk")] - if s.ssl().ssl_context().is_rpk() { - write!(f, "{}", prefix)?; - return write!(f, " {}", s.error()); - } - - match s.ssl().verify_result() { - // INVALID_CALL is returned if no verification took place, - // such as before a cert is sent. - Ok(()) | Err(X509VerifyError::INVALID_CALL) => write!(f, "{prefix}")?, - Err(verify) => write!(f, "{prefix}: cert verification failed - {verify}")?, - } - - write!(f, " {}", s.error()) -} - -impl From for HandshakeError { - fn from(e: ErrorStack) -> HandshakeError { - HandshakeError::SetupFailure(e) - } -} diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index bd864bdae..a71886771 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -15,7 +15,11 @@ //! let connector = SslConnector::builder(SslMethod::tls()).unwrap().build(); //! //! let stream = TcpStream::connect("google.com:443").unwrap(); -//! let mut stream = connector.connect("google.com", stream).unwrap(); +//! let mut stream = connector +//! .setup_connect("google.com", stream) +//! .unwrap() +//! .handshake() +//! .unwrap(); //! //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); //! let mut res = vec![]; @@ -49,7 +53,12 @@ //! Ok(stream) => { //! let acceptor = acceptor.clone(); //! thread::spawn(move || { -//! let stream = acceptor.accept(stream).unwrap(); +//! let stream = acceptor +//! .setup_accept(stream) +//! .unwrap() +//! .handshake() +//! .unwrap(); +//! //! handle_client(stream); //! }); //! } @@ -107,7 +116,7 @@ pub use self::connector::{ ConnectConfiguration, SslAcceptor, SslAcceptorBuilder, SslConnector, SslConnectorBuilder, }; pub use self::ech::{SslEchKeys, SslEchKeysRef}; -pub use self::error::{Error, ErrorCode, HandshakeError}; +pub use self::error::{Error, ErrorCode}; mod async_callbacks; mod bio; @@ -2742,22 +2751,6 @@ impl Ssl { SslStreamBuilder::new(self, stream).setup_connect() } - /// Attempts a client-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - /// - /// # Warning - /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// [`SslConnector`] rather than `Ssl` directly, as it manages that configuration. - pub fn connect(self, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_connect(stream).handshake() - } - /// Initiates a server-side TLS handshake. /// /// This method is guaranteed to return without calling any callback defined @@ -2790,24 +2783,6 @@ impl Ssl { SslStreamBuilder::new(self, stream).setup_accept() } - - /// Attempts a server-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_accept`] and - /// [`MidHandshakeSslStream::handshake`]. - /// - /// # Warning - /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// `SslAcceptor` rather than `Ssl` directly, as it manages that configuration. - /// - /// [`SSL_accept`]: https://www.openssl.org/docs/manmaster/man3/SSL_accept.html - pub fn accept(self, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_accept(stream).handshake() - } } impl fmt::Debug for SslRef { @@ -3903,18 +3878,43 @@ impl MidHandshakeSslStream { /// Restarts the handshake process. #[corresponds(SSL_do_handshake)] - pub fn handshake(mut self) -> Result, HandshakeError> { + pub fn handshake(mut self) -> Result, Self> { let ret = unsafe { ffi::SSL_do_handshake(self.stream.ssl.as_ptr()) }; if ret > 0 { Ok(self.stream) } else { self.error = self.stream.make_error(ret); - Err(if self.error.would_block() { - HandshakeError::WouldBlock(self) - } else { - HandshakeError::Failure(self) - }) + + Err(self) + } + } + + /// An `impl Display` suitable to represent the current error. + pub fn display_error<'a>(&'a self) -> impl fmt::Display + 'a { + struct Display<'a, S>(&'a MidHandshakeSslStream); + + impl fmt::Display for Display<'_, S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.write_str("TLS handshake failed")?; + + #[cfg(feature = "rpk")] + if self.0.ssl().ssl_context().is_rpk() { + return self.0.error().fmt(fmt); + } + + match self.0.ssl().verify_result() { + // INVALID_CALL is returned if no verification took place, + // such as before a cert is sent. + Ok(()) | Err(X509VerifyError::INVALID_CALL) => {} + Err(verify) => write!(fmt, ": cert verification failed - {verify}")?, + } + + fmt.write_str(" ")?; + self.0.error().fmt(fmt) + } } + + Display(self) } } @@ -4249,22 +4249,20 @@ where /// Configure as an outgoing stream from a client. #[corresponds(SSL_set_connect_state)] - pub fn set_connect_state(&mut self) { + fn set_connect_state(&mut self) { unsafe { ffi::SSL_set_connect_state(self.inner.ssl.as_ptr()) } } /// Configure as an incoming stream to a server. #[corresponds(SSL_set_accept_state)] - pub fn set_accept_state(&mut self) { + fn set_accept_state(&mut self) { unsafe { ffi::SSL_set_accept_state(self.inner.ssl.as_ptr()) } } /// Initiates a client-side TLS handshake, returning a [`MidHandshakeSslStream`]. /// - /// This method calls [`Self::set_connect_state`] and returns without actually - /// initiating the handshake. The caller is then free to call - /// [`MidHandshakeSslStream`] and loop on [`HandshakeError::WouldBlock`]. - #[must_use] + /// The caller is then free to call [`MidHandshakeSslStream::handshake`] and retry + /// on blocking errors. pub fn setup_connect(mut self) -> MidHandshakeSslStream { self.set_connect_state(); @@ -4280,20 +4278,10 @@ where } } - /// Attempts a client-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn connect(self) -> Result, HandshakeError> { - self.setup_connect().handshake() - } - /// Initiates a server-side TLS handshake, returning a [`MidHandshakeSslStream`]. /// - /// This method calls [`Self::set_accept_state`] and returns without actually - /// initiating the handshake. The caller is then free to call - /// [`MidHandshakeSslStream`] and loop on [`HandshakeError::WouldBlock`]. - #[must_use] + /// The caller is then free to call [`MidHandshakeSslStream::handshake`] and retry + /// on blocking errors. pub fn setup_accept(mut self) -> MidHandshakeSslStream { self.set_accept_state(); @@ -4308,33 +4296,6 @@ where }, } } - - /// Attempts a server-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_accept`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn accept(self) -> Result, HandshakeError> { - self.setup_accept().handshake() - } - - /// Initiates the handshake. - /// - /// This will fail if `set_accept_state` or `set_connect_state` was not called first. - #[corresponds(SSL_do_handshake)] - pub fn handshake(self) -> Result, HandshakeError> { - let mut stream = self.inner; - let ret = unsafe { ffi::SSL_do_handshake(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - Err(if error.would_block() { - HandshakeError::WouldBlock(MidHandshakeSslStream { stream, error }) - } else { - HandshakeError::Failure(MidHandshakeSslStream { stream, error }) - }) - } - } } impl SslStreamBuilder { diff --git a/boring/src/ssl/test/custom_verify.rs b/boring/src/ssl/test/custom_verify.rs index 64e8f89b5..b30e3a563 100644 --- a/boring/src/ssl/test/custom_verify.rs +++ b/boring/src/ssl/test/custom_verify.rs @@ -1,5 +1,5 @@ use super::server::Server; -use crate::ssl::{ErrorCode, HandshakeError, SslAlert, SslVerifyMode}; +use crate::ssl::{ErrorCode, SslAlert, SslVerifyMode}; use crate::x509::X509StoreContext; use crate::{hash::MessageDigest, ssl::SslVerifyError}; use hex; @@ -9,11 +9,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; fn untrusted_callback_override_bad() { let mut server = Server::builder(); - server.err_cb(|err| { - let HandshakeError::Failure(handshake) = err else { - panic!("expected failure error"); - }; - + server.err_cb(|handshake| { assert_eq!( handshake.error().to_string(), "[SSLV3_ALERT_CERTIFICATE_REVOKED]" @@ -242,11 +238,7 @@ fn both_callback() { fn retry() { let mut server = Server::builder(); - server.err_cb(|err| { - let HandshakeError::Failure(handshake) = err else { - panic!("expected failure error"); - }; - + server.err_cb(|handshake| { assert_eq!( handshake.error().to_string(), "[SSLV3_ALERT_CERTIFICATE_REVOKED]" @@ -267,9 +259,7 @@ fn retry() { Err(SslVerifyError::Invalid(SslAlert::CERTIFICATE_REVOKED)) }); - let HandshakeError::WouldBlock(handshake) = client.connect_err() else { - panic!("should be WouldBlock"); - }; + let handshake = client.connect_err(); assert!(CALLED_BACK.load(Ordering::SeqCst)); assert!(handshake.error().would_block()); diff --git a/boring/src/ssl/test/ech.rs b/boring/src/ssl/test/ech.rs index d2797d427..815fe5cc3 100644 --- a/boring/src/ssl/test/ech.rs +++ b/boring/src/ssl/test/ech.rs @@ -1,7 +1,6 @@ use crate::hpke::HpkeKey; use crate::ssl::ech::SslEchKeys; use crate::ssl::test::server::{ClientSslBuilder, Server}; -use crate::ssl::HandshakeError; // For future reference, these configs are generated by building the bssl tool (the binary is built // alongside boringssl) and running the following command: @@ -49,9 +48,8 @@ fn ech_rejection() { // `ECH_CONFIG_LIST_2` should trigger rejection. let (_server, client) = bootstrap_ech(ECH_CONFIG_2, ECH_KEY_2, ECH_CONFIG_LIST); - let HandshakeError::Failure(failed_ssl_stream) = client.connect_err() else { - panic!("wrong HandshakeError failure variant!"); - }; + let failed_ssl_stream = client.connect_err(); + assert_eq!( failed_ssl_stream.ssl().get_ech_name_override(), Some(b"ech.com".to_vec().as_ref()) diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index aded182d1..bd85672fb 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -125,7 +125,7 @@ fn test_connect_with_srtp_ctx() { .unwrap(); let mut ssl = Ssl::new(&ctx.build()).unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.accept(stream).unwrap(); + let mut stream = ssl.setup_accept(stream).handshake().unwrap(); let mut buf = [0; 60]; stream @@ -144,7 +144,7 @@ fn test_connect_with_srtp_ctx() { .unwrap(); let mut ssl = Ssl::new(&ctx.build()).unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.connect(stream).unwrap(); + let mut stream = ssl.setup_connect(stream).handshake().unwrap(); let mut buf = [1; 60]; { @@ -194,7 +194,7 @@ fn test_connect_with_srtp_ssl() { profilenames ); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.accept(stream).unwrap(); + let mut stream = ssl.setup_accept(stream).handshake().unwrap(); let mut buf = [0; 60]; stream @@ -213,7 +213,7 @@ fn test_connect_with_srtp_ssl() { ssl.set_tlsext_use_srtp("SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32") .unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.connect(stream).unwrap(); + let mut stream = ssl.setup_connect(stream).handshake().unwrap(); let mut buf = [1; 60]; { @@ -445,7 +445,10 @@ fn write_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -476,7 +479,10 @@ fn read_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -507,7 +513,10 @@ fn flush_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -537,7 +546,7 @@ fn default_verify_paths() { }; let mut ssl = Ssl::new(&ctx).unwrap(); ssl.set_hostname("google.com").unwrap(); - let mut socket = ssl.connect(s).unwrap(); + let mut socket = ssl.setup_connect(s).handshake().unwrap(); socket.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); let mut result = vec![]; @@ -668,7 +677,12 @@ fn connector_valid_hostname() { connector.set_ca_file("test/root-ca.pem").unwrap(); let s = server.connect_tcp(); - let mut s = connector.build().connect("foobar.com", s).unwrap(); + let mut s = connector + .build() + .setup_connect("foobar.com", s) + .unwrap() + .handshake() + .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -682,7 +696,12 @@ fn connector_invalid_hostname() { connector.set_ca_file("test/root-ca.pem").unwrap(); let s = server.connect_tcp(); - connector.build().connect("bogus.com", s).unwrap_err(); + connector + .build() + .setup_connect("bogus.com", s) + .unwrap() + .handshake() + .unwrap_err(); } #[test] @@ -698,7 +717,9 @@ fn connector_invalid_no_hostname_verification() { .configure() .unwrap() .verify_hostname(false) - .connect("bogus.com", s) + .setup_connect("bogus.com", s) + .unwrap() + .handshake() .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -716,7 +737,9 @@ fn connector_no_hostname_still_verifies() { .configure() .unwrap() .verify_hostname(false) - .connect("fizzbuzz.com", s) + .setup_connect("fizzbuzz.com", s) + .unwrap() + .handshake() .is_err()); } @@ -733,7 +756,9 @@ fn connector_no_hostname_can_disable_verify() { .configure() .unwrap() .verify_hostname(false) - .connect("foobar.com", s) + .setup_connect("foobar.com", s) + .unwrap() + .handshake() .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -750,7 +775,7 @@ fn test_mozilla_server(new: fn(SslMethod) -> Result Result, io_cb: Box) + Send>, - err_cb: Box) + Send>, + err_cb: Box) + Send>, should_error: bool, expected_connections_count: usize, } @@ -102,7 +103,7 @@ impl Builder { self.io_cb = Box::new(cb); } - pub fn err_cb(&mut self, cb: impl FnMut(HandshakeError) + Send + 'static) { + pub fn err_cb(&mut self, cb: impl FnMut(MidHandshakeSslStream) + Send + 'static) { self.should_error(); self.err_cb = Box::new(cb); @@ -133,7 +134,7 @@ impl Builder { ssl_cb(&mut ssl); - let r = ssl.accept(socket); + let r = ssl.setup_accept(socket).handshake(); if should_error { err_cb(r.unwrap_err()); @@ -176,7 +177,7 @@ impl ClientBuilder { self.build().builder().connect() } - pub fn connect_err(self) -> HandshakeError { + pub fn connect_err(self) -> MidHandshakeSslStream { self.build().builder().connect_err() } } @@ -207,12 +208,12 @@ impl ClientSslBuilder { pub fn connect(self) -> SslStream { let socket = TcpStream::connect(self.addr).unwrap(); - let mut s = self.ssl.connect(socket).unwrap(); + let mut s = self.ssl.setup_connect(socket).handshake().unwrap(); s.read_exact(&mut [0]).unwrap(); s } - pub fn connect_err(self) -> HandshakeError { + pub fn connect_err(self) -> MidHandshakeSslStream { let socket = TcpStream::connect(self.addr).unwrap(); self.ssl.setup_connect(socket).handshake().unwrap_err() diff --git a/boring/src/ssl/test/session.rs b/boring/src/ssl/test/session.rs index 23c0f4d5d..873a7f236 100644 --- a/boring/src/ssl/test/session.rs +++ b/boring/src/ssl/test/session.rs @@ -4,8 +4,8 @@ use std::sync::OnceLock; use crate::ssl::test::server::Server; use crate::ssl::{ - ErrorCode, GetSessionPendingError, HandshakeError, Ssl, SslContext, SslContextBuilder, - SslMethod, SslOptions, SslSession, SslSessionCacheMode, SslVersion, + ErrorCode, GetSessionPendingError, Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, + SslSession, SslSessionCacheMode, SslVersion, }; #[test] @@ -125,11 +125,7 @@ fn new_get_session_callback_pending() { }); } server.ctx().set_session_id_context(b"foo").unwrap(); - server.err_cb(|error| { - let HandshakeError::WouldBlock(mid_handshake) = error else { - panic!("should be WouldBlock"); - }; - + server.err_cb(|mid_handshake| { assert!(mid_handshake.error().would_block()); assert_eq!(mid_handshake.error().code(), ErrorCode::PENDING_SESSION); diff --git a/hyper-boring/tests/v1.rs b/hyper-boring/tests/v1.rs index 4082d2cef..90769e13c 100644 --- a/hyper-boring/tests/v1.rs +++ b/hyper-boring/tests/v1.rs @@ -47,7 +47,10 @@ async fn localhost() { for _ in 0..3 { let stream = listener.accept().await.unwrap().0; - let stream = tokio_boring::accept(&acceptor, stream).await.unwrap(); + let stream = tokio_boring::accept(&acceptor, stream) + .unwrap() + .await + .unwrap(); let service = service::service_fn(|_| async { Ok::<_, io::Error>(Response::new(>::new())) @@ -117,7 +120,10 @@ async fn alpn_h2() { let acceptor = acceptor.build(); let stream = listener.accept().await.unwrap().0; - let stream = tokio_boring::accept(&acceptor, stream).await.unwrap(); + let stream = tokio_boring::accept(&acceptor, stream) + .unwrap() + .await + .unwrap(); assert_eq!(stream.ssl().selected_alpn_protocol().unwrap(), b"h2"); let service = service::service_fn(|_| async { diff --git a/tokio-boring/examples/simple-async.rs b/tokio-boring/examples/simple-async.rs index f4a69a1c9..d9b159ded 100644 --- a/tokio-boring/examples/simple-async.rs +++ b/tokio-boring/examples/simple-async.rs @@ -11,6 +11,6 @@ async fn main() -> anyhow::Result<()> { ssl_builder.set_default_verify_paths()?; ssl_builder.set_verify(ssl::SslVerifyMode::PEER); let acceptor = ssl_builder.build(); - let _ssl_stream = tokio_boring::accept(&acceptor, tcp_stream).await?; + let _ssl_stream = tokio_boring::accept(&acceptor, tcp_stream)?.await?; Ok(()) } diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index 374a0bde0..513cc0389 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -13,6 +13,7 @@ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +use boring::error::ErrorStack; use boring::ssl::{ self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor, SslRef, @@ -42,34 +43,40 @@ pub use boring::ssl::{ /// /// This function automatically sets the task waker on the `Ssl` from `config` to /// allow to make use of async callbacks provided by the boring crate. -pub async fn connect( +pub fn connect( config: ConnectConfiguration, domain: &str, stream: S, -) -> Result, HandshakeError> +) -> Result, ErrorStack> where S: AsyncRead + AsyncWrite + Unpin, { - let mid_handshake = config - .setup_connect(domain, AsyncStreamBridge::new(stream)) - .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?; - - HandshakeFuture(Some(mid_handshake)).await + handshake(|s| config.setup_connect(domain, s), stream) } /// Asynchronously performs a server-side TLS handshake over the provided stream. /// /// This function automatically sets the task waker on the `Ssl` from `config` to /// allow to make use of async callbacks provided by the boring crate. -pub async fn accept(acceptor: &SslAcceptor, stream: S) -> Result, HandshakeError> +pub fn accept(acceptor: &SslAcceptor, stream: S) -> Result, ErrorStack> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + handshake(|s| acceptor.setup_accept(s), stream) +} + +fn handshake( + f: impl FnOnce( + AsyncStreamBridge, + ) -> Result>, ErrorStack>, + stream: S, +) -> Result, ErrorStack> where S: AsyncRead + AsyncWrite + Unpin, { - let mid_handshake = acceptor - .setup_accept(AsyncStreamBridge::new(stream)) - .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?; + let ongoing_handshake = Some(f(AsyncStreamBridge::new(stream))?); - HandshakeFuture(Some(mid_handshake)).await + Ok(HandshakeFuture(ongoing_handshake)) } fn cvt(r: io::Result) -> Poll> { @@ -252,52 +259,33 @@ where } /// The error type returned after a failed handshake. -pub struct HandshakeError(ssl::HandshakeError>); +pub struct HandshakeError(MidHandshakeSslStream>); impl HandshakeError { /// Returns a shared reference to the `Ssl` object associated with this error. - #[must_use] - pub fn ssl(&self) -> Option<&SslRef> { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(s.ssl()), - _ => None, - } + pub fn ssl(&self) -> &SslRef { + self.0.ssl() } /// Converts error to the source data stream that was used for the handshake. - #[must_use] - pub fn into_source_stream(self) -> Option { - match self.0 { - ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream), - _ => None, - } + pub fn into_source_stream(self) -> S { + self.0.into_source_stream().stream } /// Returns a reference to the source data stream. - #[must_use] - pub fn as_source_stream(&self) -> Option<&S> { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream), - _ => None, - } + pub fn as_source_stream(&self) -> &S { + &self.0.get_ref().stream } /// Returns the error code, if any. - #[must_use] - pub fn code(&self) -> Option { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(s.error().code()), - _ => None, - } + pub fn code(&self) -> ErrorCode { + self.0.error().code() } /// Returns a reference to the inner I/O error, if any. #[must_use] pub fn as_io_error(&self) -> Option<&io::Error> { - match &self.0 { - ssl::HandshakeError::Failure(s) => s.error().io_error(), - _ => None, - } + self.0.error().io_error() } } @@ -312,7 +300,7 @@ where impl fmt::Display for HandshakeError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.0, fmt) + self.0.display_error().fmt(fmt) } } @@ -321,7 +309,7 @@ where S: fmt::Debug, { fn source(&self) -> Option<&(dyn Error + 'static)> { - self.0.source() + self.0.error().source() } } @@ -351,7 +339,7 @@ where Poll::Ready(Ok(SslStream(stream))) } - Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => { + Err(mut mid_handshake) if mid_handshake.error().would_block() => { mid_handshake.get_mut().set_waker(None); mid_handshake.ssl_mut().set_task_waker(None); @@ -359,15 +347,10 @@ where Poll::Pending } - Err(ssl::HandshakeError::Failure(mut mid_handshake)) => { + Err(mut mid_handshake) => { mid_handshake.get_mut().set_waker(None); - Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure( - mid_handshake, - )))) - } - Err(err @ ssl::HandshakeError::SetupFailure(_)) => { - Poll::Ready(Err(HandshakeError(err))) + Poll::Ready(Err(HandshakeError(mid_handshake))) } } } diff --git a/tokio-boring/tests/async_get_session.rs b/tokio-boring/tests/async_get_session.rs index 0ab9b396e..3cb0c7a61 100644 --- a/tokio-boring/tests/async_get_session.rs +++ b/tokio-boring/tests/async_get_session.rs @@ -57,6 +57,7 @@ async fn test() { let server = async move { tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0) + .unwrap() .await .unwrap(); @@ -64,6 +65,7 @@ async fn test() { assert!(!FOUND_SESSION.load(Ordering::SeqCst)); tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0) + .unwrap() .await .unwrap(); @@ -76,6 +78,7 @@ async fn test() { "localhost", TcpStream::connect(&addr).await.unwrap(), ) + .unwrap() .await .unwrap(); @@ -94,6 +97,7 @@ async fn test() { "localhost", TcpStream::connect(&addr).await.unwrap(), ) + .unwrap() .await .unwrap(); }; diff --git a/tokio-boring/tests/client_server.rs b/tokio-boring/tests/client_server.rs index 925f9875e..872033f8e 100644 --- a/tokio-boring/tests/client_server.rs +++ b/tokio-boring/tests/client_server.rs @@ -19,6 +19,7 @@ async fn google() { .configure() .unwrap(); let mut stream = tokio_boring::connect(config, "google.com", stream) + .unwrap() .await .unwrap(); @@ -44,15 +45,11 @@ async fn handshake_error() { let (stream, addr) = create_server(|_| ()); let server = async { - let err = stream.await.unwrap_err(); - - assert!(err.into_source_stream().is_some()); + let _err = stream.await.unwrap_err(); }; let client = async { - let err = connect(addr, |_| Ok(())).await.unwrap_err(); - - assert!(err.into_source_stream().is_some()); + let _err = connect(addr, |_| Ok(())).await.unwrap_err(); }; future::join(server, client).await; diff --git a/tokio-boring/tests/common/mod.rs b/tokio-boring/tests/common/mod.rs index b28917b42..a94787baa 100644 --- a/tokio-boring/tests/common/mod.rs +++ b/tokio-boring/tests/common/mod.rs @@ -24,7 +24,7 @@ pub(crate) fn create_server( let stream = listener.accept().await.unwrap().0; - tokio_boring::accept(&acceptor, stream).await + tokio_boring::accept(&acceptor, stream).unwrap().await }; (server, addr) @@ -65,7 +65,9 @@ pub(crate) async fn connect( let stream = TcpStream::connect(&addr).await.unwrap(); - tokio_boring::connect(config, "localhost", stream).await + tokio_boring::connect(config, "localhost", stream) + .unwrap() + .await } pub(crate) fn create_connector( diff --git a/tokio-boring/tests/rpk.rs b/tokio-boring/tests/rpk.rs index 5492767ab..48f53655c 100644 --- a/tokio-boring/tests/rpk.rs +++ b/tokio-boring/tests/rpk.rs @@ -34,7 +34,7 @@ mod test_rpk { let stream = listener.accept().await.unwrap().0; - tokio_boring::accept(&acceptor, stream).await + tokio_boring::accept(&acceptor, stream).unwrap().await }; (server, addr) @@ -66,6 +66,7 @@ mod test_rpk { let stream = TcpStream::connect(&addr).await.unwrap(); let mut stream = tokio_boring::connect(config, "localhost", stream) + .unwrap() .await .unwrap(); @@ -97,6 +98,7 @@ mod test_rpk { let stream = TcpStream::connect(&addr).await.unwrap(); let err = tokio_boring::connect(config, "localhost", stream) + .unwrap() .await .unwrap_err();