Skip to content

Commit

Permalink
Merge branch 'self_refine_qa' of https://github.com/alckasoc/agential
Browse files Browse the repository at this point in the history
…into self_refine_qa
  • Loading branch information
alckasoc committed Jul 13, 2024
2 parents 558fb91 + e40aa57 commit 8d68636
Showing 1 changed file with 195 additions and 0 deletions.
195 changes: 195 additions & 0 deletions agential/cog/self_refine/strategies/qa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
"""Self-Refine Agent strategies for QA."""

from typing import Any, Dict

Check warning on line 3 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L3

Added line #L3 was not covered by tests

from langchain_core.language_models.chat_models import BaseChatModel

Check warning on line 5 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L5

Added line #L5 was not covered by tests

from agential.cog.self_refine.functional import (

Check warning on line 7 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L7

Added line #L7 was not covered by tests
_prompt_agent,
_prompt_critique,
_prompt_refine,
)
from agential.cog.self_refine.strategies.base import SelfRefineBaseStrategy

Check warning on line 12 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L12

Added line #L12 was not covered by tests


class SelfRefineQAStrategy(SelfRefineBaseStrategy):

Check warning on line 15 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L15

Added line #L15 was not covered by tests
"""A strategy class for QA benchmarks using the Self-Refine agent.
Attributes:
llm (BaseChatModel): The language model used for generating answers and critiques.
patience (int): The number of interactions to tolerate the same incorrect answer
before halting further attempts. Defaults to 2.
"""

def __init__(self, llm: BaseChatModel, patience: int = 2) -> None:

Check warning on line 24 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L24

Added line #L24 was not covered by tests
"""Initialization."""
super().__init__(llm)
self.patience = patience
self._prev_code_answer = ""
self.patience_counter = 0
self._halt = False

Check warning on line 30 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L26-L30

Added lines #L26 - L30 were not covered by tests

def generate(

Check warning on line 32 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L32

Added line #L32 was not covered by tests
self,
question: str,
examples: str,
prompt: str,
additional_keys: Dict[str, str],
**kwargs: Dict[str, Any],
) -> str:
"""Generates an answer for the given question using the provided prompt and examples.
Args:
question (str): The qa question to generate an answer for.
examples (str): Few-shot examples to guide the language model.
prompt (str): The prompt to generate an answer.
additional_keys (Dict[str, str]): Additional keys for the prompt.
**kwargs (Dict[str, Any]): Additional arguments.
Returns:
str: The generated answer.
"""
answer = _prompt_agent(

Check warning on line 52 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L52

Added line #L52 was not covered by tests
llm=self.llm,
question=question,
examples=examples,
prompt=prompt,
additional_keys=additional_keys,
)
answer = answer.split("```python")[-1].split("```")[0].strip()

Check warning on line 59 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L59

Added line #L59 was not covered by tests

return answer

Check warning on line 61 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L61

Added line #L61 was not covered by tests

def generate_critique(

Check warning on line 63 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L63

Added line #L63 was not covered by tests
self,
question: str,
examples: str,
answer: str,
prompt: str,
additional_keys: Dict[str, str],
) -> str:
"""Generates a critique for the provided answer using the given prompt and examples.
Stops early if patience is reached and answer remains the same.
Args:
question (str): The qa question that was answered.
examples (str): Few-shot examples to guide the language model in generating the critique.
answer (str): The answer to be critiqued.
prompt (str): The prompt to generate a critique.
additional_keys (Dict[str, str]): Additional keys for the prompt.
Returns:
str: The generated critique. If the same incorrect answer is repeated for the number of
interactions specified by patience, the halting condition is triggered.
"""
critique = _prompt_critique(

Check warning on line 86 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L86

Added line #L86 was not covered by tests
llm=self.llm,
question=question,
examples=examples,
answer=answer,
prompt=prompt,
additional_keys=additional_keys,
)

if answer.strip() == self._prev_code_answer:
self.patience_counter += 1
if self.patience_counter == self.patience:
self._halt = True

Check warning on line 98 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L95-L98

