Skip to content

Commit

Permalink
add tot examples
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaHSR committed Jan 14, 2024
1 parent da87e6a commit 2829964
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/tot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date : 1/12/2024 6:05 PM
# @Author : stellahong ([email protected])
# @Desc :
73 changes: 73 additions & 0 deletions examples/tot/creative_writing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 1:06 PM
# @Author : stellahong ([email protected])
# @Desc :
import re

from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
BaseParser,
Strategy,
ThoughtSolverConfig,
)
from examples.tot.prompt_templates.creative_writing import cot_prompt, vote_prompt


class TextGenParser(BaseParser):
propose_prompt: str = cot_prompt
value_prompt: str = vote_prompt

def __call__(self, input_text: str) -> str:
return input_text

def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)

def value(self, input: str = "", **kwargs) -> str:
# node_result = self(input)
id = kwargs.get("node_id", "0")
return self.value_prompt + f"Choice {id}:\n{input}\n"


class TextGenEvaluator(BaseEvaluator):
value_map: dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map: dict = {val: key for key, val in value_map.items()}

def __call__(self, evaluation: str, **kwargs) -> float:
try:
value = 0
node_id = kwargs.get("node_id", "0")
pattern = r".*best choice is .*(\d+).*"
match = re.match(pattern, evaluation, re.DOTALL)

if match:
vote = int(match.groups()[0])
print(vote)
if vote == int(node_id):
value = 1
except:
value = 0
return value

def status_verify(self, value):
status = False
if value in self.status_map:
status_value = self.status_map[value]
if status_value != "impossible":
status = True
return status


if __name__ == "__main__":
import asyncio

initial_prompt = """It isn't difficult to do a handstand if you just stand on your hands. It caught him off guard that space smelled of seared steak. When she didn’t like a guy who was trying to pick her up, she started using sign language. Each person who knows you has a different perception of who you are."""

parser = TextGenParser()
evaluator = TextGenEvaluator()

config = ThoughtSolverConfig(n_generate_sample=3, parser=parser, evaluator=evaluator)

tot_base = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot_base.solve(init_prompt=initial_prompt))
64 changes: 64 additions & 0 deletions examples/tot/game24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 1:36 AM
# @Author : stellahong ([email protected])
# @Desc :
import re

from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
BaseParser,
Strategy,
ThoughtSolverConfig,
)
from examples.tot.prompt_templates.game24 import propose_prompt, value_prompt


class Game24Parser(BaseParser):
propose_prompt: str = propose_prompt
value_prompt: str = value_prompt

def __call__(self, input_text: str) -> str:
last_line = input_text.strip().split("\n")[-1]
return last_line.split("left: ")[-1].split(")")[0]

def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)

def value(self, input: str = "", **kwargs) -> str:
node_result = self(input)
return self.value_prompt.format(input=node_result)


class Game24Evaluator(BaseEvaluator):
value_map : dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map : dict = {val: key for key, val in value_map.items()}

def __call__(self, evaluation: str, **kwargs) -> float:
try:
matches = re.findall(r"\b(impossible|sure|likely)\b", evaluation)
value = self.value_map[matches[0]]
except:
value = 0.001
return value

def status_verify(self, value):
status = False
if value in self.status_map:
status_value = self.status_map[value]
if status_value != "impossible":
status = True
return status


if __name__ == "__main__":
import asyncio

initial_prompt = """4 5 6 10"""
parser = Game24Parser()
evaluator = Game24Evaluator()

config = ThoughtSolverConfig(n_generate_sample=5, parser=parser, evaluator=evaluator)

tot = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot.solve(init_prompt=initial_prompt))
4 changes: 4 additions & 0 deletions examples/tot/prompt_templates/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date : 12/23/2023 5:21 PM
# @Author : stellahong ([email protected])
# @Desc :
25 changes: 25 additions & 0 deletions examples/tot/prompt_templates/creative_writing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
standard_prompt = """
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
"""

cot_prompt = """
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
Make a plan then write. Your output should be of the following format:
Plan:
Your plan here.
Passage:
Your passage here.
"""


vote_prompt = """Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s the integer id of the choice.
"""

compare_prompt = """Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent".
"""

score_prompt = """Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10.
"""
139 changes: 139 additions & 0 deletions examples/tot/prompt_templates/game24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# 5-shot
standard_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Input: 1 4 8 8
Answer: (8 / 4 + 1) * 8 = 24
Input: 5 5 5 9
Answer: 5 + 5 + 5 + 9 = 24
Input: {input}
"""

# 5-shot
cot_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: {input}
"""

# 1-shot
propose_prompt = """Here is an Example for 1 input and 8 possible thoughts:
Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 (left: 8 10 14)
8 / 2 = 4 (left: 4 8 14)
14 + 2 = 16 (left: 8 8 16)
2 * 8 = 16 (left: 8 14 16)
8 - 2 = 6 (left: 6 8 14)
14 - 8 = 6 (left: 2 6 8)
14 / 2 = 7 (left: 7 8 8)
14 - 2 = 12 (left: 8 8 12)
Here is my task for 1 input and {n_generate_sample} possible thoughts:
Input: {input}
Possible next steps:
"""

value_prompt = """Evaluate if given numbers can reach 24 (sure/likely/impossible)
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 10 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
{input}
"""

value_last_step_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Judge:
sure
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Judge:
sure
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Judge:
sure
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) + 1 = 25
Judge:
impossible
Input: 2 9 10 12
Answer: 2 * (12 - 10) = 24
Judge:
impossible
Input: 4 9 10 13
Answer: (13 - 4) * (10 - 9) = 24
Judge:
impossible
Input: {input}
Answer: {answer}
Judge:"""

0 comments on commit 2829964

Please sign in to comment.