diff --git a/azure-pipelines.yml b/azure-pipelines.yml index cc04f70c..f5fa5d1f 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,6 +30,7 @@ jobs: python --version python -m pip install --upgrade pip python -m pip install -r requirements.txt + python -m pip install jax python -m pip install torch displayName: 'Install dependencies' @@ -46,7 +47,7 @@ jobs: python -m pytest -v tslearn/ --doctest-modules displayName: 'Test' -- job: 'linux_without_torch' +- job: 'linux_without_torch_and_jax' pool: vmImage: 'ubuntu-latest' strategy: @@ -69,6 +70,7 @@ jobs: python --version python -m pip install --upgrade pip python -m pip install -r requirements.txt + python -m pip uninstall -y jax python -m pip uninstall torch displayName: 'Install dependencies' @@ -110,6 +112,7 @@ jobs: python --version python -m pip install --upgrade pip python -m pip install -r requirements.txt + python -m pip install jax python -m pip install torch displayName: 'Install dependencies' @@ -163,6 +166,7 @@ jobs: export OPENBLAS=$(brew --prefix openblas) python -m pip install --upgrade pip python -m pip install -r requirements.txt + python -m pip install jax python -m pip install torch displayName: 'Install dependencies' @@ -209,6 +213,7 @@ jobs: python --version python -m pip install --upgrade pip python -m pip install -r requirements_nocast.txt + python -m pip install jax python -m pip install torch displayName: 'Install dependencies' diff --git a/setup.py b/setup.py index ba887080..c0a9eba1 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ packages=find_packages(), package_data={"tslearn": [".cached_datasets/Trace.npz"]}, install_requires=['numpy', 'scipy', 'scikit-learn', 'numba', 'joblib'], - extras_require={'tests': ['pytest', 'torch'], 'pytorch': ['torch']}, + extras_require={'tests': ['pytest', 'torch', 'jax'], 'pytorch': ['torch'], 'jax': ['jax']}, version=VERSION, url="http://tslearn.readthedocs.io/", project_urls={ diff --git a/tslearn/backend/backend.py b/tslearn/backend/backend.py index c5465e8d..d1c031e1 100755 --- a/tslearn/backend/backend.py +++ b/tslearn/backend/backend.py @@ -2,6 +2,17 @@ from tslearn.backend.numpy_backend import NumPyBackend +try: + import jax + + from tslearn.backend.jax_backend import JAXBackend +except ImportError: + + class JAXBackend: + def __init__(self): + raise ValueError("Could not use JAX backend since JAX is not installed.") + + try: import torch @@ -11,7 +22,7 @@ class PyTorchBackend: def __init__(self): raise ValueError( - "Could not use pytorch backend since torch is not installed" + "Could not use PyTorch backend since torch is not installed." ) @@ -28,14 +39,14 @@ def instantiate_backend(*args): backend : Backend instance The backend instance. """ + backends_str = ["numpy", "jax", "torch"] for arg in args: if isinstance(arg, Backend): return arg - arg_string = (str(type(arg)) + str(arg)).lower() - if "numpy" in arg_string: - return Backend("numpy") - if "torch" in arg_string: - return Backend("pytorch") + arg_str = (str(type(arg)) + str(arg)).lower() + for backend_str in backends_str: + if backend_str in arg_str: + return Backend(backend_str) return Backend("numpy") @@ -54,11 +65,15 @@ def select_backend(data): The backend class. If data is a NumPy array or data equals 'numpy' or data is None, backend equals NumPyBackend(). - If data is a PyTorch array or data equals 'pytorch', + If data is a JAX array or data equals 'jax', + backend equals JAXBackend(). + If data is a PyTorch tensor or data equals 'pytorch', backend equals PyTorchBackend(). 
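+ Examples + -------- + Illustrative only (object representations are indicative, doctests skipped): + >>> select_backend("jax") # doctest: +SKIP + <tslearn.backend.jax_backend.JAXBackend object at 0x...> + >>> select_backend(None) # doctest: +SKIP + <tslearn.backend.numpy_backend.NumPyBackend object at 0x...>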
""" - arg_string = (str(type(data)) + str(data)).lower() - if "torch" in arg_string: + arg_str = (str(type(data)) + str(data)).lower() + if "jax" in arg_str: + return JAXBackend() + if "torch" in arg_str: return PyTorchBackend() return NumPyBackend() @@ -72,6 +87,8 @@ class Backend(object): Indicates the backend to choose. If data is a Numpy array or data equals 'numpy' or data is None, self.backend is set to NumpyBackend(). + If data is a JAX array or data equals 'jax', + self.backend is set to JAXBackend(). If data is a PyTorch array or data equals 'pytorch', self.backend is set to PytorchBackend(). Optional, default equals None. @@ -85,6 +102,7 @@ def __init__(self, data=None): setattr(self, element, getattr(self.backend, element)) self.is_numpy = self.backend_string == "numpy" + self.is_jax = self.backend_string == "jax" self.is_pytorch = self.backend_string == "pytorch" def get_backend(self): @@ -103,20 +121,24 @@ def cast(data, array_type="numpy"): The input data should be a list or numpy array or torch array. The data to cast. array_type: string - The type to cast the data. It can be "numpy", "pytorch" or "list". + The type to cast the data. It can be "numpy", "jax", "pytorch" or "list". Returns -------- data_cast: array-like Data cast to array_type. """ - data_type_string = str(type(data)).lower() + data_type_str = str(type(data)).lower() array_type = array_type.lower() if array_type == "pytorch": array_type = "torch" - if array_type in data_type_string: + if array_type in data_type_str: return data if array_type == "list": return data.tolist() be = Backend(array_type) + backends_str = ["numpy", "jax", "torch"] + for backend_str in backends_str: + if backend_str in data_type_str: + data = data.tolist() return be.array(data) diff --git a/tslearn/backend/jax_backend.py b/tslearn/backend/jax_backend.py new file mode 100755 index 00000000..fbd074f3 --- /dev/null +++ b/tslearn/backend/jax_backend.py @@ -0,0 +1,536 @@ +"""The JAX backend.""" + +import numpy as _np +from scipy.spatial.distance import cdist, pdist +from sklearn.metrics.pairwise import euclidean_distances, pairwise_distances + +try: + import jax as _jax + import jax.numpy as _jnp + from jax import config + + config.update("jax_enable_x64", True) + + HAS_JAX = True +except ImportError: + HAS_JAX = False + +if not HAS_JAX: + + class JAXBackend: + def __init__(self): + raise ValueError( + "Could not use the JAX backend since JAX is not installed." 
+ ) + +else: + + class JAXMutableArray: + # Mutable wrapper around immutable jax.numpy arrays: in-place item + # assignment is emulated with JAX's functional .at[...].set updates. + def __init__(self, *args, **kwargs): + if len(args) + len(kwargs) == 0: + self.array = None + else: + self.array = _jnp.array(*args, **kwargs) + + dtype = property(lambda self: self.array.dtype) + ndim = property(lambda self: self.array.ndim) + shape = property(lambda self: self.array.shape) + T = property(lambda self: self.from_jnp_array(self.array.T)) + + @classmethod + def from_jnp_array(cls, arr): + jni = cls() + jni.array = arr + return jni + + def __abs__(self): + return self.from_jnp_array(abs(self.array)) + + def __add__(self, other): + return self.from_jnp_array(self.array + other) + + def __array__(self, *args, **kwargs): + return _np.array(self.array, *args, **kwargs) + + def __bool__(self): + return self.array.__bool__() + + def __div__(self, other): + return self.from_jnp_array(self.array.__div__(other)) + + def __eq__(self, other): + return self.from_jnp_array(self.array == other) + + def __float__(self): + return self.array.__float__() + + def __floordiv__(self, other): + return self.from_jnp_array(self.array // other) + + def __ge__(self, other): + return self.from_jnp_array(self.array >= other) + + def __getitem__(self, key): + return self.from_jnp_array(self.array.at[key].get()) + + def __gt__(self, other): + return self.from_jnp_array(self.array > other) + + def __index__(self): + # Must return a plain Python int, not a wrapped array. + return self.array.__index__() + + def __int__(self): + return self.array.__int__() + + def __invert__(self): + return self.from_jnp_array(self.array.__invert__()) + + def __iter__(self): + # Yield wrapped elements so iteration stays within the wrapper type. + return (self.from_jnp_array(a) for a in self.array) + + def __jax_array__(self): + return self.array + + def __len__(self): + return self.array.__len__() + + def __le__(self, other): + return self.from_jnp_array(self.array <= other) + + def __lt__(self, other): + return self.from_jnp_array(self.array < other) + + def __lshift__(self, other): + return self.from_jnp_array(self.array << other) + + def __matmul__(self, other): + return self.from_jnp_array(self.array @ other) + + def __mod__(self, other): + return self.from_jnp_array(self.array % other) + + def __mul__(self, other): + return self.from_jnp_array(self.array * other) + + def __ne__(self, other): + return self.from_jnp_array(self.array != other) + + def __neg__(self): + return self.from_jnp_array(self.array.__neg__()) + + def __or__(self, other): + return self.from_jnp_array(self.array | other) + + def __pos__(self): + return self.from_jnp_array(self.array.__pos__()) + + def __pow__(self, other): + return self.from_jnp_array(self.array ** other) + + def __radd__(self, other): + return self.from_jnp_array(other + self.array) + + def __rdiv__(self, other): + return self.from_jnp_array(self.array.__rdiv__(other)) + + def __repr__(self): + return self.array.__repr__() + + def __rmul__(self, other): + return self.from_jnp_array(other * self.array) + + def __rshift__(self, other): + return self.from_jnp_array(self.array >> other) + + def __rsub__(self, other): + return self.from_jnp_array(other - self.array) + + def __rtruediv__(self, other): + return self.from_jnp_array(other / self.array) + + def __setitem__(self, key, value): + if hasattr(key, 'array'): + key = key.array + self.array = self.array.at[key].set(value) + + def __sub__(self, other): + return self.from_jnp_array(self.array - other) + + def __truediv__(self, other): + return self.from_jnp_array(self.array.__truediv__(other)) + + def __xor__(self, other): + 
return self.from_jnp_array(self.array ^ other) + + def astype(self, dtype): + return self.from_jnp_array(self.array.astype(dtype)) + + def conj(self): + return self.from_jnp_array(self.array.conj()) + + def copy(self): + return self.from_jnp_array(self.array.copy()) + + def reshape(self, shape, order='C'): + return self.from_jnp_array(self.array.reshape(shape, order=order)) + + def tolist(self): + return self.array.tolist() + + + class JAXBackend(object): + """Class for the JAX backend.""" + + def __init__(self): + self.backend_string = "jax" + + self.linalg = JAXLinalg() + self.random = JAXRandom() + self.testing = JAXTesting() + + self.int8 = _jnp.int8 + self.int16 = _jnp.int16 + self.int32 = _jnp.int32 + self.int64 = _jnp.int64 + self.float32 = _jnp.float32 + self.float64 = _jnp.float64 + self.complex64 = _jnp.complex64 + self.complex128 = _jnp.complex128 + + self.dbl_max = _jnp.finfo("double").max + self.inf = _jnp.inf + self.nan = _jnp.nan + + @staticmethod + def abs(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.abs(*args, **kwargs)) + + @staticmethod + def all(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.all(*args, **kwargs)) + + @staticmethod + def any(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.any(*args, **kwargs)) + + @staticmethod + def arange(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.arange(*args, **kwargs)) + + @staticmethod + def argmax(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.argmax(*args, **kwargs)) + + @staticmethod + def argmin(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.argmin(*args, **kwargs)) + + @staticmethod + def array(*args, **kwargs): + return JAXMutableArray(*args, **kwargs) + + @staticmethod + def belongs_to_backend(x): + return "jax" in str(type(x)).lower() + + def cast(self, x, dtype): + return self.array(x, dtype=dtype) + + @staticmethod + def ceil(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.ceil(*args, **kwargs)) + + @staticmethod + def cos(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.cos(*args, **kwargs)) + + @staticmethod + def copy(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.copy(*args, **kwargs)) + + def cdist(self, XA, XB, metric='euclidean', p=2): + if metric == "euclidean": + metric = lambda x, y: self.linalg.norm(x - y) + if metric == "minkowski": + metric = lambda x, y: self.linalg.norm(x - y, ord=p) + if metric == "sqeuclidean": + metric = lambda x, y: self.linalg.norm(x - y) ** 2 + if metric == "chebyshev": + metric = lambda x, y: self.linalg.norm(x - y, ord=self.inf) + if callable(metric): + distance_matrix = self.zeros((XA.shape[0], XB.shape[0])) + for i in range(XA.shape[0]): + for j in range(XB.shape[0]): + distance_matrix[i, j] = metric(XA[i, ...], XB[j, ...]) + return distance_matrix + raise ValueError(f"Metric {metric} not implemented in JAX backend.") + + @staticmethod + def diag(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.diag(*args, **kwargs)) + + @staticmethod + def empty(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.empty(*args, **kwargs)) + + @staticmethod + def exp(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.exp(*args, **kwargs)) + + @staticmethod + def eye(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.eye(*args, **kwargs)) + + @staticmethod + def floor(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.floor(*args, **kwargs)) + + @staticmethod + def 
from_numpy(x): + return JAXMutableArray.from_jnp_array(_jnp.array(x)) + + @staticmethod + def full(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.full(*args, **kwargs)) + + @staticmethod + def full_like(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.full_like(*args, **kwargs)) + + @staticmethod + def hstack(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.hstack(*args, **kwargs)) + + @staticmethod + def is_array(x): + return type(x) is JAXMutableArray + + @staticmethod + def isclose(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.isclose(*args, **kwargs)) + + @staticmethod + def iscomplex(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.iscomplex(*args, **kwargs)) + + @staticmethod + def isfinite(x): + return JAXMutableArray.from_jnp_array(_jnp.isfinite(x)) + + @staticmethod + def is_float(x): + if hasattr(x, 'dtype'): + return 'float' in str(x.dtype) + return 'float' in str(x) + str(type(x)) + + @staticmethod + def is_float32(x): + if hasattr(x, 'dtype'): + return 'float32' in str(x.dtype) + return 'float32' in str(x) + str(type(x)) + + @staticmethod + def is_float64(x): + if hasattr(x, 'dtype'): + return 'float64' in str(x.dtype) + return 'float64' in str(x) + str(type(x)) + + @staticmethod + def isnan(x): + return JAXMutableArray.from_jnp_array(_jnp.isnan(x)) + + @staticmethod + def log(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.log(*args, **kwargs)) + + @staticmethod + def mean(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.mean(*args, **kwargs)) + + @staticmethod + def median(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.median(*args, **kwargs)) + + @staticmethod + def max(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.max(*args, **kwargs)) + + @staticmethod + def min(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.min(*args, **kwargs)) + + @staticmethod + def ndim(a): + return _jnp.ndim(a) + + @staticmethod + def ones(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.ones(*args, **kwargs)) + + def pairwise_distances(self, X, Y=None, metric="euclidean"): + if Y is None: + Y = X + if metric == "euclidean": + metric = lambda x, y: self.linalg.norm(x - y) + if metric == "sqeuclidean": + metric = lambda x, y: self.linalg.norm(x - y) ** 2 + if callable(metric): + distance_matrix = self.zeros((X.shape[0], Y.shape[0])) + for i in range(X.shape[0]): + for j in range(Y.shape[0]): + distance_matrix[i, j] = metric(X[i, ...], Y[j, ...]) + return distance_matrix + raise ValueError(f"Metric {metric} not implemented in JAX backend.") + + def pairwise_euclidean_distances(self, X, Y=None): + return self.pairwise_distances(X=X, Y=Y, metric="euclidean") + + def pdist(self, x, metric="euclidean", p=2): + if metric == "euclidean": + metric = lambda x, y: self.linalg.norm(x - y) + if metric == "minkowski": + metric = lambda x, y: self.linalg.norm(x - y, ord=p) + if metric == "sqeuclidean": + metric = lambda x, y: self.linalg.norm(x - y) ** 2 + if metric == "chebyshev": + metric = lambda x, y: self.linalg.norm(x - y, ord=self.inf) + if callable(metric): + n = x.shape[0] + distances = self.zeros((n * (n - 1)) // 2) + for i in range(n): + for j in range(i + 1, n): + distances[n * i + j - ((i + 2) * (i + 1)) // 2] = metric(x[i, ...], x[j, ...]) + return distances + raise ValueError(f"Metric {metric} not implemented in JAX backend.") + + @staticmethod + def reshape(a, newshape, order='C'): + return JAXMutableArray.from_jnp_array(_jnp.reshape(a, newshape, order=order)) + 
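+ # Layout note for pdist above: for i < j, the distance of the pair (i, j) + # is stored at condensed index n * i + j - ((i + 2) * (i + 1)) // 2, which + # matches the ordering of scipy.spatial.distance.pdist. Worked example for + # n = 4: (0, 1) -> 0, (0, 2) -> 1, (0, 3) -> 2, + # (1, 2) -> 4 + 2 - 3 = 3, (1, 3) -> 4, (2, 3) -> 5.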
+ @staticmethod + def round(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.round(*args, **kwargs)) + + @staticmethod + def shape(a): + return _jnp.shape(a) + + @staticmethod + def sin(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.sin(*args, **kwargs)) + + @staticmethod + def sqrt(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.sqrt(*args, **kwargs)) + + @staticmethod + def sum(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.sum(*args, **kwargs)) + + @staticmethod + def to_numpy(x): + return _np.array(_jnp.array(x)) + + @staticmethod + def transpose(a, axes=None): + return JAXMutableArray.from_jnp_array(_jnp.transpose(a=a, axes=axes)) + + @staticmethod + def tril(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.tril(*args, **kwargs)) + + @staticmethod + def tril_indices(n, k=0, m=None): + return _jnp.tril_indices(n=n, k=k, m=m) + + @staticmethod + def triu(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.triu(*args, **kwargs)) + + @staticmethod + def triu_indices(n, k=0, m=None): + return _jnp.triu_indices(n=n, k=k, m=m) + + @staticmethod + def vstack(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.vstack(*args, **kwargs)) + + @staticmethod + def zeros(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.zeros(*args, **kwargs)) + + @staticmethod + def zeros_like(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.zeros_like(*args, **kwargs)) + + + class JAXLinalg: + + @staticmethod + def inv(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.linalg.inv(*args, **kwargs)) + + @staticmethod + def norm(*args, **kwargs): + return JAXMutableArray.from_jnp_array(_jnp.linalg.norm(*args, **kwargs)) + + class JAXRandom: + def __init__(self): + self.key = _jax.random.PRNGKey(0) + + def rand(self, *args, dtype=float, key=None): + if key is None: + self.key = _jax.random.split(self.key, num=1)[0, :] + key = self.key + return JAXMutableArray.from_jnp_array(_jax.random.uniform( + key=key, shape=args, dtype=dtype, minval=0.0, maxval=1.0 + )) + + def randint(self, low, high=None, size=(), dtype=int, key=None): + if key is None: + self.key = _jax.random.split(self.key, num=1)[0, :] + key = self.key + if high is None: + minval = 0 + maxval = low + else: + minval = low + maxval = high + return JAXMutableArray.from_jnp_array(_jax.random.randint( + key=key, shape=size, minval=minval, maxval=maxval, dtype=dtype + )) + + def randn(self, *args, dtype=float, key=None): + if key is None: + self.key = _jax.random.split(self.key, num=1)[0, :] + key = self.key + return JAXMutableArray.from_jnp_array(_jax.random.normal( + key=key, shape=args, dtype=dtype + )) + + def uniform(self, low=0.0, high=1.0, size=(), dtype=float, key=None): + if key is None: + self.key = _jax.random.split(self.key, num=1)[0, :] + key = self.key + return JAXMutableArray.from_jnp_array(_jax.random.uniform( + key=key, shape=size, dtype=dtype, minval=low, maxval=high + )) + + class JAXTesting: + def __init__(self): + self.assert_equal = _np.testing.assert_equal + + @staticmethod + def assert_allclose(actual, desired, rtol=1e-07, atol=0, equal_nan=True, err_msg='', verbose=True): + return _np.testing.assert_allclose( + actual=JAXBackend.to_numpy(actual), + desired=JAXBackend.to_numpy(desired), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + err_msg=err_msg, + verbose=verbose) diff --git a/tslearn/backend/pytorch_backend.py b/tslearn/backend/pytorch_backend.py index 0c5755df..dc4039a9 100755 --- 
a/tslearn/backend/pytorch_backend.py +++ b/tslearn/backend/pytorch_backend.py @@ -131,13 +131,19 @@ def cast(self, x, dtype): return self.array(x, dtype=dtype) @staticmethod - def cdist(x, y, metric="euclidean", p=None): + def cdist(x, y, metric="euclidean", p=2): if metric == "euclidean": return _torch.cdist(x, y) if metric == "sqeuclidean": return _torch.cdist(x, y) ** 2 if metric == "minkowski": return _torch.cdist(x, y, p=p) + if callable(metric): + distance_matrix = _torch.zeros(x.shape[0], y.shape[0]) + for i in range(x.shape[0]): + for j in range(y.shape[0]): + distance_matrix[i, j] = metric(x[i, ...], y[j, ...]) + return distance_matrix raise ValueError(f"Metric {metric} not implemented in PyTorch backend.") @staticmethod @@ -189,13 +195,20 @@ def pairwise_distances(X, Y=None, metric="euclidean"): raise ValueError(f"Metric {metric} not implemented in PyTorch backend.") @staticmethod - def pdist(x, metric="euclidean", p=None): + def pdist(x, metric="euclidean", p=2): if metric == "euclidean": return _torch.pdist(x) if metric == "sqeuclidean": return _torch.pdist(x) ** 2 if metric == "minkowski": return _torch.pdist(x, p=p) + if callable(metric): + n = x.shape[0] + distances = _torch.zeros((n * (n - 1)) // 2) + for i in range(n): + for j in range(i + 1, n): + distances[n * i + j - ((i + 2) * (i + 1)) // 2] = metric(x[i, ...], x[j, ...]) + return distances raise ValueError(f"Metric {metric} not implemented in PyTorch backend.") def shape(self, data): @@ -261,5 +274,5 @@ def uniform(low=0.0, high=1.0, size=(1,), dtype=None): class PyTorchTesting: def __init__(self): - self.assert_allclose = _torch.allclose - self.assert_equal = _torch.testing.assert_close + self.assert_allclose = _np.testing.assert_allclose + self.assert_equal = _np.testing.assert_equal diff --git a/tslearn/metrics/ctw.py b/tslearn/metrics/ctw.py index a5ff19d0..df3e9e6b 100644 --- a/tslearn/metrics/ctw.py +++ b/tslearn/metrics/ctw.py @@ -1,7 +1,7 @@ import numpy as np from sklearn.cross_decomposition import CCA -from tslearn.backend import instantiate_backend +from tslearn.backend import cast, instantiate_backend from ..utils import to_time_series from .dtw_variants import dtw_path @@ -283,8 +283,8 @@ def ctw( human behavior". NIPS 2009. """ be = instantiate_backend(be, s1, s2) - s1 = be.array(s1) - s2 = be.array(s2) + s1 = cast(s1, array_type=be.backend_string) + s2 = cast(s2, array_type=be.backend_string) return ctw_path( s1=s1, s2=s2, diff --git a/tslearn/metrics/dtw_variants.py b/tslearn/metrics/dtw_variants.py index 48f0275c..ce3714da 100644 --- a/tslearn/metrics/dtw_variants.py +++ b/tslearn/metrics/dtw_variants.py @@ -3,7 +3,7 @@ import numpy from numba import njit, prange -from tslearn.backend import instantiate_backend +from tslearn.backend import cast, instantiate_backend from tslearn.utils import to_time_series from .utils import _cdist_generic @@ -625,7 +625,7 @@ def dtw_path_from_metric( """ # noqa: E501 be = instantiate_backend(be, s1, s2) if metric == "precomputed": # Pairwise distance given as input - s1 = be.array(s1) + s1 = cast(s1, array_type=be.backend_string) sz1, sz2 = be.shape(s1) mask = compute_mask( sz1, @@ -1223,8 +1223,8 @@ def subsequence_cost_matrix(subseq, longseq, be=None): Accumulated cost matrix. 
""" be = instantiate_backend(be, subseq, longseq) - subseq = be.array(subseq) - longseq = be.array(longseq) + subseq = cast(subseq, array_type=be.backend_string) + longseq = cast(longseq, array_type=be.backend_string) subseq = to_time_series(subseq, remove_nans=True, be=be) longseq = to_time_series(longseq, remove_nans=True, be=be) if be.is_numpy: @@ -1785,8 +1785,8 @@ def compute_mask( if isinstance(s1, int) and isinstance(s2, int): sz1, sz2 = s1, s2 else: - s1 = be.array(s1) - s2 = be.array(s2) + s1 = cast(s1, array_type=be.backend_string) + s2 = cast(s2, array_type=be.backend_string) sz1 = be.shape(s1)[0] sz2 = be.shape(s2)[0] if ( @@ -2163,7 +2163,7 @@ def lb_envelope(ts, radius=1, be=None): Conference on Very Large Data Bases, 2002. pp 406-417. """ be = instantiate_backend(be, ts) - ts = be.array(ts) + ts = cast(ts, array_type=be.backend_string) ts = to_time_series(ts, be=be) if be.is_numpy: return _njit_lb_envelope(ts, radius=radius) @@ -2252,10 +2252,10 @@ def lcss_accumulated_matrix(s1, s2, eps, mask, be=None): else: squared_dist = _local_squared_dist(s1[i - 1], s2[j - 1], be=be) if be.sqrt(squared_dist) <= eps: - acc_cost_mat[i][j] = 1 + acc_cost_mat[i - 1][j - 1] + acc_cost_mat[i, j] = 1 + acc_cost_mat[i - 1, j - 1] else: - acc_cost_mat[i][j] = max( - acc_cost_mat[i][j - 1], acc_cost_mat[i - 1][j] + acc_cost_mat[i, j] = max( + acc_cost_mat[i, j - 1], acc_cost_mat[i - 1, j] ) return acc_cost_mat @@ -2828,10 +2828,10 @@ def lcss_accumulated_matrix_from_dist_matrix(dist_matrix, eps, mask, be=None): for j in range(1, l2 + 1): if be.isfinite(mask[i - 1, j - 1]): if dist_matrix[i - 1, j - 1] <= eps: - acc_cost_mat[i][j] = 1 + acc_cost_mat[i - 1][j - 1] + acc_cost_mat[i, j] = 1 + acc_cost_mat[i - 1, j - 1] else: - acc_cost_mat[i][j] = max( - acc_cost_mat[i][j - 1], acc_cost_mat[i - 1][j] + acc_cost_mat[i, j] = max( + acc_cost_mat[i, j - 1], acc_cost_mat[i - 1, j] ) return acc_cost_mat diff --git a/tslearn/metrics/soft_dtw_fast.py b/tslearn/metrics/soft_dtw_fast.py index d5720372..471d6824 100644 --- a/tslearn/metrics/soft_dtw_fast.py +++ b/tslearn/metrics/soft_dtw_fast.py @@ -98,20 +98,25 @@ def _softmin3(a, b, c, gamma, be=None): @njit(fastmath=True) -def _njit_soft_dtw(D, R, gamma): +def _njit_soft_dtw(D, gamma): """Compute soft dynamic time warping. Parameters ---------- D : array-like, shape=(m, n), dtype=float64 - R : array-like, shape=(m+2, n+2), dtype=float64 gamma : float64 Regularization parameter. + + Returns + ------- + R : array-like, shape=(m+2, n+2), dtype=float64 + We need +2 because we use indices starting from 1 + and to deal with edge cases in the backward recursion. """ - m = D.shape[0] - n = D.shape[1] + m, n = D.shape # Initialization. + R = np.zeros((m + 2, n + 2), dtype=np.float64) R[: m + 1, 0] = DBL_MAX R[0, : n + 1] = DBL_MAX R[0, 0] = 0 @@ -123,15 +128,15 @@ def _njit_soft_dtw(D, R, gamma): R[i, j] = D[i - 1, j - 1] + _njit_softmin3( R[i - 1, j], R[i - 1, j - 1], R[i, j - 1], gamma ) + return R -def _soft_dtw(D, R, gamma, be=None): +def _soft_dtw(D, gamma, be=None): """Compute soft dynamic time warping. Parameters ---------- D : array-like, shape=(m, n), dtype=float64 - R : array-like, shape=(m+2, n+2), dtype=float64 gamma : float64 Regularization parameter. be : Backend object or string or None @@ -141,12 +146,19 @@ def _soft_dtw(D, R, gamma, be=None): the PyTorch backend is used. If `be` is `None`, the backend is determined by the input arrays. See our :ref:`dedicated user-guide page ` for more information. 
+ + Returns + ------- + R : array-like, shape=(m+2, n+2), dtype=float64 + We need +2 because we use indices starting from 1 + and to deal with edge cases in the backward recursion. """ - be = instantiate_backend(be, D, R, gamma) + be = instantiate_backend(be, D, gamma) m, n = be.shape(D) # Initialization. + R = be.zeros((m + 2, n + 2), dtype=be.float64) R[: m + 1, 0] = be.dbl_max R[0, : n + 1] = be.dbl_max R[0, 0] = 0 @@ -162,30 +174,36 @@ gamma, be=be, ) + return R @njit(parallel=True) -def _njit_soft_dtw_batch(D, R, gamma): +def _njit_soft_dtw_batch(D, gamma): """Compute soft dynamic time warping. Parameters ---------- D : array-like, shape=(b, m, n), dtype=float64 - R : array-like, shape=(b, m+2, n+2), dtype=float64 gamma : float64 Regularization parameter. + + Returns + ------- + R : array-like, shape=(b, m+2, n+2), dtype=float64 """ - for i_sample in prange(D.shape[0]): - _njit_soft_dtw(D[i_sample, :, :], R[i_sample, :, :], gamma) + b, m, n = D.shape + R = np.zeros((b, m + 2, n + 2), dtype=np.float64) + for i_sample in prange(b): + R[i_sample, :, :] = _njit_soft_dtw(D[i_sample, :, :], gamma) + return R -def _soft_dtw_batch(D, R, gamma, be=None): +def _soft_dtw_batch(D, gamma, be=None): """Compute soft dynamic time warping. Parameters ---------- D : array-like, shape=(b, m, n), dtype=float64 - R : array-like, shape=(b, m+2, n+2), dtype=float64 gamma : float64 Regularization parameter. be : Backend object or string or None @@ -195,14 +213,21 @@ the PyTorch backend is used. If `be` is `None`, the backend is determined by the input arrays. See our :ref:`dedicated user-guide page ` for more information. + + Returns + ------- + R : array-like, shape=(b, m+2, n+2), dtype=float64 """ - be = instantiate_backend(be, D, R) + be = instantiate_backend(be, D) + b, m, n = D.shape + R = be.zeros((b, m + 2, n + 2), dtype=be.float64) - for i_sample in range(D.shape[0]): - _soft_dtw(D[i_sample, :, :], R[i_sample, :, :], gamma, be=be) + for i_sample in range(b): + R[i_sample, :, :] = _soft_dtw(D[i_sample, :, :], gamma, be=be) + return R @njit(fastmath=True) -def _njit_soft_dtw_grad(D, R, E, gamma): +def _njit_soft_dtw_grad(D, R, gamma): """Compute gradient of soft-DTW w.r.t. D. Parameters @@ -212,19 +237,26 @@ D : array-like, shape=(m, n), dtype=float64 R : array-like, shape=(m+2, n+2), dtype=float64 - E : array-like, shape=(m+2, n+2), dtype=float64 gamma : float64 Regularization parameter. + + Returns + ------- + E : array-like, shape=(m+2, n+2), dtype=float64 + We need +2 because we use indices starting from 1 + and to deal with edge cases in the recursion. """ - m = D.shape[0] - 1 - n = D.shape[1] - 1 + m, n = D.shape + + # Add an extra row and an extra column to D. + # Needed to deal with edge cases in the recursion. + D = np.vstack((D, np.zeros((1, n)))) + D = np.hstack((D, np.zeros((m + 1, 1)))) # Initialization. 
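+ # The extra border of R is filled with -DBL_MAX so that the exp() weights + # computed in the loop below vanish for out-of-range cells; E[m + 1, n + 1] + # is set to 1 and R[m + 1, n + 1] to the final alignment cost R[m, n] to + # seed the reverse recursion.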
- D[:m, n] = 0 - D[m, :n] = 0 - R[1 : m + 1, n + 1] = -DBL_MAX - R[m + 1, 1 : n + 1] = -DBL_MAX + R[1 : m + 1, n + 1] = -DBL_MAX + R[m + 1, 1 : n + 1] = -DBL_MAX + R[m + 1, n + 1] = R[m, n] + E = np.zeros((m + 2, n + 2), dtype=np.float64) E[m + 1, n + 1] = 1 - R[m + 1, n + 1] = R[m, n] - D[m, n] = 0 for j in range(n, 0, -1): # ranges from n to 1 for i in range(m, 0, -1): # ranges from m to 1 @@ -232,16 +265,16 @@ b = np.exp((R[i, j + 1] - R[i, j] - D[i - 1, j]) / gamma) c = np.exp((R[i + 1, j + 1] - R[i, j] - D[i, j]) / gamma) E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c + return E -def _soft_dtw_grad(D, R, E, gamma, be=None): +def _soft_dtw_grad(D, R, gamma, be=None): """Compute gradient of soft-DTW w.r.t. D. Parameters ---------- D : array-like, shape=(m, n), dtype=float64 R : array-like, shape=(m+2, n+2), dtype=float64 - E : array-like, shape=(m+2, n+2), dtype=float64 gamma : float64 Regularization parameter. be : Backend object or string or None @@ -251,21 +284,29 @@ the PyTorch backend is used. If `be` is `None`, the backend is determined by the input arrays. See our :ref:`dedicated user-guide page ` for more information. + + Returns + ------- + E : array-like, shape=(m+2, n+2), dtype=float64 + We need +2 because we use indices starting from 1 + and to deal with edge cases in the recursion. """ - be = instantiate_backend(be, D, R, E) + be = instantiate_backend(be, D, R) + + m, n = D.shape - m = D.shape[0] - 1 - n = D.shape[1] - 1 + # Add an extra row and an extra column to D. + # Needed to deal with edge cases in the recursion. + D = be.vstack((D, be.zeros((1, n)))) + D = be.hstack((D, be.zeros((m + 1, 1)))) # Initialization. - D[:m, n] = 0 - D[m, :n] = 0 - R[1 : m + 1, n + 1] = -be.dbl_max - R[m + 1, 1 : n + 1] = -be.dbl_max + R[1 : m + 1, n + 1] = -be.dbl_max + R[m + 1, 1 : n + 1] = -be.dbl_max + R[m + 1, n + 1] = R[m, n] + E = be.zeros((m + 2, n + 2), dtype=be.float64) E[m + 1, n + 1] = 1 - R[m + 1, n + 1] = R[m, n] - D[m, n] = 0 for j in range(n, 0, -1): # ranges from n to 1 for i in range(m, 0, -1): # ranges from m to 1 @@ -273,32 +314,38 @@ b = be.exp((R[i, j + 1] - R[i, j] - D[i - 1, j]) / gamma) c = be.exp((R[i + 1, j + 1] - R[i, j] - D[i, j]) / gamma) E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c + return E @njit(parallel=True) -def _njit_soft_dtw_grad_batch(D, R, E, gamma): +def _njit_soft_dtw_grad_batch(D, R, gamma): """Compute gradient of soft-DTW w.r.t. D. Parameters ---------- D : array-like, shape=(b, m, n), dtype=float64 R : array-like, shape=(b, m+2, n+2), dtype=float64 - E : array-like, shape=(b, m+2, n+2), dtype=float64 gamma : float64 Regularization parameter. + + Returns + ------- + E : array-like, shape=(b, m+2, n+2), dtype=float64 """ - for i_sample in prange(D.shape[0]): - _njit_soft_dtw_grad(D[i_sample, :, :], R[i_sample, :, :], E[i_sample, :, :], gamma) + b, m, n = D.shape + E = np.zeros((b, m + 2, n + 2), dtype=np.float64) + for i_sample in prange(b): + E[i_sample, :, :] = _njit_soft_dtw_grad(D[i_sample, :, :], R[i_sample, :, :], gamma) + return E -def _soft_dtw_grad_batch(D, R, E, gamma, be=None): +def _soft_dtw_grad_batch(D, R, gamma, be=None): """Compute gradient of soft-DTW w.r.t. D. 
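+ + Batched version of ``_soft_dtw_grad``: it loops over the first axis of D + and stacks the per-sample gradient matrices E.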
Parameters ---------- D : array-like, shape=(b, m, n), dtype=float64 R : array-like, shape=(b, m+2, n+2), dtype=float64 - E : array-like, shape=(b, m+2, n+2), dtype=float64 gamma : float64 Regularization parameter. be : Backend object or string or None @@ -308,14 +355,21 @@ the PyTorch backend is used. If `be` is `None`, the backend is determined by the input arrays. See our :ref:`dedicated user-guide page ` for more information. + + Returns + ------- + E : array-like, shape=(b, m+2, n+2), dtype=float64 """ - be = instantiate_backend(be, D, R, E) - for i_sample in prange(D.shape[0]): - _soft_dtw_grad(D[i_sample, :, :], R[i_sample, :, :], E[i_sample, :, :], gamma, be=be) + be = instantiate_backend(be, D, R) + b, m, n = D.shape + E = be.zeros((b, m + 2, n + 2), dtype=be.float64) + for i_sample in range(b): + E[i_sample, :, :] = _soft_dtw_grad(D[i_sample, :, :], R[i_sample, :, :], gamma, be=be) + return E @njit(parallel=True, fastmath=True) -def _njit_jacobian_product_sq_euc(X, Y, E, G): +def _njit_jacobian_product_sq_euc(X, Y, E): """Compute the square Euclidean product between the Jacobian (a linear map from m x d to m x n) and a matrix E. Parameters ---------- X: array-like, shape=(m, d), dtype=float64 First time series. Y: array-like, shape=(n, d), dtype=float64 Second time series. E: array-like, shape=(m, n), dtype=float64 + + Returns + ------- G: array-like, shape=(m, d), dtype=float64 Product with Jacobian. ([m x d, m x n] * [m x n] = [m x d]). @@ -334,13 +391,16 @@ n = Y.shape[0] d = X.shape[1] + G = np.zeros_like(X, dtype=np.float64) + for i in prange(m): for j in range(n): for k in range(d): G[i, k] += E[i, j] * 2 * (X[i, k] - Y[j, k]) + return G -def _jacobian_product_sq_euc(X, Y, E, G): +def _jacobian_product_sq_euc(X, Y, E, be): """Compute the square Euclidean product between the Jacobian (a linear map from m x d to m x n) and a matrix E. Parameters ---------- X: array-like, shape=(m, d), dtype=float64 First time series. Y: array-like, shape=(n, d), dtype=float64 Second time series. E: array-like, shape=(m, n), dtype=float64 + be : Backend object + Backend instance used to allocate the output G. + + Returns + ------- G: array-like, shape=(m, d), dtype=float64 Product with Jacobian. ([m x d, m x n] * [m x n] = [m x d]). 
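+ + Notes + ----- + Entry-wise, G[i, k] = sum_j E[i, j] * 2 * (X[i, k] - Y[j, k]): the matrix E + weighted by the gradient of the squared Euclidean distances between X and Y.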
@@ -359,7 +424,10 @@ def _jacobian_product_sq_euc(X, Y, E, G): n = Y.shape[0] d = X.shape[1] + G = be.zeros_like(X, dtype=be.float64) + for i in range(m): for j in range(n): for k in range(d): G[i, k] += E[i, j] * 2 * (X[i, k] - Y[j, k]) + return G diff --git a/tslearn/metrics/soft_dtw_loss_pytorch.py b/tslearn/metrics/soft_dtw_loss_pytorch.py index 71921e5d..8e8dfd2f 100644 --- a/tslearn/metrics/soft_dtw_loss_pytorch.py +++ b/tslearn/metrics/soft_dtw_loss_pytorch.py @@ -49,10 +49,8 @@ def forward(ctx, D, gamma): """ dev = D.device dtype = D.dtype - b, m, n = torch.Tensor.size(D) D_ = D.detach().cpu().numpy() - R_ = np.zeros((b, m + 2, n + 2), dtype=np.float64) - _njit_soft_dtw_batch(D_, R_, gamma) + R_ = _njit_soft_dtw_batch(D_, gamma) gamma_tensor = torch.Tensor([gamma]).to(dev).type(dtype) R = torch.Tensor(R_).to(dev).type(dtype) ctx.save_for_backward(D, R, gamma_tensor) @@ -66,9 +64,8 @@ def backward(ctx, grad_output): b, m, n = torch.Tensor.size(D) D_ = D.detach().cpu().numpy() R_ = R.detach().cpu().numpy() - E_ = np.zeros((b, m + 2, n + 2), dtype=np.float64) gamma = gamma_tensor.item() - _njit_soft_dtw_grad_batch(D_, R_, E_, gamma) + E_ = _njit_soft_dtw_grad_batch(D_, R_, gamma) E = torch.Tensor(E_[:, 1 : m + 1, 1 : n + 1]).to(dev).type(dtype) return grad_output.view(-1, 1, 1).expand_as(E) * E, None, None diff --git a/tslearn/metrics/softdtw_variants.py b/tslearn/metrics/softdtw_variants.py index 3be98011..fbfb909a 100644 --- a/tslearn/metrics/softdtw_variants.py +++ b/tslearn/metrics/softdtw_variants.py @@ -3,7 +3,7 @@ from numba import njit from sklearn.utils import check_random_state -from tslearn.backend import instantiate_backend +from tslearn.backend import cast, instantiate_backend from tslearn.utils import ( check_equal_size, to_time_series, @@ -51,7 +51,7 @@ def _gak(gram, be=None): Kernel value """ be = instantiate_backend(be, gram) - gram = be.array(gram) + gram = cast(gram, array_type=be.backend_string) sz1, sz2 = be.shape(gram) cum_sum = be.zeros((sz1 + 1, sz2 + 1)) @@ -181,9 +181,7 @@ def unnormalized_gak(s1, s2, sigma=1.0, be=None): be = instantiate_backend(be, s1, s2) s1 = to_time_series(s1, remove_nans=True, be=be) s2 = to_time_series(s2, remove_nans=True, be=be) - gram = _gak_gram(s1, s2, sigma=sigma, be=be) - if be.is_numpy: return _njit_gak(gram) return _gak(gram, be=be) @@ -238,8 +236,8 @@ def gak(s1, s2, sigma=1.0, be=None): # TODO: better doc (formula for the kernel .. [1] M. Cuturi, "Fast global alignment kernels," ICML 2011. """ be = instantiate_backend(be, s1, s2) - s1 = be.array(s1) - s2 = be.array(s2) + s1 = cast(s1, array_type=be.backend_string) + s2 = cast(s2, array_type=be.backend_string) denom = be.sqrt( unnormalized_gak(s1, s1, sigma=sigma, be=be) * unnormalized_gak(s2, s2, sigma=sigma, be=be) @@ -556,8 +554,8 @@ def soft_dtw(ts1, ts2, gamma=1.0, be=None, compute_with_backend=False): Time-Series," ICML 2017. """ # noqa: E501 be = instantiate_backend(be, ts1, ts2) - ts1 = be.array(ts1) - ts2 = be.array(ts2) + ts1 = cast(ts1, array_type=be.backend_string) + ts2 = cast(ts2, array_type=be.backend_string) if gamma == 0.0: return dtw(ts1, ts2, be=be) ** 2 return SoftDTW( @@ -668,8 +666,8 @@ def soft_dtw_alignment(ts1, ts2, gamma=1.0, be=None, compute_with_backend=False) Time-Series," ICML 2017. 
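+ + Examples + -------- + Illustrative only (doctest skipped): + >>> a, sim = soft_dtw_alignment([1.0, 2.0, 3.0], [1.0, 2.0, 2.0, 3.0], gamma=1.0) # doctest: +SKIP + >>> a.shape # doctest: +SKIP + (3, 4)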
""" # noqa: E501 be = instantiate_backend(be, ts1, ts2) - ts1 = be.array(ts1) - ts2 = be.array(ts2) + ts1 = cast(ts1, array_type=be.backend_string) + ts2 = cast(ts2, array_type=be.backend_string) if gamma == 0.0: path, dist = dtw_path(ts1, ts2, be=be) dist_sq = dist**2 @@ -972,11 +970,10 @@ def __init__(self, D, gamma=1.0, be=None, compute_with_backend=False): Attributes ---------- - self.R_: array-like, shape =(m + 2, n + 2) + self.R_: array-like, shape=(m + 2, n + 2) Accumulated cost matrix (stored after calling `compute`). """ - be = instantiate_backend(be, D) - self.be = be + self.be = instantiate_backend(be, D) self.compute_with_backend = compute_with_backend if hasattr(D, "compute"): self.D = D.compute() @@ -1004,16 +1001,15 @@ def compute(self): m, n = self.be.shape(self.D) if self.be.is_numpy: - _njit_soft_dtw(self.D, self.R_, gamma=self.gamma) + self.R_ = _njit_soft_dtw(self.D, gamma=self.gamma) elif not self.compute_with_backend: - _njit_soft_dtw( + self.R_ = _njit_soft_dtw( self.be.to_numpy(self.D), - self.be.to_numpy(self.R_), gamma=self.be.to_numpy(self.gamma), ) self.R_ = self.be.array(self.R_) else: - _soft_dtw(self.D, self.R_, gamma=self.gamma, be=self.be) + self.R_ = _soft_dtw(self.D, gamma=self.gamma, be=self.be) self.computed = True @@ -1030,30 +1026,17 @@ def grad(self): if not self.computed: raise ValueError("Needs to call compute() first.") - m, n = self.be.shape(self.D) - - # Add an extra row and an extra column to D. - # Needed to deal with edge cases in the recursion. - D = self.be.vstack((self.D, self.be.zeros(n))) - D = self.be.hstack((D, self.be.zeros((m + 1, 1)))) - - # Allocate memory. - # We need +2 because we use indices starting from 1 - # and to deal with edge cases in the recursion. - E = self.be.zeros((m + 2, n + 2), dtype=self.be.float64) - if self.be.is_numpy: - _njit_soft_dtw_grad(D, self.R_, E, gamma=self.gamma) + E = _njit_soft_dtw_grad(self.D, self.R_, gamma=self.gamma) elif not self.compute_with_backend: - _njit_soft_dtw_grad( - self.be.to_numpy(D), + E = _njit_soft_dtw_grad( + self.be.to_numpy(self.D), self.be.to_numpy(self.R_), - self.be.to_numpy(E), gamma=self.be.to_numpy(self.gamma), ) self.R_ = self.be.array(self.R_) else: - _soft_dtw_grad(D, self.R_, E, gamma=self.gamma, be=self.be) + E = _soft_dtw_grad(self.D, self.R_, gamma=self.gamma, be=self.be) return E[1:-1, 1:-1] @@ -1086,8 +1069,8 @@ def __init__(self, X, Y, be=None, compute_with_backend=False): """ self.be = instantiate_backend(be, X, Y) self.compute_with_backend = compute_with_backend - self.X = self.be.cast(to_time_series(X, be=be), dtype=self.be.float64) - self.Y = self.be.cast(to_time_series(Y, be=be), dtype=self.be.float64) + self.X = self.be.cast(to_time_series(X, be=self.be), dtype=self.be.float64) + self.Y = self.be.cast(to_time_series(Y, be=self.be), dtype=self.be.float64) def compute(self): """Compute distance matrix. @@ -1114,21 +1097,18 @@ def jacobian_product(self, E): Product with Jacobian. ([m x d, m x n] * [m x n] = [m x d]). 
""" - G = self.be.zeros_like(self.X, dtype=self.be.float64) - if self.be.is_numpy: - _njit_jacobian_product_sq_euc(self.X, self.Y, E.astype(np.float64), G) + G = _njit_jacobian_product_sq_euc(self.X, self.Y, E.astype(np.float64)) elif not self.compute_with_backend: - _njit_jacobian_product_sq_euc( + G = _njit_jacobian_product_sq_euc( self.be.to_numpy(self.X), self.be.to_numpy(self.Y), self.be.to_numpy(E).astype(np.float64), - self.be.to_numpy(G), ) G = self.be.array(G) else: - _jacobian_product_sq_euc( - self.X, self.Y, self.be.cast(E, self.be.float64), G + G = _jacobian_product_sq_euc( + self.X, self.Y, self.be.cast(E, self.be.float64), be=self.be, ) return G diff --git a/tslearn/metrics/utils.py b/tslearn/metrics/utils.py index 7c93f098..f7e04a85 100644 --- a/tslearn/metrics/utils.py +++ b/tslearn/metrics/utils.py @@ -75,7 +75,6 @@ def _cdist_generic( """ # noqa: E501 be = instantiate_backend(be, dataset1, dataset2) dataset1 = to_time_series_dataset(dataset1, dtype=dtype, be=be) - if dataset2 is None: # Inspired from code by @GillesVandewiele: # https://github.com/rtavenar/tslearn/pull/128#discussion_r314978479 @@ -83,7 +82,6 @@ def _cdist_generic( indices = be.triu_indices( len(dataset1), k=0 if compute_diagonal else 1, m=len(dataset1) ) - matrix[indices] = be.array( Parallel(n_jobs=n_jobs, prefer="threads", verbose=verbose)( delayed(dist_fun)(dataset1[i], dataset1[j], *args, **kwargs) @@ -91,10 +89,8 @@ def _cdist_generic( for j in range(i if compute_diagonal else i + 1, len(dataset1)) ) ) - indices = be.tril_indices(len(dataset1), k=-1, m=len(dataset1)) matrix[indices] = matrix.T[indices] - return matrix else: dataset2 = to_time_series_dataset(dataset2, dtype=dtype, be=be) diff --git a/tslearn/tests/test_metrics.py b/tslearn/tests/test_metrics.py index ebd971d7..e8d0098c 100644 --- a/tslearn/tests/test_metrics.py +++ b/tslearn/tests/test_metrics.py @@ -12,13 +12,28 @@ __author__ = "Romain Tavenard romain.tavenard[at]univ-rennes2.fr" +backends = [Backend("numpy"), None] +array_types = ["numpy", "list"] + +try: + import jax + backends += [Backend("jax")] + array_types += ["jax"] + HAS_JAX = True +except ImportError: + HAS_JAX = False + try: import torch - backends = [Backend("numpy"), Backend("pytorch"), None] - array_types = ["numpy", "pytorch", "list"] + backends += [Backend("pytorch")] + array_types += ["pytorch"] + HAS_TORCH = True except ImportError: - backends = [Backend("numpy")] - array_types = ["numpy", "list"] + HAS_TORCH = False + +def test_backends_installation(): + assert HAS_TORCH + assert HAS_JAX def test_dtw(): @@ -27,14 +42,14 @@ def test_dtw(): backend = instantiate_backend(be, array_type) # dtw_path path, dist = tslearn.metrics.dtw_path(cast([1, 2, 3], array_type), cast([1.0, 2.0, 2.0, 3.0], array_type), be=be) - np.testing.assert_equal(path, [(0, 0), (1, 1), (1, 2), (2, 3)]) - np.testing.assert_allclose(dist, [0.0]) + backend.testing.assert_equal(path, [(0, 0), (1, 1), (1, 2), (2, 3)]) + backend.testing.assert_allclose(dist, 0.0) assert backend.belongs_to_backend(dist) path, dist = tslearn.metrics.dtw_path( cast([1, 2, 3], array_type), cast([1.0, 2.0, 2.0, 3.0, 4.0], array_type), be=be ) - np.testing.assert_allclose(dist, [1.0]) + backend.testing.assert_allclose(dist, 1.0) assert backend.belongs_to_backend(dist) # dtw @@ -43,19 +58,19 @@ def test_dtw(): x = cast(rng.randn(n1, d), array_type) y = cast(rng.randn(n2, d), array_type) - np.testing.assert_allclose( + backend.testing.assert_allclose( tslearn.metrics.dtw(x, y, be=be), tslearn.metrics.dtw_path(x, y, be=be)[1] 
) # cdist_dtw dists = tslearn.metrics.cdist_dtw(cast([[1, 2, 2, 3], [1.0, 2.0, 3.0, 4.0]], array_type), be=be) - np.testing.assert_allclose(dists, cast([[0.0, 1.0], [1.0, 0.0]], array_type)) + backend.testing.assert_allclose(dists, [[0.0, 1.0], [1.0, 0.0]]) dists = tslearn.metrics.cdist_dtw( cast([[1, 2, 2, 3], [1.0, 2.0, 3.0, 4.0]], array_type), [[1, 2, 3], [2, 3, 4, 5]], # The second dataset can not be cast to array because of its shape be=be ) - np.testing.assert_allclose(dists, [[0.0, 2.44949], [1.0, 1.414214]], atol=1e-5) + backend.testing.assert_allclose(dists, [[0.0, 2.44949], [1.0, 1.414214]], atol=1e-5) assert backend.belongs_to_backend(dists) @@ -67,14 +82,14 @@ def test_ctw(): path, cca, dist = tslearn.metrics.ctw_path( cast([1, 2, 3], array_type), cast([1.0, 2.0, 2.0, 3.0], array_type), be=be ) - np.testing.assert_equal(path, [(0, 0), (1, 1), (1, 2), (2, 3)]) - np.testing.assert_allclose(dist, 0.0) + backend.testing.assert_equal(path, [(0, 0), (1, 1), (1, 2), (2, 3)]) + backend.testing.assert_allclose(dist, 0.0) assert backend.belongs_to_backend(dist) path, cca, dist = tslearn.metrics.ctw_path( cast([1, 2, 3], array_type), cast([1.0, 2.0, 2.0, 3.0, 4.0], array_type), be=be ) - np.testing.assert_allclose(dist, 1.0) + backend.testing.assert_allclose(dist, 1.0) assert backend.belongs_to_backend(dist) - # dtw + # ctw n1, n2, d1, d2 = 15, 10, 3, 1 rng = np.random.RandomState(0) x = cast(rng.randn(n1, d1), array_type) y = cast(rng.randn(n2, d2), array_type) - np.testing.assert_allclose( + backend.testing.assert_allclose( tslearn.metrics.ctw(x, y, be=be), tslearn.metrics.ctw(y, x, be=be) ) - np.testing.assert_allclose( + backend.testing.assert_allclose( tslearn.metrics.ctw(x, y, be=be), tslearn.metrics.ctw_path(x, y, be=be)[-1] ) - # cdist_dtw + # cdist_ctw dists = tslearn.metrics.cdist_ctw(cast([[1, 2, 2, 3], [1.0, 2.0, 3.0, 4.0]], array_type), be=be) - np.testing.assert_allclose(dists, [[0.0, 1.0], [1.0, 0.0]]) + backend.testing.assert_allclose(dists, [[0.0, 1.0], [1.0, 0.0]]) - assert backend.belongs_to_backend(dist) + assert backend.belongs_to_backend(dists) dists = tslearn.metrics.cdist_ctw( cast([[1, 2, 2, 3], [1.0, 2.0, 3.0, 4.0]], array_type), [[1, 2, 3], [2, 3, 4, 5]], be=be # The second dataset can not be cast to array because of its shape ) - np.testing.assert_allclose( + backend.testing.assert_allclose( dists, [[0.0, 2.44949], [1.0, 1.414214]], atol=1e-5 ) assert backend.belongs_to_backend(dists) @@ -107,6 +122,7 @@ def test_ldtw(): for be in backends: for array_type in array_types: + backend = instantiate_backend(be, array_type) n1, n2, d = 15, 10, 3 rng = np.random.RandomState(0) x = cast(rng.randn(n1, d), array_type) y = cast(rng.randn(n2, d), array_type) ldtw_n1_plus_2 = tslearn.metrics.dtw_limited_warping_length(x, y, n1 + 2, be=be) # LDTW >= DTW - np.testing.assert_allclose( + backend.testing.assert_allclose( tslearn.metrics.dtw(x, y, be=be), ldtw_n1_plus_2, ) - backend = instantiate_backend(be, array_type) assert backend.belongs_to_backend(ldtw_n1_plus_2) # if path is too short, LDTW raises a ValueError @@ -139,16 +154,16 @@ ) # if max_length is geq than length of optimal DTW path, LDTW = DTW - np.testing.assert_allclose( - cost, tslearn.metrics.dtw_limited_warping_length(x, y, len(path)) + backend.testing.assert_allclose( + cost, tslearn.metrics.dtw_limited_warping_length(x, y, len(path), be=be) ) - np.testing.assert_allclose( + backend.testing.assert_allclose( cost, tslearn.metrics.dtw_limited_warping_length(x, y, len(path) + 1, be=be) ) path, cost = 
tslearn.metrics.dtw_path_limited_warping_length( x, y, n1 + 2, be=be ) - np.testing.assert_allclose( + backend.testing.assert_allclose( cost, ldtw_n1_plus_2 ) assert len(path) <= n1 + 2 @@ -157,57 +172,60 @@ def test_ldtw(): def test_lcss(): for be in backends: for array_type in array_types: + backend = instantiate_backend(be, array_type) sim = tslearn.metrics.lcss(cast([1, 2, 3], array_type), cast([1.0, 2.0, 2.0, 3.0], array_type), be=be) - np.testing.assert_equal(sim, 1.0) + backend.testing.assert_equal(sim, 1.0) assert isinstance(sim, float) sim = tslearn.metrics.lcss([1, 2, 3], [1.0, 2.0, 2.0, 4.0], be=be) - np.testing.assert_equal(sim, 1.0) + backend.testing.assert_equal(sim, 1.0) sim = tslearn.metrics.lcss([1, 2, 3], [-2.0, 5.0, 7.0], eps=3, be=be) - np.testing.assert_equal(round(sim, 2), 0.67) + backend.testing.assert_equal(round(sim, 2), 0.67) sim = tslearn.metrics.lcss([1, 2, 3], [1.0, 2.0, 2.0, 2.0, 3.0], eps=0, be=be) - np.testing.assert_equal(sim, 1.0) + backend.testing.assert_equal(sim, 1.0) sim = tslearn.metrics.lcss( [[1, 1], [2, 2], [3, 3]], [[1.0, 1.0], [2.0, 2.0], [2.0, 2.0], [2.0, 2.0], [3.0, 3.0]], eps=0, be=be) - np.testing.assert_equal(sim, 1.0) + backend.testing.assert_equal(sim, 1.0) def test_lcss_path(): for be in backends: for array_type in array_types: + backend = instantiate_backend(be, array_type) path, sim = tslearn.metrics.lcss_path( cast([1.0, 2.0, 3.0], array_type), cast([1.0, 2.0, 2.0, 3.0], array_type), be=be ) - np.testing.assert_equal(sim, 1.0) + backend.testing.assert_equal(sim, 1.0) assert isinstance(sim, float) - np.testing.assert_equal(path, [(0, 1), (1, 2), (2, 3)]) + backend.testing.assert_equal(path, [(0, 1), (1, 2), (2, 3)]) assert isinstance(path, list) path, sim = tslearn.metrics.lcss_path( cast([1.0, 2.0, 3.0], array_type), cast([1.0, 2.0, 2.0, 4.0], array_type), be=be ) - np.testing.assert_equal(sim, 1.0) - np.testing.assert_equal(path, [(0, 1), (1, 2), (2, 3)]) + backend.testing.assert_equal(sim, 1.0) + backend.testing.assert_equal(path, [(0, 1), (1, 2), (2, 3)]) path, sim = tslearn.metrics.lcss_path( cast([1.0, 2.0, 3.0], array_type), cast([-2.0, 5.0, 7.0], array_type), eps=3, be=be ) - np.testing.assert_equal(round(sim, 2), 0.67) - np.testing.assert_equal(path, [(0, 0), (2, 1)]) + backend.testing.assert_equal(round(sim, 2), 0.67) + backend.testing.assert_equal(path, [(0, 0), (2, 1)]) path, sim = tslearn.metrics.lcss_path( cast([1.0, 2.0, 3.0], array_type), cast([1.0, 2.0, 2.0, 2.0, 3.0], array_type), eps=0, be=be ) - np.testing.assert_equal(sim, 1.0) - np.testing.assert_equal(path, [(0, 0), (1, 3), (2, 4)]) + backend.testing.assert_equal(sim, 1.0) + backend.testing.assert_equal(path, [(0, 0), (1, 3), (2, 4)]) def test_lcss_path_from_metric(): for be in backends: for array_type in array_types: + backend = instantiate_backend(be, array_type) for d in np.arange(1, 5): rng = np.random.RandomState(0) s1 = cast(rng.randn(10, d), array_type) @@ -221,9 +239,9 @@ def test_lcss_path_from_metric(): s1, s2, metric="sqeuclidean", be=be ) - np.testing.assert_equal(path, path_ref) + backend.testing.assert_equal(path, path_ref) assert isinstance(path, list) - np.testing.assert_equal(sim, sim_ref) + backend.testing.assert_equal(sim, sim_ref) assert isinstance(sim, float) # Test of defining a custom function @@ -233,35 +251,40 @@ def sqeuclidean(x, y): path, sim = tslearn.metrics.lcss_path_from_metric( s1, s2, metric=sqeuclidean, be=be ) - np.testing.assert_equal(path, path_ref) - np.testing.assert_equal(sim, sim_ref) + backend.testing.assert_equal(path, 
path_ref) + backend.testing.assert_equal(sim, sim_ref) # Test of precomputing the distance matrix - dist_matrix = cdist(s1, s2, metric="sqeuclidean") + dist_matrix = backend.cdist( + cast(s1, array_type=backend.backend_string), + cast(s2, array_type=backend.backend_string), + metric="sqeuclidean") path, sim = tslearn.metrics.lcss_path_from_metric( dist_matrix, metric="precomputed", be=be ) - np.testing.assert_equal(path, path_ref) - np.testing.assert_equal(sim, sim_ref) + backend.testing.assert_equal(path, path_ref) + backend.testing.assert_equal(sim, sim_ref) def test_constrained_paths(): for be in backends: for array_type in array_types: + backend = instantiate_backend(be, array_type) n, d = 10, 3 rng = np.random.RandomState(0) - x = cast(rng.randn(n, d), array_type) - y = cast(rng.randn(n, d), array_type) + x_np = rng.randn(n, d) + y_np = rng.randn(n, d) + x = cast(x_np, array_type) + y = cast(y_np, array_type) dtw_sakoe = tslearn.metrics.dtw( x, y, global_constraint="sakoe_chiba", sakoe_chiba_radius=0, be=be ) dtw_itak = tslearn.metrics.dtw( x, y, global_constraint="itakura", itakura_max_slope=1.0, be=be ) - backend = instantiate_backend(be, array_type) - euc_dist = backend.linalg.norm(backend.array(x) - backend.array(y)) - np.testing.assert_allclose(dtw_sakoe, euc_dist, atol=1e-5) - np.testing.assert_allclose(dtw_itak, euc_dist, atol=1e-5) + euc_dist = backend.linalg.norm(cast(x_np, backend.backend_string) - cast(y_np, backend.backend_string)) + backend.testing.assert_allclose(dtw_sakoe, euc_dist, atol=1e-5) + backend.testing.assert_allclose(dtw_itak, euc_dist, atol=1e-5) - backend = instantiate_backend(be, array_type) assert backend.is_float(dtw_sakoe) assert backend.is_float(dtw_itak) @@ -290,19 +313,20 @@ def test_dtw_subseq(): for be in backends: for array_type in array_types: + backend = instantiate_backend(be, array_type) path, dist = tslearn.metrics.dtw_subsequence_path( cast([2, 3], array_type), cast([1.0, 2.0, 2.0, 3.0, 4.0], array_type), be=be ) - np.testing.assert_allclose(path, [(0, 2), (1, 3)]) - np.testing.assert_allclose(dist, 0.0) + backend.testing.assert_allclose(path, [(0, 2), (1, 3)]) + backend.testing.assert_allclose(dist, 0.0) - backend = instantiate_backend(be, array_type) - backend.belongs_to_backend(dist) + assert backend.belongs_to_backend(dist) path, dist = tslearn.metrics.dtw_subsequence_path( cast([1, 4], array_type), cast([1.0, 2.0, 2.0, 3.0, 4.0], array_type), be=be ) - np.testing.assert_allclose(path, [(0, 2), (1, 3)]) - np.testing.assert_allclose(dist, np.sqrt(2.0)) + backend.testing.assert_allclose(path, [(0, 2), (1, 3)]) + backend.testing.assert_allclose(dist, np.sqrt(2.0)) assert backend.belongs_to_backend(dist) @@ -319,11 +343,11 @@ def test_dtw_subseq_path(): assert backend.belongs_to_backend(cost_matrix) path = tslearn.metrics.subsequence_path(cost_matrix, 3, be=be) - np.testing.assert_equal(path, [(0, 2), (1, 3)]) + backend.testing.assert_equal(path, [(0, 2), (1, 3)]) assert isinstance(path, list) path = tslearn.metrics.subsequence_path(cost_matrix, 1, be=be) - np.testing.assert_equal(path, [(0, 0), (1, 1)]) + backend.testing.assert_equal(path, [(0, 0), (1, 1)]) assert isinstance(path, list) @@ -340,7 +364,7 @@ def test_masks(): [np.inf, np.inf, 0.0, 0.0], ] ) - np.testing.assert_allclose(sk_mask, reference_mask) + backend.testing.assert_allclose(sk_mask, reference_mask) assert backend.belongs_to_backend(sk_mask) sk_mask = tslearn.metrics.sakoe_chiba_mask(7, 3, 1, be=be) reference_mask = np.array( [ [0.0, 0.0, np.inf], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [np.inf, 0.0, 0.0], ] ) - 
np.testing.assert_allclose(sk_mask, reference_mask) + backend.testing.assert_allclose(sk_mask, reference_mask) assert backend.belongs_to_backend(sk_mask) i_mask = tslearn.metrics.itakura_mask(6, 6, be=be) @@ -369,7 +393,7 @@ def test_masks(): [np.inf, np.inf, np.inf, np.inf, np.inf, 0.0], ] ) - np.testing.assert_allclose(i_mask, reference_mask) + backend.testing.assert_allclose(i_mask, reference_mask) assert backend.belongs_to_backend(i_mask) # Test masks for different combinations of global_constraints / @@ -380,8 +404,8 @@ def test_masks(): mask_no_constraint = tslearn.metrics.dtw_variants.compute_mask( ts0, ts1, global_constraint=0, be=be ) - np.testing.assert_allclose(mask_no_constraint, np.zeros((sz, sz))) backend = instantiate_backend(be, array_type) + backend.testing.assert_allclose(mask_no_constraint, backend.zeros((sz, sz))) assert backend.belongs_to_backend(mask_no_constraint) mask_itakura = tslearn.metrics.dtw_variants.compute_mask( @@ -390,7 +414,7 @@ def test_masks(): mask_itakura_bis = tslearn.metrics.dtw_variants.compute_mask( ts0, ts1, itakura_max_slope=2.0, be=be ) - np.testing.assert_allclose(mask_itakura, mask_itakura_bis) + backend.testing.assert_allclose(mask_itakura, mask_itakura_bis) assert backend.belongs_to_backend(mask_itakura) mask_sakoe = tslearn.metrics.dtw_variants.compute_mask( @@ -400,7 +424,7 @@ def test_masks(): mask_sakoe_bis = tslearn.metrics.dtw_variants.compute_mask( ts0, ts1, sakoe_chiba_radius=1, be=be ) - np.testing.assert_allclose(mask_sakoe, mask_sakoe_bis) + backend.testing.assert_allclose(mask_sakoe, mask_sakoe_bis) assert backend.belongs_to_backend(mask_sakoe) np.testing.assert_raises( @@ -433,12 +457,8 @@ def test_masks(): max_iter=5, random_state=0, ) - np.testing.assert_allclose( - estimator1.fit(time_series).labels_, estimator2.fit(time_series).labels_ - ) - np.testing.assert_allclose( - estimator1.fit(time_series).labels_, estimator3.fit(time_series).labels_ - ) + backend.testing.assert_allclose(estimator1.fit(time_series).labels_, estimator2.fit(time_series).labels_) + backend.testing.assert_allclose(estimator1.fit(time_series).labels_, estimator3.fit(time_series).labels_) def test_gak(): @@ -449,8 +469,8 @@ def test_gak(): g = tslearn.metrics.cdist_gak( cast([[1, 2, 2, 3], [1.0, 2.0, 3.0, 4.0]], array_type), sigma=2.0, be=be ) - np.testing.assert_allclose( - g, np.array([[1.0, 0.656297], [0.656297, 1.0]]), atol=1e-5 + backend.testing.assert_allclose( + g, backend.array([[1.0, 0.656297], [0.656297, 1.0]]), atol=1e-5 ) assert backend.belongs_to_backend(g) g = tslearn.metrics.cdist_gak( @@ -459,8 +479,8 @@ def test_gak(): sigma=2.0, be=be, ) - np.testing.assert_allclose( - g, np.array([[0.710595, 0.297229], [0.656297, 1.0]]), atol=1e-5 + backend.testing.assert_allclose( + g, backend.array([[0.710595, 0.297229], [0.656297, 1.0]]), atol=1e-5 ) assert backend.belongs_to_backend(g) @@ -468,8 +488,8 @@ def test_gak(): d = tslearn.metrics.cdist_soft_dtw( cast([[1, 2, 2, 3], [1.0, 2.0, 3.0, 4.0]], array_type), gamma=0.01, be=be ) - np.testing.assert_allclose( - d, np.array([[-0.010986, 1.0], [1.0, 0.0]]), atol=1e-5 + backend.testing.assert_allclose( + d, backend.array([[-0.010986, 1.0], [1.0, 0.0]]), atol=1e-5 ) assert backend.belongs_to_backend(d) @@ -479,8 +499,8 @@ def test_gak(): gamma=0.01, be=be, ) - np.testing.assert_allclose( - d, np.array([[-0.010986, 1.0], [1.0, 0.0]]), atol=1e-5 + backend.testing.assert_allclose( + d, backend.array([[-0.010986, 1.0], [1.0, 0.0]]), atol=1e-5 ) assert backend.belongs_to_backend(d) @@ -499,7 +519,7 @@ def 
                 cast(np.array(v2.flat), array_type), be=be)
             c_dist = cdist(v1, v2, metric="sqeuclidean")
             sqeuc_compute = sqeuc.compute()
-            np.testing.assert_allclose(sqeuc_compute, c_dist, atol=1e-5)
+            backend.testing.assert_allclose(sqeuc_compute, c_dist, atol=1e-5)
             assert backend.belongs_to_backend(sqeuc_compute)
 
 
@@ -511,12 +531,12 @@ def test_gak():
 
 
 def test_gamma_soft_dtw():
     for be in backends:
         for array_type in array_types:
+            backend = instantiate_backend(be, array_type)
             dataset = cast([[1, 2, 2, 3], [1.0, 2.0, 3.0, 4.0]], array_type)
             gamma = tslearn.metrics.gamma_soft_dtw(
                 dataset=dataset, n_samples=200, random_state=0, be=be
             )
-            np.testing.assert_allclose(gamma, 8.0)
-            backend = instantiate_backend(be, array_type)
+            backend.testing.assert_allclose(gamma, 8.0)
             assert backend.belongs_to_backend(gamma)
@@ -533,26 +553,26 @@ def test_symmetric_cdist():
             dataset = rng.randn(5, 10, 2)
             dataset = cast(dataset, array_type)
             c_dist_dtw = tslearn.metrics.cdist_dtw(dataset, be=be)
-            np.testing.assert_allclose(
+            backend.testing.assert_allclose(
                 c_dist_dtw,
                 tslearn.metrics.cdist_dtw(dataset, dataset, be=be),
             )
             assert backend.belongs_to_backend(c_dist_dtw)
             c_dist_gak = tslearn.metrics.cdist_gak(dataset, be=be)
-            np.testing.assert_allclose(
+            backend.testing.assert_allclose(
                 c_dist_gak,
                 tslearn.metrics.cdist_gak(dataset, dataset, be=be),
                 atol=1e-5,
             )
             assert backend.belongs_to_backend(c_dist_gak)
             c_dist_soft_dtw = tslearn.metrics.cdist_soft_dtw(dataset, be=be)
-            np.testing.assert_allclose(
+            backend.testing.assert_allclose(
                 c_dist_soft_dtw,
                 tslearn.metrics.cdist_soft_dtw(dataset, dataset, be=be),
             )
             assert backend.belongs_to_backend(c_dist_soft_dtw)
             c_dist_soft_dtw_normalized = tslearn.metrics.cdist_soft_dtw_normalized(dataset, be=be)
-            np.testing.assert_allclose(
+            backend.testing.assert_allclose(
                 c_dist_soft_dtw_normalized,
                 tslearn.metrics.cdist_soft_dtw_normalized(dataset, dataset, be=be),
             )
@@ -562,15 +582,15 @@ def test_lb_keogh():
     for be in backends:
         for array_type in array_types:
+            backend = instantiate_backend(be, array_type)
             ts1 = cast([1, 2, 3, 2, 1], array_type)
             env_low, env_up = tslearn.metrics.lb_envelope(ts1, radius=1, be=be)
-            np.testing.assert_allclose(
+            backend.testing.assert_allclose(
                 env_low, np.array([[1.0], [1.0], [2.0], [1.0], [1.0]])
             )
-            np.testing.assert_allclose(
+            backend.testing.assert_allclose(
                 env_up, np.array([[2.0], [3.0], [3.0], [3.0], [2.0]])
             )
-            backend = instantiate_backend(be, array_type)
             assert backend.belongs_to_backend(env_low)
             assert backend.belongs_to_backend(env_up)
@@ -590,9 +610,9 @@ def test_dtw_path_from_metric():
             path, dist = tslearn.metrics.dtw_path_from_metric(
                 s1, s2, metric="sqeuclidean", be=be
             )
-            np.testing.assert_equal(path, path_ref)
+            backend.testing.assert_equal(path, path_ref)
             assert isinstance(path, list)
-            np.testing.assert_allclose(backend.sqrt(dist), dist_ref)
+            backend.testing.assert_allclose(backend.sqrt(dist), dist_ref)
             assert backend.belongs_to_backend(dist)
 
             # Test of defining a custom function
@@ -602,26 +622,30 @@ def sqeuclidean(x, y):
             path, dist = tslearn.metrics.dtw_path_from_metric(
                 s1, s2, metric=sqeuclidean, be=be
             )
-            np.testing.assert_equal(path, path_ref)
+            backend.testing.assert_equal(path, path_ref)
             assert isinstance(path, list)
-            np.testing.assert_allclose(backend.sqrt(dist), dist_ref)
+            backend.testing.assert_allclose(backend.sqrt(dist), dist_ref)
             assert backend.belongs_to_backend(dist)
 
             # Test of precomputing the distance matrix
-            dist_matrix = cdist(s1, s2, metric="sqeuclidean")
+            dist_matrix = backend.cdist(
+                cast(s1, array_type=backend.backend_string),
+                cast(s2, array_type=backend.backend_string),
+                metric="sqeuclidean")
             dist_matrix = cast(dist_matrix, array_type)
             path, dist = tslearn.metrics.dtw_path_from_metric(
                 dist_matrix, metric="precomputed", be=be
             )
-            np.testing.assert_equal(path, path_ref)
+            backend.testing.assert_equal(path, path_ref)
             assert isinstance(path, list)
-            np.testing.assert_allclose(backend.sqrt(dist), dist_ref)
+            backend.testing.assert_allclose(backend.sqrt(dist), dist_ref)
             assert backend.belongs_to_backend(dist)
 
 
 def test_softdtw():
     for be in backends:
         for array_type in array_types:
+            backend = instantiate_backend(be, array_type)
             rng = np.random.RandomState(0)
             s1 = cast(rng.rand(10, 2), array_type)
             s2 = cast(rng.rand(30, 2), array_type)
@@ -629,7 +653,6 @@ def test_softdtw():
             # Use dtw_path as a reference
             path_ref, dist_ref = tslearn.metrics.dtw_path(s1, s2, be=be)
             assert isinstance(path_ref, list)
-            backend = instantiate_backend(be, array_type)
             assert backend.belongs_to_backend(dist_ref)
             mat_path_ref = np.zeros((10, 30))
             for i, j in path_ref:
@@ -640,8 +663,8 @@ def test_softdtw():
             assert backend.belongs_to_backend(matrix_path)
             assert backend.belongs_to_backend(dist)
-            np.testing.assert_equal(dist, dist_ref**2)
-            np.testing.assert_allclose(matrix_path, mat_path_ref)
+            backend.testing.assert_allclose(dist, dist_ref**2, atol=1e-6)
+            backend.testing.assert_allclose(matrix_path, mat_path_ref)
 
             ts1 = cast([[0.0]], array_type)
             ts2 = cast([[1.0]], array_type)
@@ -673,7 +696,7 @@ def test_dtw_path_with_empty_or_nan_inputs():
 
 
 @pytest.mark.skipif(
-    len(backends) == 1,
+    not HAS_TORCH,
     reason="Skipping test that requires pytorch backend",
 )
 def test_soft_dtw_loss_pytorch():
diff --git a/tslearn/utils/utils.py b/tslearn/utils/utils.py
index 9f4b6458..49ea6c1b 100644
--- a/tslearn/utils/utils.py
+++ b/tslearn/utils/utils.py
@@ -17,7 +17,7 @@
     from sklearn.utils.estimator_checks import _NotAnArray as NotAnArray
 except ImportError:  # Old sklearn versions
     from sklearn.utils.estimator_checks import NotAnArray
-from tslearn.backend import instantiate_backend
+from tslearn.backend import cast, instantiate_backend
 from tslearn.bases import TimeSeriesBaseEstimator
 
 __author__ = "Romain Tavenard romain.tavenard[at]univ-rennes2.fr"
@@ -156,7 +156,7 @@ def to_time_series(ts, remove_nans=False, be=None):
     to_time_series_dataset : Transforms a dataset of time series
     """
     be = instantiate_backend(be, ts)
-    ts_out = be.array(ts)
+    ts_out = cast(ts, array_type=be.backend_string)
     if ts_out.ndim <= 1:
         ts_out = be.reshape(ts_out, (-1, 1))
     if not be.is_float(ts_out):
@@ -219,7 +219,7 @@ def to_time_series_dataset(dataset, dtype=float, be=None):
         return to_time_series_dataset(be.array(dataset), be=be)
     if len(dataset) == 0:
         return be.zeros((0, 0, 0))
-    if be.ndim(be.array(dataset[0])) == 0:
+    if be.ndim(cast(dataset[0], array_type=be.backend_string)) == 0:
         dataset = [dataset]
     n_ts = len(dataset)
     max_sz = max(