Skip to content
This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

[MSM] fix dt_model in bayesian msm, added tests for other estimators #1116

Merged
merged 6 commits into from
Jun 29, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyemma/msm/estimators/maximum_likelihood_msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def __init__(self, lag=1, reversible=True, count_mode='sliding', sparse=False,
# time step
self.dt_traj = dt_traj
self.timestep_traj = _TimeUnit(dt_traj)
self.dt_model = self.dt_traj
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hI @marscher I can't really control much, but at what other instance was dt_model being set in the models that HAD it set?


# score
self.score_method = score_method
Expand Down
9 changes: 7 additions & 2 deletions pyemma/msm/tests/test_bayesian_msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def setUpClass(cls):
cls.nsamples = 100

cls.lag = 100
cls.bmsm_rev = bayesian_markov_model(obs_macro, cls.lag,
cls.bmsm_rev = bayesian_markov_model(obs_macro, cls.lag, dt_traj='4 fs',
reversible=True, nsamples=cls.nsamples)
cls.bmsm_revpi = bayesian_markov_model(obs_macro, cls.lag,
cls.bmsm_revpi = bayesian_markov_model(obs_macro, cls.lag, dt_traj='4 fs',
reversible=True, statdist=pi_macro,
nsamples=cls.nsamples)

Expand Down Expand Up @@ -309,6 +309,11 @@ def _timescales_stats(self, msm):

# TODO: these tests can be made compact because they are almost the same. can define general functions for testing
# TODO: samples and stats, only need to implement consistency check individually.

def test_dt_model(self):
from pyemma.util.units import TimeUnit
tu = TimeUnit("4 fs").get_scaled(self.bmsm_rev.lag)
self.assertEqual(self.bmsm_rev.dt_model, tu)

if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions pyemma/msm/tests/test_hmsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,11 @@ def test_simulate_HMSM(self):
assert (len(traj) <= N)
assert (len(np.unique(traj)) <= len(hmsm.transition_matrix))

def test_dt_model(self):
from pyemma.util.units import TimeUnit
tu = TimeUnit("1 step").get_scaled(self.hmsm_lag10.lag)
self.assertEqual(self.hmsm_lag10.dt_model, tu)

# ----------------------------------
# MORE COMPLEX TESTS / SANITY CHECKS
# ----------------------------------
Expand Down
13 changes: 13 additions & 0 deletions pyemma/msm/tests/test_msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,19 @@ def test_timestep(self):
self._timestep(self.msmrevpi_sparse)
self._timestep(self.msm_sparse)

def _dt_model(self, msm):
from pyemma.util.units import TimeUnit
tu = TimeUnit("1 step").get_scaled(self.msm.lag)
self.assertEqual(msm.dt_model, tu)

def test_dt_model(self):
self._dt_model(self.msmrev)
self._dt_model(self.msmrevpi)
self._dt_model(self.msm)
self._dt_model(self.msmrev_sparse)
self._dt_model(self.msmrevpi_sparse)
self._dt_model(self.msm_sparse)

def _transition_matrix(self, msm):
P = msm.transition_matrix
# should be ndarray by default
Expand Down
10 changes: 10 additions & 0 deletions pyemma/util/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def __str__(self):
else:
return str(self._factor)+' '+self._unit_names[self._unit]

def __repr__(self):
return "[TimeUnit {}]".format(self)

@property
def dt(self):
return self._factor
Expand Down Expand Up @@ -140,6 +143,13 @@ def rescale_around1(self, times):
# nothing to do
return times, self._unit

def __eq__(self, other):
if not isinstance(other, TimeUnit):
return False

return self._unit == other._unit and self._factor == other._factor


def bytes_to_string(num, suffix='B'):
"""
Returns the size of num (bytes) in a human readable form up to Yottabytes (YB).
Expand Down