Skip to content

Commit b832823

Browse files
authored
Merge pull request #81 from stackhpc/resource-management
Add support for resource management
2 parents 138c501 + fcb8b86 commit b832823

File tree

7 files changed

+228
-11
lines changed

7 files changed

+228
-11
lines changed

benches/s3_client.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use aws_sdk_s3::Client;
55
use aws_types::region::Region;
66
use axum::body::Bytes;
77
use criterion::{black_box, criterion_group, criterion_main, Criterion};
8+
use reductionist::resource_manager::ResourceManager;
89
use reductionist::s3_client::{S3Client, S3ClientMap};
910
use url::Url;
1011
// Bring trait into scope to use as_bytes method.
@@ -42,6 +43,7 @@ fn criterion_benchmark(c: &mut Criterion) {
4243
let bucket = "s3-client-bench";
4344
let runtime = tokio::runtime::Runtime::new().unwrap();
4445
let map = S3ClientMap::new();
46+
let resource_manager = ResourceManager::new(None, None, None);
4547
for size_k in [64, 256, 1024] {
4648
let size: isize = size_k * 1024;
4749
let data: Vec<u32> = (0_u32..(size as u32)).collect::<Vec<u32>>();
@@ -53,7 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) {
5355
b.to_async(&runtime).iter(|| async {
5456
let client = S3Client::new(&url, username, password).await;
5557
client
56-
.download_object(black_box(bucket), &key, None)
58+
.download_object(black_box(bucket), &key, None, &resource_manager, &mut None)
5759
.await
5860
.unwrap();
5961
})
@@ -63,7 +65,7 @@ fn criterion_benchmark(c: &mut Criterion) {
6365
b.to_async(&runtime).iter(|| async {
6466
let client = map.get(&url, username, password).await;
6567
client
66-
.download_object(black_box(bucket), &key, None)
68+
.download_object(black_box(bucket), &key, None, &resource_manager, &mut None)
6769
.await
6870
.unwrap();
6971
})

src/app.rs

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::metrics::{metrics_handler, track_metrics};
77
use crate::models;
88
use crate::operation;
99
use crate::operations;
10+
use crate::resource_manager::ResourceManager;
1011
use crate::s3_client;
1112
use crate::types::{ByteOrder, NATIVE_BYTE_ORDER};
1213
use crate::validated_json::ValidatedJson;
@@ -25,6 +26,7 @@ use axum::{
2526
};
2627

2728
use std::sync::Arc;
29+
use tokio::sync::SemaphorePermit;
2830
use tower::Layer;
2931
use tower::ServiceBuilder;
3032
use tower_http::normalize_path::NormalizePathLayer;
@@ -54,14 +56,21 @@ struct AppState {
5456

5557
/// Map of S3 client objects.
5658
s3_client_map: s3_client::S3ClientMap,
59+
60+
/// Resource manager.
61+
resource_manager: ResourceManager,
5762
}
5863

5964
impl AppState {
6065
/// Create and return an [AppState].
6166
fn new(args: &CommandLineArgs) -> Self {
67+
let task_limit = args.thread_limit.or_else(|| Some(num_cpus::get() - 1));
68+
let resource_manager =
69+
ResourceManager::new(args.s3_connection_limit, args.memory_limit, task_limit);
6270
Self {
6371
args: args.clone(),
6472
s3_client_map: s3_client::S3ClientMap::new(),
73+
resource_manager,
6574
}
6675
}
6776
}
@@ -176,14 +185,26 @@ async fn schema() -> &'static str {
176185
///
177186
/// * `auth`: Basic authentication credentials
178187
/// * `request_data`: RequestData object for the request
179-
#[tracing::instrument(level = "DEBUG", skip(client, request_data))]
180-
async fn download_object(
188+
#[tracing::instrument(
189+
level = "DEBUG",
190+
skip(client, request_data, resource_manager, mem_permits)
191+
)]
192+
async fn download_object<'a>(
181193
client: &s3_client::S3Client,
182194
request_data: &models::RequestData,
195+
resource_manager: &'a ResourceManager,
196+
mem_permits: &mut Option<SemaphorePermit<'a>>,
183197
) -> Result<Bytes, ActiveStorageError> {
184198
let range = s3_client::get_range(request_data.offset, request_data.size);
199+
let _conn_permits = resource_manager.s3_connection().await?;
185200
client
186-
.download_object(&request_data.bucket, &request_data.object, range)
201+
.download_object(
202+
&request_data.bucket,
203+
&request_data.object,
204+
range,
205+
resource_manager,
206+
mem_permits,
207+
)
187208
.await
188209
}
189210

@@ -206,19 +227,27 @@ async fn operation_handler<T: operation::Operation>(
206227
TypedHeader(auth): TypedHeader<Authorization<Basic>>,
207228
ValidatedJson(request_data): ValidatedJson<models::RequestData>,
208229
) -> Result<models::Response, ActiveStorageError> {
230+
let memory = request_data.size.unwrap_or(0);
231+
let mut _mem_permits = state.resource_manager.memory(memory).await?;
209232
let s3_client = state
210233
.s3_client_map
211234
.get(&request_data.source, auth.username(), auth.password())
212235
.instrument(tracing::Span::current())
213236
.await;
214-
let data = download_object(&s3_client, &request_data)
215-
.instrument(tracing::Span::current())
216-
.await?;
237+
let data = download_object(
238+
&s3_client,
239+
&request_data,
240+
&state.resource_manager,
241+
&mut _mem_permits,
242+
)
243+
.instrument(tracing::Span::current())
244+
.await?;
217245
// All remaining work is synchronous. If the use_rayon argument was specified, delegate to the
218246
// Rayon thread pool. Otherwise, execute as normal using Tokio.
219247
if state.args.use_rayon {
220248
tokio_rayon::spawn(move || operation::<T>(request_data, data)).await
221249
} else {
250+
let _task_permit = state.resource_manager.task().await?;
222251
operation::<T>(request_data, data)
223252
}
224253
}

src/cli.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ pub struct CommandLineArgs {
3737
/// Whether to use Rayon for execution of CPU-bound tasks.
3838
#[arg(long, default_value_t = false, env = "REDUCTIONIST_USE_RAYON")]
3939
pub use_rayon: bool,
40+
/// Memory limit in bytes. Default is no limit.
41+
#[arg(long, env = "REDUCTIONIST_MEMORY_LIMIT")]
42+
pub memory_limit: Option<usize>,
43+
/// S3 connection limit. Default is no limit.
44+
#[arg(long, env = "REDUCTIONIST_S3_CONNECTION_LIMIT")]
45+
pub s3_connection_limit: Option<usize>,
46+
/// Thread limit for CPU-bound tasks. Default is one less than the number of CPUs. Used only
47+
/// when use_rayon is false.
48+
#[arg(long, env = "REDUCTIONIST_THREAD_LIMIT")]
49+
pub thread_limit: Option<usize>,
4050
}
4151

4252
/// Returns parsed command line arguments.

src/error.rs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use ndarray::ShapeError;
1414
use serde::{Deserialize, Serialize};
1515
use std::error::Error;
1616
use thiserror::Error;
17+
use tokio::sync::AcquireError;
1718
use tracing::{event, Level};
1819
use zune_inflate::errors::InflateDecodeErrors;
1920

@@ -41,9 +42,14 @@ pub enum ActiveStorageError {
4142
#[error("failed to convert from bytes to {type_name}")]
4243
FromBytes { type_name: &'static str },
4344

45+
/// Incompatible missing data descriptor
4446
#[error("Incompatible value {0} for missing")]
4547
IncompatibleMissing(DValue),
4648

49+
/// Insufficient memory to process request
50+
#[error("Insufficient memory to process request ({requested} > {total})")]
51+
InsufficientMemory { requested: usize, total: usize },
52+
4753
/// Error deserialising request data into RequestData
4854
#[error("request data is not valid")]
4955
RequestDataJsonRejection(#[from] JsonRejection),
@@ -64,6 +70,10 @@ pub enum ActiveStorageError {
6470
#[error("error retrieving object from S3 storage")]
6571
S3GetObject(#[from] SdkError<GetObjectError>),
6672

73+
/// Error acquiring a semaphore
74+
#[error("error acquiring resources")]
75+
SemaphoreAcquireError(#[from] AcquireError),
76+
6777
/// Error creating ndarray ArrayView from Shape
6878
#[error("failed to create array from shape")]
6979
ShapeInvalid(#[from] ShapeError),
@@ -196,6 +206,10 @@ impl From<ActiveStorageError> for ErrorResponse {
196206
| ActiveStorageError::DecompressionZune(_)
197207
| ActiveStorageError::EmptyArray { operation: _ }
198208
| ActiveStorageError::IncompatibleMissing(_)
209+
| ActiveStorageError::InsufficientMemory {
210+
requested: _,
211+
total: _,
212+
}
199213
| ActiveStorageError::RequestDataJsonRejection(_)
200214
| ActiveStorageError::RequestDataValidationSingle(_)
201215
| ActiveStorageError::RequestDataValidation(_)
@@ -207,7 +221,8 @@ impl From<ActiveStorageError> for ErrorResponse {
207221
// Internal server error
208222
ActiveStorageError::FromBytes { type_name: _ }
209223
| ActiveStorageError::TryFromInt(_)
210-
| ActiveStorageError::S3ByteStream(_) => Self::internal_server_error(&error),
224+
| ActiveStorageError::S3ByteStream(_)
225+
| ActiveStorageError::SemaphoreAcquireError(_) => Self::internal_server_error(&error),
211226

212227
ActiveStorageError::S3GetObject(sdk_error) => {
213228
// Tailor the response based on the specific SdkError variant.
@@ -377,6 +392,17 @@ mod tests {
377392
test_active_storage_error(error, StatusCode::BAD_REQUEST, message, caused_by).await;
378393
}
379394

395+
#[tokio::test]
396+
async fn insufficient_memory() {
397+
let error = ActiveStorageError::InsufficientMemory {
398+
requested: 2,
399+
total: 1,
400+
};
401+
let message = "Insufficient memory to process request (2 > 1)";
402+
let caused_by = None;
403+
test_active_storage_error(error, StatusCode::BAD_REQUEST, message, caused_by).await;
404+
}
405+
380406
#[tokio::test]
381407
async fn request_data_validation_single() {
382408
let validation_error = validator::ValidationError::new("foo");
@@ -504,6 +530,17 @@ mod tests {
504530
.await;
505531
}
506532

533+
#[tokio::test]
534+
async fn semaphore_acquire_error() {
535+
let sem = tokio::sync::Semaphore::new(1);
536+
sem.close();
537+
let error = ActiveStorageError::SemaphoreAcquireError(sem.acquire().await.unwrap_err());
538+
let message = "error acquiring resources";
539+
let caused_by = Some(vec!["semaphore closed"]);
540+
test_active_storage_error(error, StatusCode::INTERNAL_SERVER_ERROR, message, caused_by)
541+
.await;
542+
}
543+
507544
#[tokio::test]
508545
async fn shape_error() {
509546
let error = ActiveStorageError::ShapeInvalid(ShapeError::from_kind(

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ pub mod metrics;
296296
pub mod models;
297297
pub mod operation;
298298
pub mod operations;
299+
pub mod resource_manager;
299300
pub mod s3_client;
300301
pub mod server;
301302
#[cfg(test)]

src/resource_manager.rs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
//! Resource management
2+
3+
use crate::error::ActiveStorageError;
4+
5+
use tokio::sync::{Semaphore, SemaphorePermit};
6+
7+
/// [crate::resource_manager::ResourceManager] provides a simple way to allocate various resources
8+
/// to tasks. Resource management is performed using a Tokio Semaphore for each type of resource.
9+
pub struct ResourceManager {
10+
/// Optional semaphore for S3 connections.
11+
s3_connections: Option<Semaphore>,
12+
13+
/// Optional semaphore for memory (bytes).
14+
memory: Option<Semaphore>,
15+
16+
/// Optional total memory pool in bytes.
17+
total_memory: Option<usize>,
18+
19+
/// Optional semaphore for tasks.
20+
tasks: Option<Semaphore>,
21+
}
22+
23+
impl ResourceManager {
24+
/// Returns a new ResourceManager object.
25+
pub fn new(
26+
s3_connection_limit: Option<usize>,
27+
memory_limit: Option<usize>,
28+
task_limit: Option<usize>,
29+
) -> Self {
30+
Self {
31+
s3_connections: s3_connection_limit.map(Semaphore::new),
32+
memory: memory_limit.map(Semaphore::new),
33+
total_memory: memory_limit,
34+
tasks: task_limit.map(Semaphore::new),
35+
}
36+
}
37+
38+
/// Acquire an S3 connection resource.
39+
pub async fn s3_connection(&self) -> Result<Option<SemaphorePermit>, ActiveStorageError> {
40+
optional_acquire(&self.s3_connections, 1).await
41+
}
42+
43+
/// Acquire memory resource.
44+
pub async fn memory(
45+
&self,
46+
bytes: usize,
47+
) -> Result<Option<SemaphorePermit>, ActiveStorageError> {
48+
if let Some(total_memory) = self.total_memory {
49+
if bytes > total_memory {
50+
return Err(ActiveStorageError::InsufficientMemory {
51+
requested: bytes,
52+
total: total_memory,
53+
});
54+
};
55+
};
56+
optional_acquire(&self.memory, bytes).await
57+
}
58+
59+
/// Acquire a task resource.
60+
pub async fn task(&self) -> Result<Option<SemaphorePermit>, ActiveStorageError> {
61+
optional_acquire(&self.tasks, 1).await
62+
}
63+
}
64+
65+
/// Acquire permits on an optional Semaphore, if present.
66+
async fn optional_acquire(
67+
sem: &Option<Semaphore>,
68+
n: usize,
69+
) -> Result<Option<SemaphorePermit>, ActiveStorageError> {
70+
let n = n.try_into()?;
71+
if let Some(sem) = sem {
72+
sem.acquire_many(n)
73+
.await
74+
.map(Some)
75+
.map_err(|err| err.into())
76+
} else {
77+
Ok(None)
78+
}
79+
}
80+
81+
#[cfg(test)]
82+
mod tests {
83+
use super::*;
84+
85+
use tokio::sync::TryAcquireError;
86+
87+
#[tokio::test]
88+
async fn no_resource_management() {
89+
let rm = ResourceManager::new(None, None, None);
90+
assert!(rm.s3_connections.is_none());
91+
assert!(rm.memory.is_none());
92+
assert!(rm.tasks.is_none());
93+
let _c = rm.s3_connection().await.unwrap();
94+
let _m = rm.memory(1).await.unwrap();
95+
let _t = rm.task().await.unwrap();
96+
assert!(_c.is_none());
97+
assert!(_m.is_none());
98+
assert!(_t.is_none());
99+
}
100+
101+
#[tokio::test]
102+
async fn full_resource_management() {
103+
let rm = ResourceManager::new(Some(1), Some(1), Some(1));
104+
assert!(rm.s3_connections.is_some());
105+
assert!(rm.memory.is_some());
106+
assert!(rm.tasks.is_some());
107+
let _c = rm.s3_connection().await.unwrap();
108+
let _m = rm.memory(1).await.unwrap();
109+
let _t = rm.task().await.unwrap();
110+
assert!(_c.is_some());
111+
assert!(_m.is_some());
112+
assert!(_t.is_some());
113+
// Check that there are no more resources (without blocking).
114+
assert_eq!(
115+
rm.s3_connections.as_ref().unwrap().try_acquire().err(),
116+
Some(TryAcquireError::NoPermits)
117+
);
118+
assert_eq!(
119+
rm.memory.as_ref().unwrap().try_acquire().err(),
120+
Some(TryAcquireError::NoPermits)
121+
);
122+
assert_eq!(
123+
rm.tasks.as_ref().unwrap().try_acquire().err(),
124+
Some(TryAcquireError::NoPermits)
125+
);
126+
}
127+
}

0 commit comments

Comments
 (0)