Skip to content

Commit 8c28d46

Browse files
committed
Perform size validation after decompression
When data is compressed, the size parameter refers to the size of the compressed data. Typically this is not equal to the size of the uncompressed data, so we can't validate it against the data type size. This change skips initial size/dtype validation when compression is applied, instead performing it once the data has been decompressed. It also adds an additional validation that the size matches the shape, when a shape has been specified.
1 parent 6518d00 commit 8c28d46

File tree

3 files changed

+69
-9
lines changed

3 files changed

+69
-9
lines changed

src/app.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ async fn operation_handler<T: operation::Operation>(
161161
) -> Result<models::Response, ActiveStorageError> {
162162
let data = download_object(&auth, &request_data).await?;
163163
let data = filter_pipeline::filter_pipeline(&request_data, &data)?;
164+
if request_data.compression.is_some() || request_data.size.is_none() {
165+
// Validate the raw uncompressed data size now that we know it.
166+
models::validate_raw_size(data.len(), request_data.dtype, &request_data.shape)?;
167+
}
164168
T::execute(&request_data, &data)
165169
}
166170

src/error.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ pub enum ActiveStorageError {
3838
#[error("request data is not valid")]
3939
RequestDataJsonRejection(#[from] JsonRejection),
4040

41-
/// Error validating RequestData
41+
/// Error validating RequestData (single error)
42+
#[error("request data is not valid")]
43+
RequestDataValidationSingle(#[from] validator::ValidationError),
44+
45+
/// Error validating RequestData (multiple errors)
4246
#[error("request data is not valid")]
4347
RequestDataValidation(#[from] validator::ValidationErrors),
4448

@@ -181,6 +185,7 @@ impl From<ActiveStorageError> for ErrorResponse {
181185
ActiveStorageError::Decompression(_)
182186
| ActiveStorageError::EmptyArray { operation: _ }
183187
| ActiveStorageError::RequestDataJsonRejection(_)
188+
| ActiveStorageError::RequestDataValidationSingle(_)
184189
| ActiveStorageError::RequestDataValidation(_)
185190
| ActiveStorageError::ShapeInvalid(_) => Self::bad_request(&error),
186191

@@ -340,6 +345,15 @@ mod tests {
340345
.await;
341346
}
342347

348+
#[tokio::test]
349+
async fn request_data_validation_single() {
350+
let validation_error = validator::ValidationError::new("foo");
351+
let error = ActiveStorageError::RequestDataValidationSingle(validation_error);
352+
let message = "request data is not valid";
353+
let caused_by = Some(vec!["Validation error: foo [{}]"]);
354+
test_active_storage_error(error, StatusCode::BAD_REQUEST, message, caused_by).await;
355+
}
356+
343357
#[tokio::test]
344358
async fn request_data_validation() {
345359
let mut validation_errors = validator::ValidationErrors::new();

src/models.rs

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ pub enum DType {
2626

2727
impl DType {
2828
/// Returns the size of the associated type in bytes.
29-
fn size_of(self) -> usize {
29+
pub fn size_of(self) -> usize {
3030
match self {
3131
Self::Int32 => std::mem::size_of::<i32>(),
3232
Self::Int64 => std::mem::size_of::<i64>(),
@@ -164,16 +164,47 @@ fn validate_shape_selection(
164164
Ok(())
165165
}
166166

167+
/// Validate raw data size against data type and shape.
168+
///
169+
/// # Arguments
170+
///
171+
/// * `raw_size`: Raw (uncompressed) size of the data in bytes.
172+
/// * `dtype`: Data type
173+
/// * `shape`: Optional shape of the multi-dimensional array
174+
pub fn validate_raw_size(
175+
raw_size: usize,
176+
dtype: DType,
177+
shape: &Option<Vec<usize>>,
178+
) -> Result<(), ValidationError> {
179+
let dtype_size = dtype.size_of();
180+
if let Some(shape) = shape {
181+
let expected_size = shape.iter().product::<usize>() * dtype_size;
182+
if raw_size != expected_size {
183+
let mut error =
184+
ValidationError::new("Raw data size must be equal to the product of shape indices and dtype size in bytes");
185+
error.add_param("raw size".into(), &raw_size);
186+
error.add_param("dtype size".into(), &dtype_size);
187+
error.add_param("expected size".into(), &expected_size);
188+
return Err(error);
189+
}
190+
} else if raw_size % dtype_size != 0 {
191+
let mut error =
192+
ValidationError::new("Raw data size must be a multiple of dtype size in bytes");
193+
error.add_param("raw size".into(), &raw_size);
194+
error.add_param("dtype size".into(), &dtype_size);
195+
return Err(error);
196+
}
197+
Ok(())
198+
}
199+
167200
/// Validate request data
168201
fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationError> {
169202
// Validation of multiple fields in RequestData.
170203
if let Some(size) = &request_data.size {
171-
let dtype_size = request_data.dtype.size_of();
172-
if size % dtype_size != 0 {
173-
let mut error = ValidationError::new("Size must be a multiple of dtype size in bytes");
174-
error.add_param("size".into(), &size);
175-
error.add_param("dtype size".into(), &dtype_size);
176-
return Err(error);
204+
// If the data is compressed then the size refers to the size of the compressed data, so we
205+
// can't validate it at this point.
206+
if request_data.compression.is_none() {
207+
validate_raw_size(*size, request_data.dtype, &request_data.shape)?;
177208
}
178209
};
179210
match (&request_data.shape, &request_data.selection) {
@@ -531,13 +562,24 @@ mod tests {
531562
}
532563

533564
#[test]
534-
#[should_panic(expected = "Size must be a multiple of dtype size in bytes")]
565+
#[should_panic(expected = "Raw data size must be a multiple of dtype size in bytes")]
535566
fn test_invalid_size_for_dtype() {
536567
let mut request_data = get_test_request_data();
537568
request_data.size = Some(1);
538569
request_data.validate().unwrap()
539570
}
540571

572+
#[test]
573+
#[should_panic(
574+
expected = "Raw data size must be equal to the product of shape indices and dtype size in bytes"
575+
)]
576+
fn test_invalid_size_for_shape() {
577+
let mut request_data = get_test_request_data();
578+
request_data.size = Some(4);
579+
request_data.shape = Some(vec![1, 2]);
580+
request_data.validate().unwrap()
581+
}
582+
541583
#[test]
542584
#[should_panic(expected = "Shape and selection must have the same length")]
543585
fn test_shape_selection_mismatch() {

0 commit comments

Comments
 (0)