Skip to content

Commit

Permalink
Add provider specific args and allow using unrecognized model names (#…
Browse files Browse the repository at this point in the history
…1621)

Add provider specific args and allow passing model name not from the mapping with warning

Signed-off-by: elronbandel <[email protected]>
  • Loading branch information
elronbandel authored Feb 23, 2025
1 parent d9d9a9d commit 20a99df
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/unitxt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .artifact import Artifact
from .dataclass import InternalField, NonPositionalField
from .deprecation_utils import deprecation
from .error_utils import UnitxtError
from .error_utils import UnitxtError, UnitxtWarning
from .image_operators import (
EncodeImageToString,
ImageDataString,
Expand Down Expand Up @@ -2985,10 +2985,12 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
mapping each supported API to a corresponding
model identifier string. This mapping allows consistent access to models
across different API backends.
provider_specific_args: (Optional[Dict[str, Dict[str,str]]]) Args specific to a provider for example provider_specific_args={"watsonx": {"max_requests_per_second": 4}}
"""

label: str = "cross_provider"
provider: Optional[_supported_apis] = None
provider_specific_args: Optional[Dict[str, Dict[str,str]]] = None

provider_model_map: Dict[_supported_apis, Dict[str, str]] = {
"watsonx": {
Expand Down Expand Up @@ -3148,12 +3150,18 @@ def prepare_engine(self):
f"{provider} is not a configured API for CrossProviderInferenceEngine. Supported apis: {','.join(self.provider_model_map.keys())}"
)
if self.model not in self.provider_model_map[provider]:
raise UnitxtError(
f"{self.model} is not configured for provider {provider}. Supported models: {','.join(self.provider_model_map[provider].keys())}"
UnitxtWarning(
f"{self.model} is not configured for provider {provider}. Supported models: {','.join(self.provider_model_map[provider].keys())}. Using un normalized name will make it impossible to switch to different provider on request."
)
cls = self.__class__._provider_to_base_class[provider]
args = self.to_dict([StandardAPIParamsMixin])
args["model"] = self.provider_model_map[provider][self.model]
args["model"] = self.provider_model_map[provider].get(self.model, self.model)

if self.provider_specific_args is not None:
provider_args = self.provider_specific_args.get(provider)
if provider_args is not None:
args.update(provider_args)

params = list(args.keys())
if provider in self._provider_param_renaming:
for param in params:
Expand Down

0 comments on commit 20a99df

Please sign in to comment.