Added lines #L95 - L98 were not covered by tests
else:
self._prev_code_answer = answer.strip()

Check warning on line 100 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L100

Added line #L100 was not covered by tests

return critique

Check warning on line 102 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L102

Added line #L102 was not covered by tests

def create_output_dict(self, answer: str, critique: str) -> Dict[str, str]:

Check warning on line 104 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L104

Added line #L104 was not covered by tests
"""Creates an output dictionary containing the answer and critique.
Args:
answer (str): The generated answer.
critique (str): The generated critique.
Returns:
Dict[str, str]: The output dictionary.
"""
return {"answer": answer, "critique": critique}

Check warning on line 114 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L114

Added line #L114 was not covered by tests

def update_answer_based_on_critique(

Check warning on line 116 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L116

Added line #L116 was not covered by tests
self,
question: str,
examples: str,
answer: str,
critique: str,
prompt: str,
additional_keys: Dict[str, str],
) -> str:
"""Updates the answer based on the given critique.
Args:
question: The question that was answered by the language model.
examples: Few-shot examples to guide the language model.
answer: The answer provided by the language model.
critique: The critique of the answer.
prompt: The prompt to be used for generating the updated answer.
additional_keys: Additional context or parameters to include in the critique prompt.
Returns:
str: The updated answer.
"""
new_answer = _prompt_refine(

Check warning on line 138 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L138

Added line #L138 was not covered by tests
llm=self.llm,
question=question,
examples=examples,
answer=answer,
critique=critique,
prompt=prompt,
additional_keys=additional_keys,
)
new_answer = new_answer.split("```python")[-1].split("```")[0].strip()

Check warning on line 147 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L147

Added line #L147 was not covered by tests

return new_answer

Check warning on line 149 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L149

Added line #L149 was not covered by tests

def halting_condition(self) -> bool:

Check warning on line 151 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L151

Added line #L151 was not covered by tests
"""Checks if the halting condition has been met.
Returns True if the Self-Refine Agent's generated answer remains the same for `patience` number of steps.
Returns:
bool: True if the halting condition has been met, False otherwise.
"""
return self._halt

Check warning on line 159 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L159

Added line #L159 was not covered by tests

def reset(self, **kwargs: Dict[str, Any]) -> None:

Check warning on line 161 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L161

Added line #L161 was not covered by tests
"""Resets the strategy to its initial state.
Resets internal variables keeping track of halting.
Args:
**kwargs (Dict[str, Any]): Additional arguments.
"""
self._prev_code_answer = ""
self.patience_counter = 0
self._halt = False

Check warning on line 171 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L169-L171

Added lines #L169 - L171 were not covered by tests


class SelfRefineHotQAStrategy(SelfRefineQAStrategy):

Check warning on line 174 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L174

Added line #L174 was not covered by tests
"""A strategy class for the HotpotQA benchmark using the Self-Refine agent."""

pass

Check warning on line 177 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L177

Added line #L177 was not covered by tests


class SelfRefineFEVERStrategy(SelfRefineQAStrategy):

Check warning on line 180 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L180

Added line #L180 was not covered by tests
"""A strategy class for the FEVER benchmark using the Self-Refine agent."""

pass

Check warning on line 183 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L183

Added line #L183 was not covered by tests


class SelfRefineTriviaQAStrategy(SelfRefineQAStrategy):

Check warning on line 186 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L186

Added line #L186 was not covered by tests
"""A strategy class for the TriviaQA benchmark using the Self-Refine agent."""

pass

Check warning on line 189 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L189

Added line #L189 was not covered by tests


class SelfRefineAmbigNQStrategy(SelfRefineQAStrategy):

Check warning on line 192 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L192

Added line #L192 was not covered by tests
"""A strategy class for the AmbigNQ benchmark using the Self-Refine agent."""

pass

Check warning on line 195 in agential/cog/self_refine/strategies/qa.py

View check run for this annotation

Codecov / codecov/patch

agential/cog/self_refine/strategies/qa.py#L195

Added line #L195 was not covered by tests

0 comments on commit 8d68636

Please sign in to comment.