Skip to content

Commit 46f183d

Browse files
committed
Replace dates with floats for interpolation and curves
1 parent b72b0e0 commit 46f183d

File tree

10 files changed

+79
-206
lines changed

10 files changed

+79
-206
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ rayon = "1.9.0" # https://docs.rs/rayon/latest/rayon/
9999
rust_decimal = { version = "1.34.3", features = ["maths"] } # https://docs.rs/rust_decimal/latest/rust_decimal/
100100
statrs = "0.17.1" # https://docs.rs/statrs/latest/statrs/
101101
thiserror = "1.0.57" # https://docs.rs/thiserror/latest/thiserror/
102+
ordered-float = "5.1.0" # https://docs.rs/ordered-float/latest/ordered_float/
102103

103104
# https://docs.rs/ndarray/latest/ndarray/
104105
ndarray = { version = "0.16.1", features = ["rayon"] }

crates/RustQuant_data/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ time = { workspace = true }
3030
plotly = { workspace = true }
3131
argmin = { workspace = true }
3232
argmin-math = { workspace = true }
33+
ordered-float = {workspace=true}
3334
pyo3 = {workspace=true}
3435

3536
## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

crates/RustQuant_data/src/curves.rs

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use plotly::{common::Mode, Plot, Scatter};
3535
use pyo3::{pyclass, pymethods, PyResult};
3636
use std::collections::BTreeMap;
3737
use std::sync::Arc;
38-
use time::Date;
38+
use ordered_float::OrderedFloat;
3939
use RustQuant_math::interpolation::{ExponentialInterpolator, Interpolator, LinearInterpolator};
4040
use RustQuant_stochastics::{CurveModel, NelsonSiegelSvensson};
4141

