Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions src/input/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::pyclass::CompareOp;
use pyo3::types::PyTuple;
use pyo3::types::{PyDate, PyDateTime, PyDelta, PyDeltaAccess, PyDict, PyTime, PyTzInfo};
use pyo3::IntoPyObjectExt;
use speedate::DateConfig;
use speedate::{
Date, DateTime, DateTimeConfig, Duration, MicrosecondsPrecisionOverflowBehavior, ParseError, Time, TimeConfig,
};
Expand All @@ -21,6 +22,7 @@ use super::Input;
use crate::errors::ToErrorValue;
use crate::errors::{ErrorType, ValError, ValResult};
use crate::tools::py_err;
use crate::validators::TemporalUnitMode;

#[cfg_attr(debug_assertions, derive(Debug))]
pub enum EitherDate<'py> {
Expand Down Expand Up @@ -324,8 +326,12 @@ impl<'py> EitherDateTime<'py> {
}
}

pub fn bytes_as_date<'py>(input: &(impl Input<'py> + ?Sized), bytes: &[u8]) -> ValResult<EitherDate<'py>> {
match Date::parse_bytes(bytes) {
pub fn bytes_as_date<'py>(
input: &(impl Input<'py> + ?Sized),
bytes: &[u8],
mode: TemporalUnitMode,
) -> ValResult<EitherDate<'py>> {
match Date::parse_bytes_with_config(bytes, &DateConfig::builder().timestamp_unit(mode.into()).build()) {
Ok(date) => Ok(date.into()),
Err(err) => Err(ValError::new(
ErrorType::DateParsing {
Expand Down Expand Up @@ -364,6 +370,7 @@ pub fn bytes_as_datetime<'py>(
input: &(impl Input<'py> + ?Sized),
bytes: &[u8],
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
Comment on lines 459 to +460
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it makes sense to just pass &DateTimeConfig here to simplify things?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I chose not to, as I wasn't a fan of using something from the serialization module in the validation logic. Happy to be overridden though!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davidhewitt happy to do whatever here to get this over the line 😄

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough, let's keep it as-is 👍

) -> ValResult<EitherDateTime<'py>> {
match DateTime::parse_bytes_with_config(
bytes,
Expand All @@ -372,7 +379,7 @@ pub fn bytes_as_datetime<'py>(
microseconds_precision_overflow_behavior: microseconds_overflow_behavior,
unix_timestamp_offset: Some(0),
},
..Default::default()
timestamp_unit: mode.into(),
},
) {
Ok(dt) => Ok(dt.into()),
Expand All @@ -390,6 +397,7 @@ pub fn int_as_datetime<'py>(
input: &(impl Input<'py> + ?Sized),
timestamp: i64,
timestamp_microseconds: u32,
mode: TemporalUnitMode,
) -> ValResult<EitherDateTime<'py>> {
match DateTime::from_timestamp_with_config(
timestamp,
Expand All @@ -399,7 +407,7 @@ pub fn int_as_datetime<'py>(
unix_timestamp_offset: Some(0),
..Default::default()
},
..Default::default()
timestamp_unit: mode.into(),
},
) {
Ok(dt) => Ok(dt.into()),
Expand Down Expand Up @@ -427,12 +435,30 @@ macro_rules! nan_check {
};
}

pub fn float_as_datetime<'py>(input: &(impl Input<'py> + ?Sized), timestamp: f64) -> ValResult<EitherDateTime<'py>> {
pub fn float_as_datetime<'py>(
input: &(impl Input<'py> + ?Sized),
timestamp: f64,
mode: TemporalUnitMode,
) -> ValResult<EitherDateTime<'py>> {
nan_check!(input, timestamp, DatetimeParsing);
let microseconds = timestamp.fract().abs() * 1_000_000.0;
let microseconds = match mode {
TemporalUnitMode::Seconds => timestamp.fract().abs() * 1_000_000.0,
TemporalUnitMode::Milliseconds => timestamp.fract().abs() * 1_000.0,
TemporalUnitMode::Infer => {
// Use the same watershed from speedate to determine if we treat the float as seconds or milliseconds.
// TODO: should we expose this from speedate?
if timestamp.abs() <= 20_000_000_000.0 {
// treat as seconds
timestamp.fract().abs() * 1_000_000.0
} else {
// treat as milliseconds
timestamp.fract().abs() * 1_000.0
}
}
};
// checking for extra digits in microseconds is unreliable with large floats,
// so we just round to the nearest microsecond
int_as_datetime(input, timestamp.floor() as i64, microseconds.round() as u32)
int_as_datetime(input, timestamp.floor() as i64, microseconds.round() as u32, mode)
}

pub fn date_as_datetime<'py>(date: &Bound<'py, PyDate>) -> PyResult<EitherDateTime<'py>> {
Expand Down
5 changes: 3 additions & 2 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use pyo3::{intern, prelude::*, IntoPyObjectExt};
use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::py_err;
use crate::validators::ValBytesMode;
use crate::validators::{TemporalUnitMode, ValBytesMode};

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherComplex, EitherInt, EitherString};
Expand Down Expand Up @@ -158,7 +158,7 @@ pub trait Input<'py>: fmt::Debug {

