@@ -35,7 +35,7 @@ use plotly::{common::Mode, Plot, Scatter};
3535use pyo3:: { pyclass, pymethods, PyResult } ;
3636use std:: collections:: BTreeMap ;
3737use std:: sync:: Arc ;
38- use time :: Date ;
38+ use ordered_float :: OrderedFloat ;
3939use RustQuant_math :: interpolation:: { ExponentialInterpolator , Interpolator , LinearInterpolator } ;
4040use RustQuant_stochastics :: { CurveModel , NelsonSiegelSvensson } ;
4141
@@ -84,7 +84,7 @@ pub enum InterpolationMethod {
8484#[ pyclass]
8585pub struct Curve {
8686 /// The nodes of the curve.
87- nodes : BTreeMap < Date , f64 > ,
87+ nodes : BTreeMap < OrderedFloat < f64 > , f64 > ,
8888
8989 /// The type of the curve.
9090 curve_type : CurveType ,
@@ -93,7 +93,7 @@ pub struct Curve {
9393 interpolation_method : InterpolationMethod ,
9494
9595 /// Interpolator backend.
96- interpolator : Arc < dyn Interpolator < Date , f64 > > ,
96+ interpolator : Arc < dyn Interpolator < f64 , f64 > > ,
9797
9898 /// Nelson-Siegel-Svensson curve parameters.
9999 nss : Option < NelsonSiegelSvensson > ,
@@ -104,12 +104,12 @@ impl Curve {
104104 /// Create a new curve.
105105 #[ new]
106106 pub fn new (
107- dates : Vec < Date > ,
107+ dates : Vec < f64 > ,
108108 rates : Vec < f64 > ,
109109 curve_type : CurveType ,
110110 interpolation_method : InterpolationMethod ,
111111 ) -> PyResult < Self > {
112- let interpolator: Arc < dyn Interpolator < Date , f64 > > = match interpolation_method {
112+ let interpolator: Arc < dyn Interpolator < f64 , f64 > > = match interpolation_method {
113113 InterpolationMethod :: Linear => {
114114 Arc :: new ( LinearInterpolator :: new ( dates. clone ( ) , rates. clone ( ) ) ?)
115115 }
@@ -122,93 +122,100 @@ impl Curve {
122122 InterpolationMethod :: Lagrange => {
123123 todo ! ( "Implement LagrangeInterpolator" )
124124 }
125- } ;
125+ } ;
126126
127127 Ok ( Self {
128- nodes : dates. into_iter ( ) . zip ( rates. into_iter ( ) ) . collect ( ) ,
128+ nodes : dates. into_iter ( ) . zip ( rates. into_iter ( ) ) . map ( | ( a , b ) | ( OrderedFloat ( a ) , b ) ) . collect ( ) ,
129129 curve_type,
130130 interpolation_method,
131131 interpolator,
132132 nss : None ,
133133 } )
134134 }
135135
136- /// Create a new Curve from a list of nodes.
137- #[ staticmethod]
138- pub fn from_nodes (
139- nodes : BTreeMap < Date , f64 > ,
140- curve_type : CurveType ,
141- interpolation_method : InterpolationMethod ,
142- ) -> PyResult < Self > {
143- let interpolator: Arc < dyn Interpolator < Date , f64 > > = match interpolation_method {
144- InterpolationMethod :: Linear => Arc :: new ( LinearInterpolator :: new (
145- nodes. keys ( ) . cloned ( ) . collect ( ) ,
146- nodes. values ( ) . cloned ( ) . collect ( ) ,
147- ) ?) ,
148- InterpolationMethod :: Exponential => Arc :: new ( ExponentialInterpolator :: new (
149- nodes. keys ( ) . cloned ( ) . collect ( ) ,
150- nodes. values ( ) . cloned ( ) . collect ( ) ,
151- ) ?) ,
152- InterpolationMethod :: CubicSpline => {
153- todo ! ( "Implement CubicSplineInterpolator" )
154- }
155- InterpolationMethod :: Lagrange => {
156- todo ! ( "Implement LagrangeInterpolator" )
157- }
158- } ;
159136
160- Ok ( Self {
161- nodes,
162- curve_type,
163- interpolation_method,
164- interpolator,
165- nss : None ,
166- } )
167- }
137+ // /// Create a new Curve from a list of nodes.
138+ // #[staticmethod]
139+ // pub fn from_nodes(
140+ // nodes: BTreeMap<Date, f64>,
141+ // curve_type: CurveType,
142+ // interpolation_method: InterpolationMethod,
143+ // ) -> PyResult<Self> {
144+ // let interpolator: Arc<dyn Interpolator<Date, f64>> = match interpolation_method {
145+ // InterpolationMethod::Linear => Arc::new(LinearInterpolator::new(
146+ // nodes.keys().cloned().collect(),
147+ // nodes.values().cloned().collect(),
148+ // )?),
149+ // InterpolationMethod::Exponential => Arc::new(ExponentialInterpolator::new(
150+ // nodes.keys().cloned().collect(),
151+ // nodes.values().cloned().collect(),
152+ // )?),
153+ // InterpolationMethod::CubicSpline => {
154+ // todo!("Implement CubicSplineInterpolator")
155+ // }
156+ // InterpolationMethod::Lagrange => {
157+ // todo!("Implement LagrangeInterpolator")
158+ // }
159+ // };
160+
161+ // Ok(Self {
162+ // nodes,
163+ // curve_type,
164+ // interpolation_method,
165+ // interpolator,
166+ // nss: None,
167+ // })
168+ // }
168169
169170 /// Get the interpolation method used by the curve.
170171 pub fn interpolation_method ( & self ) -> InterpolationMethod {
171172 self . interpolation_method
172173 }
173174
174175 /// Get a rate from the curve.
175- pub fn get_rate ( & self , date : Date ) -> Option < f64 > {
176- match self . nodes . get ( & date) {
176+ pub fn get_rate ( & self , date : f64 ) -> Option < f64 > {
177+ match self . nodes . get ( & OrderedFloat ( date) ) {
177178 Some ( rate) => Some ( * rate) ,
178179 None => self . interpolator . interpolate ( date) . ok ( ) ,
179180 }
180181 }
181182
182183 /// Get multiple rates from the curve.
183- pub fn get_rates ( & self , dates : Vec < Date > ) -> Vec < Option < f64 > > {
184+ pub fn get_rates ( & self , dates : Vec < f64 > ) -> Vec < Option < f64 > > {
184185 dates. iter ( ) . map ( |date| self . get_rate ( * date) ) . collect ( )
185186 }
186187
187188 /// Set a rate in the curve.
188- pub fn set_rate ( & mut self , date : Date , rate : f64 ) {
189- self . nodes . insert ( date, rate) ;
189+ pub fn set_rate ( & mut self , date : f64 , rate : f64 ) {
190+ self . nodes . insert ( OrderedFloat ( date) , rate) ;
190191 }
191192
192193 /// Set multiple rates in the curve.
193- pub fn set_rates ( & mut self , rates : Vec < ( Date , f64 ) > ) {
194+ pub fn set_rates ( & mut self , rates : Vec < ( f64 , f64 ) > ) {
194195 for ( date, rate) in rates {
195196 self . set_rate ( date, rate) ;
196197 }
197198 }
198199
199200 /// Get the first date in the curve.
200- pub fn first_date ( & self ) -> Option < & Date > {
201- self . nodes . keys ( ) . next ( )
201+ pub fn first_date ( & self ) -> Option < & f64 > {
202+ match self . nodes . keys ( ) . next ( ) {
203+ Some ( date) => Some ( & date. 0 ) ,
204+ None => None ,
205+ }
202206 }
203207
204208 /// Get the last date in the curve.
205- pub fn last_date ( & self ) -> Option < & Date > {
206- self . nodes . keys ( ) . next_back ( )
209+ pub fn last_date ( & self ) -> Option < & f64 > {
210+ match self . nodes . keys ( ) . next_back ( ) {
211+ Some ( date) => Some ( & date. 0 ) ,
212+ None => None ,
213+ }
207214 }
208215
209216 /// Get the dates of the curve.
210- pub fn dates ( & self ) -> Vec < Date > {
211- self . nodes . keys ( ) . cloned ( ) . collect ( )
217+ pub fn dates ( & self ) -> Vec < f64 > {
218+ self . nodes . keys ( ) . map ( |k| k . 0 ) . collect ( ) //. cloned().collect()
212219 }
213220
214221 /// Get the rates of the curve.
@@ -237,7 +244,7 @@ impl Curve {
237244 }
238245
239246 /// Get the bracketing indices for a specific index.
240- pub fn get_brackets ( & self , index : Date ) -> ( Date , Date ) {
247+ pub fn get_brackets ( & self , index : f64 ) -> ( f64 , f64 ) {
241248 let first = self . first_date ( ) . unwrap ( ) ;
242249 let last = self . last_date ( ) . unwrap ( ) ;
243250
@@ -248,10 +255,10 @@ impl Curve {
248255 return ( * last, * last) ;
249256 }
250257
251- let left = self . nodes . range ( ..index) . next_back ( ) . unwrap ( ) . 0 ;
252- let right = self . nodes . range ( index..) . next ( ) . unwrap ( ) . 0 ;
258+ let left = self . nodes . range ( ..OrderedFloat ( index) ) . next_back ( ) . unwrap ( ) . 0 ;
259+ let right = self . nodes . range ( OrderedFloat ( index) ..) . next ( ) . unwrap ( ) . 0 ;
253260
254- return ( * left, * right) ;
261+ return ( left. 0 , right. 0 ) ;
255262 }
256263
257264 /// Shift the curve by a constant value.
@@ -306,7 +313,7 @@ impl CostFunction for Curve {
306313
307314 let y_model = x
308315 . into_iter ( )
309- . map ( |date| curve_function ( & nss, * date) )
316+ . map ( |date| curve_function ( & nss, * * date) )
310317 . collect :: < Vec < f64 > > ( ) ;
311318
312319 let data = std:: iter:: zip ( y, y_model) ;
0 commit comments