forked from geekan/MetaGPT
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# -*- coding: utf-8 -*- | ||
# @Date : 1/12/2024 6:05 PM | ||
# @Author : stellahong ([email protected]) | ||
# @Desc : |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# -*- coding: utf-8 -*- | ||
# @Date : 12/25/2023 1:06 PM | ||
# @Author : stellahong ([email protected]) | ||
# @Desc : | ||
import re | ||
|
||
from metagpt.strategy.tot import TreeofThought | ||
from metagpt.strategy.tot_schema import ( | ||
BaseEvaluator, | ||
BaseParser, | ||
Strategy, | ||
ThoughtSolverConfig, | ||
) | ||
from examples.tot.prompt_templates.creative_writing import cot_prompt, vote_prompt | ||
|
||
|
||
class TextGenParser(BaseParser): | ||
propose_prompt: str = cot_prompt | ||
value_prompt: str = vote_prompt | ||
|
||
def __call__(self, input_text: str) -> str: | ||
return input_text | ||
|
||
def propose(self, current_state: str, **kwargs) -> str: | ||
return self.propose_prompt.format(input=current_state, **kwargs) | ||
|
||
def value(self, input: str = "", **kwargs) -> str: | ||
# node_result = self(input) | ||
id = kwargs.get("node_id", "0") | ||
return self.value_prompt + f"Choice {id}:\n{input}\n" | ||
|
||
|
||
class TextGenEvaluator(BaseEvaluator): | ||
value_map: dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc | ||
status_map: dict = {val: key for key, val in value_map.items()} | ||
|
||
def __call__(self, evaluation: str, **kwargs) -> float: | ||
try: | ||
value = 0 | ||
node_id = kwargs.get("node_id", "0") | ||
pattern = r".*best choice is .*(\d+).*" | ||
match = re.match(pattern, evaluation, re.DOTALL) | ||
|
||
if match: | ||
vote = int(match.groups()[0]) | ||
print(vote) | ||
if vote == int(node_id): | ||
value = 1 | ||
except: | ||
value = 0 | ||
return value | ||
|
||
def status_verify(self, value): | ||
status = False | ||
if value in self.status_map: | ||
status_value = self.status_map[value] | ||
if status_value != "impossible": | ||
status = True | ||
return status | ||
|
||
|
||
if __name__ == "__main__": | ||
import asyncio | ||
|
||
initial_prompt = """It isn't difficult to do a handstand if you just stand on your hands. It caught him off guard that space smelled of seared steak. When she didn’t like a guy who was trying to pick her up, she started using sign language. Each person who knows you has a different perception of who you are.""" | ||
|
||
parser = TextGenParser() | ||
evaluator = TextGenEvaluator() | ||
|
||
config = ThoughtSolverConfig(n_generate_sample=3, parser=parser, evaluator=evaluator) | ||
|
||
tot_base = TreeofThought(strategy=Strategy.BFS, config=config) | ||
asyncio.run(tot_base.solve(init_prompt=initial_prompt)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# -*- coding: utf-8 -*- | ||
# @Date : 12/25/2023 1:36 AM | ||
# @Author : stellahong ([email protected]) | ||
# @Desc : | ||
import re | ||
|
||
from metagpt.strategy.tot import TreeofThought | ||
from metagpt.strategy.tot_schema import ( | ||
BaseEvaluator, | ||
BaseParser, | ||
Strategy, | ||
ThoughtSolverConfig, | ||
) | ||
from examples.tot.prompt_templates.game24 import propose_prompt, value_prompt | ||
|
||
|
||
class Game24Parser(BaseParser): | ||
propose_prompt: str = propose_prompt | ||
value_prompt: str = value_prompt | ||
|
||
def __call__(self, input_text: str) -> str: | ||
last_line = input_text.strip().split("\n")[-1] | ||
return last_line.split("left: ")[-1].split(")")[0] | ||
|
||
def propose(self, current_state: str, **kwargs) -> str: | ||
return self.propose_prompt.format(input=current_state, **kwargs) | ||
|
||
def value(self, input: str = "", **kwargs) -> str: | ||
node_result = self(input) | ||
return self.value_prompt.format(input=node_result) | ||
|
||
|
||
class Game24Evaluator(BaseEvaluator): | ||
value_map : dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc | ||
status_map : dict = {val: key for key, val in value_map.items()} | ||
|
||
def __call__(self, evaluation: str, **kwargs) -> float: | ||
try: | ||
matches = re.findall(r"\b(impossible|sure|likely)\b", evaluation) | ||
value = self.value_map[matches[0]] | ||
except: | ||
value = 0.001 | ||
return value | ||
|
||
def status_verify(self, value): | ||
status = False | ||
if value in self.status_map: | ||
status_value = self.status_map[value] | ||
if status_value != "impossible": | ||
status = True | ||
return status | ||
|
||
|
||
if __name__ == "__main__": | ||
import asyncio | ||
|
||
initial_prompt = """4 5 6 10""" | ||
parser = Game24Parser() | ||
evaluator = Game24Evaluator() | ||
|
||
config = ThoughtSolverConfig(n_generate_sample=5, parser=parser, evaluator=evaluator) | ||
|
||
tot = TreeofThought(strategy=Strategy.BFS, config=config) | ||
asyncio.run(tot.solve(init_prompt=initial_prompt)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# -*- coding: utf-8 -*- | ||
# @Date : 12/23/2023 5:21 PM | ||
# @Author : stellahong ([email protected]) | ||
# @Desc : |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
standard_prompt = """ | ||
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input} | ||
""" | ||
|
||
cot_prompt = """ | ||
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input} | ||
Make a plan then write. Your output should be of the following format: | ||
Plan: | ||
Your plan here. | ||
Passage: | ||
Your passage here. | ||
""" | ||
|
||
|
||
vote_prompt = """Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s the integer id of the choice. | ||
""" | ||
|
||
compare_prompt = """Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent". | ||
""" | ||
|
||
score_prompt = """Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# 5-shot | ||
standard_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. | ||
Input: 4 4 6 8 | ||
Answer: (4 + 8) * (6 - 4) = 24 | ||
Input: 2 9 10 12 | ||
Answer: 2 * 12 * (10 - 9) = 24 | ||
Input: 4 9 10 13 | ||
Answer: (13 - 9) * (10 - 4) = 24 | ||
Input: 1 4 8 8 | ||
Answer: (8 / 4 + 1) * 8 = 24 | ||
Input: 5 5 5 9 | ||
Answer: 5 + 5 + 5 + 9 = 24 | ||
Input: {input} | ||
""" | ||
|
||
# 5-shot | ||
cot_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number. | ||
Input: 4 4 6 8 | ||
Steps: | ||
4 + 8 = 12 (left: 4 6 12) | ||
6 - 4 = 2 (left: 2 12) | ||
2 * 12 = 24 (left: 24) | ||
Answer: (6 - 4) * (4 + 8) = 24 | ||
Input: 2 9 10 12 | ||
Steps: | ||
12 * 2 = 24 (left: 9 10 24) | ||
10 - 9 = 1 (left: 1 24) | ||
24 * 1 = 24 (left: 24) | ||
Answer: (12 * 2) * (10 - 9) = 24 | ||
Input: 4 9 10 13 | ||
Steps: | ||
13 - 10 = 3 (left: 3 4 9) | ||
9 - 3 = 6 (left: 4 6) | ||
4 * 6 = 24 (left: 24) | ||
Answer: 4 * (9 - (13 - 10)) = 24 | ||
Input: 1 4 8 8 | ||
Steps: | ||
8 / 4 = 2 (left: 1 2 8) | ||
1 + 2 = 3 (left: 3 8) | ||
3 * 8 = 24 (left: 24) | ||
Answer: (1 + 8 / 4) * 8 = 24 | ||
Input: 5 5 5 9 | ||
Steps: | ||
5 + 5 = 10 (left: 5 9 10) | ||
10 + 5 = 15 (left: 9 15) | ||
15 + 9 = 24 (left: 24) | ||
Answer: ((5 + 5) + 5) + 9 = 24 | ||
Input: {input} | ||
""" | ||
|
||
# 1-shot | ||
propose_prompt = """Here is an Example for 1 input and 8 possible thoughts: | ||
Input: 2 8 8 14 | ||
Possible next steps: | ||
2 + 8 = 10 (left: 8 10 14) | ||
8 / 2 = 4 (left: 4 8 14) | ||
14 + 2 = 16 (left: 8 8 16) | ||
2 * 8 = 16 (left: 8 14 16) | ||
8 - 2 = 6 (left: 6 8 14) | ||
14 - 8 = 6 (left: 2 6 8) | ||
14 / 2 = 7 (left: 7 8 8) | ||
14 - 2 = 12 (left: 8 8 12) | ||
Here is my task for 1 input and {n_generate_sample} possible thoughts: | ||
Input: {input} | ||
Possible next steps: | ||
""" | ||
|
||
value_prompt = """Evaluate if given numbers can reach 24 (sure/likely/impossible) | ||
10 14 | ||
10 + 14 = 24 | ||
sure | ||
11 12 | ||
11 + 12 = 23 | ||
12 - 11 = 1 | ||
11 * 12 = 132 | ||
11 / 12 = 0.91 | ||
impossible | ||
4 4 10 | ||
4 + 4 + 10 = 8 + 10 = 18 | ||
4 * 10 - 4 = 40 - 4 = 36 | ||
(10 - 4) * 4 = 6 * 4 = 24 | ||
sure | ||
4 9 11 | ||
9 + 11 + 4 = 20 + 4 = 24 | ||
sure | ||
5 7 8 | ||
5 + 7 + 8 = 12 + 8 = 20 | ||
(8 - 5) * 7 = 3 * 7 = 21 | ||
I cannot obtain 24 now, but numbers are within a reasonable range | ||
likely | ||
5 6 6 | ||
5 + 6 + 6 = 17 | ||
(6 - 5) * 6 = 1 * 6 = 6 | ||
I cannot obtain 24 now, but numbers are within a reasonable range | ||
likely | ||
10 10 11 | ||
10 + 10 + 11 = 31 | ||
(11 - 10) * 10 = 10 | ||
10 10 10 are all too big | ||
impossible | ||
1 3 3 | ||
1 * 3 * 3 = 9 | ||
(1 + 3) * 3 = 12 | ||
1 3 3 are all too small | ||
impossible | ||
{input} | ||
""" | ||
|
||
value_last_step_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24. | ||
Input: 4 4 6 8 | ||
Answer: (4 + 8) * (6 - 4) = 24 | ||
Judge: | ||
sure | ||
Input: 2 9 10 12 | ||
Answer: 2 * 12 * (10 - 9) = 24 | ||
Judge: | ||
sure | ||
Input: 4 9 10 13 | ||
Answer: (13 - 9) * (10 - 4) = 24 | ||
Judge: | ||
sure | ||
Input: 4 4 6 8 | ||
Answer: (4 + 8) * (6 - 4) + 1 = 25 | ||
Judge: | ||
impossible | ||
Input: 2 9 10 12 | ||
Answer: 2 * (12 - 10) = 24 | ||
Judge: | ||
impossible | ||
Input: 4 9 10 13 | ||
Answer: (13 - 4) * (10 - 9) = 24 | ||
Judge: | ||
impossible | ||
Input: {input} | ||
Answer: {answer} | ||
Judge:""" |