diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 36003319856c..b6273368c565 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -86,7 +86,8 @@ def _hf_model_list() -> List[ModelInfo]: cache_dir = os.getenv("HF_HOME", "~/.cache/huggingface") cache_file_path = os.path.join(cache_dir, "DS_model_cache.pkl") - cache_expiration_seconds = 60 * 60 * 24 # 1 day + num_days = os.getenv("HF_CACHE_EXPIRY_DAYS", 1) + cache_expiration_seconds = num_days * 60 * 60 * 24 # Load or initialize the cache model_data = {"cache_time": 0, "model_list": []} @@ -97,7 +98,8 @@ def _hf_model_list() -> List[ModelInfo]: current_time = time.time() # Update the cache if it has expired - if (model_data["cache_time"] + cache_expiration_seconds) < current_time: + if ((model_data["cache_time"] + cache_expiration_seconds) < current_time) or os.getenv("FORCE_UPDATE_HF_CACHE", + default=False): api = HfApi() model_data["model_list"] = [ ModelInfo(modelId=m.modelId, pipeline_tag=m.pipeline_tag, tags=m.tags) for m in api.list_models()