Skip to content
This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

Improved topology LRU cache #1545

Merged
merged 1 commit into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyemma/coordinates/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
'assign_to_centers',
]

_string_types = str
_string_types = (str, Path)

# ==============================================================================
#
Expand Down
16 changes: 10 additions & 6 deletions pyemma/coordinates/data/feature_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@


from copy import copy
from pathlib import Path

import mdtraj
import numpy as np

from pyemma._base.serialization.serialization import SerializableMixIn
from pyemma.coordinates.data._base.datasource import DataSource, EncapsulatedIterator
from pyemma.coordinates.data._base.random_accessible import RandomAccessStrategy
from pyemma.coordinates.data.featurization.featurizer import MDFeaturizer
from pyemma.coordinates.data.util.reader_utils import file_suffix
from pyemma.coordinates.data.util.traj_info_cache import TrajInfo
from pyemma.coordinates.util import patches
from pyemma.util.annotators import deprecated, fix_docs
Expand Down Expand Up @@ -91,15 +94,15 @@ def __init__(self, trajectories, topologyfile=None, chunksize=1000, featurizer=N
super(FeatureReader, self).__init__(chunksize=chunksize)
self._is_reader = True
self.topfile = topologyfile
self.filenames = copy(trajectories) # this is modified in-place in mdtraj.load
if not isinstance(trajectories, (list, tuple)):
trajectories = [trajectories]
self.filenames = copy([str(traj) for traj in trajectories]) # this is modified in-place in mdtraj.load
self._return_traj_obj = False

self._is_random_accessible = all(
(f.endswith(FeatureReader.SUPPORTED_RANDOM_ACCESS_FORMATS)
for f in self.filenames)
)
self._is_random_accessible = all(file_suffix(f) in FeatureReader.SUPPORTED_RANDOM_ACCESS_FORMATS
for f in self.filenames)
# check we have at least mdtraj-1.6.1 to efficiently seek xtc, trr formats
if any(f.endswith('.xtc') or f.endswith('.trr') for f in trajectories):
if any(file_suffix(f) == '.xtc' or file_suffix(f) == '.trr' for f in trajectories):
from distutils.version import LooseVersion
xtc_trr_random_accessible = True if LooseVersion(mdtraj.version.version) >= LooseVersion('1.6.1') else False
self._is_random_accessible &= xtc_trr_random_accessible
Expand Down Expand Up @@ -128,6 +131,7 @@ def trajfiles(self):
return self.filenames

def _get_traj_info(self, filename):
filename = str(filename) if isinstance(filename, Path) else filename
with mdtraj.open(filename, mode='r') as fh:
try:
length = len(fh)
Expand Down
5 changes: 3 additions & 2 deletions pyemma/coordinates/data/featurization/featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


import warnings
from pathlib import Path

from pyemma._base.loggable import Loggable
from pyemma._base.serialization.serialization import SerializableMixIn
Expand Down Expand Up @@ -68,8 +69,8 @@ def topologyfile(self):
@topologyfile.setter
def topologyfile(self, topfile):
self._topologyfile = topfile
if isinstance(topfile, str):
self.topology = load_topology_cached(topfile)
if isinstance(topfile, (Path, str)):
self.topology = load_topology_cached(str(topfile))
self._topologyfile = topfile
elif isinstance(topfile, mdtraj.Topology):
self.topology = topfile
Expand Down
4 changes: 3 additions & 1 deletion pyemma/coordinates/data/h5_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def __init__(self, filenames, selection='/*', chunk_size=5000, **kw):
# and the interface of the cache does not allow for such a mapping (1:1 relation filename:(dimension, len)).
from pyemma.util.contexts import settings
with settings(use_trajectory_lengths_cache=False):
self.filenames = filenames
if not isinstance(filenames, (list, tuple)):
filenames = [filenames]
self.filenames = [str(fname) for fname in filenames]

# we need to override the ntraj attribute to be equal with the itraj_counter to respect all data sets.
self._ntraj = self._itraj_counter
Expand Down
5 changes: 3 additions & 2 deletions pyemma/coordinates/data/numpy_filereader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@


import functools
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -57,12 +58,12 @@ def __init__(self, filenames, chunksize=1000, mmap_mode='r'):
filenames = [filenames]

for f in filenames:
if not f.endswith('.npy'):
if Path(f).suffix != '.npy':
raise ValueError('given file "%s" is not supported by this'
' reader, since it does not end with .npy' % f)

self.mmap_mode = mmap_mode
self.filenames = filenames
self.filenames = [str(fname) for fname in filenames]

def _create_iterator(self, skip=0, chunk=0, stride=1, return_trajindex=False, cols=None):
return NPYIterator(self, skip=skip, chunk=chunk, stride=stride,
Expand Down
6 changes: 4 additions & 2 deletions pyemma/coordinates/data/py_csv_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import csv
import os
from math import ceil
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -202,8 +203,9 @@ def __init__(self, filenames, chunksize=1000, delimiters=None, comments='#',

if isinstance(filenames, (tuple, list)):
n = len(filenames)
elif isinstance(filenames, str):
elif isinstance(filenames, (str, Path)):
n = 1
filenames = [filenames]
else:
raise TypeError("'filenames' argument has to be list, tuple or string")
self._comments = PyCSVReader.__parse_args(comments, '#', n)
Expand All @@ -216,7 +218,7 @@ def __init__(self, filenames, chunksize=1000, delimiters=None, comments='#',

self._skip = np.zeros(n, dtype=int)
# invoke filename setter
self.filenames = filenames
self.filenames = [str(fname) for fname in filenames]

@staticmethod
def __parse_args(arg, default, n):
Expand Down
28 changes: 17 additions & 11 deletions pyemma/coordinates/data/util/reader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,15 @@
from numpy import vstack
import mdtraj as md
import numpy as np
import os


def file_suffix(path):
    r"""Return the file-name suffix (extension, including the dot) of *path*.

    *path* may be a :class:`pathlib.Path`, a string, or anything else the
    :class:`~pathlib.Path` constructor accepts; ``Path`` is idempotent on
    ``Path`` instances, so the coercion is unconditional.
    """
    return Path(path).suffix


def create_file_reader(input_files, topology, featurizer, chunksize=None, **kw):
r"""
Expand Down Expand Up @@ -51,15 +57,14 @@ def create_file_reader(input_files, topology, featurizer, chunksize=None, **kw):
return FragmentedTrajectoryReader(input_files, topology, chunksize, featurizer)

# normal trajectories
if (isinstance(input_files, str)
if (isinstance(input_files, (Path, str))
or (isinstance(input_files, (list, tuple))
and (any(isinstance(item, str) for item in input_files)
and (any(isinstance(item, (Path, str)) for item in input_files)
or len(input_files) == 0))):
reader = None
# check: if single string create a one-element list
if isinstance(input_files, str):
if isinstance(input_files, (Path, str)):
input_list = [input_files]
elif len(input_files) > 0 and all(isinstance(item, str) for item in input_files):
elif len(input_files) > 0 and all(isinstance(item, (Path, str)) for item in input_files):
input_list = input_files
else:
if len(input_files) == 0:
Expand All @@ -68,20 +73,21 @@ def create_file_reader(input_files, topology, featurizer, chunksize=None, **kw):
raise ValueError("The passed list did not exclusively contain strings or was a list of lists "
"(fragmented trajectory).")

# TODO: this does not handle suffixes like .xyz.gz (rare)
_, suffix = os.path.splitext(input_list[0])
# convert to list of paths
input_list = [Path(f) for f in input_list]

suffix = str(suffix)
# TODO: this does not handle suffixes like .xyz.gz (rare)
suffix = input_list[0].suffix

# check: do all files have the same file type? If not: raise ValueError.
if all(item.endswith(suffix) for item in input_list):
if all(item.suffix == suffix for item in input_list):

# do all the files exist? If not: Raise value error
all_exist = True
from six import StringIO
err_msg = StringIO()
for item in input_list:
if not os.path.isfile(item):
if not item.is_file():
err_msg.write('\n' if err_msg.tell() > 0 else "")
err_msg.write('File %s did not exist or was no file' % item)
all_exist = False
Expand Down
11 changes: 6 additions & 5 deletions pyemma/coordinates/data/util/traj_info_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import warnings
from io import BytesIO
from logging import getLogger
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -180,8 +181,10 @@ def _handle_csv(self, reader, filename, length):

def __getitem__(self, filename_reader_tuple):
filename, reader = filename_reader_tuple
if isinstance(filename, Path):
filename = str(filename)
abs_path = os.path.abspath(filename)
key = self._get_file_hash_v2(filename)
key = self.compute_file_hash(abs_path)
try:
info = self._database.get(key)
if not isinstance(info, TrajInfo):
Expand Down Expand Up @@ -226,15 +229,13 @@ def _get_file_hash(self, filename):
hash_value ^= hash(data)
return str(hash_value)

def _get_file_hash_v2(self, filename):
@staticmethod
def compute_file_hash(filename):
statinfo = os.stat(filename)
# now read the first megabyte and hash it
with open(filename, mode='rb') as fh:
data = fh.read(1024)

if sys.version_info > (3,):
long = int

hasher = hashlib.md5()
hasher.update(os.path.basename(filename).encode('utf-8'))
hasher.update(str(statinfo.st_mtime).encode('ascii'))
Expand Down
26 changes: 23 additions & 3 deletions pyemma/coordinates/tests/test_traj_info_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

@author: marscher
'''


import shutil
from pathlib import Path
from tempfile import NamedTemporaryFile

import os
Expand All @@ -35,7 +35,7 @@
from pyemma.coordinates.data.py_csv_reader import PyCSVReader
from pyemma.coordinates.data.util.traj_info_backends import SqliteDB
from pyemma.coordinates.data.util.traj_info_cache import TrajectoryInfoCache
from pyemma.coordinates.tests.util import create_traj
from pyemma.coordinates.tests.util import create_traj, get_top
from pyemma.datasets import get_bpti_test_data
from pyemma.util import config
from pyemma.util.contexts import settings
Expand Down Expand Up @@ -268,6 +268,26 @@ def test_max_n_entries(self):
self.assertLessEqual(self.db.num_entries, max_entries)
self.assertGreater(self.db.num_entries, 0)

    def test_cache_miss_same_filename(self):
        """Regression test for issue #1541: replacing files *in place* (same
        path, new content) must not make the reader reuse stale cached
        trajectory info / topology for the old content."""
        # reproduces issue #1541
        tmpdir = None
        try:
            fname_pdb = os.path.basename(pdbfile)
            fname_xtc = os.path.basename(xtcfiles[0])
            tmpdir = Path(tempfile.mkdtemp())
            # stage copies of the reference topology + trajectory under a temp dir
            shutil.copyfile(pdbfile, tmpdir / fname_pdb)
            shutil.copyfile(xtcfiles[0], tmpdir / fname_xtc)
            # first read populates the info/topology caches for these paths
            _ = pyemma.coordinates.source(tmpdir / fname_xtc, top=tmpdir / fname_pdb)
            shutil.copyfile(get_top(), tmpdir / fname_pdb)  # overwrite pdb

            # write a brand-new trajectory to the *same* xtc path
            # (different content, identical filename)
            t = mdtraj.load(tmpdir / fname_pdb)
            t.xyz = np.zeros(shape=(400, 3, 3))
            t.time = np.arange(len(t.xyz))
            t.save(tmpdir / fname_xtc, force_overwrite=True)
            # second read must notice the changed content (content-hash cache
            # key) rather than serving the stale entry — would raise before #1545
            _ = pyemma.coordinates.source(tmpdir / fname_xtc, top=tmpdir / fname_pdb)
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)

def test_max_size(self):
data = [np.random.random((150, 10)) for _ in range(150)]
max_size = 1
Expand Down
6 changes: 4 additions & 2 deletions pyemma/coordinates/util/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,19 @@
from mdtraj.utils import in_units_of
from mdtraj.utils.validation import cast_indices

from pyemma.coordinates.data.util.traj_info_cache import TrajectoryInfoCache

TrajData = namedtuple("traj_data", ('xyz', 'unitcell_lengths', 'unitcell_angles', 'box'))


@lru_cache(maxsize=32)
def _load(top_file, hash):
    """Load the topology stored in *top_file*, memoized on ``(top_file, hash)``.

    ``hash`` is intentionally unused in the body: it only participates in the
    ``lru_cache`` key, so a file whose on-disk content changed (and therefore
    hashes differently) yields a freshly loaded topology instead of a stale
    cached one.
    """
    topology = load_topology(top_file)
    return topology


def load_topology_cached(top_file):
if isinstance(top_file, str):
return _load(top_file)
return _load(top_file, TrajectoryInfoCache.compute_file_hash(top_file))
if isinstance(top_file, Topology):
return top_file
if isinstance(top_file, Trajectory):
Expand Down