Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Sakoe-Chiba band constraint #9

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdtw/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .soft_dtw_fast import _jacobian_product_sq_euc


class SquaredEuclidean(object):

def __init__(self, X, Y):
Expand Down
19 changes: 13 additions & 6 deletions sdtw/soft_dtw.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class SoftDTW(object):

def __init__(self, D, gamma=1.0):
def __init__(self, D, gamma=1.0, sakoe_chiba_band=-1):
"""
Parameters
----------
Expand All @@ -20,6 +20,10 @@ def __init__(self, D, gamma=1.0):
Regularization parameter.
Lower is less smoothed (closer to true DTW).

sakoe_chiba_band: int
If non-negative, the DTW is restricted to a Sakoe-Chiba band around
the diagonal. The band has a width of 2 * sakoe_chiba_band + 1.

Attributes
----------
self.R_: array, shape = [m + 2, n + 2]
Expand All @@ -33,6 +37,7 @@ def __init__(self, D, gamma=1.0):
self.D = self.D.astype(np.float64)

self.gamma = gamma
self.sakoe_chiba_band = sakoe_chiba_band

def compute(self):
"""
Expand All @@ -48,9 +53,10 @@ def compute(self):
# Allocate memory.
# We need +2 because we use indices starting from 1
# and to deal with edge cases in the backward recursion.
self.R_ = np.zeros((m+2, n+2), dtype=np.float64)
self.R_ = np.zeros((m + 2, n + 2), dtype=np.float64)

_soft_dtw(self.D, self.R_, gamma=self.gamma)
_soft_dtw(self.D, self.R_, gamma=self.gamma,
sakoe_chiba_band=self.sakoe_chiba_band)

return self.R_[m, n]

Expand All @@ -71,13 +77,14 @@ def grad(self):
# Add an extra row and an extra column to D.
# Needed to deal with edge cases in the recursion.
D = np.vstack((self.D, np.zeros(n)))
D = np.hstack((D, np.zeros((m+1, 1))))
D = np.hstack((D, np.zeros((m + 1, 1))))

# Allocate memory.
# We need +2 because we use indices starting from 1
# and to deal with edge cases in the recursion.
E = np.zeros((m+2, n+2))
E = np.zeros((m + 2, n + 2))

_soft_dtw_grad(D, self.R_, E, gamma=self.gamma)
_soft_dtw_grad(D, self.R_, E, gamma=self.gamma,
sakoe_chiba_band=self.sakoe_chiba_band)

return E[1:-1, 1:-1]
74 changes: 63 additions & 11 deletions sdtw/soft_dtw_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,35 @@ cdef inline double _softmin3(double a,
return -gamma * (log(tmp) + max_val)


cdef inline int is_outside_sakoe_chiba_band(int sakoe_chiba_band,
int i,
int j,
int m,
int n):
"""True if the Sakoe-Chiba band constraint is used, and if (i, j) is outside

This constraints the wrapping to a band around the diagonal.
"""
cdef int diff, bound

if sakoe_chiba_band < 0:
return 0
else:
# since (i, j) starts at (1, 1)
i, j = i - 1, j - 1

diff = i * (n - 1) - j * (m - 1)
diff = abs(diff * 2)
bound = max(m, n) * (sakoe_chiba_band + 1)
is_in_band = diff < bound

return not is_in_band


def _soft_dtw(np.ndarray[double, ndim=2] D,
np.ndarray[double, ndim=2] R,
double gamma):
double gamma,
int sakoe_chiba_band=-1):

cdef int m = D.shape[0]
cdef int n = D.shape[1]
Expand All @@ -57,17 +83,22 @@ def _soft_dtw(np.ndarray[double, ndim=2] D,
# DP recursion.
for i in range(1, m + 1):
for j in range(1, n + 1):
# D is indexed starting from 0.
R[i, j] = D[i-1, j-1] + _softmin3(R[i-1, j],
R[i-1, j-1],
R[i, j-1],
gamma)

if is_outside_sakoe_chiba_band(sakoe_chiba_band, i, j, m, n):
R[i, j] = DBL_MAX
else:
# D is indexed starting from 0.
R[i, j] = D[i-1, j-1] + _softmin3(R[i-1, j],
R[i-1, j-1],
R[i, j-1],
gamma)


def _soft_dtw_grad(np.ndarray[double, ndim=2] D,
np.ndarray[double, ndim=2] R,
np.ndarray[double, ndim=2] E,
double gamma):
double gamma,
int sakoe_chiba_band=-1):

