Skip to content

Commit 93e294c

Browse files
committed
ENH: Add ability to append to a model
1 parent 98a2bfa commit 93e294c

File tree

3 files changed

+74
-2
lines changed

3 files changed

+74
-2
lines changed

arch/univariate/base.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636
from arch.univariate.distribution import Distribution, Normal
3737
from arch.univariate.volatility import ConstantVariance, VolatilityProcess
38-
from arch.utility.array import ensure1d
38+
from arch.utility.array import ensure1d, append_same_type
3939
from arch.utility.exceptions import (
4040
ConvergenceWarning,
4141
DataScaleWarning,
@@ -230,6 +230,28 @@ def name(self) -> str:
230230
"""The name of the model."""
231231
return self._name
232232

233+
def append(self, y: ArrayLike) -> None:
234+
"""
235+
Append data to the model
236+
237+
Parameters
238+
----------
239+
y : ndarray or Series
240+
Data to append
241+
242+
Returns
243+
-------
244+
ARCHModel
245+
Model with data appended
246+
"""
247+
_y = ensure1d(y, "y", series=True)
248+
self._y_original = append_same_type(self._y_original, y)
249+
self._y_series = pd.concat([self._y_series, _y])
250+
self._y = np.concatenate([self._y, np.asarray(_y)])
251+
252+
self._fit_indices: [0, int(self._y.shape[0])]
253+
self._fit_y = self._y
254+
233255
def constraints(self) -> tuple[Float64Array, Float64Array]:
234256
"""
235257
Construct linear constraint arrays for use in non-linear optimization

arch/univariate/mean.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
SkewStudent,
4040
StudentsT,
4141
)
42+
from arch.utility.array import append_same_type
4243

4344
if TYPE_CHECKING:
4445
# Fake path to satisfy mypy
@@ -269,6 +270,7 @@ def __init__(
269270
distribution=distribution,
270271
rescale=rescale,
271272
)
273+
self._x_original = x
272274
self._x = x
273275
self._x_names: list[str] = []
274276
self._x_index: None | NDArray | pd.Index = None
@@ -307,6 +309,25 @@ def __init__(
307309

308310
self._init_model()
309311

312+
def append(self, y: ArrayLike, x: ArrayLike2D | None = None) -> None:
313+
super().append(y)
314+
if x is not None:
315+
if self._x is None:
316+
raise ValueError("x was not provided in the original model")
317+
_x = np.atleast_2d(np.asarray(x))
318+
if _x.ndim != 2:
319+
raise ValueError("x must be 2-d")
320+
elif _x.shape[1] != self._x.shape[1]:
321+
raise ValueError(
322+
"x must have the same number of columns as the original x"
323+
)
324+
self._x_original = append_same_type(self._x_original, x)
325+
self._x = np.asarray(self._x_original)
326+
if self._x.shape[0] != self._y.shape[0]:
327+
raise ValueError("x must have the same number of observations as y")
328+
329+
self._init_model()
330+
310331
def _scale_changed(self):
311332
"""
312333
Called when the scale has changed. This allows the model

arch/utility/array.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,16 @@
1212
from typing import Any, Literal, overload
1313

1414
import numpy as np
15-
from pandas import DataFrame, DatetimeIndex, Index, NaT, Series, Timestamp, to_datetime
15+
from pandas import (
16+
DataFrame,
17+
DatetimeIndex,
18+
Index,
19+
NaT,
20+
Series,
21+
Timestamp,
22+
to_datetime,
23+
concat,
24+
)
1625

1726
from arch.typing import AnyPandas, ArrayLike, DateLike, NDArray
1827

@@ -310,3 +319,23 @@ def find_index(s: AnyPandas, index: int | DateLike) -> int:
310319
if loc.size == 0:
311320
raise ValueError("index not found")
312321
return int(loc)
322+
323+
324+
def append_same_type(original, new):
325+
if not isinstance(new, type(original)):
326+
raise TypeError(
327+
"Input data must be the same type as the original data. "
328+
f"Got {type(new)}, expected {type(original)}."
329+
)
330+
if isinstance(original, (Series, DataFrame)):
331+
extended = concat([original, new], axis=0)
332+
elif isinstance(original, np.ndarray):
333+
extended = np.concatenate([original, new])
334+
elif isinstance(original, list):
335+
extended = original + new
336+
else:
337+
raise TypeError(
338+
"Input data must be a pandas Series, DataFrame, numpy ndarray, or "
339+
f"list. Got {type(original)}."
340+
)
341+
return extended

0 commit comments

Comments
 (0)