fn validate_iter(&self) -> ValResult<GenericIterator<'static>>;

fn validate_date(&self, strict: bool) -> ValMatch<EitherDate<'py>>;
fn validate_date(&self, strict: bool, mode: TemporalUnitMode) -> ValMatch<EitherDate<'py>>;

fn validate_time(
&self,
Expand All @@ -170,6 +170,7 @@ pub trait Input<'py>: fmt::Debug {
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
) -> ValMatch<EitherDateTime<'py>>;

fn validate_timedelta(
Expand Down
20 changes: 11 additions & 9 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::input::return_enums::EitherComplex;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;
use crate::validators::{TemporalUnitMode, ValBytesMode};

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
Expand Down Expand Up @@ -277,9 +277,9 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
}
}

fn validate_date(&self, _strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>> {
fn validate_date(&self, _strict: bool, mode: TemporalUnitMode) -> ValResult<ValidationMatch<EitherDate<'py>>> {
match self {
JsonValue::Str(v) => bytes_as_date(self, v.as_bytes()).map(ValidationMatch::strict),
JsonValue::Str(v) => bytes_as_date(self, v.as_bytes(), mode).map(ValidationMatch::strict),
_ => Err(ValError::new(ErrorTypeDefaults::DateType, self)),
}
}
Expand Down Expand Up @@ -313,13 +313,14 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
) -> ValResult<ValidationMatch<EitherDateTime<'py>>> {
match self {
JsonValue::Str(v) => {
bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::strict)
bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior, mode).map(ValidationMatch::strict)
}
JsonValue::Int(v) if !strict => int_as_datetime(self, *v, 0).map(ValidationMatch::lax),
JsonValue::Float(v) if !strict => float_as_datetime(self, *v).map(ValidationMatch::lax),
JsonValue::Int(v) if !strict => int_as_datetime(self, *v, 0, mode).map(ValidationMatch::lax),
JsonValue::Float(v) if !strict => float_as_datetime(self, *v, mode).map(ValidationMatch::lax),
_ => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)),
}
}
Expand Down Expand Up @@ -485,8 +486,8 @@ impl<'py> Input<'py> for str {
Ok(string_to_vec(self).into())
}

fn validate_date(&self, _strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>> {
bytes_as_date(self, self.as_bytes()).map(ValidationMatch::lax)
fn validate_date(&self, _strict: bool, mode: TemporalUnitMode) -> ValResult<ValidationMatch<EitherDate<'py>>> {
bytes_as_date(self, self.as_bytes(), mode).map(ValidationMatch::lax)
}

fn validate_time(
Expand All @@ -501,8 +502,9 @@ impl<'py> Input<'py> for str {
&self,
_strict: bool,
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
) -> ValResult<ValidationMatch<EitherDateTime<'py>>> {
bytes_as_datetime(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax)
bytes_as_datetime(self, self.as_bytes(), microseconds_overflow_behavior, mode).map(ValidationMatch::lax)
}

fn validate_timedelta(
Expand Down
14 changes: 8 additions & 6 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::tools::{extract_i64, safe_repr};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::validators::Exactness;
use crate::validators::TemporalUnitMode;
use crate::validators::ValBytesMode;
use crate::ArgsKwargs;

Expand Down Expand Up @@ -494,7 +495,7 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
}
}

