Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QA for Self-Refine (HotpotQA, FEVER, AmbigNQ, TriviaQA) #227

Merged
merged 32 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 62 additions & 16 deletions agential/cog/self_refine/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,60 @@
from agential.base.factory import BaseFactory
from agential.cog.constants import BENCHMARK_FEWSHOTS, Benchmarks, FewShotType
from agential.cog.self_refine.prompts import (
AMBIGNQ_CRITIQUE_FEWSHOT_EXAMPLES,
AMBIGNQ_REFINE_FEWSHOT_EXAMPLES,
FEVER_CRITIQUE_FEWSHOT_EXAMPLES,
FEVER_REFINE_FEWSHOT_EXAMPLES,
GSM8K_CRITIQUE_FEWSHOT_EXAMPLES,
GSM8K_REFINE_FEWSHOT_EXAMPLES,
HOTPOTQA_CRITIQUE_FEWSHOT_EXAMPLES,
HOTPOTQA_REFINE_FEWSHOT_EXAMPLES,
SELF_REFINE_CRITIQUE_INSTRUCTION_AMBIGNQ,
SELF_REFINE_CRITIQUE_INSTRUCTION_FEVER,
SELF_REFINE_CRITIQUE_INSTRUCTION_GSM8K,
SELF_REFINE_CRITIQUE_INSTRUCTION_HOTPOTQA,
SELF_REFINE_CRITIQUE_INSTRUCTION_SVAMP,
SELF_REFINE_CRITIQUE_INSTRUCTION_TABMWP,
SELF_REFINE_CRITIQUE_INSTRUCTION_TRIVIAQA,
SELF_REFINE_INSTRUCTION_AMBIGNQ,
SELF_REFINE_INSTRUCTION_FEVER,
SELF_REFINE_INSTRUCTION_GSM8K,
SELF_REFINE_INSTRUCTION_HOTPOTQA,
SELF_REFINE_INSTRUCTION_SVAMP,
SELF_REFINE_INSTRUCTION_TABMWP,
SELF_REFINE_INSTRUCTION_TRIVIAQA,
SELF_REFINE_REFINE_INSTRUCTION_AMBIGNQ,
SELF_REFINE_REFINE_INSTRUCTION_FEVER,
SELF_REFINE_REFINE_INSTRUCTION_GSM8K,
SELF_REFINE_REFINE_INSTRUCTION_HOTPOTQA,
SELF_REFINE_REFINE_INSTRUCTION_SVAMP,
SELF_REFINE_REFINE_INSTRUCTION_TABMWP,
SELF_REFINE_REFINE_INSTRUCTION_TRIVIAQA,
SVAMP_CRITIQUE_FEWSHOT_EXAMPLES,
SVAMP_REFINE_FEWSHOT_EXAMPLES,
TABMWP_CRITIQUE_FEWSHOT_EXAMPLES,
TABMWP_REFINE_FEWSHOT_EXAMPLES,
TRIVIAQA_CRITIQUE_FEWSHOT_EXAMPLES,
TRIVIAQA_REFINE_FEWSHOT_EXAMPLES,
)
from agential.cog.self_refine.strategies.base import SelfRefineBaseStrategy
from agential.cog.self_refine.strategies.math import (
SelfRefineGSM8KStrategy,
SelfRefineSVAMPStrategy,
SelfRefineTabMWPStrategy,
)
from agential.cog.self_refine.strategies.qa import (
SelfRefineAmbigNQStrategy,
SelfRefineFEVERStrategy,
SelfRefineHotQAStrategy,
SelfRefineTriviaQAStrategy,
)