@@ -84,7 +84,7 @@ pub enum InterpolationMethod {
8484
#[pyclass]
8585
pub struct Curve {
8686
/// The nodes of the curve.
87-
nodes: BTreeMap<Date, f64>,
87+
nodes: BTreeMap<OrderedFloat<f64>, f64>,
8888

8989
/// The type of the curve.
9090
curve_type: CurveType,
@@ -93,7 +93,7 @@ pub struct Curve {
9393
interpolation_method: InterpolationMethod,
9494

9595
/// Interpolator backend.
96-
interpolator: Arc<dyn Interpolator<Date, f64>>,
96+
interpolator: Arc<dyn Interpolator<f64, f64>>,
9797

9898
/// Nelson-Siegel-Svensson curve parameters.
9999
nss: Option<NelsonSiegelSvensson>,
@@ -104,12 +104,12 @@ impl Curve {
104104
/// Create a new curve.
105105
#[new]
106106
pub fn new(
107-
dates: Vec<Date>,
107+
dates: Vec<f64>,
108108
rates: Vec<f64>,
109109
curve_type: CurveType,
110110
interpolation_method: InterpolationMethod,
111111
) -> PyResult<Self> {
112-
let interpolator: Arc<dyn Interpolator<Date, f64>> = match interpolation_method {
112+
let interpolator: Arc<dyn Interpolator<f64, f64>> = match interpolation_method {
113113
InterpolationMethod::Linear => {
114114
Arc::new(LinearInterpolator::new(dates.clone(), rates.clone())?)
115115
}
@@ -123,92 +123,101 @@ impl Curve {
123123
todo!("Implement LagrangeInterpolator")
124124
}
125125
};
126+
// let x: BTreeMap<f64, f64> = dates.into_iter().zip(rates.into_iter()).collect();
127+
126128

127129
Ok(Self {
128-
nodes: dates.into_iter().zip(rates.into_iter()).collect(),
130+
nodes: dates.into_iter().zip(rates.into_iter()).map(|(a, b)| (OrderedFloat(a), b)).collect(),
129131
curve_type,
130132
interpolation_method,
131133
interpolator,
132134
nss: None,
133135
})
134136
}
135137

136-
/// Create a new Curve from a list of nodes.
137-
#[staticmethod]
138-
pub fn from_nodes(
139-
nodes: BTreeMap<Date, f64>,
140-
curve_type: CurveType,
141-
interpolation_method: InterpolationMethod,
142-
) -> PyResult<Self> {
143-
let interpolator: Arc<dyn Interpolator<Date, f64>> = match interpolation_method {
144-
InterpolationMethod::Linear => Arc::new(LinearInterpolator::new(
145-
nodes.keys().cloned().collect(),
146-
nodes.values().cloned().collect(),
147-
)?),
148-
InterpolationMethod::Exponential => Arc::new(ExponentialInterpolator::new(
149-
nodes.keys().cloned().collect(),
150-
nodes.values().cloned().collect(),
151-
)?),
152-
InterpolationMethod::CubicSpline => {
153-
todo!("Implement CubicSplineInterpolator")
154-
}
155-
InterpolationMethod::Lagrange => {
156-
todo!("Implement LagrangeInterpolator")
157-
}
158-
};
159138

160-
Ok(Self {
161-
nodes,
162-
curve_type,
163-
interpolation_method,
164-
interpolator,
165-
nss: None,
166-
})
167-
}
139+
// /// Create a new Curve from a list of nodes.
140+
// #[staticmethod]
141+
// pub fn from_nodes(
142+
// nodes: BTreeMap<Date, f64>,
143+
// curve_type: CurveType,
144+
// interpolation_method: InterpolationMethod,
145+
// ) -> PyResult<Self> {
146+
// let interpolator: Arc<dyn Interpolator<Date, f64>> = match interpolation_method {
147+
// InterpolationMethod::Linear => Arc::new(LinearInterpolator::new(
148+
// nodes.keys().cloned().collect(),
149+
// nodes.values().cloned().collect(),
150+
// )?),
151+
// InterpolationMethod::Exponential => Arc::new(ExponentialInterpolator::new(
152+
// nodes.keys().cloned().collect(),
153+
// nodes.values().cloned().collect(),
154+
// )?),
155+
// InterpolationMethod::CubicSpline => {
156+
// todo!("Implement CubicSplineInterpolator")
157+
// }
158+
// InterpolationMethod::Lagrange => {
159+
// todo!("Implement LagrangeInterpolator")
160+
// }
161+
// };
162+
163+
// Ok(Self {
164+
// nodes,
165+
// curve_type,
166+
// interpolation_method,
167+
// interpolator,
168+
// nss: None,
169+
// })
170+
// }
168171

169172
/// Get the interpolation method used by the curve.
170173
pub fn interpolation_method(&self) -> InterpolationMethod {
171174
self.interpolation_method
172175
}
173176

174177
/// Get a rate from the curve.
175-
pub fn get_rate(&self, date: Date) -> Option<f64> {
176-
match self.nodes.get(&date) {
178+
pub fn get_rate(&self, date: f64) -> Option<f64> {
179+
match self.nodes.get(&OrderedFloat(date)) {
177180
Some(rate) => Some(*rate),
178181
None => self.interpolator.interpolate(date).ok(),
179182
}
180183
}
181184

182185
/// Get multiple rates from the curve.
183-
pub fn get_rates(&self, dates: Vec<Date>) -> Vec<Option<f64>> {
186+
pub fn get_rates(&self, dates: Vec<f64>) -> Vec<Option<f64>> {
184187
dates.iter().map(|date| self.get_rate(*date)).collect()
185188
}
186189

187190
/// Set a rate in the curve.
188-
pub fn set_rate(&mut self, date: Date, rate: f64) {
189-
self.nodes.insert(date, rate);
191+
pub fn set_rate(&mut self, date: f64, rate: f64) {
192+
self.nodes.insert(OrderedFloat(date), rate);
190193
}
191194

192195
/// Set multiple rates in the curve.
193-
pub fn set_rates(&mut self, rates: Vec<(Date, f64)>) {
196+
pub fn set_rates(&mut self, rates: Vec<(f64, f64)>) {
194197
for (date, rate) in rates {
195198
self.set_rate(date, rate);
196199
}
197200
}
198201

199202
/// Get the first date in the curve.
200-
pub fn first_date(&self) -> Option<&Date> {
201-
self.nodes.keys().next()
203+
pub fn first_date(&self) -> Option<&f64> {
204+
match self.nodes.keys().next() {
205+
Some(date) => Some(&date.0),
206+
None => None,
207+
}
202208
}
203209

204210
/// Get the last date in the curve.
205-
pub fn last_date(&self) -> Option<&Date> {
206-
self.nodes.keys().next_back()
211+
pub fn last_date(&self) -> Option<&f64> {
212+
match self.nodes.keys().next_back() {
213+
Some(date) => Some(&date.0),
214+
None => None,
215+
}
207216
}
208217

209218
/// Get the dates of the curve.
210-
pub fn dates(&self) -> Vec<Date> {
211-
self.nodes.keys().cloned().collect()
219+
pub fn dates(&self) -> Vec<f64> {
220+
self.nodes.keys().map(|k| k.0).collect()//.cloned().collect()
212221
}
213222

214223
/// Get the rates of the curve.
@@ -237,7 +246,7 @@ impl Curve {
237246
}
238247

239248
/// Get the bracketing indices for a specific index.
240-
pub fn get_brackets(&self, index: Date) -> (Date, Date) {
249+
pub fn get_brackets(&self, index: f64) -> (f64, f64) {
241250
let first = self.first_date().unwrap();
242251
let last = self.last_date().unwrap();
243252

@@ -248,10 +257,10 @@ impl Curve {
248257
return (*last, *last);
249258
}
250259

251-
let left = self.nodes.range(..index).next_back().unwrap().0;
252-
let right = self.nodes.range(index..).next().unwrap().0;
260+
let left = self.nodes.range(..OrderedFloat(index)).next_back().unwrap().0;
261+
let right = self.nodes.range(OrderedFloat(index)..).next().unwrap().0;
253262

254-
return (*left, *right);
263+
return (left.0, right.0);
255264
}
256265

257266
/// Shift the curve by a constant value.
@@ -306,7 +315,7 @@ impl CostFunction for Curve {
306315

307316
let y_model = x
308317
.into_iter()
309-
.map(|date| curve_function(&nss, *date))
318+
.map(|date| curve_function(&nss, **date))
310319
.collect::<Vec<f64>>();
311320

312321
let data = std::iter::zip(y, y_model);

crates/RustQuant_math/src/interpolation/b_splines.rs

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -181,32 +181,6 @@ mod tests_b_splines {
181181
);
182182
}
183183

184-
#[test]
185-
fn test_b_spline_dates() {
186-
let now = time::OffsetDateTime::now_utc();
187-
let knots: Vec<time::OffsetDateTime> = vec![
188-
now,
189-
now + time::Duration::days(1),
190-
now + time::Duration::days(2),
191-
now + time::Duration::days(3),
192-
now + time::Duration::days(4),
193-
now + time::Duration::days(5),
194-
now + time::Duration::days(6),
195-
];
196-
let control_points = vec![-1.0, 2.0, 0.0, -1.0];
197-
198-
let mut interpolator = BSplineInterpolator::new(knots.clone(), control_points, 2).unwrap();
199-
let _ = interpolator.fit();
200-
201-
assert_approx_equal!(
202-
1.375,
203-
interpolator
204-
.interpolate(knots[2] + time::Duration::hours(12))
205-
.unwrap(),
206-
RUSTQUANT_EPSILON
207-
);
208-
}
209-
210184
#[test]
211185
fn test_b_spline_inconsistent_parameters() {
212186
let knots = vec![0.0, 1.0, 2.0, 3.0, 4.0];

crates/RustQuant_math/src/interpolation/cubic_spline.rs

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -298,32 +298,6 @@ mod tests_cubic_spline_interpolation {
298298
);
299299
}
300300

301-
#[test]
302-
fn test_natural_cubic_interpolation_dates() {
303-
let now: time::OffsetDateTime = time::OffsetDateTime::now_utc();
304-
305-
let xs: Vec<time::OffsetDateTime> = vec![
306-
now,
307-
now + time::Duration::days(1),
308-
now + time::Duration::days(2),
309-
now + time::Duration::days(3),
310-
now + time::Duration::days(4),
311-
];
312-
313-
let ys: Vec<f64> = vec![0., 1., 16., 81., 256.];
314-
315-
let mut interpolator: CubicSplineInterpolator<time::OffsetDateTime, f64> = CubicSplineInterpolator::new(xs.clone(), ys).unwrap();
316-
let _ = interpolator.fit();
317-
318-
assert_approx_equal!(
319-
36.660714285714285,
320-
interpolator
321-
.interpolate(xs[2] + time::Duration::hours(12))
322-
.unwrap(),
323-
RUSTQUANT_EPSILON
324-
);
325-
}
326-
327301
#[test]
328302
fn test_cubic_interpolation_out_of_range() {
329303
let xs: Vec<f64> = vec![1., 2., 3., 4., 5.];

crates/RustQuant_math/src/interpolation/exponential_interpolator.rs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -188,24 +188,4 @@ mod tests_exponential_interpolation {
188188
RUSTQUANT_EPSILON
189189
);
190190
}
191-
192-
#[test]
193-
fn test_exponential_interpolation_dates() {
194-
let d_1m = date!(1990 - 06 - 16);
195-
let d_2m = date!(1990 - 07 - 17);
196-
197-
let r_1m = 0.9870;
198-
let r_2m = 0.9753;
199-
200-
let dates = vec![d_1m, d_2m];
201-
let rates = vec![r_1m, r_2m];
202-
203-
let interpolator = ExponentialInterpolator::new(dates, rates).unwrap();
204-
205-
assert_approx_equal!(
206-
0.9854824711068088,
207-
interpolator.interpolate(date!(1990 - 06 - 20)).unwrap(),
208-
RUSTQUANT_EPSILON
209-
);
210-
}
211191
}

crates/RustQuant_math/src/interpolation/linear_interpolator.rs

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ where
129129
#[cfg(test)]
130130
mod tests_linear_interpolation {
131131
use super::*;
132-
use time::macros::date;
133132
use RustQuant_utils::{assert_approx_equal, RUSTQUANT_EPSILON};
134133

135134
#[test]
@@ -162,51 +161,4 @@ mod tests_linear_interpolation {
162161

163162
assert!(interpolator.interpolate(6.).is_err());
164163
}
165-
166-
#[test]
167-
fn test_linear_interpolation_dates() {
168-
let now = time::OffsetDateTime::now_utc();
169-
170-
let xs = vec![
171-
now,
172-
now + time::Duration::days(1),
173-
now + time::Duration::days(2),
174-
now + time::Duration::days(3),
175-
now + time::Duration::days(4),
176-
];
177-
178-
let ys = vec![1., 2., 3., 4., 5.];
179-
180-
let mut interpolator = LinearInterpolator::new(xs.clone(), ys).unwrap();
181-
let _ = interpolator.fit();
182-
183-
assert_approx_equal!(
184-
2.5,
185-
interpolator
186-
.interpolate(xs[1] + time::Duration::hours(12))
187-
.unwrap(),
188-
RUSTQUANT_EPSILON
189-
);
190-
}
191-
192-
#[test]
193-
fn test_linear_interpolation_dates_textbook() {
194-
let d_1m = date!(1990 - 06 - 16);
195-
let d_2m = date!(1990 - 07 - 17);
196-
197-
let r_1m = 0.9870;
198-
let r_2m = 0.9753;
199-
200-
let dates = vec![d_1m, d_2m];
201-
let rates = vec![r_1m, r_2m];
202-
203-
let interpolator = LinearInterpolator::new(dates, rates).unwrap();
204-
205-
let d = date!(1990 - 06 - 20);
206-
assert_approx_equal!(
207-
interpolator.interpolate(d).unwrap(),
208-
0.9854903225806452,
209-
RUSTQUANT_EPSILON
210-
);
211-
}
212164
}

0 commit comments

Comments
 (0)