# We added an extra row and an extra column on the Python side.
cdef int m = D.shape[0] - 1
Expand Down Expand Up @@ -95,10 +126,27 @@ def _soft_dtw_grad(np.ndarray[double, ndim=2] D,
# DP recursion.
for j in reversed(range(1, n+1)): # ranges from n to 1
for i in reversed(range(1, m+1)): # ranges from m to 1
a = exp((R[i+1, j] - R[i, j] - D[i, j-1]) / gamma)
b = exp((R[i, j+1] - R[i, j] - D[i-1, j]) / gamma)
c = exp((R[i+1, j+1] - R[i, j] - D[i, j]) / gamma)
E[i, j] = E[i+1, j] * a + E[i, j+1] * b + E[i+1,j+1] * c

if is_outside_sakoe_chiba_band(sakoe_chiba_band, i, j, m, n):
E[i, j] = 0
R[i, j] = -DBL_MAX
else:
if E[i+1, j] == 0:
a = 0
else:
a = exp((R[i+1, j] - R[i, j] - D[i, j-1]) / gamma)

if E[i, j+1] == 0:
b = 0
else:
b = exp((R[i, j+1] - R[i, j] - D[i-1, j]) / gamma)

if E[i+1,j+1] == 0:
c = 0
else:
c = exp((R[i+1, j+1] - R[i, j] - D[i, j]) / gamma)

E[i, j] = E[i+1, j] * a + E[i, j+1] * b + E[i+1,j+1] * c


def _jacobian_product_sq_euc(np.ndarray[double, ndim=2] X,
Expand All @@ -108,8 +156,12 @@ def _jacobian_product_sq_euc(np.ndarray[double, ndim=2] X,
cdef int m = X.shape[0]
cdef int n = Y.shape[0]
cdef int d = X.shape[1]
cdef int i, j, k

for i in range(m):
for j in range(n):

if E[i,j] == 0:
continue
for k in range(d):
G[i, k] += E[i,j] * 2 * (X[i, k] - Y[j, k])
45 changes: 45 additions & 0 deletions sdtw/tests/test_soft_dtw.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_soft_dtw():
assert_almost_equal(SoftDTW(D, gamma).compute(),
_soft_dtw_bf(D, gamma=gamma))


def test_soft_dtw_grad():
def make_func(gamma):
def func(d):
Expand Down Expand Up @@ -71,3 +72,47 @@ def func(x):
func = make_func(gamma)
G_num = approx_fprime(X.ravel(), func, 1e-6).reshape(*G.shape)
assert_array_almost_equal(G, G_num, 5)


def test_soft_dtw_grad_band():
def make_func(gamma, sakoe_chiba_band):
def func(d):
D_ = d.reshape(*D.shape)
return SoftDTW(D_, gamma, sakoe_chiba_band).compute()
return func

for gamma in (0.001, 0.01, 0.1, 1, 10, 100, 1000):
for sakoe_chiba_band in [-1, 0, 1, 3]:
sdtw = SoftDTW(D, gamma, sakoe_chiba_band)
sdtw.compute()
E = sdtw.grad()
func = make_func(gamma, sakoe_chiba_band)
E_num = approx_fprime(D.ravel(), func, 1e-6).reshape(*E.shape)
assert_array_almost_equal(E, E_num, 5)


def _path_is_in_band(A, sakoe_chiba_band):
if sakoe_chiba_band < 0:
return True

# construct a mask which is True inside the Sakoe-Chiba band
mm, nn = A.shape
ii = np.arange(mm)[:, None]
jj = np.arange(nn)[None, :]
mask = ii * (nn - 1) - jj * (mm - 1)
mask = np.abs(2 * mask) < (max(mm, nn) * (sakoe_chiba_band + 1))

return np.all(A[~mask] == 0)


def _soft_dtw_band_bf(D, gamma, sakoe_chiba_band):
costs = [np.sum(A * D) for A in gen_all_paths(D.shape[0], D.shape[1])
if _path_is_in_band(A, sakoe_chiba_band)]
return _softmin(costs, gamma)


def test_soft_dtw_band():
gamma = 0.01
for sakoe_chiba_band in [-1, 0, 1, 2]:
assert_almost_equal(SoftDTW(D, gamma, sakoe_chiba_band).compute(),
_soft_dtw_band_bf(D, gamma, sakoe_chiba_band))