From f5d0531cc2fab93f5d7007677373d95d28350bf2 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Tue, 1 Mar 2022 13:47:09 +0100 Subject: [PATCH] fixes issue #1541 --- pyemma/coordinates/api.py | 2 +- pyemma/coordinates/data/feature_reader.py | 16 +++++++---- .../data/featurization/featurizer.py | 5 ++-- pyemma/coordinates/data/h5_reader.py | 4 ++- pyemma/coordinates/data/numpy_filereader.py | 5 ++-- pyemma/coordinates/data/py_csv_reader.py | 6 ++-- pyemma/coordinates/data/util/reader_utils.py | 28 +++++++++++-------- .../coordinates/data/util/traj_info_cache.py | 11 ++++---- .../coordinates/tests/test_traj_info_cache.py | 26 +++++++++++++++-- pyemma/coordinates/util/patches.py | 6 ++-- 10 files changed, 74 insertions(+), 35 deletions(-) diff --git a/pyemma/coordinates/api.py b/pyemma/coordinates/api.py index 32862d053..754d4c0fb 100644 --- a/pyemma/coordinates/api.py +++ b/pyemma/coordinates/api.py @@ -62,7 +62,7 @@ 'assign_to_centers', ] -_string_types = str +_string_types = (str, Path) # ============================================================================== # diff --git a/pyemma/coordinates/data/feature_reader.py b/pyemma/coordinates/data/feature_reader.py index 66e490971..869b5dd10 100644 --- a/pyemma/coordinates/data/feature_reader.py +++ b/pyemma/coordinates/data/feature_reader.py @@ -17,6 +17,8 @@ from copy import copy +from pathlib import Path + import mdtraj import numpy as np @@ -24,6 +26,7 @@ from pyemma.coordinates.data._base.datasource import DataSource, EncapsulatedIterator from pyemma.coordinates.data._base.random_accessible import RandomAccessStrategy from pyemma.coordinates.data.featurization.featurizer import MDFeaturizer +from pyemma.coordinates.data.util.reader_utils import file_suffix from pyemma.coordinates.data.util.traj_info_cache import TrajInfo from pyemma.coordinates.util import patches from pyemma.util.annotators import deprecated, fix_docs @@ -91,15 +94,15 @@ def __init__(self, trajectories, topologyfile=None, chunksize=1000, featurizer=N super(FeatureReader, self).__init__(chunksize=chunksize) self._is_reader = True self.topfile = topologyfile - self.filenames = copy(trajectories) # this is modified in-place in mdtraj.load + if not isinstance(trajectories, (list, tuple)): + trajectories = [trajectories] + self.filenames = copy([str(traj) for traj in trajectories]) # this is modified in-place in mdtraj.load self._return_traj_obj = False - self._is_random_accessible = all( - (f.endswith(FeatureReader.SUPPORTED_RANDOM_ACCESS_FORMATS) - for f in self.filenames) - ) + self._is_random_accessible = all(file_suffix(f) in FeatureReader.SUPPORTED_RANDOM_ACCESS_FORMATS + for f in self.filenames) # check we have at least mdtraj-1.6.1 to efficiently seek xtc, trr formats - if any(f.endswith('.xtc') or f.endswith('.trr') for f in trajectories): + if any(file_suffix(f) == '.xtc' or file_suffix(f) == '.trr' for f in trajectories): from distutils.version import LooseVersion xtc_trr_random_accessible = True if LooseVersion(mdtraj.version.version) >= LooseVersion('1.6.1') else False self._is_random_accessible &= xtc_trr_random_accessible @@ -128,6 +131,7 @@ def trajfiles(self): return self.filenames def _get_traj_info(self, filename): + filename = str(filename) if isinstance(filename, Path) else filename with mdtraj.open(filename, mode='r') as fh: try: length = len(fh) diff --git a/pyemma/coordinates/data/featurization/featurizer.py b/pyemma/coordinates/data/featurization/featurizer.py index a75e49ec0..2b79558ce 100644 --- a/pyemma/coordinates/data/featurization/featurizer.py +++ b/pyemma/coordinates/data/featurization/featurizer.py @@ -17,6 +17,7 @@ import warnings +from pathlib import Path from pyemma._base.loggable import Loggable from pyemma._base.serialization.serialization import SerializableMixIn @@ -68,8 +69,8 @@ def topologyfile(self): @topologyfile.setter def topologyfile(self, topfile): self._topologyfile = topfile - if isinstance(topfile, str): - self.topology = load_topology_cached(topfile) + if isinstance(topfile, (Path, str)): + self.topology = load_topology_cached(str(topfile)) self._topologyfile = topfile elif isinstance(topfile, mdtraj.Topology): self.topology = topfile diff --git a/pyemma/coordinates/data/h5_reader.py b/pyemma/coordinates/data/h5_reader.py index 30cf6da20..2753672ee 100644 --- a/pyemma/coordinates/data/h5_reader.py +++ b/pyemma/coordinates/data/h5_reader.py @@ -58,7 +58,9 @@ def __init__(self, filenames, selection='/*', chunk_size=5000, **kw): # and the interface of the cache does not allow for such a mapping (1:1 relation filename:(dimension, len)). from pyemma.util.contexts import settings with settings(use_trajectory_lengths_cache=False): - self.filenames = filenames + if not isinstance(filenames, (list, tuple)): + filenames = [filenames] + self.filenames = [str(fname) for fname in filenames] # we need to override the ntraj attribute to be equal with the itraj_counter to respect all data sets. self._ntraj = self._itraj_counter diff --git a/pyemma/coordinates/data/numpy_filereader.py b/pyemma/coordinates/data/numpy_filereader.py index dd258e091..125d7ce13 100644 --- a/pyemma/coordinates/data/numpy_filereader.py +++ b/pyemma/coordinates/data/numpy_filereader.py @@ -22,6 +22,7 @@ import functools +from pathlib import Path import numpy as np @@ -57,12 +58,12 @@ def __init__(self, filenames, chunksize=1000, mmap_mode='r'): filenames = [filenames] for f in filenames: - if not f.endswith('.npy'): + if Path(f).suffix != '.npy': raise ValueError('given file "%s" is not supported by this' ' reader, since it does not end with .npy' % f) self.mmap_mode = mmap_mode - self.filenames = filenames + self.filenames = [str(fname) for fname in filenames] def _create_iterator(self, skip=0, chunk=0, stride=1, return_trajindex=False, cols=None): return NPYIterator(self, skip=skip, chunk=chunk, stride=stride, diff --git a/pyemma/coordinates/data/py_csv_reader.py b/pyemma/coordinates/data/py_csv_reader.py index 710cc5ff1..bd90c22a1 100644 --- a/pyemma/coordinates/data/py_csv_reader.py +++ b/pyemma/coordinates/data/py_csv_reader.py @@ -24,6 +24,7 @@ import csv import os from math import ceil +from pathlib import Path import numpy as np @@ -202,8 +203,9 @@ def __init__(self, filenames, chunksize=1000, delimiters=None, comments='#', if isinstance(filenames, (tuple, list)): n = len(filenames) - elif isinstance(filenames, str): + elif isinstance(filenames, (str, Path)): n = 1 + filenames = [filenames] else: raise TypeError("'filenames' argument has to be list, tuple or string") self._comments = PyCSVReader.__parse_args(comments, '#', n) @@ -216,7 +218,7 @@ def __init__(self, filenames, chunksize=1000, delimiters=None, comments='#', self._skip = np.zeros(n, dtype=int) # invoke filename setter - self.filenames = filenames + self.filenames = [str(fname) for fname in filenames] @staticmethod def __parse_args(arg, default, n): diff --git a/pyemma/coordinates/data/util/reader_utils.py b/pyemma/coordinates/data/util/reader_utils.py index 7ec134715..dac001b90 100644 --- a/pyemma/coordinates/data/util/reader_utils.py +++ b/pyemma/coordinates/data/util/reader_utils.py @@ -21,9 +21,15 @@ from numpy import vstack import mdtraj as md import numpy as np -import os +def file_suffix(path): + r""" Returns the suffix of a path. The path may be any kind of object that can be converted into a pathlib Path + object. """ + if not isinstance(path, Path): + path = Path(path) + return path.suffix + def create_file_reader(input_files, topology, featurizer, chunksize=None, **kw): r""" @@ -51,15 +57,14 @@ def create_file_reader(input_files, topology, featurizer, chunksize=None, **kw): return FragmentedTrajectoryReader(input_files, topology, chunksize, featurizer) # normal trajectories - if (isinstance(input_files, str) + if (isinstance(input_files, (Path, str)) or (isinstance(input_files, (list, tuple)) - and (any(isinstance(item, str) for item in input_files) + and (any(isinstance(item, (Path, str)) for item in input_files) or len(input_files) == 0))): - reader = None # check: if single string create a one-element list - if isinstance(input_files, str): + if isinstance(input_files, (Path, str)): input_list = [input_files] - elif len(input_files) > 0 and all(isinstance(item, str) for item in input_files): + elif len(input_files) > 0 and all(isinstance(item, (Path, str)) for item in input_files): input_list = input_files else: if len(input_files) == 0: @@ -68,20 +73,21 @@ def create_file_reader(input_files, topology, featurizer, chunksize=None, **kw): raise ValueError("The passed list did not exclusively contain strings or was a list of lists " "(fragmented trajectory).") - # TODO: this does not handle suffixes like .xyz.gz (rare) - _, suffix = os.path.splitext(input_list[0]) + # convert to list of paths + input_list = [Path(f) for f in input_list] - suffix = str(suffix) + # TODO: this does not handle suffixes like .xyz.gz (rare) + suffix = input_list[0].suffix # check: do all files have the same file type? If not: raise ValueError. - if all(item.endswith(suffix) for item in input_list): + if all(item.suffix == suffix for item in input_list): # do all the files exist? If not: Raise value error all_exist = True from six import StringIO err_msg = StringIO() for item in input_list: - if not os.path.isfile(item): + if not item.is_file(): err_msg.write('\n' if err_msg.tell() > 0 else "") err_msg.write('File %s did not exist or was no file' % item) all_exist = False diff --git a/pyemma/coordinates/data/util/traj_info_cache.py b/pyemma/coordinates/data/util/traj_info_cache.py index d6915fd95..14a599c1d 100644 --- a/pyemma/coordinates/data/util/traj_info_cache.py +++ b/pyemma/coordinates/data/util/traj_info_cache.py @@ -27,6 +27,7 @@ import warnings from io import BytesIO from logging import getLogger +from pathlib import Path import numpy as np @@ -180,8 +181,10 @@ def _handle_csv(self, reader, filename, length): def __getitem__(self, filename_reader_tuple): filename, reader = filename_reader_tuple + if isinstance(filename, Path): + filename = str(filename) abs_path = os.path.abspath(filename) - key = self._get_file_hash_v2(filename) + key = self.compute_file_hash(abs_path) try: info = self._database.get(key) if not isinstance(info, TrajInfo): @@ -226,15 +229,13 @@ def _get_file_hash(self, filename): hash_value ^= hash(data) return str(hash_value) - def _get_file_hash_v2(self, filename): + @staticmethod + def compute_file_hash(filename): statinfo = os.stat(filename) # now read the first megabyte and hash it with open(filename, mode='rb') as fh: data = fh.read(1024) - if sys.version_info > (3,): - long = int - hasher = hashlib.md5() hasher.update(os.path.basename(filename).encode('utf-8')) hasher.update(str(statinfo.st_mtime).encode('ascii')) diff --git a/pyemma/coordinates/tests/test_traj_info_cache.py b/pyemma/coordinates/tests/test_traj_info_cache.py index c578e7c3b..31c25e582 100644 --- a/pyemma/coordinates/tests/test_traj_info_cache.py +++ b/pyemma/coordinates/tests/test_traj_info_cache.py @@ -19,8 +19,8 @@ @author: marscher ''' - - +import shutil +from pathlib import Path from tempfile import NamedTemporaryFile import os @@ -35,7 +35,7 @@ from pyemma.coordinates.data.py_csv_reader import PyCSVReader from pyemma.coordinates.data.util.traj_info_backends import SqliteDB from pyemma.coordinates.data.util.traj_info_cache import TrajectoryInfoCache -from pyemma.coordinates.tests.util import create_traj +from pyemma.coordinates.tests.util import create_traj, get_top from pyemma.datasets import get_bpti_test_data from pyemma.util import config from pyemma.util.contexts import settings @@ -268,6 +268,26 @@ def test_max_n_entries(self): self.assertLessEqual(self.db.num_entries, max_entries) self.assertGreater(self.db.num_entries, 0) + def test_cache_miss_same_filename(self): + # reproduces issue #1541 + tmpdir = None + try: + fname_pdb = os.path.basename(pdbfile) + fname_xtc = os.path.basename(xtcfiles[0]) + tmpdir = Path(tempfile.mkdtemp()) + shutil.copyfile(pdbfile, tmpdir / fname_pdb) + shutil.copyfile(xtcfiles[0], tmpdir / fname_xtc) + _ = pyemma.coordinates.source(tmpdir / fname_xtc, top=tmpdir / fname_pdb) + shutil.copyfile(get_top(), tmpdir / fname_pdb) # overwrite pdb + + t = mdtraj.load(tmpdir / fname_pdb) + t.xyz = np.zeros(shape=(400, 3, 3)) + t.time = np.arange(len(t.xyz)) + t.save(tmpdir / fname_xtc, force_overwrite=True) + _ = pyemma.coordinates.source(tmpdir / fname_xtc, top=tmpdir / fname_pdb) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + def test_max_size(self): data = [np.random.random((150, 10)) for _ in range(150)] max_size = 1 diff --git a/pyemma/coordinates/util/patches.py b/pyemma/coordinates/util/patches.py index 95985e5da..d274c468e 100644 --- a/pyemma/coordinates/util/patches.py +++ b/pyemma/coordinates/util/patches.py @@ -34,17 +34,19 @@ from mdtraj.utils import in_units_of from mdtraj.utils.validation import cast_indices +from pyemma.coordinates.data.util.traj_info_cache import TrajectoryInfoCache + TrajData = namedtuple("traj_data", ('xyz', 'unitcell_lengths', 'unitcell_angles', 'box')) @lru_cache(maxsize=32) -def _load(top_file): +def _load(top_file, hash): return load_topology(top_file) def load_topology_cached(top_file): if isinstance(top_file, str): - return _load(top_file) + return _load(top_file, TrajectoryInfoCache.compute_file_hash(top_file)) if isinstance(top_file, Topology): return top_file if isinstance(top_file, Trajectory):