diff --git a/Cargo.toml b/Cargo.toml index 03cbed0c..189e06b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ futures-util = { version = "0.3.31", default-features = false, features = ["io"] [target.'cfg(not(target_arch = "wasm32"))'.dependencies] jsonwebtoken = { version = "9.3.1", default-features = false } +tokio = { version = "1.38", optional = true, features = ["time"] } [target.'cfg(target_arch = "wasm32")'.dependencies] uuid = { version = "1.17.0", default-features = false, features = ["v4", "js"] } @@ -42,7 +43,7 @@ wasm-bindgen-futures = "0.4" [features] default = ["reqwest"] -reqwest = ["dep:reqwest", "pin-project-lite", "bytes"] +reqwest = ["dep:reqwest", "dep:tokio", "pin-project-lite", "bytes"] futures-unsend = [] [dev-dependencies] diff --git a/src/client.rs b/src/client.rs index 3e5068c9..81ba91d7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -12,7 +12,7 @@ use crate::{ search::*, task_info::TaskInfo, tasks::{Task, TasksCancelQuery, TasksDeleteQuery, TasksResults, TasksSearchQuery}, - utils::async_sleep, + utils::SleepBackend, DefaultHttpClient, }; @@ -933,7 +933,7 @@ impl Client { } Task::Enqueued { .. } | Task::Processing { .. } => { elapsed_time += interval; - async_sleep(interval).await; + self.sleep_backend().sleep(interval).await; } }, Err(error) => return Err(error), @@ -1144,6 +1144,10 @@ impl Client { crate::tenant_tokens::generate_tenant_token(api_key_uid, search_rules, api_key, expires_at) } + + fn sleep_backend(&self) -> SleepBackend { + SleepBackend::infer(self.http_client.is_tokio()) + } } #[derive(Debug, Clone, Deserialize)] diff --git a/src/request.rs b/src/request.rs index 49888adc..d3b426b8 100644 --- a/src/request.rs +++ b/src/request.rs @@ -101,6 +101,10 @@ pub trait HttpClient: Clone + Send + Sync { content_type: &str, expected_status_code: u16, ) -> Result; + + fn is_tokio(&self) -> bool { + false + } } pub fn parse_response( diff --git a/src/reqwest.rs b/src/reqwest.rs index 5a43d2da..356aac36 100644 --- a/src/reqwest.rs +++ b/src/reqwest.rs @@ -112,6 +112,10 @@ impl HttpClient for ReqwestClient { parse_response(status, expected_status_code, &body, url.to_string()) } + + fn is_tokio(&self) -> bool { + true + } } fn verb(method: &Method) -> reqwest::Method { diff --git a/src/utils.rs b/src/utils.rs index 91b63fe6..075f4e5f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,31 +1,65 @@ use std::time::Duration; -#[cfg(not(target_arch = "wasm32"))] -pub(crate) async fn async_sleep(interval: Duration) { - let (sender, receiver) = futures_channel::oneshot::channel::<()>(); - std::thread::spawn(move || { - std::thread::sleep(interval); - let _ = sender.send(()); - }); - let _ = receiver.await; +#[derive(Debug, Copy, Clone)] +pub(crate) enum SleepBackend { + #[cfg(all(not(target_arch = "wasm32"), feature = "reqwest"))] + Tokio, + #[cfg(not(target_arch = "wasm32"))] + Thread, + #[cfg(target_arch = "wasm32")] + Javascript, } -#[cfg(target_arch = "wasm32")] -pub(crate) async fn async_sleep(interval: Duration) { - use std::convert::TryInto; - use wasm_bindgen_futures::JsFuture; - - JsFuture::from(web_sys::js_sys::Promise::new(&mut |yes, _| { - web_sys::window() - .unwrap() - .set_timeout_with_callback_and_timeout_and_arguments_0( - &yes, - interval.as_millis().try_into().unwrap(), - ) - .unwrap(); - })) - .await - .unwrap(); +impl SleepBackend { + pub(crate) fn infer(is_tokio: bool) -> Self { + #[cfg(all(not(target_arch = "wasm32"), feature = "reqwest"))] + if is_tokio { + return Self::Tokio; + } + #[cfg(any(target_arch = "wasm32", not(feature = "reqwest")))] + let _ = is_tokio; + + #[cfg(not(target_arch = "wasm32"))] + return Self::Thread; + + #[cfg(target_arch = "wasm32")] + return Self::Javascript; + } + + pub(crate) async fn sleep(self, interval: Duration) { + match self { + #[cfg(all(not(target_arch = "wasm32"), feature = "reqwest"))] + Self::Tokio => { + tokio::time::sleep(interval).await; + } + #[cfg(not(target_arch = "wasm32"))] + Self::Thread => { + let (sender, receiver) = futures_channel::oneshot::channel::<()>(); + std::thread::spawn(move || { + std::thread::sleep(interval); + let _ = sender.send(()); + }); + let _ = receiver.await; + } + #[cfg(target_arch = "wasm32")] + Self::Javascript => { + use std::convert::TryInto; + use wasm_bindgen_futures::JsFuture; + + JsFuture::from(web_sys::js_sys::Promise::new(&mut |yes, _| { + web_sys::window() + .unwrap() + .set_timeout_with_callback_and_timeout_and_arguments_0( + &yes, + interval.as_millis().try_into().unwrap(), + ) + .unwrap(); + })) + .await + .unwrap(); + } + } + } } #[cfg(test)] @@ -33,12 +67,35 @@ mod test { use super::*; use meilisearch_test_macro::meilisearch_test; + #[cfg(all(not(target_arch = "wasm32"), feature = "reqwest"))] + #[meilisearch_test] + async fn sleep_tokio() { + let sleep_duration = Duration::from_millis(10); + let now = std::time::Instant::now(); + + SleepBackend::Tokio.sleep(sleep_duration).await; + + assert!(now.elapsed() >= sleep_duration); + } + + #[cfg(not(target_arch = "wasm32"))] + #[meilisearch_test] + async fn sleep_thread() { + let sleep_duration = Duration::from_millis(10); + let now = std::time::Instant::now(); + + SleepBackend::Thread.sleep(sleep_duration).await; + + assert!(now.elapsed() >= sleep_duration); + } + + #[cfg(target_arch = "wasm32")] #[meilisearch_test] - async fn test_async_sleep() { + async fn sleep_javascript() { let sleep_duration = Duration::from_millis(10); let now = std::time::Instant::now(); - async_sleep(sleep_duration).await; + SleepBackend::Javascript.sleep(sleep_duration).await; assert!(now.elapsed() >= sleep_duration); }