Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
alckasoc committed Jun 24, 2024
1 parent ee6791e commit bf56ed5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
11 changes: 9 additions & 2 deletions agential/cog/strategies/react/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def generate_action(

return action_type, query

def generate_observation(self, idx: int, action_type: str, query: str) -> Tuple[str, Dict[str, Any]]:
def generate_observation(
self, idx: int, action_type: str, query: str
) -> Tuple[str, Dict[str, Any]]:
"""Generates an observation based on the action type and query.
Args:
Expand Down Expand Up @@ -188,7 +190,12 @@ def generate_observation(self, idx: int, action_type: str, query: str) -> Tuple[
return obs, external_tool_info

def create_output_dict(
self, thought: str, action_type: str, query: str, obs: str, external_tool_info: Dict[str, Any]
self,
thought: str,
action_type: str,
query: str,
obs: str,
external_tool_info: Dict[str, Any],
) -> Dict[str, Any]:
"""Creates a dictionary of the output components.
Expand Down
22 changes: 16 additions & 6 deletions tests/cog/strategies/react/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def test_generate_observation() -> None:
query = "def first_repeated_char(s):\n char_set = set()\n for char in s:\n if char in char_set:\n return char\n else:\n char_set.add(char)\n return None"
llm = FakeListChatModel(responses=[])
strategy = ReActCodeStrategy(llm=llm)
obs, external_tool_info = strategy.generate_observation(idx=0, action_type=action_type, query=query)
obs, external_tool_info = strategy.generate_observation(
idx=0, action_type=action_type, query=query
)
assert obs == gt_obs
assert strategy._answer == query
assert strategy._finished is False
Expand All @@ -130,7 +132,9 @@ def test_generate_observation() -> None:
llm = FakeListChatModel(responses=[])
strategy = ReActCodeStrategy(llm=llm)
strategy._answer = "print('Hello World')"
obs, external_tool_info = strategy.generate_observation(idx=0, action_type=action_type, query=query)
obs, external_tool_info = strategy.generate_observation(
idx=0, action_type=action_type, query=query
)
assert obs == gt_obs
assert strategy._answer == "print('Hello World')"
assert strategy._finished is False
Expand All @@ -144,7 +148,9 @@ def test_generate_observation() -> None:
query = "def first_repeated_char(s):\n char_set = set()\n for char in s:\n if char in char_set:\n return char\n else:\n char_set.add(char)\n return None"
llm = FakeListChatModel(responses=[])
strategy = ReActCodeStrategy(llm=llm)
obs, external_tool_info = strategy.generate_observation(idx=0, action_type=action_type, query=query)
obs, external_tool_info = strategy.generate_observation(
idx=0, action_type=action_type, query=query
)
assert obs == gt_obs
assert strategy._answer == query
assert strategy._finished is True
Expand All @@ -157,7 +163,9 @@ def test_generate_observation() -> None:
query = "def first_repeated_char(s):\n char_set = set()\n for char in s:\n if char in char_set:\n return char\n else:\n char_set.add(char)\n return None"
llm = FakeListChatModel(responses=[])
strategy = ReActCodeStrategy(llm=llm)
obs, external_tool_info = strategy.generate_observation(idx=0, action_type=action_type, query=query)
obs, external_tool_info = strategy.generate_observation(
idx=0, action_type=action_type, query=query
)
assert (
obs
== "Invalid Action. Valid Actions are Implement[code] Test[code] and Finish[answer]."
Expand Down Expand Up @@ -185,10 +193,12 @@ def test_create_output_dict() -> None:
"query": query,
"observation": obs,
"answer": strategy._answer,
"external_tool_info": external_tool_info
"external_tool_info": external_tool_info,
}

output = strategy.create_output_dict(thought, action_type, query, obs, external_tool_info)
output = strategy.create_output_dict(
thought, action_type, query, obs, external_tool_info
)
assert output == expected_output


Expand Down

0 comments on commit bf56ed5

Please sign in to comment.