Structured Outputs with Pydantic for CRITIC (QA, Math, Code) #183

Merged Jul 2, 2024 (41 commits)

Changes from 11 commits
6601ccf
Update code.py
emmalin-7 Jun 21, 2024
c832c60
Merge pull request #2 from emmalin-7/emmalin-7-patch-7
emmalin-7 Jun 21, 2024
d2a689f
Update critic.py
emmalin-7 Jun 21, 2024
e7dcb61
Update base.py
emmalin-7 Jun 21, 2024
52f0456
Update critic.py
emmalin-7 Jun 21, 2024
b279798
Update agential/cog/agent/critic.py
emmalin-7 Jun 21, 2024
a73a1a6
Update agential/cog/strategies/critic/code.py
emmalin-7 Jun 21, 2024
031e08d
Update critic.py
emmalin-7 Jun 23, 2024
ba0778d
Update critic.py
emmalin-7 Jun 23, 2024
df4074e
Update base.py
emmalin-7 Jun 24, 2024
a438d04
Update code.py
emmalin-7 Jun 24, 2024
60b3f07
Update math.py
emmalin-7 Jun 24, 2024
7dc1d03
Update code.py
emmalin-7 Jun 24, 2024
b2ff0f9
Update qa.py
emmalin-7 Jun 24, 2024
9f35788
Update math.py
emmalin-7 Jun 24, 2024
793846f
Update base.py
emmalin-7 Jun 27, 2024
8e5cb5a
Update qa.py
emmalin-7 Jun 27, 2024
6031a3b
Update code.py
emmalin-7 Jun 27, 2024
b1cdc78
Update math.py
emmalin-7 Jun 27, 2024
708b44e
Update qa.py
emmalin-7 Jun 27, 2024
17324d5
Merge branch 'main' into main
emmalin-7 Jun 27, 2024
750b46b
Update code.py
emmalin-7 Jun 30, 2024
9ef4051
Update math.py
emmalin-7 Jun 30, 2024
422ec09
Update qa.py
emmalin-7 Jun 30, 2024
964237a
Update critic.py
emmalin-7 Jun 30, 2024
75b9ec1
Merge branch 'main' into main
emmalin-7 Jun 30, 2024
9558357
Update base.py
emmalin-7 Jun 30, 2024
bb86e03
Update code.py
emmalin-7 Jun 30, 2024
6b4a933
Update qa.py
emmalin-7 Jun 30, 2024
b436a48
Update qa.py
emmalin-7 Jun 30, 2024
3789081
Update critic.py
emmalin-7 Jun 30, 2024
8ac7cce
ok
alckasoc Jul 1, 2024
85d2c55
reset
alckasoc Jul 1, 2024
cb83560
added pydantic output class; TODO: linting + unit testing
alckasoc Jul 1, 2024
0703de7
Merge branch 'main' into main
alckasoc Jul 2, 2024
72daf65
testing
emmalin-7 Jul 2, 2024
aacdd58
?
emmalin-7 Jul 2, 2024
c1e78c1
Merge branch 'main' of https://github.com/emmalin-7/llmagential into …
emmalin-7 Jul 2, 2024
7bf371c
Update critic.py
alckasoc Jul 2, 2024
d7ef41f
Delete notebooks/critic.ipynb
alckasoc Jul 2, 2024
1afdb53
Create critic.ipynb
alckasoc Jul 2, 2024
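The diffs below add one Pydantic model per CRITIC mode so that each step of the agent loop emits a validated, self-documenting record instead of a plain dict. As a minimal sketch of that pattern (the model and field names here are illustrative, not taken from the diff):

    from pydantic import BaseModel, Field

    # A hypothetical per-step output schema; Pydantic raises a
    # ValidationError at construction time if a field is missing.
    class QAStepOutput(BaseModel):
        answer: str = Field(..., description="The answer generated by the agent.")
        critique: str = Field(..., description="The critique of the answer.")

    step = QAStepOutput(answer="Paris", critique="Consistent with the search result.")
    print(step.dict())  # Pydantic v1-style serialization to a plain dict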
145 changes: 101 additions & 44 deletions agential/cog/agent/critic.py
@@ -8,21 +8,63 @@

from langchain_core.language_models.chat_models import BaseChatModel

from pydantic import BaseModel, Field

from agential.cog.agent.base import BaseAgent
from agential.cog.strategies.strategy_factory import CriticStrategyFactory

class CriticPydanticOutput(BaseModel):
    class Config:
        title = "Critic Output"
        description = "Critic output for different modes"

    class QA(BaseModel):
        answer: str = Field(..., description="The answer generated by the agent.")
        critique: str = Field(..., description="The critique of the answer generated by the agent.")
        query: str = Field(..., description="The query requested by the agent.")
        search_result: str = Field(..., description="The search result requested by the agent.")
        revised_answer: str = Field(..., description="The revised answer generated by the agent.")

    class Math(BaseModel):
        code: str = Field(..., description="The code generated by the agent.")
        critique: str = Field(..., description="The critique of the code generated by the agent.")
        execution_status: str = Field(..., description="The execution status of the agent.")
        code_answer: str = Field(..., description="The code answer generated by the agent.")
        improved_code: str = Field(..., description="The improved code generated by the agent.")

    class Code(BaseModel):
        code: str = Field(..., description="The code generated by the agent.")
        critique: str = Field(..., description="The critique of the code generated by the agent.")
        execution_status: str = Field(..., description="The execution status of the agent.")
        improved_code: str = Field(..., description="The improved code generated by the agent.")


#answer: str = Field(..., description = "The answer generated by the agent.")
#critique : str = Field(..., description = "The critique of the answer generated by the agent.")
#query: str = Field(..., description = "The query requested by the agent. ")
#search_result: str = Field(..., description = "The search result requested by the agent.")
#revised_answer: str = Field(..., description = "The revised answer generated by the agent. ")

#code: int/str? = Field(..., description = "The code generated by the agent.")
#critique: str = Field(..., description = " The critique of the answer generated by the agent.")

#execution_status: str = Field(..., description = "The execution status of the agent.")
#code_answer: int/str? = Field(..., description = "The code answer generated by the agent.")
#improved_code: int/str? = Field(..., description = "The improved code generated by the agent.")





class CriticAgent(BaseAgent):
"""CRITIC Agent.

Attributes:
llm (BaseChatModel): An instance of a language model used for generating initial answers
and critiques.
mode (Dict[str, str]): A dictionary specifying the CRITIC agent's mode and the benchmark.
For example, {"qa": "hotpotqa"}, {"math": "gsm8k"}, or {"code": "mbpp"}.
**strategy_kwargs (Dict[str, Any]): Additional strategy-specific arguments.
"""

def __init__(
self,
llm: BaseChatModel,
@@ -31,14 +73,11 @@ def __init__(
) -> None:
"""Initialization."""
super().__init__()

self.llm = llm
self.mode = mode

self.strategy = CriticStrategyFactory().get_strategy(
mode=self.mode, llm=self.llm, **strategy_kwargs
)

def generate(
self,
question: str,
@@ -53,35 +92,11 @@
reset: bool = True,
**kwargs: Dict[str, Any],
) -> List[Dict[str, Any]]:
"""Generates an answer that is refined with search results.

Args:
question (str): The question to be answered.
examples (str): Few-shot examples to guide the language model in generating the initial answer.
prompt (str): The instruction template used to prompt the language model for the initial answer.
critique_examples (str): Few-shot examples to guide the language model in generating critiques.
critique_prompt (str): The instruction template for generating critiques.
additional_keys (Dict[str, str]): Additional keys to format the prompt. Defaults to {}.
critique_additional_keys (Dict[str, str]): Additional keys to format the critique_prompt. Defaults to {}.
max_interactions (int): The maximum number of critique cycles. Defaults to 7.
use_tool (bool): Use the external tool. Flag to decide whether to use the interpreter tool for math/code execution, or search tool for QA. Defaults to True.
reset (bool): Resets the agent's state. Defaults to True.
**kwargs (Dict[str, Any]): Additional parameters for flexibility.

Returns:
List[Dict[str, Any]]: A list of dictionaries.
- For "qa" mode: Each dictionary contains an "answer" and "critique". Optionally, a dictionary may include the search "query" and "search_result", and the final dictionary includes the final "revised_answer".
- For "math" mode: Each dictionary contains "code" and "critique". Optionally, a dictionary may include the "execution_status" and "code_answer" if use_interpreter_tool is True. If the critic improves the solution, then the dictionary will have an "improved_code" key.
- For "code" mode: Each dictionary contains "code" and "critique". Optionally, a dictionary may include the "execution_status" if use_interpreter_tool is True. If the critic improves the solution, then the dictionary will have an "improved_code" key.
"""
        """Generates an answer that is refined with search results."""
        if reset:
            self.reset()

out = []

# Initial answer generation.
answer = self.strategy.generate(question, examples, prompt, additional_keys)

critique = ""
for idx in range(max_interactions):
critique, external_tool_info = self.strategy.generate_critique(
@@ -97,25 +112,67 @@
**kwargs,
)

out.append(
self.strategy.create_output_dict(answer, critique, external_tool_info)
)
if self.mode["qa"]:
output = CriticPydanticOutput.QA(
answer=answer,
critique=critique,
query=external_tool_info.get("query", ""),
search_result=external_tool_info.get("search_result", ""),
revised_answer=self.strategy.update_answer_based_on_critique(
question=question,
examples=critique_examples,
answer=answer,
critique=critique,
prompt=critique_prompt,
additional_keys=critique_additional_keys,
external_tool_info=external_tool_info,
**kwargs,
),
)
elif self.mode["math"]:
output = CriticPydanticOutput.Math(
code=answer,
critique=critique,
execution_status=external_tool_info.get("execution_status", ""),
code_answer=external_tool_info.get("code_answer", ""),
improved_code=self.strategy.update_answer_based_on_critique(
question=question,
examples=critique_examples,
answer=answer,
critique=critique,
prompt=critique_prompt,
additional_keys=critique_additional_keys,
external_tool_info=external_tool_info,
**kwargs,
),
)

elif self.mode["code"]:
output = CriticPydanticOutput.Code(
code=answer,
critique=critique,
execution_status=external_tool_info.get("execution_status", ""),
improved_code=self.strategy.update_answer_based_on_critique(
question=question,
examples=critique_examples,
answer=answer,
critique=critique,
prompt=critique_prompt,
additional_keys=critique_additional_keys,
external_tool_info=external_tool_info,
**kwargs,
),
)

out.append(output.dict())


if self.strategy.halting_condition():
break

# Update answer for the next iteration.
answer = self.strategy.update_answer_based_on_critique(
question=question,
examples=critique_examples,
answer=answer,
critique=critique,
prompt=critique_prompt,
additional_keys=critique_additional_keys,
external_tool_info=external_tool_info,
**kwargs,
)

            answer = output.revised_answer if "qa" in self.mode else output.improved_code

return out

def reset(self) -> None:
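For reference, a quick way to exercise the nested output models added above on their own (the field values here are invented placeholders):

    from agential.cog.agent.critic import CriticPydanticOutput

    # Build a QA-mode record directly; Pydantic raises a ValidationError
    # if any of the required fields is missing.
    record = CriticPydanticOutput.QA(
        answer="42",
        critique="The answer cites no supporting evidence.",
        query="meaning of life",
        search_result="Douglas Adams, The Hitchhiker's Guide to the Galaxy.",
        revised_answer="42, per Douglas Adams' novel.",
    )
    print(record.dict()["revised_answer"])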
27 changes: 20 additions & 7 deletions agential/cog/strategies/critic/base.py
@@ -7,14 +7,13 @@

from agential.cog.strategies.base import BaseStrategy

from pydantic import BaseModel, Field

class CriticBaseStrategy(BaseStrategy):
"""An abstract base class for defining strategies for the CRITIC Agent."""

def __init__(self, llm: BaseChatModel) -> None:
"""Initialization."""
super().__init__(llm)

@abstractmethod
def generate_critique(
self,
@@ -30,7 +29,6 @@ def generate_critique(
**kwargs: Dict[str, Any],
) -> Tuple[str, Dict[str, Any]]:
"""Generates a critique of the provided answer using the given language model, question, examples, and prompt.

Args:
idx (int): The index of the current interaction.
question (str): The question that was answered by the language model.
@@ -42,28 +40,43 @@
use_tool (bool): Whether to use an external tool (e.g., code interpreter, search tool) during critique.
max_interactions (int): The maximum number of critique interactions.
**kwargs (Dict[str, Any]): Additional arguments that might be needed for specific implementations.

Returns:
Tuple[str, Dict[str, Any]]: The generated critique and external tool information.
"""
pass

@abstractmethod
def create_output_dict(
self, answer: str, critique: str, external_tool_info: Dict[str, str]
) -> Dict[str, Any]:
"""Creates a dictionary containing the answer and critique, along with any additional key updates.

Args:
answer (str): The original answer.
critique (str): The generated critique.
external_tool_info (Dict[str, str]): Information from any external tools used during the critique.

Returns:
Dict[str, Any]: A dictionary containing the answer, critique, and additional key updates.
"""
pass


    @abstractmethod
    def create_output_pydantic(
        self, answer: str, critique: str, query: str, search_result: str, revised_answer: str
    ) -> BaseModel:
        """Creates a Pydantic model of the output.

        Args:
            answer (str): The answer generated by the agent.
            critique (str): The critique of the answer generated by the agent.
            query (str): The query requested by the agent.
            search_result (str): The search result requested by the agent.
            revised_answer (str): The revised answer generated by the agent.

        Returns:
            BaseModel: A Pydantic model containing the answer, critique, query, search result, and revised answer.
        """
        pass


@abstractmethod
def update_answer_based_on_critique(
self,
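A concrete strategy would implement the new abstract method along these lines. This is a sketch that reuses the agent-level CriticPydanticOutput.QA model for brevity; the strategies in this PR define their own output models, and the remaining abstract methods are omitted here:

    from pydantic import BaseModel

    from agential.cog.agent.critic import CriticPydanticOutput
    from agential.cog.strategies.critic.base import CriticBaseStrategy

    class SketchQAStrategy(CriticBaseStrategy):
        # Only create_output_pydantic is shown; generate, generate_critique,
        # create_output_dict, and update_answer_based_on_critique would still
        # need implementations before this class could be instantiated.
        def create_output_pydantic(
            self,
            answer: str,
            critique: str,
            query: str,
            search_result: str,
            revised_answer: str,
        ) -> BaseModel:
            return CriticPydanticOutput.QA(
                answer=answer,
                critique=critique,
                query=query,
                search_result=search_result,
                revised_answer=revised_answer,
            )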
46 changes: 38 additions & 8 deletions agential/cog/strategies/critic/code.py
@@ -1,27 +1,40 @@
"""CRITIC Agent strategies for Code."""

from typing import Any, Dict, Tuple

from langchain_core.language_models.chat_models import BaseChatModel

from agential.cog.functional.critic import _prompt_agent, _prompt_critique
from agential.cog.strategies.critic.base import CriticBaseStrategy
from agential.utils.general import safe_execute
from agential.utils.validation import validate_overlapping_keys

from pydantic import BaseModel, Field


class CriticPydanticOutput(BaseModel):
    """Critic output.

    Attributes:
        code (str): The code generated by the agent.
        critique (str): The critique of the code generated by the agent.
        execution_status (str): The execution status of the agent.
        improved_code (str): The improved code generated by the agent.
    """

    code: str = Field(..., description="The code generated by the agent.")
    critique: str = Field(..., description="The critique of the code generated by the agent.")
    execution_status: str = Field(..., description="The execution status of the agent.")
    improved_code: str = Field(..., description="The improved code generated by the agent.")



class CriticCodeStrategy(CriticBaseStrategy):
"""A strategy class for Code benchmarks using the CRITIC agent.

Attributes:
llm (BaseChatModel): The language model used for generating answers and critiques.
"""

def __init__(self, llm: BaseChatModel) -> None:
"""Initialization."""
super().__init__(llm)
self._halt = False

def generate(
self,
question: str,
@@ -31,14 +44,12 @@ def generate(
**kwargs: Dict[str, Any],
) -> str:
"""Generates an answer for the given question using the provided prompt and examples.

Args:
question (str): The coding question to generate an answer for.
examples (str): Few-shot examples to guide the language model.
prompt (str): The prompt to generate an answer.
additional_keys (Dict[str, str]): Additional keys for the prompt.
**kwargs (Dict[str, Any]): Additional arguments.

Returns:
str: The generated answer.
"""
@@ -53,6 +64,25 @@

return answer


    def create_output_pydantic(
        self, code: str, critique: str, execution_status: str, improved_code: str
    ) -> CriticPydanticOutput:
        """Creates a Pydantic model of the output.

        Args:
            code (str): The code generated by the agent.
            critique (str): The critique of the code generated by the agent.
            execution_status (str): The execution status of the agent.
            improved_code (str): The improved code generated by the agent.

        Returns:
            CriticPydanticOutput: A Pydantic model containing the code, critique, execution status, and improved code.
        """
        return CriticPydanticOutput(
            code=code,
            critique=critique,
            execution_status=execution_status,
            improved_code=improved_code,
        )

def generate_critique(
self,
idx: int,
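With the signature fix above, create_output_pydantic can be smoke-tested in isolation. A sketch, assuming llm is any BaseChatModel instance and using placeholder argument values:

    from agential.cog.strategies.critic.code import CriticCodeStrategy

    strategy = CriticCodeStrategy(llm=llm)  # llm: any BaseChatModel instance
    output = strategy.create_output_pydantic(
        code="def add(a, b):\n    return a + b",
        critique="Correct, but lacks type hints.",
        execution_status="Done",
        improved_code="def add(a: int, b: int) -> int:\n    return a + b",
    )
    print(output.dict())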