diff --git a/tensorflow_datasets/core/lazy_imports_lib.py b/tensorflow_datasets/core/lazy_imports_lib.py index 268bfa04bde..271ad399e91 100644 --- a/tensorflow_datasets/core/lazy_imports_lib.py +++ b/tensorflow_datasets/core/lazy_imports_lib.py @@ -20,6 +20,8 @@ from __future__ import division from __future__ import print_function +from builtins import __import__ as original_import_fun +import builtins import contextlib import importlib import sys @@ -75,7 +77,6 @@ def _try_import(module_name): class FakeModule(types.ModuleType): """A fake module which raise ImportError whenever an unknown attribute is accessed""" def __init__(self, name): - self.__path__ = None super(FakeModule, self).__init__(name) def __getattr__(self, attr): @@ -83,22 +84,13 @@ def __getattr__(self, attr): raise ImportError(err_msg) -class CustomImporter(object): - """Finder and Loader for modules in `_ALLOWED_LAZY_DEPS`""" - - def find_module(self, fullname, _): - """Accept if fullname is present in `_ALLOWED_LAZY_DEPS`, else return None""" - if fullname in _ALLOWED_LAZY_DEPS: - return self - return None - - def load_module(self, fullname): - """Load a `FakeModule` if the requested module is not present in sys.modules""" - if fullname not in sys.modules: - mod = FakeModule(fullname) - mod.__loader__ = self - sys.modules[fullname] = mod - return sys.modules[fullname] +def _custom_import(name, *args, **kwargs): + try: + return original_import_fun(name, *args, **kwargs) + except ImportError: + if name in _ALLOWED_LAZY_DEPS: + return FakeModule(name) + raise ImportError @contextlib.contextmanager @@ -109,16 +101,14 @@ def try_import(): It is created only if module_name is present in `_ALLOWED_LAZY_DEPS`. """ try: - sys.meta_path.append(CustomImporter()) - yield + with utils.temporary_assignment(builtins, "__import__", _custom_import): + yield except ImportError as err: err_msg = ("Unknown import {name}. Currently lazy_imports does not " "support {name} module. If you believe this is correct, " "please add it to the list of `_ALLOWED_LAZY_DEPS` in " "`tfds/core/lazy_imports_lib.py`".format(name=err.name)) utils.reraise(suffix=err_msg) - finally: - sys.meta_path.pop() class LazyImporter(object): diff --git a/tensorflow_datasets/core/lazy_imports_lib_test.py b/tensorflow_datasets/core/lazy_imports_lib_test.py index 99aed7c2b90..4a2b7bf0237 100644 --- a/tensorflow_datasets/core/lazy_imports_lib_test.py +++ b/tensorflow_datasets/core/lazy_imports_lib_test.py @@ -36,7 +36,6 @@ class LazyImportsTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters( "cv2", "langdetect", - "matplotlib", "mwparserfromhell", "nltk", "os", @@ -63,10 +62,8 @@ def test_bad_import(self): def test_lazy_import_context_manager(self): with tfds.core.try_import(): import pandas - import matplotlib.pyplot as plt self.assertTrue(hasattr(pandas, "read_csv")) - self.assertTrue(hasattr(plt, "figure")) def test_import_without_context_manager(self): import nltk @@ -90,6 +87,17 @@ def test_lazy_import_context_manager_errors(self): with self.assertRaisesWithPredicateMatch(ImportError, "extras_require"): new_module.some_function() + def test_nested_imports(self): + with tfds.core.try_import(): + import matplotlib + + self.assertFalse(hasattr(matplotlib, "pyplot")) + + with tfds.core.try_import(): + import matplotlib.pyplot + + self.assertTrue(hasattr(matplotlib, "pyplot")) + # pylint: enable=import-outside-toplevel, unused-import