Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QA for Self-Refine (HotpotQA, FEVER, AmbigNQ, TriviaQA) #227

Merged
merged 32 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions agential/cog/reflexion/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def generate(
**kwargs (Dict[str, Any], optional): Additional keyword arguments for the strategy.

Returns:
result (List[ReflexionCoTOutput]): A list of ReflexionCoTOutput containing the thought, action, observation, is_correct, and reflections.
List[ReflexionCoTOutput]: A list of ReflexionCoTOutput containing the thought, action, observation, is_correct, and reflections.
"""
if not prompt or not reflect_prompt or not examples or not reflect_examples:
if not fewshot_type:
Expand Down Expand Up @@ -359,7 +359,7 @@ def generate(
**kwargs (Any): Additional keyword arguments for the strategy.

Returns:
result (List[ReflexionReActOutput]): List of ReflexionReActOutput where each ReflexionReActOutput contains the ReAct output and
List[ReflexionReActOutput]: List of ReflexionReActOutput where each ReflexionReActOutput contains the ReAct output and
the reflections at the end of the trial.
"""
if not prompt or not reflect_prompt or not examples or not reflect_examples:
Expand Down
52 changes: 39 additions & 13 deletions agential/cog/self_refine/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from langchain_core.language_models.chat_models import BaseChatModel

from agential.base.agent import BaseAgent
from agential.cog.self_refine.factory import SelfRefineFactory
from agential.cog.self_refine.factory import (
SELF_REFINE_BENCHMARK_FEWSHOTS,
SelfRefineFactory,
)
from agential.cog.self_refine.output import SelfRefineOutput


Expand Down Expand Up @@ -46,15 +49,16 @@ def __init__(
def generate(
self,
question: str,
examples: str,
prompt: str,
critique_examples: str,
critique_prompt: str,
refine_examples: str,
refine_prompt: str,
examples: str = "",
prompt: str = "",
critique_examples: str = "",
critique_prompt: str = "",
refine_examples: str = "",
refine_prompt: str = "",
additional_keys: Dict[str, str] = {},
critique_additional_keys: Dict[str, str] = {},
refine_additional_keys: Dict[str, str] = {},
fewshot_type: str = "",
max_interactions: int = 3,
reset: bool = True,
) -> List[SelfRefineOutput]:
Expand All @@ -65,21 +69,43 @@ def generate(

Args:
question (str): The question or problem to solve.
examples (str): Precedent examples to guide initial solution generation.
prompt (str): Instructional prompt for initial solution generation.
critique_examples (str): Precedent examples to guide critique generation.
critique_prompt (str): Instructional prompt for critique generation.
refine_examples (str): Precedent examples to guide solution refinement.
refine_prompt (str): Instructional prompt for refining the solution.
examples (str, optional): Precedent examples to guide initial solution generation. Defaults to "".
prompt (str, optional): Instructional prompt for initial solution generation. Defaults to "".
critique_examples (str, optional): Precedent examples to guide critique generation. Defaults to "".
critique_prompt (str, optional): Instructional prompt for critique generation. Defaults to "".
refine_examples (str, optional): Precedent examples to guide solution refinement. Defaults to "".
refine_prompt (str, optional): Instructional prompt for refining the solution. Defaults to "".
additional_keys (Dict[str, str]): Additional keys to format the prompt. Defaults to {}.
critique_additional_keys (Dict[str, str]): Additional keys to format the critique_prompt. Defaults to {}.
refine_additional_keys (Dict[str, str]): Additional keys to format the refine_prompt. Defaults to {}.
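fewshot_type (str): The type of few-shot examples to load when any prompt or example set is omitted; if empty, the first type listed for the benchmark in SELF_REFINE_BENCHMARK_FEWSHOTS is used. Defaults to "".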
max_interactions (int): Maximum number of refinement iterations. Defaults to 3.
reset (bool): Resets the agent's state. Defaults to True.

Returns:
List[SelfRefineOutput]: A list of answers and critiques.
"""
if (
not prompt
or not critique_prompt
or not examples
or not critique_examples
or not refine_examples
or not refine_prompt
):
if not fewshot_type:
fewshot_type = SELF_REFINE_BENCHMARK_FEWSHOTS[self.benchmark][0] # type: ignore
fewshots = SelfRefineFactory.get_fewshots(
benchmark=self.benchmark, fewshot_type=fewshot_type
)
prompts = SelfRefineFactory.get_prompts(benchmark=self.benchmark)
examples = fewshots["examples"]
critique_examples = fewshots["critique_examples"]
refine_examples = fewshots["refine_examples"]
prompt = prompts["prompt"]
critique_prompt = prompts["critique_prompt"]
refine_prompt = prompts["refine_prompt"]

if reset:
self.reset()

Expand Down
82 changes: 64 additions & 18 deletions agential/cog/self_refine/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,60 @@
from agential.base.factory import BaseFactory
from agential.cog.constants import BENCHMARK_FEWSHOTS, Benchmarks, FewShotType
from agential.cog.self_refine.prompts import (
AMBIGNQ_CRITIQUE_FEWSHOT_EXAMPLES,
AMBIGNQ_REFINE_FEWSHOT_EXAMPLES,
FEVER_CRITIQUE_FEWSHOT_EXAMPLES,
FEVER_REFINE_FEWSHOT_EXAMPLES,
GSM8K_CRITIQUE_FEWSHOT_EXAMPLES,
GSM8K_REFINE_FEWSHOT_EXAMPLES,
HOTPOTQA_CRITIQUE_FEWSHOT_EXAMPLES,
HOTPOTQA_REFINE_FEWSHOT_EXAMPLES,
SELF_REFINE_CRITIQUE_INSTRUCTION_AMBIGNQ,
SELF_REFINE_CRITIQUE_INSTRUCTION_FEVER,
SELF_REFINE_CRITIQUE_INSTRUCTION_GSM8K,
SELF_REFINE_CRITIQUE_INSTRUCTION_HOTPOTQA,
SELF_REFINE_CRITIQUE_INSTRUCTION_SVAMP,
SELF_REFINE_CRITIQUE_INSTRUCTION_TABMWP,
SELF_REFINE_CRITIQUE_INSTRUCTION_TRIVIAQA,
SELF_REFINE_INSTRUCTION_AMBIGNQ,
SELF_REFINE_INSTRUCTION_FEVER,
SELF_REFINE_INSTRUCTION_GSM8K,
SELF_REFINE_INSTRUCTION_HOTPOTQA,
SELF_REFINE_INSTRUCTION_SVAMP,
SELF_REFINE_INSTRUCTION_TABMWP,
SELF_REFINE_INSTRUCTION_TRIVIAQA,
SELF_REFINE_REFINE_INSTRUCTION_AMBIGNQ,
SELF_REFINE_REFINE_INSTRUCTION_FEVER,
SELF_REFINE_REFINE_INSTRUCTION_GSM8K,
SELF_REFINE_REFINE_INSTRUCTION_HOTPOTQA,
SELF_REFINE_REFINE_INSTRUCTION_SVAMP,
SELF_REFINE_REFINE_INSTRUCTION_TABMWP,
SELF_REFINE_REFINE_INSTRUCTION_TRIVIAQA,
SVAMP_CRITIQUE_FEWSHOT_EXAMPLES,
SVAMP_REFINE_FEWSHOT_EXAMPLES,
TABMWP_CRITIQUE_FEWSHOT_EXAMPLES,
TABMWP_REFINE_FEWSHOT_EXAMPLES,
TRIVIAQA_CRITIQUE_FEWSHOT_EXAMPLES,
TRIVIAQA_REFINE_FEWSHOT_EXAMPLES,
)
from agential.cog.self_refine.strategies.base import SelfRefineBaseStrategy
from agential.cog.self_refine.strategies.math import (
SelfRefineGSM8KStrategy,
SelfRefineSVAMPStrategy,
SelfRefineTabMWPStrategy,
)
from agential.cog.self_refine.strategies.qa import (
SelfRefineAmbigNQStrategy,
SelfRefineFEVERStrategy,
SelfRefineHotQAStrategy,
SelfRefineTriviaQAStrategy,
)

SELF_REFINE_BENCHMARK_FEWSHOTS = {
Benchmarks.HOTPOTQA: [],
Benchmarks.FEVER: [],
Benchmarks.TRIVIAQA: [],
Benchmarks.AMBIGNQ: [],
Benchmarks.HOTPOTQA: [FewShotType.COT, FewShotType.DIRECT, FewShotType.REACT],
Benchmarks.FEVER: [FewShotType.COT, FewShotType.DIRECT, FewShotType.REACT],
Benchmarks.TRIVIAQA: [FewShotType.COT, FewShotType.DIRECT, FewShotType.REACT],
Benchmarks.AMBIGNQ: [FewShotType.COT, FewShotType.DIRECT, FewShotType.REACT],
Benchmarks.GSM8K: [FewShotType.POT],
Benchmarks.SVAMP: [FewShotType.POT],
Benchmarks.TABMWP: [FewShotType.POT],
Expand All @@ -42,16 +68,24 @@

SELF_REFINE_PROMPTS = {
Benchmarks.HOTPOTQA: {
"prompt": "",
"prompt": SELF_REFINE_INSTRUCTION_HOTPOTQA,
"critique_prompt": SELF_REFINE_CRITIQUE_INSTRUCTION_HOTPOTQA,
"refine_prompt": SELF_REFINE_REFINE_INSTRUCTION_HOTPOTQA,
},
Benchmarks.FEVER: {
"prompt": "",
"prompt": SELF_REFINE_INSTRUCTION_FEVER,
"critique_prompt": SELF_REFINE_CRITIQUE_INSTRUCTION_FEVER,
"refine_prompt": SELF_REFINE_REFINE_INSTRUCTION_FEVER,
},
Benchmarks.TRIVIAQA: {
"prompt": "",
"prompt": SELF_REFINE_INSTRUCTION_TRIVIAQA,
"critique_prompt": SELF_REFINE_CRITIQUE_INSTRUCTION_TRIVIAQA,
"refine_prompt": SELF_REFINE_REFINE_INSTRUCTION_TRIVIAQA,
},
Benchmarks.AMBIGNQ: {
"prompt": "",
"prompt": SELF_REFINE_INSTRUCTION_AMBIGNQ,
"critique_prompt": SELF_REFINE_CRITIQUE_INSTRUCTION_AMBIGNQ,
"refine_prompt": SELF_REFINE_REFINE_INSTRUCTION_AMBIGNQ,
},
Benchmarks.GSM8K: {
"prompt": SELF_REFINE_INSTRUCTION_GSM8K,
Expand All @@ -77,10 +111,22 @@
}

SELF_REFINE_FEWSHOTS: Dict[str, Dict] = {
Benchmarks.HOTPOTQA: {},
Benchmarks.FEVER: {},
Benchmarks.TRIVIAQA: {},
Benchmarks.AMBIGNQ: {},
Benchmarks.HOTPOTQA: {
"critique_examples": HOTPOTQA_CRITIQUE_FEWSHOT_EXAMPLES,
"refine_examples": HOTPOTQA_REFINE_FEWSHOT_EXAMPLES,
},
Benchmarks.FEVER: {
"critique_examples": FEVER_CRITIQUE_FEWSHOT_EXAMPLES,
"refine_examples": FEVER_REFINE_FEWSHOT_EXAMPLES,
},
Benchmarks.TRIVIAQA: {
"critique_examples": TRIVIAQA_CRITIQUE_FEWSHOT_EXAMPLES,
"refine_examples": TRIVIAQA_REFINE_FEWSHOT_EXAMPLES,
},
Benchmarks.AMBIGNQ: {
"critique_examples": AMBIGNQ_CRITIQUE_FEWSHOT_EXAMPLES,
"refine_examples": AMBIGNQ_REFINE_FEWSHOT_EXAMPLES,
},
Benchmarks.GSM8K: {
"critique_examples": GSM8K_CRITIQUE_FEWSHOT_EXAMPLES,
"refine_examples": GSM8K_REFINE_FEWSHOT_EXAMPLES,
Expand All @@ -98,10 +144,10 @@
}

SELF_REFINE_STRATEGIES = {
Benchmarks.HOTPOTQA: None,
Benchmarks.FEVER: None,
Benchmarks.TRIVIAQA: None,
Benchmarks.AMBIGNQ: None,
Benchmarks.HOTPOTQA: SelfRefineHotQAStrategy,
Benchmarks.FEVER: SelfRefineFEVERStrategy,
Benchmarks.TRIVIAQA: SelfRefineTriviaQAStrategy,
Benchmarks.AMBIGNQ: SelfRefineAmbigNQStrategy,
Benchmarks.GSM8K: SelfRefineGSM8KStrategy,
Benchmarks.SVAMP: SelfRefineSVAMPStrategy,
Benchmarks.TABMWP: SelfRefineTabMWPStrategy,
Expand Down Expand Up @@ -137,7 +183,7 @@ def get_fewshots(
f"Benchmark '{benchmark}' few-shot type not supported for Self-Refine."
)

benchmark_fewshots = BENCHMARK_FEWSHOTS[benchmark]
benchmark_fewshots = BENCHMARK_FEWSHOTS[benchmark][fewshot_type]

return {"examples": benchmark_fewshots, **SELF_REFINE_FEWSHOTS[benchmark]} # type: ignore

Expand Down Expand Up @@ -180,4 +226,4 @@ def get_strategy(benchmark: str, **kwargs: Any) -> SelfRefineBaseStrategy:
if strategy is None:
raise ValueError(f"No strategy defined for benchmark: {benchmark}")

return strategy(**kwargs)
return strategy(**kwargs) # type: ignore
Loading
Loading