Commit de1a56f

Merge pull request #51 from stackhpc/compression
Add support for Gzip and Zlib compression
2 parents db787cf + c3490bb commit de1a56f

15 files changed: 379 additions, 19 deletions

Cargo.lock

Lines changed: 21 additions & 1 deletion
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ axum = { version = "0.6", features = ["headers"] }
 axum-server = { version = "0.4.7", features = ["tls-rustls"] }
 clap = { version = "4.2", features = ["derive", "env"] }
 expanduser = "1.2.2"
+flate2 = "1.0"
 http = "*"
 hyper = { version = "0.14", features = ["full"] }
 lazy_static = "1.4"

README.md

Lines changed: 6 additions & 2 deletions
@@ -73,7 +73,11 @@ with a JSON payload of the form:
     "selection": [
       [0, 19, 2],
       [1, 3, 1]
-    ]
+    ],
+
+    // Algorithm used to compress the data
+    // - optional, defaults to no compression
+    "compression": "gzip|zlib"
 }
 ```

@@ -92,7 +96,7 @@ In particular, the following are known limitations which we intend to address:
 
 * Error handling and reporting is minimal
 * No support for missing data
-* No support for compressed or encrypted objects
+* No support for encrypted objects
 
 ## Running

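For illustration, here is a minimal sketch of a request body that uses the new field. Only "compression" comes from this change; the other values are placeholders, and fields required by the API but untouched by this diff (such as the object location and dtype) are omitted.

```python
# Sketch of a request body using the new "compression" field. Values are
# placeholders; other required fields (object location, dtype, ...) are omitted
# because they are not part of this change.
import json

request_data = {
    "offset": 0,
    "size": 144,                            # size in bytes of the stored object (assumption)
    "order": "C",
    "selection": [[0, 19, 2], [1, 3, 1]],
    "compression": "gzip",                  # or "zlib"; omit for uncompressed data
}

print(json.dumps(request_data, indent=2))
```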
scripts/client.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,7 @@ def get_args() -> argparse.Namespace:
     parser.add_argument("--shape", type=str)
     parser.add_argument("--order", default="C") #, choices=["C", "F"]) allow invalid for testing
     parser.add_argument("--selection", type=str)
+    parser.add_argument("--compression", type=str)
     parser.add_argument("--show-response-headers", action=argparse.BooleanOptionalAction)
     return parser.parse_args()

@@ -49,6 +50,7 @@ def build_request_data(args: argparse.Namespace) -> dict:
         'offset': args.offset,
         'size': args.size,
         'order': args.order,
+        'compression': args.compression,
     }
     if args.shape:
         request_data["shape"] = json.loads(args.shape)

scripts/upload_sample_data.py

Lines changed: 14 additions & 3 deletions
@@ -1,10 +1,13 @@
 from enum import Enum
+import gzip
 import numpy as np
 import pathlib
 import s3fs
+import zlib
 
 NUM_ITEMS = 10
 OBJECT_PREFIX = "data"
+COMPRESSION_ALGS = [None, "gzip", "zlib"]
 
 #Use enum which also subclasses string type so that
 # auto-generated OpenAPI schema can determine allowed dtypes

@@ -33,8 +36,16 @@ def n_bytes(self):
         pass
 
 # Create numpy arrays and upload to S3 as bytes
-for d in AllowedDatatypes.__members__.keys():
-    with s3_fs.open(bucket / f'{OBJECT_PREFIX}-{d}.dat', 'wb') as s3_file:
-        s3_file.write(np.arange(NUM_ITEMS, dtype=d).tobytes())
+for compression in COMPRESSION_ALGS:
+    compression_suffix = f"-{compression}" if compression else ""
+    for d in AllowedDatatypes.__members__.keys():
+        obj_name = f'{OBJECT_PREFIX}-{d}{compression_suffix}.dat'
+        with s3_fs.open(bucket / obj_name, 'wb') as s3_file:
+            data = np.arange(NUM_ITEMS, dtype=d).tobytes()
+            if compression == "gzip":
+                data = gzip.compress(data)
+            elif compression == "zlib":
+                data = zlib.compress(data)
+            s3_file.write(data)
 
 print("Data upload successful. \nBucket contents:\n", s3_fs.ls(bucket))

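As a sanity check, a compressed sample object written by this script should decompress back to the original array. A sketch of that round trip follows; the S3 endpoint, credentials, bucket name and the "int64" dtype are assumptions for the example, not taken from this diff.

```python
# Round-trip check for one of the compressed sample objects uploaded above.
# Endpoint, credentials, bucket name and dtype are assumptions for this sketch.
import gzip
import numpy as np
import s3fs

s3_fs = s3fs.S3FileSystem(
    key="minioadmin", secret="minioadmin",
    client_kwargs={"endpoint_url": "http://localhost:9000"},
)

with s3_fs.open("sample-data/data-int64-gzip.dat", "rb") as s3_file:
    compressed = s3_file.read()

values = np.frombuffer(gzip.decompress(compressed), dtype="int64")
assert (values == np.arange(10)).all()  # NUM_ITEMS == 10 in the script above
```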
src/app.rs

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 //! Active Storage server API
 
 use crate::error::ActiveStorageError;
+use crate::filter_pipeline;
 use crate::metrics::{metrics_handler, track_metrics};
 use crate::models;
 use crate::operation;

@@ -159,6 +160,11 @@ async fn operation_handler<T: operation::Operation>(
     ValidatedJson(request_data): ValidatedJson<models::RequestData>,
 ) -> Result<models::Response, ActiveStorageError> {
     let data = download_object(&auth, &request_data).await?;
+    let data = filter_pipeline::filter_pipeline(&request_data, &data)?;
+    if request_data.compression.is_some() || request_data.size.is_none() {
+        // Validate the raw uncompressed data size now that we know it.
+        models::validate_raw_size(data.len(), request_data.dtype, &request_data.shape)?;
+    }
     T::execute(&request_data, &data)
 }

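The raw-size validation moves to after the filter pipeline because the stored length of a compressed object is not dtype times element count; only the decompressed length can be checked against the declared dtype and shape. A small standalone illustration:

```python
# The stored (compressed) length generally differs from the logical size
# dtype.itemsize * number_of_elements, so size validation has to happen on the
# decompressed bytes.
import gzip
import numpy as np

raw = np.arange(10, dtype="int64").tobytes()
compressed = gzip.compress(raw)

print(len(raw))         # 80: 10 elements * 8 bytes each
print(len(compressed))  # compressed length, typically different from 80
assert len(gzip.decompress(compressed)) == len(raw)
```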
src/array.rs

Lines changed: 5 additions & 0 deletions
@@ -236,6 +236,7 @@ mod tests {
             shape: None,
             order: None,
             selection: None,
+            compression: None,
         },
     );
     assert_eq!([42], shape.raw_dim().as_array_view().as_slice().unwrap());

@@ -255,6 +256,7 @@
             shape: Some(vec![1, 2, 3]),
             order: None,
             selection: None,
+            compression: None,
         },
     );
     assert_eq!(

@@ -458,6 +460,7 @@
         shape: None,
         order: None,
         selection: None,
+        compression: None,
     };
     let bytes = Bytes::copy_from_slice(&data);
     let array = build_array::<u32>(&request_data, &bytes).unwrap();

@@ -477,6 +480,7 @@
         shape: Some(vec![2, 1]),
         order: None,
         selection: None,
+        compression: None,
     };
     let bytes = Bytes::copy_from_slice(&data);
     let array = build_array::<i64>(&request_data, &bytes).unwrap();

@@ -496,6 +500,7 @@
         shape: None,
         order: None,
         selection: None,
+        compression: None,
     };
     let bytes = Bytes::copy_from_slice(&data);
     let array = build_array::<u32>(&request_data, &bytes).unwrap();

src/compression.rs

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
//! (De)compression support.

use crate::error::ActiveStorageError;
use crate::models;

use axum::body::Bytes;
use flate2::read::{GzDecoder, ZlibDecoder};
use std::io::Read;

/// Decompresses some Bytes and returns the uncompressed data.
///
/// # Arguments
///
/// * `compression`: Compression algorithm
/// * `data`: Compressed data [Bytes](axum::body::Bytes)
pub fn decompress(
    compression: models::Compression,
    data: &Bytes,
) -> Result<Bytes, ActiveStorageError> {
    let mut decoder: Box<dyn Read> = match compression {
        models::Compression::Gzip => Box::new(GzDecoder::<&[u8]>::new(data)),
        models::Compression::Zlib => Box::new(ZlibDecoder::<&[u8]>::new(data)),
    };
    // The data returned by the S3 client does not have any alignment guarantees. In order to
    // reinterpret the data as an array of numbers with a higher alignment than 1, we need to
    // return the data in Bytes object in which the underlying data has a higher alignment.
    // For now we're hard-coding an alignment of 8 bytes, although this should depend on the
    // data type, and potentially whether there are any SIMD requirements.
    // Create an 8-byte aligned Vec<u8>.
    // FIXME: The compressed length will not be enough to store the uncompressed data, and may
    // result in a change in the underlying buffer to one that is not correctly aligned.
    let mut buf = maligned::align_first::<u8, maligned::A8>(data.len());
    decoder.read_to_end(&mut buf)?;
    // Release any unnecessary capacity.
    buf.shrink_to(0);
    Ok(buf.into())
}

#[cfg(test)]
mod tests {
    use super::*;
    use flate2::read::{GzEncoder, ZlibEncoder};
    use flate2::Compression;

    fn compress_gzip() -> Vec<u8> {
        // Adapated from flate2 documentation.
        let mut result = Vec::<u8>::new();
        let input = b"hello world";
        let mut deflater = GzEncoder::new(&input[..], Compression::fast());
        deflater.read_to_end(&mut result).unwrap();
        result
    }

    fn compress_zlib() -> Vec<u8> {
        // Adapated from flate2 documentation.
        let mut result = Vec::<u8>::new();
        let input = b"hello world";
        let mut deflater = ZlibEncoder::new(&input[..], Compression::fast());
        deflater.read_to_end(&mut result).unwrap();
        result
    }

    #[test]
    fn test_decompress_gzip() {
        let compressed = compress_gzip();
        let result = decompress(models::Compression::Gzip, &compressed.into()).unwrap();
        assert_eq!(result, b"hello world".as_ref());
        assert_eq!(result.as_ptr().align_offset(8), 0);
    }

    #[test]
    fn test_decompress_zlib() {
        let compressed = compress_zlib();
        let result = decompress(models::Compression::Zlib, &compressed.into()).unwrap();
        assert_eq!(result, b"hello world".as_ref());
        assert_eq!(result.as_ptr().align_offset(8), 0);
    }

    #[test]
    fn test_decompress_invalid_gzip() {
        let invalid = b"invalid format";
        let err = decompress(models::Compression::Gzip, &invalid.as_ref().into()).unwrap_err();
        match err {
            ActiveStorageError::Decompression(io_err) => {
                assert_eq!(io_err.kind(), std::io::ErrorKind::InvalidInput);
                assert_eq!(io_err.to_string(), "invalid gzip header");
            }
            err => panic!("unexpected error {}", err),
        }
    }

    #[test]
    fn test_decompress_invalid_zlib() {
        let invalid = b"invalid format";
        let err = decompress(models::Compression::Zlib, &invalid.as_ref().into()).unwrap_err();
        match err {
            ActiveStorageError::Decompression(io_err) => {
                assert_eq!(io_err.kind(), std::io::ErrorKind::InvalidInput);
                assert_eq!(io_err.to_string(), "corrupt deflate stream");
            }
            err => panic!("unexpected error {}", err),
        }
    }
}

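The two supported algorithms wrap the same DEFLATE payload in different containers, which is why the invalid-input tests above report distinct errors ("invalid gzip header" versus "corrupt deflate stream"). A quick standard-library illustration of the two framings:

```python
# gzip frames DEFLATE data with a multi-byte header (magic bytes 0x1f 0x8b) and a
# CRC32/size trailer; zlib uses a 2-byte header and an Adler-32 checksum.
import gzip
import zlib

data = b"hello world"
gz = gzip.compress(data)
zl = zlib.compress(data)

print(gz[:2].hex())  # "1f8b" - gzip magic bytes
print(zl[:1].hex())  # "78"   - typical zlib header byte
assert gzip.decompress(gz) == data
assert zlib.decompress(zl) == data
```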
src/error.rs

Lines changed: 30 additions & 2 deletions
@@ -22,6 +22,10 @@ use tracing::{event, Level};
 /// Each variant may result in a different API error response.
 #[derive(Debug, Error)]
 pub enum ActiveStorageError {
+    /// Error decompressing data
+    #[error("failed to decompress data")]
+    Decompression(#[from] std::io::Error),
+
     /// Attempt to perform an invalid operation on an empty array or selection
     #[error("cannot perform {operation} on empty array or selection")]
     EmptyArray { operation: &'static str },

@@ -34,7 +38,11 @@ pub enum ActiveStorageError {
     #[error("request data is not valid")]
     RequestDataJsonRejection(#[from] JsonRejection),
 
-    /// Error validating RequestData
+    /// Error validating RequestData (single error)
+    #[error("request data is not valid")]
+    RequestDataValidationSingle(#[from] validator::ValidationError),
+
+    /// Error validating RequestData (multiple errors)
     #[error("request data is not valid")]
     RequestDataValidation(#[from] validator::ValidationErrors),
 

@@ -174,8 +182,10 @@ impl From<ActiveStorageError> for ErrorResponse {
     fn from(error: ActiveStorageError) -> Self {
         let response = match &error {
             // Bad request
-            ActiveStorageError::EmptyArray { operation: _ }
+            ActiveStorageError::Decompression(_)
+            | ActiveStorageError::EmptyArray { operation: _ }
             | ActiveStorageError::RequestDataJsonRejection(_)
+            | ActiveStorageError::RequestDataValidationSingle(_)
             | ActiveStorageError::RequestDataValidation(_)
             | ActiveStorageError::ShapeInvalid(_) => Self::bad_request(&error),
 

@@ -309,6 +319,15 @@ mod tests {
         assert_eq!(caused_by, error_response.error.caused_by);
     }
 
+    #[tokio::test]
+    async fn decompression_error() {
+        let io_error = std::io::Error::new(std::io::ErrorKind::InvalidInput, "decompression error");
+        let error = ActiveStorageError::Decompression(io_error);
+        let message = "failed to decompress data";
+        let caused_by = Some(vec!["decompression error"]);
+        test_active_storage_error(error, StatusCode::BAD_REQUEST, message, caused_by).await;
+    }
+
     #[tokio::test]
     async fn empty_array_op_error() {
         let error = ActiveStorageError::EmptyArray { operation: "foo" };

@@ -326,6 +345,15 @@ mod tests {
         .await;
     }
 
+    #[tokio::test]
+    async fn request_data_validation_single() {
+        let validation_error = validator::ValidationError::new("foo");
+        let error = ActiveStorageError::RequestDataValidationSingle(validation_error);
+        let message = "request data is not valid";
+        let caused_by = Some(vec!["Validation error: foo [{}]"]);
+        test_active_storage_error(error, StatusCode::BAD_REQUEST, message, caused_by).await;
+    }
+
     #[tokio::test]
     async fn request_data_validation() {
         let mut validation_errors = validator::ValidationErrors::new();

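Putting the pieces together, a request that asks for "gzip" decompression of data that is not valid gzip should now surface as a 400 Bad Request carrying the "failed to decompress data" message. A hypothetical client-side check follows; the URL, operation path, request fields and response body layout are placeholders, not taken from this diff.

```python
# Hypothetical check of the new error path: only the 400 status and the
# "failed to decompress data" message come from the error mapping above;
# everything else here is an assumption.
import requests

request_data = {
    "offset": 0,
    "size": 80,
    "order": "C",
    "compression": "gzip",  # but the object is not actually gzip-compressed
}

response = requests.post("http://localhost:8080/v1/sum", json=request_data)
print(response.status_code)  # expected: 400
print(response.text)         # expected to mention "failed to decompress data"
```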