fn validate_date(&self, strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>> {
fn validate_date(&self, strict: bool, mode: TemporalUnitMode) -> ValResult<ValidationMatch<EitherDate<'py>>> {
if let Ok(date) = self.downcast_exact::<PyDate>() {
Ok(ValidationMatch::exact(date.clone().into()))
} else if self.is_instance_of::<PyDateTime>() {
Expand All @@ -515,7 +516,7 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
None
}
} {
bytes_as_date(self, bytes).map(ValidationMatch::lax)
bytes_as_date(self, bytes, mode).map(ValidationMatch::lax)
} else {
Err(ValError::new(ErrorTypeDefaults::DateType, self))
}
Expand Down Expand Up @@ -559,6 +560,7 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
&self,
strict: bool,
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
) -> ValResult<ValidationMatch<EitherDateTime<'py>>> {
if let Ok(dt) = self.downcast_exact::<PyDateTime>() {
return Ok(ValidationMatch::exact(dt.clone().into()));
Expand All @@ -570,15 +572,15 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
if !strict {
return if let Ok(py_str) = self.downcast::<PyString>() {
let str = py_string_str(py_str)?;
bytes_as_datetime(self, str.as_bytes(), microseconds_overflow_behavior)
bytes_as_datetime(self, str.as_bytes(), microseconds_overflow_behavior, mode)
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior, mode)
} else if self.is_exact_instance_of::<PyBool>() {
Err(ValError::new(ErrorTypeDefaults::DatetimeType, self))
} else if let Some(int) = extract_i64(self) {
int_as_datetime(self, int, 0)
int_as_datetime(self, int, 0, mode)
} else if let Ok(float) = self.extract::<f64>() {
float_as_datetime(self, float)
float_as_datetime(self, float, mode)
} else if let Ok(date) = self.downcast::<PyDate>() {
Ok(date_as_datetime(date)?)
} else {
Expand Down
13 changes: 8 additions & 5 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::safe_repr;
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;
use crate::validators::{TemporalUnitMode, ValBytesMode};

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
Expand Down Expand Up @@ -201,9 +201,9 @@ impl<'py> Input<'py> for StringMapping<'py> {
Err(ValError::new(ErrorTypeDefaults::IterableType, self))
}

fn validate_date(&self, _strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>> {
fn validate_date(&self, _strict: bool, mode: TemporalUnitMode) -> ValResult<ValidationMatch<EitherDate<'py>>> {
match self {
Self::String(s) => bytes_as_date(self, py_string_str(s)?.as_bytes()).map(ValidationMatch::strict),
Self::String(s) => bytes_as_date(self, py_string_str(s)?.as_bytes(), mode).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DateType, self)),
}
}
Expand All @@ -224,10 +224,13 @@ impl<'py> Input<'py> for StringMapping<'py> {
&self,
_strict: bool,
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
) -> ValResult<ValidationMatch<EitherDateTime<'py>>> {
match self {
Self::String(s) => bytes_as_datetime(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior)
.map(ValidationMatch::strict),
Self::String(s) => {
bytes_as_datetime(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior, mode)
.map(ValidationMatch::strict)
}
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)),
}
}
Expand Down
60 changes: 55 additions & 5 deletions src/validators/config.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use std::borrow::Cow;
use std::str::FromStr;

use crate::build_tools::py_schema_err;
use crate::errors::ErrorType;
use crate::input::EitherBytes;
use crate::serializers::BytesMode;
use crate::tools::SchemaDict;
use base64::engine::general_purpose::GeneralPurpose;
use base64::engine::{DecodePaddingMode, GeneralPurposeConfig};
use base64::{alphabet, DecodeError, Engine};
use pyo3::types::{PyDict, PyString};
use pyo3::{intern, prelude::*};

use crate::errors::ErrorType;
use crate::input::EitherBytes;
use crate::serializers::BytesMode;
use crate::tools::SchemaDict;
use speedate::TimestampUnit;

const URL_SAFE_OPTIONAL_PADDING: GeneralPurpose = GeneralPurpose::new(
&alphabet::URL_SAFE,
Expand All @@ -21,6 +22,55 @@ const STANDARD_OPTIONAL_PADDING: GeneralPurpose = GeneralPurpose::new(
GeneralPurposeConfig::new().with_decode_padding_mode(DecodePaddingMode::Indifferent),
);

/// Unit in which a numeric timestamp should be interpreted when validating
/// temporal values.
///
/// The default is [`TemporalUnitMode::Infer`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum TemporalUnitMode {
    /// Treat numeric timestamps as seconds.
    Seconds,
    /// Treat numeric timestamps as milliseconds.
    Milliseconds,
    /// Decide between seconds and milliseconds from the value itself.
    #[default]
    Infer,
}

impl FromStr for TemporalUnitMode {
    type Err = PyErr;

    /// Parses one of the exact, case-sensitive strings `"seconds"`,
    /// `"milliseconds"` or `"infer"` into the matching variant.
    ///
    /// # Errors
    ///
    /// Any other input produces the error built by `py_schema_err!`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mode = match s {
            "seconds" => Self::Seconds,
            "milliseconds" => Self::Milliseconds,
            "infer" => Self::Infer,
            // NOTE(review): the message says "serialization mode" although this
            // type is a validation setting; wording intentionally left untouched.
            unknown => {
                return py_schema_err!(
                    "Invalid temporal_unit_mode serialization mode: `{}`, expected seconds, milliseconds or infer",
                    unknown
                )
            }
        };
        Ok(mode)
    }
}

impl TemporalUnitMode {
pub fn from_config(config: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<Bound<'_, PyString>>(intern!(config_dict.py(), "val_temporal_unit"))?;
let temporal_unit = raw_mode.map_or_else(
|| Ok(TemporalUnitMode::default()),
|raw| TemporalUnitMode::from_str(&raw.to_cow()?),
)?;
Ok(temporal_unit)
}
}

/// Converts the validator-side mode into the equivalent `speedate`
/// timestamp unit, variant for variant.
impl From<TemporalUnitMode> for TimestampUnit {
    fn from(mode: TemporalUnitMode) -> Self {
        match mode {
            TemporalUnitMode::Infer => Self::Infer,
            TemporalUnitMode::Milliseconds => Self::Millisecond,
            TemporalUnitMode::Seconds => Self::Second,
        }
    }
}

#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct ValBytesMode {
pub ser: BytesMode,
Expand Down
Loading
Loading