Skip to content

Commit

Permalink
fix caching (#336)
Browse files Browse the repository at this point in the history
* fix caching

* fix style

* fix exception

* add comments

* add tests
  • Loading branch information
lvwerra authored Nov 8, 2022
1 parent bad198d commit e3cf0e8
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 12 deletions.
40 changes: 29 additions & 11 deletions src/evaluate/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,9 @@ def get_module(self) -> ImportableModule:
# get most recent

def _get_modification_time(module_hash):
return (Path(importable_directory_path) / module_hash / (self.name + ".py")).stat().st_mtime
return (
(Path(importable_directory_path) / module_hash / (self.name.split("--")[-1] + ".py")).stat().st_mtime
)

hash = sorted(hashes, key=_get_modification_time)[-1]
logger.warning(
Expand All @@ -550,7 +552,9 @@ def _get_modification_time(module_hash):
f"couldn't be found locally at {self.name}, or remotely on the Hugging Face Hub."
)
# make the new module to be noticed by the import system
module_path = ".".join([os.path.basename(dynamic_modules_path), self.module_type, self.name, hash, self.name])
module_path = ".".join(
[os.path.basename(dynamic_modules_path), self.module_type, self.name, hash, self.name.split("--")[-1]]
)
importlib.invalidate_caches()
return ImportableModule(module_path, hash)

Expand Down Expand Up @@ -658,15 +662,29 @@ def evaluation_module_factory(
dynamic_modules_path=dynamic_modules_path,
).get_module()
except Exception as e1: # noqa: all the attempts failed, before raising the error we should check if the module is already cached.
try:
return CachedEvaluationModuleFactory(path, dynamic_modules_path=dynamic_modules_path).get_module()
except Exception as e2: # noqa: if it's not in the cache, then it doesn't exist.
if not isinstance(e1, (ConnectionError, FileNotFoundError)):
raise e1 from None
raise FileNotFoundError(
f"Couldn't find a module script at {relative_to_absolute_path(combined_path)}. "
f"Module '{path}' doesn't exist on the Hugging Face Hub either."
) from None
# if it's a canonical module we need to check if it's any of the types
if path.count("/") == 0:
for current_type in ["metric", "comparison", "measurement"]:
try:
return CachedEvaluationModuleFactory(
f"evaluate-{current_type}--{path}", dynamic_modules_path=dynamic_modules_path
).get_module()
except Exception as e2: # noqa: if it's not in the cache, then it doesn't exist.
pass
# if it's a community module we just need to check on path
elif path.count("/") == 1:
try:
return CachedEvaluationModuleFactory(
path.replace("/", "--"), dynamic_modules_path=dynamic_modules_path
).get_module()
except Exception as e2: # noqa: if it's not in the cache, then it doesn't exist.
pass
if not isinstance(e1, (ConnectionError, FileNotFoundError)):
raise e1 from None
raise FileNotFoundError(
f"Couldn't find a module script at {relative_to_absolute_path(combined_path)}. "
f"Module '{path}' doesn't exist on the Hugging Face Hub either."
) from None
else:
raise FileNotFoundError(f"Couldn't find a module script at {relative_to_absolute_path(combined_path)}.")

Expand Down
31 changes: 30 additions & 1 deletion tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import pytest

import evaluate
from evaluate.loading import CachedEvaluationModuleFactory, HubEvaluationModuleFactory, LocalEvaluationModuleFactory
from evaluate.loading import (
CachedEvaluationModuleFactory,
HubEvaluationModuleFactory,
LocalEvaluationModuleFactory,
evaluation_module_factory,
)
from evaluate.utils.file_utils import DownloadConfig

from .utils import OfflineSimulationMode, offline
Expand Down Expand Up @@ -109,3 +114,27 @@ def test_CachedMetricModuleFactory(self):
)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None

def test_cache_with_remote_canonical_module(self):
metric = "accuracy"
evaluation_module_factory(
metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
)

for offline_mode in OfflineSimulationMode:
with offline(offline_mode):
evaluation_module_factory(
metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
)

def test_cache_with_remote_community_module(self):
metric = "lvwerra/test"
evaluation_module_factory(
metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
)

for offline_mode in OfflineSimulationMode:
with offline(offline_mode):
evaluation_module_factory(
metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
)

0 comments on commit e3cf0e8

Please sign in to comment.