@@ -35,7 +35,7 @@ use plotly::{common::Mode, Plot, Scatter};
3535use pyo3:: { pyclass, pymethods, PyResult } ;
3636use std:: collections:: BTreeMap ;
3737use std:: sync:: Arc ;
38- use time :: Date ;
38+ use ordered_float :: OrderedFloat ;
3939use RustQuant_math :: interpolation:: { ExponentialInterpolator , Interpolator , LinearInterpolator } ;
4040use RustQuant_stochastics :: { CurveModel , NelsonSiegelSvensson } ;
4141
@@ -84,7 +84,7 @@ pub enum InterpolationMethod {
8484#[ pyclass]
8585pub struct Curve {
8686 /// The nodes of the curve.
87- nodes : BTreeMap < Date , f64 > ,
87+ nodes : BTreeMap < OrderedFloat < f64 > , f64 > ,
8888
8989 /// The type of the curve.
9090 curve_type : CurveType ,
@@ -93,7 +93,7 @@ pub struct Curve {
9393 interpolation_method : InterpolationMethod ,
9494
9595 /// Interpolator backend.
96- interpolator : Arc < dyn Interpolator < Date , f64 > > ,
96+ interpolator : Arc < dyn Interpolator < f64 , f64 > > ,
9797
9898 /// Nelson-Siegel-Svensson curve parameters.
9999 nss : Option < NelsonSiegelSvensson > ,
@@ -104,12 +104,12 @@ impl Curve {
104104 /// Create a new curve.
105105 #[ new]
106106 pub fn new (
107- dates : Vec < Date > ,
107+ dates : Vec < f64 > ,
108108 rates : Vec < f64 > ,
109109 curve_type : CurveType ,
110110 interpolation_method : InterpolationMethod ,
111111 ) -> PyResult < Self > {
112- let interpolator: Arc < dyn Interpolator < Date , f64 > > = match interpolation_method {
112+ let interpolator: Arc < dyn Interpolator < f64 , f64 > > = match interpolation_method {
113113 InterpolationMethod :: Linear => {
114114 Arc :: new ( LinearInterpolator :: new ( dates. clone ( ) , rates. clone ( ) ) ?)
115115 }
@@ -123,92 +123,101 @@ impl Curve {
123123 todo ! ( "Implement LagrangeInterpolator" )
124124 }
125125 } ;
126+ // let x: BTreeMap<f64, f64> = dates.into_iter().zip(rates.into_iter()).collect();
127+
126128
127129 Ok ( Self {
128- nodes : dates. into_iter ( ) . zip ( rates. into_iter ( ) ) . collect ( ) ,
130+ nodes : dates. into_iter ( ) . zip ( rates. into_iter ( ) ) . map ( | ( a , b ) | ( OrderedFloat ( a ) , b ) ) . collect ( ) ,
129131 curve_type,
130132 interpolation_method,
131133 interpolator,
132134 nss : None ,
133135 } )
134136 }
135137
136- /// Create a new Curve from a list of nodes.
137- #[ staticmethod]
138- pub fn from_nodes (
139- nodes : BTreeMap < Date , f64 > ,
140- curve_type : CurveType ,
141- interpolation_method : InterpolationMethod ,
142- ) -> PyResult < Self > {
143- let interpolator: Arc < dyn Interpolator < Date , f64 > > = match interpolation_method {
144- InterpolationMethod :: Linear => Arc :: new ( LinearInterpolator :: new (
145- nodes. keys ( ) . cloned ( ) . collect ( ) ,
146- nodes. values ( ) . cloned ( ) . collect ( ) ,
147- ) ?) ,
148- InterpolationMethod :: Exponential => Arc :: new ( ExponentialInterpolator :: new (
149- nodes. keys ( ) . cloned ( ) . collect ( ) ,
150- nodes. values ( ) . cloned ( ) . collect ( ) ,
151- ) ?) ,
152- InterpolationMethod :: CubicSpline => {
153- todo ! ( "Implement CubicSplineInterpolator" )
154- }
155- InterpolationMethod :: Lagrange => {
156- todo ! ( "Implement LagrangeInterpolator" )
157- }
158- } ;
159138
160- Ok ( Self {
161- nodes,
162- curve_type,
163- interpolation_method,
164- interpolator,
165- nss : None ,
166- } )
167- }
139+ // /// Create a new Curve from a list of nodes.
140+ // #[staticmethod]
141+ // pub fn from_nodes(
142+ // nodes: BTreeMap<Date, f64>,
143+ // curve_type: CurveType,
144+ // interpolation_method: InterpolationMethod,
145+ // ) -> PyResult<Self> {
146+ // let interpolator: Arc<dyn Interpolator<Date, f64>> = match interpolation_method {
147+ // InterpolationMethod::Linear => Arc::new(LinearInterpolator::new(
148+ // nodes.keys().cloned().collect(),
149+ // nodes.values().cloned().collect(),
150+ // )?),
151+ // InterpolationMethod::Exponential => Arc::new(ExponentialInterpolator::new(
152+ // nodes.keys().cloned().collect(),
153+ // nodes.values().cloned().collect(),
154+ // )?),
155+ // InterpolationMethod::CubicSpline => {
156+ // todo!("Implement CubicSplineInterpolator")
157+ // }
158+ // InterpolationMethod::Lagrange => {
159+ // todo!("Implement LagrangeInterpolator")
160+ // }
161+ // };
162+
163+ // Ok(Self {
164+ // nodes,
165+ // curve_type,
166+ // interpolation_method,
167+ // interpolator,
168+ // nss: None,
169+ // })
170+ // }
168171
169172 /// Get the interpolation method used by the curve.
170173 pub fn interpolation_method ( & self ) -> InterpolationMethod {
171174 self . interpolation_method
172175 }
173176
174177 /// Get a rate from the curve.
175- pub fn get_rate ( & self , date : Date ) -> Option < f64 > {
176- match self . nodes . get ( & date) {
178+ pub fn get_rate ( & self , date : f64 ) -> Option < f64 > {
179+ match self . nodes . get ( & OrderedFloat ( date) ) {
177180 Some ( rate) => Some ( * rate) ,
178181 None => self . interpolator . interpolate ( date) . ok ( ) ,
179182 }
180183 }
181184
182185 /// Get multiple rates from the curve.
183- pub fn get_rates ( & self , dates : Vec < Date > ) -> Vec < Option < f64 > > {
186+ pub fn get_rates ( & self , dates : Vec < f64 > ) -> Vec < Option < f64 > > {
184187 dates. iter ( ) . map ( |date| self . get_rate ( * date) ) . collect ( )
185188 }
186189
187190 /// Set a rate in the curve.
188- pub fn set_rate ( & mut self , date : Date , rate : f64 ) {
189- self . nodes . insert ( date, rate) ;
191+ pub fn set_rate ( & mut self , date : f64 , rate : f64 ) {
192+ self . nodes . insert ( OrderedFloat ( date) , rate) ;
190193 }
191194
192195 /// Set multiple rates in the curve.
193- pub fn set_rates ( & mut self , rates : Vec < ( Date , f64 ) > ) {
196+ pub fn set_rates ( & mut self , rates : Vec < ( f64 , f64 ) > ) {
194197 for ( date, rate) in rates {
195198 self . set_rate ( date, rate) ;
196199 }
197200 }
198201
199202 /// Get the first date in the curve.
200- pub fn first_date ( & self ) -> Option < & Date > {
201- self . nodes . keys ( ) . next ( )
203+ pub fn first_date ( & self ) -> Option < & f64 > {
204+ match self . nodes . keys ( ) . next ( ) {
205+ Some ( date) => Some ( & date. 0 ) ,
206+ None => None ,
207+ }
202208 }
203209
204210 /// Get the last date in the curve.
205- pub fn last_date ( & self ) -> Option < & Date > {
206- self . nodes . keys ( ) . next_back ( )
211+ pub fn last_date ( & self ) -> Option < & f64 > {
212+ match self . nodes . keys ( ) . next_back ( ) {
213+ Some ( date) => Some ( & date. 0 ) ,
214+ None => None ,
215+ }
207216 }
208217
209218 /// Get the dates of the curve.
210- pub fn dates ( & self ) -> Vec < Date > {
211- self . nodes . keys ( ) . cloned ( ) . collect ( )
219+ pub fn dates ( & self ) -> Vec < f64 > {
220+ self . nodes . keys ( ) . map ( |k| k . 0 ) . collect ( ) //. cloned().collect()
212221 }
213222
214223 /// Get the rates of the curve.
@@ -237,7 +246,7 @@ impl Curve {
237246 }
238247
239248 /// Get the bracketing indices for a specific index.
240- pub fn get_brackets ( & self , index : Date ) -> ( Date , Date ) {
249+ pub fn get_brackets ( & self , index : f64 ) -> ( f64 , f64 ) {
241250 let first = self . first_date ( ) . unwrap ( ) ;
242251 let last = self . last_date ( ) . unwrap ( ) ;
243252
@@ -248,10 +257,10 @@ impl Curve {
248257 return ( * last, * last) ;
249258 }
250259
251- let left = self . nodes . range ( ..index) . next_back ( ) . unwrap ( ) . 0 ;
252- let right = self . nodes . range ( index..) . next ( ) . unwrap ( ) . 0 ;
260+ let left = self . nodes . range ( ..OrderedFloat ( index) ) . next_back ( ) . unwrap ( ) . 0 ;
261+ let right = self . nodes . range ( OrderedFloat ( index) ..) . next ( ) . unwrap ( ) . 0 ;
253262
254- return ( * left, * right) ;
263+ return ( left. 0 , right. 0 ) ;
255264 }
256265
257266 /// Shift the curve by a constant value.
@@ -306,7 +315,7 @@ impl CostFunction for Curve {
306315
307316 let y_model = x
308317 . into_iter ( )
309- . map ( |date| curve_function ( & nss, * date) )
318+ . map ( |date| curve_function ( & nss, * * date) )
310319 . collect :: < Vec < f64 > > ( ) ;
311320
312321 let data = std:: iter:: zip ( y, y_model) ;
0 commit comments