SELF_REFINE_BENCHMARK_FEWSHOTS = {
Benchmarks.HOTPOTQA: [],
Benchmarks.FEVER: [],
Benchmarks.TRIVIAQA: [],
Benchmarks.AMBIGNQ: [],
Benchmarks.HOTPOTQA: [FewShotType.COT, FewShotType.DIRECT, FewShotType.REACT],
Benchmarks.FEVER: [FewShotType.COT, FewShotType.DIRECT, FewShotType.REACT],
Benchmarks.TRIVIAQA: [FewShotType.COT, FewShotType.DIRECT, FewShotType.REACT],
Benchmarks.AMBIGNQ: [FewShotType.COT, FewShotType.DIRECT, FewShotType.REACT],
Benchmarks.GSM8K: [FewShotType.POT],
Benchmarks.SVAMP: [FewShotType.POT],
Benchmarks.TABMWP: [FewShotType.POT],
Expand All @@ -42,16 +68,24 @@

SELF_REFINE_PROMPTS = {
Benchmarks.HOTPOTQA: {
"prompt": "",
"prompt": SELF_REFINE_INSTRUCTION_HOTPOTQA,
"critique_prompt": SELF_REFINE_CRITIQUE_INSTRUCTION_HOTPOTQA,
"refine_prompt": SELF_REFINE_REFINE_INSTRUCTION_HOTPOTQA,
},
Benchmarks.FEVER: {
"prompt": "",
"prompt": SELF_REFINE_INSTRUCTION_FEVER,
"critique_prompt": SELF_REFINE_CRITIQUE_INSTRUCTION_FEVER,
"refine_prompt": SELF_REFINE_REFINE_INSTRUCTION_FEVER,
},
Benchmarks.TRIVIAQA: {
"prompt": "",
"prompt": SELF_REFINE_INSTRUCTION_TRIVIAQA,
"critique_prompt": SELF_REFINE_CRITIQUE_INSTRUCTION_TRIVIAQA,
"refine_prompt": SELF_REFINE_REFINE_INSTRUCTION_TRIVIAQA,
},
Benchmarks.AMBIGNQ: {
"prompt": "",
"prompt": SELF_REFINE_INSTRUCTION_AMBIGNQ,
"critique_prompt": SELF_REFINE_CRITIQUE_INSTRUCTION_AMBIGNQ,
"refine_prompt": SELF_REFINE_REFINE_INSTRUCTION_AMBIGNQ,
},
Benchmarks.GSM8K: {
"prompt": SELF_REFINE_INSTRUCTION_GSM8K,
Expand All @@ -77,10 +111,22 @@
}

SELF_REFINE_FEWSHOTS: Dict[str, Dict] = {
Benchmarks.HOTPOTQA: {},
Benchmarks.FEVER: {},
Benchmarks.TRIVIAQA: {},
Benchmarks.AMBIGNQ: {},
Benchmarks.HOTPOTQA: {
"critique_examples": HOTPOTQA_CRITIQUE_FEWSHOT_EXAMPLES,
"refine_examples": HOTPOTQA_REFINE_FEWSHOT_EXAMPLES,
},
Benchmarks.FEVER: {
"critique_examples": FEVER_CRITIQUE_FEWSHOT_EXAMPLES,
"refine_examples": FEVER_REFINE_FEWSHOT_EXAMPLES,
},
Benchmarks.TRIVIAQA: {
"critique_examples": TRIVIAQA_CRITIQUE_FEWSHOT_EXAMPLES,
"refine_examples": TRIVIAQA_REFINE_FEWSHOT_EXAMPLES,
},
Benchmarks.AMBIGNQ: {
"critique_examples": AMBIGNQ_CRITIQUE_FEWSHOT_EXAMPLES,
"refine_examples": AMBIGNQ_REFINE_FEWSHOT_EXAMPLES,
},
Benchmarks.GSM8K: {
"critique_examples": GSM8K_CRITIQUE_FEWSHOT_EXAMPLES,
"refine_examples": GSM8K_REFINE_FEWSHOT_EXAMPLES,
Expand All @@ -98,10 +144,10 @@
}

SELF_REFINE_STRATEGIES = {
Benchmarks.HOTPOTQA: None,
Benchmarks.FEVER: None,
Benchmarks.TRIVIAQA: None,
Benchmarks.AMBIGNQ: None,
Benchmarks.HOTPOTQA: SelfRefineHotQAStrategy,
Benchmarks.FEVER: SelfRefineFEVERStrategy,
Benchmarks.TRIVIAQA: SelfRefineTriviaQAStrategy,
Benchmarks.AMBIGNQ: SelfRefineAmbigNQStrategy,
Benchmarks.GSM8K: SelfRefineGSM8KStrategy,
Benchmarks.SVAMP: SelfRefineSVAMPStrategy,
Benchmarks.TABMWP: SelfRefineTabMWPStrategy,
Expand Down
18 changes: 18 additions & 0 deletions agential/cog/self_refine/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,19 @@ def _prompt_agent(
prompt=prompt,
additional_keys=additional_keys,
)
print("<PROMPT AGENT=============================================>")
print(prompt)
print("<PROMPT AGENT=============================================>")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using logging instead of print statements for debugging.

Using print statements for debugging is not recommended for production code. Consider using the logging module to provide better control over log levels and outputs.

-  print("<PROMPT AGENT=============================================>")
-  print(prompt)
-  print("<PROMPT AGENT=============================================>")
+  import logging
+  logger = logging.getLogger(__name__)
+  logger.debug("<PROMPT AGENT=============================================>")
+  logger.debug(prompt)
+  logger.debug("<PROMPT AGENT=============================================>")
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print("<PROMPT AGENT=============================================>")
print(prompt)
print("<PROMPT AGENT=============================================>")
import logging
logger = logging.getLogger(__name__)
logger.debug("<PROMPT AGENT=============================================>")
logger.debug(prompt)
logger.debug("<PROMPT AGENT=============================================>")

out = llm(
[
HumanMessage(
content=prompt,
)
]
).content
print("<OUT AGENT=============================================>")
print(repr(out))
print("<OUT AGENT=============================================>")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using logging instead of print statements for debugging.

Using print statements for debugging is not recommended for production code. Consider using the logging module to provide better control over log levels and outputs.

-  print("<OUT AGENT=============================================>")
-  print(repr(out))
-  print("<OUT AGENT=============================================>")
+  logger.debug("<OUT AGENT=============================================>")
+  logger.debug(repr(out))
+  logger.debug("<OUT AGENT=============================================>")
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print("<OUT AGENT=============================================>")
print(repr(out))
print("<OUT AGENT=============================================>")
logger.debug("<OUT AGENT=============================================>")
logger.debug(repr(out))
logger.debug("<OUT AGENT=============================================>")

assert isinstance(out, str)
return out.strip()

Expand Down Expand Up @@ -133,13 +139,19 @@ def _prompt_critique(
prompt=prompt,
additional_keys=additional_keys,
)
print("<PROMPT CRITIQUE=============================================>")
print(prompt)
print("<PROMPT CRITIQUE=============================================>")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using logging instead of print statements for debugging.

Using print statements for debugging is not recommended for production code. Consider using the logging module to provide better control over log levels and outputs.

-  print("<PROMPT CRITIQUE=============================================>")
-  print(prompt)
-  print("<PROMPT CRITIQUE=============================================>")
+  logger.debug("<PROMPT CRITIQUE=============================================>")
+  logger.debug(prompt)
+  logger.debug("<PROMPT CRITIQUE=============================================>")

Committable suggestion was skipped due to low confidence.

out = llm(
[
HumanMessage(
content=prompt,
)
]
).content
print("<OUT CRITIQUE=============================================>")
print(repr(out))
print("<OUT CRITIQUE=============================================>")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using logging instead of print statements for debugging.

Using print statements for debugging is not recommended for production code. Consider using the logging module to provide better control over log levels and outputs.

-  print("<OUT CRITIQUE=============================================>")
-  print(repr(out))
-  print("<OUT CRITIQUE=============================================>")
+  logger.debug("<OUT CRITIQUE=============================================>")
+  logger.debug(repr(out))
+  logger.debug("<OUT CRITIQUE=============================================>")
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print("<OUT CRITIQUE=============================================>")
print(repr(out))
print("<OUT CRITIQUE=============================================>")
logger.debug("<OUT CRITIQUE=============================================>")
logger.debug(repr(out))
logger.debug("<OUT CRITIQUE=============================================>")

assert isinstance(out, str)
return out.strip()

Expand Down Expand Up @@ -208,12 +220,18 @@ def _prompt_refine(
prompt=prompt,
additional_keys=additional_keys,
)
print("<PROMPT REFINE=============================================>")
print(prompt)
print("<PROMPT REFINE=============================================>")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using logging instead of print statements for debugging.

Using print statements for debugging is not recommended for production code. Consider using the logging module to provide better control over log levels and outputs.

-  print("<PROMPT REFINE=============================================>")
-  print(prompt)
-  print("<PROMPT REFINE=============================================>")
+  logger.debug("<PROMPT REFINE=============================================>")
+  logger.debug(prompt)
+  logger.debug("<PROMPT REFINE=============================================>")
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print("<PROMPT REFINE=============================================>")
print(prompt)
print("<PROMPT REFINE=============================================>")
logger.debug("<PROMPT REFINE=============================================>")
logger.debug(prompt)
logger.debug("<PROMPT REFINE=============================================>")

out = llm(
[
HumanMessage(
content=prompt,
)
]
).content
print("<OUT REFINE=============================================>")
print(repr(out))
print("<OUT REFINE=============================================>")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using logging instead of print statements for debugging.

Using print statements for debugging is not recommended for production code. Consider using the logging module to provide better control over log levels and outputs.

-  print("<OUT REFINE=============================================>")
-  print(repr(out))
-  print("<OUT REFINE=============================================>")
+  logger.debug("<OUT REFINE=============================================>")
+  logger.debug(repr(out))
+  logger.debug("<OUT REFINE=============================================>")
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print("<OUT REFINE=============================================>")
print(repr(out))
print("<OUT REFINE=============================================>")
logger.debug("<OUT REFINE=============================================>")
logger.debug(repr(out))
logger.debug("<OUT REFINE=============================================>")

assert isinstance(out, str)
return out.strip()
Loading
Loading