QA for Self-Refine (HotpotQA, FEVER, AmbigNQ, TriviaQA) #227

Merged
merged 32 commits into from
Jul 13, 2024
195 changes: 195 additions & 0 deletions agential/cog/self_refine/strategies/qa.py
@@ -0,0 +1,195 @@
"""Self-Refine Agent strategies for QA."""

from typing import Any, Dict

from langchain_core.language_models.chat_models import BaseChatModel

from agential.cog.self_refine.functional import (
    _prompt_agent,
    _prompt_critique,
    _prompt_refine,
)
from agential.cog.self_refine.strategies.base import SelfRefineBaseStrategy


class SelfRefineQAStrategy(SelfRefineBaseStrategy):
    """A strategy class for QA benchmarks using the Self-Refine agent.

    Attributes:
        llm (BaseChatModel): The language model used for generating answers and critiques.
        patience (int): The number of interactions to tolerate the same incorrect answer
            before halting further attempts. Defaults to 2.
    """

    def __init__(self, llm: BaseChatModel, patience: int = 2) -> None:
        """Initialization."""
        super().__init__(llm)
        self.patience = patience
        self._prev_code_answer = ""
        self.patience_counter = 0
        self._halt = False

    def generate(
        self,
        question: str,
        examples: str,
        prompt: str,
        additional_keys: Dict[str, str],
        **kwargs: Dict[str, Any],
    ) -> str:
        """Generates an answer for the given question using the provided prompt and examples.

        Args:
            question (str): The QA question to generate an answer for.
            examples (str): Few-shot examples to guide the language model.
            prompt (str): The prompt to generate an answer.
            additional_keys (Dict[str, str]): Additional keys for the prompt.
            **kwargs (Dict[str, Any]): Additional arguments.

        Returns:
            str: The generated answer.
        """
        answer = _prompt_agent(
            llm=self.llm,
            question=question,
            examples=examples,
            prompt=prompt,
            additional_keys=additional_keys,
        )
        answer = answer.split("```python")[-1].split("```")[0].strip()

        return answer

    def generate_critique(
        self,
        question: str,
        examples: str,
        answer: str,
        prompt: str,
        additional_keys: Dict[str, str],
    ) -> str:
        """Generates a critique for the provided answer using the given prompt and examples.

        Stops early if patience is reached and answer remains the same.

        Args:
            question (str): The QA question that was answered.
            examples (str): Few-shot examples to guide the language model in generating the critique.
            answer (str): The answer to be critiqued.
            prompt (str): The prompt to generate a critique.
            additional_keys (Dict[str, str]): Additional keys for the prompt.

        Returns:
            str: The generated critique. If the same incorrect answer is repeated for the number of
                interactions specified by patience, the halting condition is triggered.
        """
        critique = _prompt_critique(
            llm=self.llm,
            question=question,
            examples=examples,
            answer=answer,
            prompt=prompt,
            additional_keys=additional_keys,
        )

        if answer.strip() == self._prev_code_answer:
            self.patience_counter += 1
            if self.patience_counter == self.patience:
                self._halt = True
        else:
            self._prev_code_answer = answer.strip()

        return critique

    def create_output_dict(self, answer: str, critique: str) -> Dict[str, str]:
        """Creates an output dictionary containing the answer and critique.

        Args:
            answer (str): The generated answer.
            critique (str): The generated critique.

        Returns:
            Dict[str, str]: The output dictionary.
        """
        return {"answer": answer, "critique": critique}

    def update_answer_based_on_critique(
        self,
        question: str,
        examples: str,
        answer: str,
        critique: str,
        prompt: str,
        additional_keys: Dict[str, str],
    ) -> str:
        """Updates the answer based on the given critique.

        Args:
            question: The question that was answered by the language model.
            examples: Few-shot examples to guide the language model.
            answer: The answer provided by the language model.
            critique: The critique of the answer.
            prompt: The prompt to be used for generating the updated answer.
            additional_keys: Additional context or parameters to include in the refine prompt.

        Returns:
            str: The updated answer.
        """
        new_answer = _prompt_refine(
            llm=self.llm,
            question=question,
            examples=examples,
            answer=answer,
            critique=critique,
            prompt=prompt,
            additional_keys=additional_keys,
        )
        new_answer = new_answer.split("```python")[-1].split("```")[0].strip()

        return new_answer

    def halting_condition(self) -> bool:
        """Checks if the halting condition has been met.

        Returns True if the Self-Refine Agent's generated answer remains the same for `patience` number of steps.

        Returns:
            bool: True if the halting condition has been met, False otherwise.
        """
        return self._halt

    def reset(self, **kwargs: Dict[str, Any]) -> None:
        """Resets the strategy to its initial state.

        Resets internal variables keeping track of halting.

        Args:
            **kwargs (Dict[str, Any]): Additional arguments.
        """
        self._prev_code_answer = ""
        self.patience_counter = 0
        self._halt = False
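The patience bookkeeping in `generate_critique`, `halting_condition`, and `reset` can be examined in isolation. The following is a minimal sketch, not part of the PR: `PatienceTracker` and `observe` are hypothetical names that copy only the counter logic from the diff above, with no LLM call involved.

```python
class PatienceTracker:
    """Hypothetical standalone mirror of the halting bookkeeping in generate_critique."""

    def __init__(self, patience: int = 2) -> None:
        self.patience = patience
        self._prev_answer = ""
        self.patience_counter = 0
        self._halt = False

    def observe(self, answer: str) -> bool:
        """Record one critiqued answer and return the current halting flag."""
        if answer.strip() == self._prev_answer:
            # Same answer as the stored one: count the repeat.
            self.patience_counter += 1
            if self.patience_counter == self.patience:
                self._halt = True
        else:
            # Different answer: remember it; the counter is left unchanged,
            # exactly as in the diff.
            self._prev_answer = answer.strip()
        return self._halt


tracker = PatienceTracker(patience=2)
history = [tracker.observe(a) for a in ["Paris", "Paris", "Paris "]]
print(history)  # [False, False, True]
```

With the default `patience` of 2, the flag is set on the second repeat, that is, the third time the same answer is critiqued. Because the counter is not reset when the answer changes, the sequence "A", "A", "B", "B" also halts on the fourth call; whether that is intended is not stated in the PR.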


class SelfRefineHotQAStrategy(SelfRefineQAStrategy):
    """A strategy class for the HotpotQA benchmark using the Self-Refine agent."""

    pass


class SelfRefineFEVERStrategy(SelfRefineQAStrategy):
    """A strategy class for the FEVER benchmark using the Self-Refine agent."""

    pass


class SelfRefineTriviaQAStrategy(SelfRefineQAStrategy):
    """A strategy class for the TriviaQA benchmark using the Self-Refine agent."""

    pass


class SelfRefineAmbigNQStrategy(SelfRefineQAStrategy):
    """A strategy class for the AmbigNQ benchmark using the Self-Refine agent."""

    pass
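Both `generate` and `update_answer_based_on_critique` post-process the model output with the same chained `split` expression. The sketch below reproduces that expression in a hypothetical helper, `extract_answer`, to show how it behaves on a few inputs. The sample strings are made up for illustration, and the fence marker is built at runtime only so that the snippet does not contain a literal fence.

```python
FENCE = "`" * 3  # three backticks, built at runtime to keep this snippet fence-safe


def extract_answer(raw: str) -> str:
    """Hypothetical helper mirroring the split-based post-processing in the diff."""
    # Keep what follows the last python-tagged fence, then cut at the next fence.
    return raw.split(FENCE + "python")[-1].split(FENCE)[0].strip()


# Plain text passes through with surrounding whitespace removed.
plain = extract_answer("  Badr Hari  ")

# A python-tagged fence is unwrapped.
fenced = extract_answer(f"{FENCE}python\nanswer = 'Badr Hari'\n{FENCE}")

# An untagged fence is not unwrapped: only the text before it survives.
untagged = extract_answer(f"Sure:\n{FENCE}\nBadr Hari\n{FENCE}")

print(repr(plain), repr(fenced), repr(untagged))
```

For QA output that is plain text, the expression reduces to `strip()`. The third case suggests that a reply wrapping its answer in an untagged fence would lose the answer, which may be worth a test given the Codecov warnings on these lines.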