Merge branch 'devel' into fix_dt_model
marscher authored Jun 26, 2017
2 parents 4ca786f + 16dd0c3 commit 3950db9
Showing 16 changed files with 634 additions and 223 deletions.
30 changes: 9 additions & 21 deletions conftest.py
@@ -1,27 +1,15 @@
import pytest
import sys

# this is important for measuring coverage
assert 'pyemma' not in sys.modules


@pytest.fixture(scope='session')
def no_progress_bars():
""" disables progress bars during testing """
import subprocess

def pg_enabled():
out = subprocess.check_output(['python', '-c', 'from __future__ import print_function; import pyemma; print(pyemma.config.show_progress_bars)'])
return out.find('True') != -1

def cache_enabled():
out = subprocess.check_output(['python', '-c', 'from __future__ import print_function; import pyemma; print(pyemma.config.use_trajectory_lengths_cache)'])
return out.find('True') != -1

cfg_script = "import pyemma; pyemma.config.show_progress_bars = {pg}; pyemma.config.use_trajectory_lengths_cache = {cache};pyemma.config.save()"

pg_old_state = pg_enabled()
cache_old_state = cache_enabled()

enable = cfg_script.format(pg=pg_old_state, cache=cache_old_state)
disable = cfg_script.format(pg=False, cache=False)

subprocess.call(['python', '-c', disable])
yield # run session, after generator returned, session is cleaned up.
subprocess.call(['python', '-c', enable])
if 'pyemma' in sys.modules:
pyemma = sys.modules['pyemma']
pyemma.config.show_progress_bars = False
pyemma.config.use_trajectory_lengths_cache = False
yield
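
The rewritten fixture above no longer shells out to a subprocess to flip the configuration; it mutates the flags directly on the already-imported pyemma module. A minimal sketch of how a test might rely on the session-scoped fixture (the test name and assertions are illustrative, not part of this commit):

import pyemma


def test_runs_without_progress_bars(no_progress_bars):
    # pytest injects the session-scoped fixture by name; once it has run,
    # the two config flags it touches should stay off for the whole session.
    assert not pyemma.config.show_progress_bars
    assert not pyemma.config.use_trajectory_lengths_cache
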
2 changes: 2 additions & 0 deletions doc/source/CHANGELOG.rst
@@ -18,9 +18,11 @@ Changelog

- datasets: fixed get_multi_temperature_data and get_umbrella_sampling_data for Python 3. #1102
- coordinates: fixed StreamingTransformers (TICA, Kmeans, etc.) not respecting the in_memory flag. #1112
- coordinates: made TrajectoryInfoCache more fail-safe in case of concurrent processes. #1122
- msm: fix setting of dt_model for BayesianMSM. This bug led to wrongly scaled time units for mean first passage times,
  correlation and relaxation times, as well as timescales for this estimator. #1116


2.4 (05-19-2017)
----------------

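
The dt_model entry in the changelog above concerns how BayesianMSM reports kinetic quantities in physical time. The following sketch illustrates the user-facing effect; it assumes the pyemma.msm.bayesian_markov_model helper and its dt_traj argument, and uses a toy discrete trajectory rather than data from this commit:

import numpy as np
import pyemma

dtraj = np.random.randint(0, 3, size=10000)   # toy discrete trajectory, 3 states
bmsm = pyemma.msm.bayesian_markov_model(dtraj, lag=10, dt_traj='1 ps')

# With dt_model set correctly, timescales and mean first passage times are
# reported in physical units (picoseconds here), not in multiples of the lag.
print(bmsm.timescales()[:3])
print(bmsm.mfpt(0, 1))
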
11 changes: 7 additions & 4 deletions pyemma/_ext/variational/estimators/running_moments.py
@@ -41,8 +41,10 @@ def combine(self, other, mean_free=False):
w1 = self.w
w2 = other.w
w = w1 + w2
dsx = (w2/w1) * self.sx - other.sx
dsy = (w2/w1) * self.sy - other.sy
# TODO: fix this div by zero error
q = w2 / w1
dsx = q * self.sx - other.sx
dsy = q * self.sy - other.sy
# update
self.w = w1 + w2
self.sx = self.sx + other.sx
@@ -239,9 +241,10 @@ def add(self, X, Y=None, weights=None):
weights = weights * np.ones(T, dtype=float)
# Check appropriate length if weights is an array:
elif isinstance(weights, np.ndarray):
assert weights.shape[0] == T, 'weights and X must have equal length'
if len(weights) != T:
raise ValueError('weights and X must have equal length. Was {} and {} respectively.'.format(len(weights), len(X)))
else:
raise TypeError('weights is of type %s, must be a number or ndarray'%(type(weights)))
raise TypeError('weights is of type %s, must be a number or ndarray' % (type(weights)))
# estimate and add to storage
if self.compute_XX and not self.compute_XY:
w, s_X, C_XX = moments_XX(X, remove_mean=self.remove_mean, weights=weights, sparse_mode=self.sparse_mode, modify_data=self.modify_data)
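
The combine() hunk above merges the weighted sums of two data chunks, and the TODO notes that q = w2 / w1 is undefined for an empty first chunk. A standalone sketch of the underlying pairwise merge, written from the standard parallel-moments formula rather than copied from the class, shows where that division enters:

import numpy as np

def merge_moments(w1, s1, M1, w2, s2, M2):
    """w: chunk weight, s: sum of samples, M: sum of squared deviations."""
    if w1 == 0:                        # empty first chunk -- avoid w2 / w1
        return w2, s2, M2
    w = w1 + w2
    mean_diff = s1 / w1 - s2 / w2
    M = M1 + M2 + (w1 * w2 / w) * mean_diff ** 2
    return w, s1 + s2, M

x = np.random.rand(150)
a, b = x[:100], x[100:]
w, s, M = merge_moments(len(a), a.sum(), ((a - a.mean()) ** 2).sum(),
                        len(b), b.sum(), ((b - b.mean()) ** 2).sum())
assert np.isclose(M / w, x.var())      # matches the variance of the full series
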
9 changes: 7 additions & 2 deletions pyemma/coordinates/api.py
@@ -1198,6 +1198,9 @@ def tica(data=None, lag=10, dim=-1, var_cutoff=0.95, kinetic_map=True, commute_m
raise ValueError("reweighting must be either 'empirical', 'koopman' or an object with a weights(data) method.")
elif hasattr(weights, 'weights') and type(getattr(weights, 'weights')) == types.MethodType:
weights = weights
elif isinstance(weights, (list, tuple)) and all(isinstance(w, _np.ndarray) for w in weights):
if data is not None and len(data) != len(weights):
raise ValueError("len of weights({}) must match len of data({}).".format(len(weights), len(data)))
else:
raise ValueError("reweighting must be either 'empirical', 'koopman' or an object with a weights(data) method.")

@@ -1217,7 +1220,7 @@ def tica(data=None, lag=10, dim=-1, var_cutoff=0.95, kinetic_map=True, commute_m


def covariance_lagged(data=None, c00=True, c0t=True, ctt=False, remove_constant_mean=None, remove_data_mean=False,
reversible=False, bessel=True, lag=0, weights="empirical", stride=1, skip=0, chunksize=None):
reversible=False, bessel=True, lag=0, weights="empirical", stride=1, skip=0, chunksize=1000):
"""
Compute lagged covariances between time series. If data is available as an array of size (TxN), where T is the
number of time steps and N the number of dimensions, this function can compute lagged covariances like
@@ -1290,7 +1293,9 @@ def covariance_lagged(data=None, c00=True, c0t=True, ctt=False, remove_constant_
else:
raise ValueError("reweighting must be either 'empirical', 'koopman' or an object with a weights(data) method.")
elif hasattr(weights, 'weights') and type(getattr(weights, 'weights')) == types.MethodType:
weights = weights
pass
elif isinstance(weights, (list, tuple, _np.ndarray)):
pass
else:
raise ValueError("reweighting must be either 'empirical', 'koopman' or an object with a weights(data) method.")

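
The api.py hunks above extend the accepted weights arguments: tica() now validates a list or tuple of per-trajectory weight arrays against the number of trajectories, and covariance_lagged() lets such inputs pass through. A hedged sketch of the call shapes the new validation accepts (whether the downstream estimator then uses per-frame weights is outside these hunks):

import numpy as np
import pyemma.coordinates as coor

trajs = [np.random.rand(1000, 5), np.random.rand(800, 5)]
# one weight array per trajectory, one weight per frame
weights = [np.ones(len(t)) for t in trajs]

tica_obj = coor.tica(trajs, lag=10, weights=weights)
cov = coor.covariance_lagged(trajs, lag=10, weights=weights)
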
21 changes: 15 additions & 6 deletions pyemma/coordinates/data/_base/transformer.py
@@ -122,6 +122,11 @@ def __init__(self, chunksize=1000):
super(StreamingTransformer, self).__init__(chunksize=chunksize)
self.data_producer = None
self._Y_source = None
self._estimated = True  # this class only transforms data and needs no estimation.

@abstractmethod
def dimension(self):
pass

@property
# overload of DataSource
@@ -180,12 +185,6 @@ def _create_iterator(self, skip=0, chunk=0, stride=1, return_trajindex=True, col
return StreamingTransformerIterator(self, skip=skip, chunk=chunk, stride=stride,
return_trajindex=return_trajindex, cols=cols)

def get_output(self, dimensions=slice(0, None), stride=1, skip=0, chunk=None):
if not self._estimated:
self.estimate(self.data_producer, stride=stride)

return super(StreamingTransformer, self).get_output(dimensions, stride, skip, chunk)

@property
def chunksize(self):
"""chunksize defines how much data is being processed at once."""
@@ -214,6 +213,10 @@ def n_frames_total(self, stride=1, skip=0):


class StreamingEstimationTransformer(StreamingTransformer, StreamingEstimator):
def __init__(self):
super(StreamingEstimationTransformer, self).__init__()
self._estimated = False

""" Basis class for pipelined Transformers, which perform also estimation. """
def estimate(self, X, **kwargs):
super(StreamingEstimationTransformer, self).estimate(X, **kwargs)
@@ -223,6 +226,12 @@ def estimate(self, X, **kwargs):
self._map_to_memory()
return self

def get_output(self, dimensions=slice(0, None), stride=1, skip=0, chunk=None):
if not self._estimated:
self.estimate(self.data_producer, stride=stride)

return super(StreamingTransformer, self).get_output(dimensions, stride, skip, chunk)


class StreamingTransformerIterator(DataSourceIterator):

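
The transformer.py hunks move get_output()'s estimate-on-demand behaviour from StreamingTransformer (now always considered estimated) into StreamingEstimationTransformer (which starts unestimated). A toy illustration of that lazy-estimation pattern, with names that are illustrative rather than PyEMMA's:

class LazyEstimator:
    def __init__(self):
        self._estimated = False
        self.data_producer = None

    def estimate(self, data):
        self.mean_ = sum(data) / len(data)   # stand-in for the real fit
        self._estimated = True
        return self

    def get_output(self):
        if not self._estimated:              # fit on demand, exactly once
            self.estimate(self.data_producer)
        return [x - self.mean_ for x in self.data_producer]

est = LazyEstimator()
est.data_producer = [1.0, 2.0, 3.0]
print(est.get_output())                      # triggers estimate() first
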
5 changes: 5 additions & 0 deletions pyemma/coordinates/data/data_in_memory.py
@@ -140,6 +140,11 @@ def load_from_files(cls, files):
def describe(self):
return "[DataInMemory array shapes: %s]" % [np.shape(x) for x in self.data]

def __str__(self):
return self.describe()

__repr__ = __str__


class DataInMemoryCuboidRandomAccessStrategy(RandomAccessStrategy):
def _handle_slice(self, idx):
90 changes: 52 additions & 38 deletions pyemma/coordinates/data/util/traj_info_backends.py
@@ -119,10 +119,12 @@ def convert_array(text):
# Converts TEXT to np.array when selecting
sqlite3.register_converter("NPARRAY", convert_array)
self._database = sqlite3.connect(filename if filename is not None else ":memory:",
detect_types=sqlite3.PARSE_DECLTYPES, timeout=1000*1000,
detect_types=sqlite3.PARSE_DECLTYPES, timeout=5,
isolation_level=None)
self.filename = filename

self.lru_timeout = 5.0 # python sqlite3 specifies timeout in seconds instead of milliseconds.

try:
cursor = self._database.execute("select num from version")
row = cursor.fetchone()
@@ -148,8 +150,8 @@ def convert_array(text):

def _create_new_db(self):
# assumes self.database is a sqlite3.Connection
create_version_table = "CREATE TABLE version (num INTEGER PRIMARY KEY);"
create_info_table = """CREATE TABLE traj_info(
create_version_table = "CREATE TABLE IF NOT EXISTS version (num INTEGER PRIMARY KEY);"
create_info_table = """CREATE TABLE IF NOT EXISTS traj_info(
hash VARCHAR(64) PRIMARY KEY,
length INTEGER,
ndim INTEGER,
@@ -176,8 +178,13 @@ def db_version(self):

@db_version.setter
def db_version(self, val):
self._database.execute("insert into version VALUES (?)", [val])
self._database.commit()
import sqlite3
with self._database:
try:
self._database.execute("insert into version VALUES (?)", [val])
except sqlite3.IntegrityError:
pass
# self._database.commit()

@property
def num_entries(self):
@@ -197,11 +204,11 @@ def set(self, traj_info):
statement = ("INSERT INTO traj_info (hash, length, ndim, offsets, abs_path, version, lru_db)"
"VALUES (?, ?, ?, ?, ?, ?, ?)", values)
try:
self._database.execute(*statement)
with self._database as c:
c.execute(*statement)
except sqlite3.IntegrityError as ie:
logger.exception("insert failed: %s " % ie)
return
self._database.commit()

self._update_time_stamp(hash_value=traj_info.hash_value)

@@ -257,20 +264,29 @@ def _update_time_stamp(self, hash_value):
if not db_name:
db_name=':memory:'

import sqlite3

with sqlite3.connect(db_name) as conn:
""" last_read is a result of time.time()"""
conn.execute('CREATE TABLE IF NOT EXISTS usage '
'(hash VARCHAR(32), last_read FLOAT)')
conn.commit()
cur = conn.execute('select * from usage where hash=?', (hash_value,))
row = cur.fetchone()
if not row:
conn.execute("insert into usage(hash, last_read) values(?, ?)", (hash_value, time.time()))
else:
conn.execute("update usage set last_read=? where hash=?", (time.time(), hash_value))
conn.commit()
def _update():
import sqlite3
try:
with sqlite3.connect(db_name, timeout=self.lru_timeout) as conn:
""" last_read is a result of time.time()"""
conn.execute('CREATE TABLE IF NOT EXISTS usage '
'(hash VARCHAR(32), last_read FLOAT)')
conn.commit()
cur = conn.execute('select * from usage where hash=?', (hash_value,))
row = cur.fetchone()
if not row:
conn.execute("insert into usage(hash, last_read) values(?, ?)", (hash_value, time.time()))
else:
conn.execute("update usage set last_read=? where hash=?", (time.time(), hash_value))
conn.commit()
except sqlite3.OperationalError:
# if there are many jobs to write to same database at same time, the timeout could be hit
logger.debug('could not update LRU info for db %s', db_name)

# this could lead to another (rare) race condition during cleaning...
#import threading
#threading.Thread(target=_update).start()
_update()

@staticmethod
def _create_traj_info(row):
@@ -324,11 +340,9 @@ def _clean(self, n):

# debug: distribution
len_by_db = {os.path.basename(db): len(hashs_by_db[db]) for db in hashs_by_db.keys()}
logger.debug("distribution of lru: %s" % str(len_by_db))
logger.debug("distribution of lru: %s", str(len_by_db))
### end dbg

self.lru_timeout = 1000 #1 sec

# collect timestamps from databases
for db in hashs_by_db.keys():
with sqlite3.connect(db, timeout=self.lru_timeout) as conn:
@@ -345,17 +359,17 @@ def _clean(self, n):

sql_compatible_ids = SqliteDB._format_tuple_for_sql(ids)

stmnt = "DELETE FROM traj_info WHERE hash in (%s)" % sql_compatible_ids
cur = self._database.execute(stmnt)
self._database.commit()
assert cur.rowcount == len(ids), "deleted not as many rows(%s) as desired(%s)" %(cur.rowcount, len(ids))

# iterate over all LRU databases and delete those ids, we've just deleted from the main db.
age_by_hash.sort(key=itemgetter(2))
for db, values in itertools.groupby(age_by_hash, key=itemgetter(2)):
values = tuple(v[0] for v in values)
with sqlite3.connect(db, timeout=self.lru_timeout) as conn:
stmnt = "DELETE FROM usage WHERE hash IN (%s)" \
% SqliteDB._format_tuple_for_sql(values)
curr = conn.execute(stmnt)
assert curr.rowcount == len(values), curr.rowcount
with self._database as c:
c.execute("DELETE FROM traj_info WHERE hash in (%s)" % sql_compatible_ids)

# iterate over all LRU databases and delete those ids, we've just deleted from the main db.
# Do this within the same execution block of the main database, because we do not want the entry to be deleted,
# in case of a subsequent failure.
age_by_hash.sort(key=itemgetter(2))
for db, values in itertools.groupby(age_by_hash, key=itemgetter(2)):
values = tuple(v[0] for v in values)
with sqlite3.connect(db, timeout=self.lru_timeout) as conn:
stmnt = "DELETE FROM usage WHERE hash IN (%s)" \
% SqliteDB._format_tuple_for_sql(values)
curr = conn.execute(stmnt)
assert curr.rowcount == len(values), curr.rowcount
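
The traj_info_backends.py changes above rest on two sqlite3 conventions: the connect() timeout is given in seconds (hence 5 instead of 1000*1000), and using the connection as a context manager wraps statements in an implicit transaction that commits on success and rolls back on an exception. A standalone sketch with an illustrative in-memory database and table:

import sqlite3

conn = sqlite3.connect(":memory:", timeout=5)    # timeout is in seconds

conn.execute("CREATE TABLE IF NOT EXISTS usage (hash TEXT, last_read REAL)")

try:
    with conn:                                   # implicit BEGIN ... COMMIT
        conn.execute("INSERT INTO usage VALUES (?, ?)", ("abc", 0.0))
except sqlite3.OperationalError:
    # concurrent writers may still hit 'database is locked' after the timeout
    pass
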