186 changes: 186 additions & 0 deletions tslearn/metrics/soft_dtw_keras
@@ -0,0 +1,186 @@
try:
import tensorflow.keras.backend as K
import tensorflow as tf
import numpy as np
import keras
from keras import layers

HAS_DEPS = True
except ImportError:
HAS_DEPS = False

if not HAS_DEPS:
class SoftDTWLossTF:
def __init__(self, *args, **kwargs):
raise ValueError(
"Could not use SoftDTWLossTF since keras and/or tensorflow are not installed"
)
else:
DBL_MAX = np.finfo("double").max

def euclidean_squared_dist(x, y):
"""Calculates the Euclidean squared distance between each element in x and y per timestep.

Parameters
----------
x : Tensor, shape=[b, m, d]
Batch of time series.
y : Tensor, shape=[b, n, d]
Batch of time series.

Returns
-------
dist : Tensor, shape=[b, m, n]
The pairwise squared Euclidean distances.
"""
# Expand to shapes (b, m, 1, d) and (b, 1, n, d) so that the subtraction
# broadcasts over all pairs of timesteps.
x = x[:, :, None]
y = y[:, None]
return K.sum(K.pow(x - y, 2), axis=3)
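# Illustrative example (comment only, values assumed): with x of shape
# (2, 5, 3) and y of shape (2, 7, 3), the result has shape (2, 5, 7) and
# entry [k, i, j] equals the squared Euclidean distance between x[k, i]
# and y[k, j].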

def softmin3(a, b, c, *, gamma):
r"""Compute softmin of 3 input variables with parameter gamma.
The inputs need to be of the same shape.

In the limit case :math:`\gamma = 0`, the softmin operator reduces to
a hard-min operator.

Parameters
----------
a : float64[...]
First input variable.
b : float64[...]
Second input variable.
c : float64[...]
Third input variable.
gamma : float64
Regularization parameter.

Returns
-------
softmin_value : float64
Softmin value.
"""
a /= -gamma
b /= -gamma
c /= -gamma

max_val = layers.Maximum()(inputs=[a, b, c])

tmp = (
K.exp(a - max_val)
+ K.exp(b - max_val)
+ K.exp(c - max_val)
)
return -gamma * (K.log(tmp) + max_val)
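# For reference (comment only): this evaluates
# softmin_gamma(a, b, c) = -gamma * log(exp(-a/gamma) + exp(-b/gamma) + exp(-c/gamma)),
# with the maximum of the rescaled inputs subtracted inside the exponentials
# (log-sum-exp shift) so that none of them can overflow.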


@tf.custom_gradient
def soft_dtw(D, gamma):
"""Compute soft dynamic time warping.

Parameters
----------
D : Tensor, shape=(b, m, n)
Batch of pairwise (squared) distances between timesteps.
gamma : float64
Regularization parameter.

Returns
-------
sdtw_value : Tensor
Soft-DTW value for each pair in the batch.
"""
b, m, n = D.shape

# Initialization.
R = np.empty((m + 2, n + 2), dtype=object)
# Doesn't work because of numpy auto broadcast
#R[: m + 1, 0] = K.constant(DBL_MAX)
#R[0, : n + 1] = K.constant(DBL_MAX)
# Instead we check for None later
dbl_max = K.constant(DBL_MAX)[None, None, None]

R[0, 0] = K.constant(0)[None, None, None]

# DP recursion.
for i in range(1, m + 1):
for j in range(1, n + 1):
# D is indexed starting from 0.
R[i, j] = D[:, i - 1, j - 1] + softmin3(
R[i - 1, j] if R[i - 1, j] is not None else dbl_max,
R[i - 1, j - 1] if R[i - 1, j - 1] is not None else dbl_max,
R[i, j - 1] if R[i, j - 1] is not None else dbl_max,
gamma=gamma
)
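# Backward pass (comment only): as shown by Cuturi & Blondel (2017), the
# gradient of soft-DTW with respect to D is the matrix E of soft alignment
# weights, obtained by the reverse recursion below and scaled by the
# upstream gradient.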

def grad(upstream):
# Work on object arrays of tensors. D_ gets an extra zero row and column so
# that the reverse recursion can index D_[i, j] for i in [1, m], j in [1, n].
D_ = np.zeros((m + 1, n + 1), dtype=object)
R_ = np.zeros((m + 2, n + 2), dtype=object)
E_ = np.zeros((m + 2, n + 2), dtype=object)

for i in range(m):
for j in range(n):
D_[i, j] = D[:, i, j]

for i in range(m + 2):
for j in range(n + 2):
R_[i, j] = R[i, j]

# Initialization of the reverse recursion, following the reference
# soft-DTW gradient (Cuturi & Blondel, 2017); m and n are the series lengths.
for i in range(1, m + 1):
R_[i, n + 1] = -dbl_max
for j in range(1, n + 1):
R_[m + 1, j] = -dbl_max

E_[m + 1, n + 1] = 1
R_[m + 1, n + 1] = R[m, n]

for j in range(n, 0, -1): # ranges from n to 1
for i in range(m, 0, -1): # ranges from m to 1
a = K.exp((R_[i + 1, j] - R_[i, j] - D_[i, j - 1]) / gamma)
b = K.exp((R_[i, j + 1] - R_[i, j] - D_[i - 1, j]) / gamma)
c = K.exp((R_[i + 1, j + 1] - R_[i, j] - D_[i, j]) / gamma)
E_[i, j] = E_[i + 1, j] * a + E_[i, j + 1] * b + E_[i + 1, j + 1] * c

ll = list(E_[1 : m + 1, 1 : n + 1].flatten())
for i in range(len(ll)):
if isinstance(ll[i], int):
ll[i] = K.constant(ll[i])[None, None, None]

E_T = layers.Reshape((m, n), name="dtw_grad_reshape")(inputs=layers.Concatenate()(inputs=ll))

return upstream[:, None, None] * E_T, None

return R[-2, -2], grad

class SoftDTWLossTF(tf.keras.losses.Loss):
"""Soft-DTW loss: gamma sets the smoothing of the min operator, normalize
switches to the normalized (divergence) variant, and dist_func is the
ground metric (squared Euclidean distance by default)."""

def __init__(self, gamma=1.0, normalize=False, dist_func=None):
super().__init__()
self.gamma = gamma
self.normalize = normalize
self.dist_func = dist_func if dist_func is not None else euclidean_squared_dist

def call(self, y_true, y_pred):
x = y_true
y = y_pred

bx, lx, dx = x.shape
by, ly, dy = y.shape

assert bx == by
assert dx == dy

d_xy = self.dist_func(x, y)
loss_xy = soft_dtw(d_xy, self.gamma)

if self.normalize:
# Normalized variant (soft-DTW divergence): subtracting the self-comparison
# terms makes the loss zero when y_pred equals y_true.
d_xx = self.dist_func(x, x)
d_yy = self.dist_func(y, y)
loss_xx = soft_dtw(d_xx, self.gamma)
loss_yy = soft_dtw(d_yy, self.gamma)
return loss_xy - 0.5 * (loss_xx + loss_yy)
else:
return loss_xy
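
A minimal usage sketch (not part of the proposed file): it assumes the module ends up importable as tslearn.metrics.soft_dtw_keras, as the file path above suggests, and that the loss behaves as intended on small eager tensors; the data below is made up for illustration.

import numpy as np
import tensorflow as tf

from tslearn.metrics.soft_dtw_keras import SoftDTWLossTF

# Two small batches of univariate series, shaped (batch, timesteps, features).
y_true = tf.constant(np.random.rand(2, 8, 1), dtype=tf.float32)
y_pred = tf.constant(np.random.rand(2, 8, 1), dtype=tf.float32)

loss_fn = SoftDTWLossTF(gamma=0.1)                    # plain soft-DTW
div_fn = SoftDTWLossTF(gamma=0.1, normalize=True)     # divergence variant

print(loss_fn(y_true, y_pred))
print(div_fn(y_true, y_pred))

# The same objects can be passed to model.compile(loss=SoftDTWLossTF(gamma=0.1)).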