Skip to content

Commit

Permalink
RUN pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaHSR committed Jan 14, 2024
1 parent 2829964 commit 379ae51
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
22 changes: 11 additions & 11 deletions examples/tot/creative_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,26 @@
# @Desc :
import re

from examples.tot.prompt_templates.creative_writing import cot_prompt, vote_prompt
from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
BaseParser,
Strategy,
ThoughtSolverConfig,
)
from examples.tot.prompt_templates.creative_writing import cot_prompt, vote_prompt


class TextGenParser(BaseParser):
propose_prompt: str = cot_prompt
value_prompt: str = vote_prompt

def __call__(self, input_text: str) -> str:
return input_text

def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)

def value(self, input: str = "", **kwargs) -> str:
# node_result = self(input)
id = kwargs.get("node_id", "0")
Expand All @@ -33,14 +33,14 @@ def value(self, input: str = "", **kwargs) -> str:
class TextGenEvaluator(BaseEvaluator):
value_map: dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map: dict = {val: key for key, val in value_map.items()}

def __call__(self, evaluation: str, **kwargs) -> float:
try:
value = 0
node_id = kwargs.get("node_id", "0")
pattern = r".*best choice is .*(\d+).*"
match = re.match(pattern, evaluation, re.DOTALL)

if match:
vote = int(match.groups()[0])
print(vote)
Expand All @@ -49,7 +49,7 @@ def __call__(self, evaluation: str, **kwargs) -> float:
except:
value = 0
return value

def status_verify(self, value):
status = False
if value in self.status_map:
Expand All @@ -61,13 +61,13 @@ def status_verify(self, value):

if __name__ == "__main__":
import asyncio

initial_prompt = """It isn't difficult to do a handstand if you just stand on your hands. It caught him off guard that space smelled of seared steak. When she didn’t like a guy who was trying to pick her up, she started using sign language. Each person who knows you has a different perception of who you are."""

parser = TextGenParser()
evaluator = TextGenEvaluator()

config = ThoughtSolverConfig(n_generate_sample=3, parser=parser, evaluator=evaluator)

tot_base = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot_base.solve(init_prompt=initial_prompt))
22 changes: 11 additions & 11 deletions examples/tot/game24.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,44 @@
# @Desc :
import re

from examples.tot.prompt_templates.game24 import propose_prompt, value_prompt
from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
BaseParser,
Strategy,
ThoughtSolverConfig,
)
from examples.tot.prompt_templates.game24 import propose_prompt, value_prompt


class Game24Parser(BaseParser):
propose_prompt: str = propose_prompt
value_prompt: str = value_prompt

def __call__(self, input_text: str) -> str:
last_line = input_text.strip().split("\n")[-1]
return last_line.split("left: ")[-1].split(")")[0]

def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)

def value(self, input: str = "", **kwargs) -> str:
node_result = self(input)
return self.value_prompt.format(input=node_result)


class Game24Evaluator(BaseEvaluator):
value_map : dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map : dict = {val: key for key, val in value_map.items()}
value_map: dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map: dict = {val: key for key, val in value_map.items()}

def __call__(self, evaluation: str, **kwargs) -> float:
try:
matches = re.findall(r"\b(impossible|sure|likely)\b", evaluation)
value = self.value_map[matches[0]]
except:
value = 0.001
return value

def status_verify(self, value):
status = False
if value in self.status_map:
Expand All @@ -53,12 +53,12 @@ def status_verify(self, value):

if __name__ == "__main__":
import asyncio

initial_prompt = """4 5 6 10"""
parser = Game24Parser()
evaluator = Game24Evaluator()

config = ThoughtSolverConfig(n_generate_sample=5, parser=parser, evaluator=evaluator)

tot = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot.solve(init_prompt=initial_prompt))

0 comments on commit 379ae51

Please sign in to comment.