
adding log_prob option for chat models #219

Merged (10 commits) on Feb 25, 2025
1 change: 1 addition & 0 deletions src/agentlab/llm/base_api.py
@@ -21,6 +21,7 @@ class BaseModelArgs(ABC):
max_new_tokens: int = None
temperature: float = 0.1
vision_support: bool = False
log_probs: bool = False
Collaborator Author:
The log_probs argument is now part of all chat_model_args, and has to be set to True in your llm config @optimass
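A minimal sketch of what that looks like in an agent config (the class name OpenAIModelArgs and the values are illustrative, not taken from this diff):

```python
# Hypothetical config: opting into log probabilities via chat_model_args.
# OpenAIModelArgs is assumed here; use whichever BaseModelArgs subclass
# your agent config actually instantiates.
from agentlab.llm.chat_api import OpenAIModelArgs

chat_model_args = OpenAIModelArgs(
    model_name="gpt-4o-mini",  # illustrative model choice
    temperature=0.1,
    max_new_tokens=512,
    log_probs=True,            # new flag; defaults to False
)
model = chat_model_args.make_model()
```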


@abstractmethod
def make_model(self) -> AbstractChatModel:
23 changes: 20 additions & 3 deletions src/agentlab/llm/chat_api.py
@@ -87,6 +87,7 @@ def make_model(self):
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
log_probs=self.log_probs,
)


@@ -100,6 +101,7 @@ def make_model(self):
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
log_probs=self.log_probs,
)


@@ -115,6 +117,7 @@ def make_model(self):
temperature=self.temperature,
max_tokens=self.max_new_tokens,
deployment_name=self.deployment_name,
log_probs=self.log_probs,
)


@@ -142,6 +145,7 @@ def make_model(self):
temperature=self.temperature,
max_new_tokens=self.max_new_tokens,
n_retry_server=self.n_retry_server,
log_probs=self.log_probs,
)
elif self.backend == "vllm":
return VLLMChatModel(
@@ -232,6 +236,7 @@ def __init__(
client_class=OpenAI,
client_args=None,
pricing_func=None,
log_probs=False,
):
assert max_retry > 0, "max_retry should be greater than 0"

@@ -240,6 +245,7 @@ def __init__(
self.max_tokens = max_tokens
self.max_retry = max_retry
self.min_retry_wait_time = min_retry_wait_time
self.log_probs = log_probs

# Get the API key from the environment variable if not provided
if api_key_env_var:
@@ -286,6 +292,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
n=n_samples,
temperature=temperature,
max_tokens=self.max_tokens,
logprobs=self.log_probs,  # OpenAI's parameter name is "logprobs"
)

if completion.usage is None:
@@ -315,7 +322,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
tracking.TRACKER.instance(input_tokens, output_tokens, cost)

if n_samples == 1:
return AIMessage(completion.choices[0].message.content)
res = AIMessage(completion.choices[0].message.content)
if self.log_probs:
res["log_probs"] = completion.choices[0].log_probs
return res
else:
return [AIMessage(c.message.content) for c in completion.choices]
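When the flag is on, callers can read the extra key off the returned message. A rough consumer sketch (assumes a chat model already built with log_probs=True; the payload shape follows the OpenAI Python client, where choices[0].logprobs is a ChoiceLogprobs object):

```python
# Hypothetical usage of the new field on AIMessage.
msg = chat_model([{"role": "user", "content": "Hello"}])
print(msg["content"])
if msg.get("log_probs") is not None:
    # .content lists per-token entries with .token and .logprob
    for tok in msg["log_probs"].content:
        print(tok.token, tok.logprob)
```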

@@ -335,6 +345,7 @@ def __init__(
max_tokens=100,
max_retry=4,
min_retry_wait_time=60,
log_probs=False,
):
super().__init__(
model_name=model_name,
@@ -346,6 +357,7 @@ def __init__(
api_key_env_var="OPENAI_API_KEY",
client_class=OpenAI,
pricing_func=tracking.get_pricing_openai,
log_probs=log_probs,
)


@@ -358,6 +370,7 @@ def __init__(
max_tokens=100,
max_retry=4,
min_retry_wait_time=60,
log_probs=False,
):
client_args = {
"base_url": "https://openrouter.ai/api/v1",
@@ -373,6 +386,7 @@ def __init__(
client_class=OpenAI,
client_args=client_args,
pricing_func=tracking.get_pricing_openrouter,
log_probs=log_probs,
)


@@ -386,6 +400,7 @@ def __init__(
max_tokens=100,
max_retry=4,
min_retry_wait_time=60,
log_probs=False,
):
api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
@@ -406,6 +421,7 @@ def __init__(
client_class=AzureOpenAI,
client_args=client_args,
pricing_func=tracking.get_pricing_openai,
log_probs=log_probs,
)


@@ -419,8 +435,9 @@ def __init__(
temperature: Optional[float] = 1e-1,
max_new_tokens: Optional[int] = 512,
n_retry_server: Optional[int] = 4,
log_probs: Optional[bool] = False,
):
super().__init__(model_name, base_model_name, n_retry_server)
super().__init__(model_name, base_model_name, n_retry_server, log_probs)
if temperature < 1e-3:
logging.warning("Models might behave weirdly when temperature is too low.")
self.temperature = temperature
@@ -429,7 +446,7 @@ def __init__(
token = os.environ["TGI_TOKEN"]

client = InferenceClient(model=model_url, token=token)
self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
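Note that details=True changes the return type of text_generation from a plain string to a TextGenerationOutput, which is why the __call__ change in huggingface_utils.py below unpacks generated_text and details separately. A small sketch of the difference (endpoint URL is illustrative):

```python
# Sketch: huggingface_hub InferenceClient with and without details.
from huggingface_hub import InferenceClient

client = InferenceClient(model="http://localhost:8080")  # illustrative URL
plain = client.text_generation("Hello", max_new_tokens=16)  # returns str
rich = client.text_generation("Hello", max_new_tokens=16, details=True)
print(rich.generated_text)        # the generated text itself
for tok in rich.details.tokens:   # per-token metadata, incl. logprob
    print(tok.text, tok.logprob)
```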


class VLLMChatModel(ChatModel):
14 changes: 9 additions & 5 deletions src/agentlab/llm/huggingface_utils.py
@@ -2,12 +2,11 @@
import time
from typing import Any, List, Optional, Union

from pydantic import Field
from transformers import AutoTokenizer, GPT2TokenizerFast

from agentlab.llm.base_api import AbstractChatModel
from agentlab.llm.llm_utils import AIMessage, Discussion
from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
from pydantic import Field
from transformers import AutoTokenizer, GPT2TokenizerFast


class HFBaseChatModel(AbstractChatModel):
@@ -40,9 +39,10 @@ class HFBaseChatModel(AbstractChatModel):
description="The number of times to retry the server if it fails to respond",
)

def __init__(self, model_name, base_model_name, n_retry_server):
def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
super().__init__()
self.n_retry_server = n_retry_server
self.log_probs = log_probs

if base_model_name is None:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -100,7 +100,11 @@ def __call__(
while True:
try:
temperature = temperature if temperature is not None else self.temperature
response = AIMessage(self.llm(prompt, temperature=temperature))
answer = self.llm(prompt, temperature=temperature)
response = AIMessage(answer)
if self.log_probs:
response["content"] = answer.generated_text
response["log_prob"] = answer.details
responses.append(response)
break
except Exception as e:
11 changes: 8 additions & 3 deletions src/agentlab/llm/llm_utils.py
@@ -382,9 +382,14 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image):


class BaseMessage(dict):
def __init__(self, role: str, content: Union[str, list[dict]]):
def __init__(self, role: str, content: Union[str, list[dict]], **kwargs):
allowed_attrs = {"log_probs"}
invalid_attrs = set(kwargs.keys()) - allowed_attrs
if invalid_attrs:
raise ValueError(f"Invalid attributes: {invalid_attrs}")
self["role"] = role
self["content"] = deepcopy(content)
self.update(kwargs)

def __str__(self, warn_if_image=False) -> str:
if isinstance(self["content"], str):
@@ -464,8 +469,8 @@ def __init__(self, content: Union[str, list[dict]]):


class AIMessage(BaseMessage):
def __init__(self, content: Union[str, list[dict]]):
super().__init__("assistant", content)
def __init__(self, content: Union[str, list[dict]], log_probs=None):
super().__init__("assistant", content, log_probs=log_probs)


class Discussion: