Skip to content
This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

Commit

Permalink
check output for finite data and raise a useful error message
Browse files Browse the repository at this point in the history
Fixes #1272
  • Loading branch information
marscher committed Mar 19, 2018
1 parent cc6dd5b commit e16e321
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
10 changes: 8 additions & 2 deletions pyemma/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_quadwell_data(ntraj=10, nstep=10000, x0=0., nskip=1, dt=0.01, kT=1.0, ma
number of integrator steps
dt: float, default=0.01
time step size
kT: float, default=10.0
kT: float, default=1.0
temperature factor
mass: float, default=1.0
mass
Expand All @@ -149,5 +149,11 @@ def get_quadwell_data(ntraj=10, nstep=10000, x0=0., nskip=1, dt=0.01, kT=1.0, ma
"""
from .potentials import PrinzModel
pw = PrinzModel(dt, kT, mass=mass, damping=damping)
trajs = [pw.sample(x0, nstep, nskip=nskip) for _ in range(ntraj)]
import warnings
import numpy as np
with warnings.catch_warnings(record=True) as w:
trajs = [pw.sample(x0, nstep, nskip=nskip) for _ in range(ntraj)]
if not np.all(tuple(np.isfinite(x) for x in trajs)):
raise RuntimeError('integrator detected invalid values in output. If you used a high temperature value (kT),'
' try decreasing the integration time step dt.')
return trajs
15 changes: 12 additions & 3 deletions pyemma/datasets/tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pyemma.thermo import estimate_umbrella_sampling
from pyemma.thermo import estimate_multi_temperature
from numpy.testing import assert_allclose
import unittest


def test_umbrella_sampling_data():
Expand Down Expand Up @@ -82,6 +83,14 @@ def test_multi_temperature_data():
assert_allclose(pi, [0.3, 0.7], rtol=0.25, atol=0.1)


def test_prinz_potential():
from pyemma.datasets import get_quadwell_data
get_quadwell_data()
class TestPrinzPotential(unittest.TestCase):

def test_prinz_potential(self):
from pyemma.datasets import get_quadwell_data
import numpy as np

d = get_quadwell_data(ntraj=1, nstep=int(1e5))
assert np.all(np.isfinite(x) for x in d)

with self.assertRaises(RuntimeError):
get_quadwell_data(ntraj=1, dt=1, kT=100)

0 comments on commit e16e321

Please sign in to comment.