Skip to content

Commit

Permalink
Fixing discussion object when adding images w/o detail (#128)
Browse files Browse the repository at this point in the history
* switching to goal_object in xray

* adding excpetion for when summary isnt available

* fixing discussion w/ images
  • Loading branch information
TLSDC authored and gasse committed Nov 20, 2024
1 parent aa4e5ba commit cb787ab
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
36 changes: 20 additions & 16 deletions src/agentlab/analyze/agent_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from agentlab.experiments.exp_utils import RESULTS_DIR
from agentlab.experiments.study import get_most_recent_study
from agentlab.llm.chat_api import make_system_message, make_user_message
from agentlab.llm.llm_utils import BaseMessage as AgentLabBaseMessage
from agentlab.llm.llm_utils import Discussion

select_dir_instructions = "Select Experiment Directory"
Expand Down Expand Up @@ -740,7 +741,7 @@ def get_episode_info(info: Info):
steps_info = info.exp_result.steps_info
step_info = steps_info[info.step]
try:
goal = step_info.obs["goal"]
goal = step_info.obs["goal_object"]
except KeyError:
goal = None
try:
Expand All @@ -757,7 +758,7 @@ def get_episode_info(info: Info):
**Goal:**
{code(goal)}
{code(str(AgentLabBaseMessage('', goal)))}
**Task info:**
Expand Down Expand Up @@ -992,20 +993,23 @@ def get_directory_contents(results_dir: Path):
continue

exp_description = dir.name
# get summary*.csv files and find the most recent
summary_files = list(dir.glob("summary*.csv"))
if len(summary_files) != 0:
most_recent_summary = max(summary_files, key=os.path.getctime)
summary_df = pd.read_csv(most_recent_summary)

# get row with max avg_reward
max_reward_row = summary_df.loc[summary_df["avg_reward"].idxmax()]
reward = max_reward_row["avg_reward"] * 100
completed = max_reward_row["n_completed"]
n_err = max_reward_row["n_err"]
exp_description += (
f" - avg-reward: {reward:.1f}% - completed: {completed} - errors: {n_err}"
)
try:
# get summary*.csv files and find the most recent
summary_files = list(dir.glob("summary*.csv"))
if len(summary_files) != 0:
most_recent_summary = max(summary_files, key=os.path.getctime)
summary_df = pd.read_csv(most_recent_summary)

# get row with max avg_reward
max_reward_row = summary_df.loc[summary_df["avg_reward"].idxmax()]
reward = max_reward_row["avg_reward"] * 100
completed = max_reward_row["n_completed"]
n_err = max_reward_row["n_err"]
exp_description += (
f" - avg-reward: {reward:.1f}% - completed: {completed} - errors: {n_err}"
)
except Exception as e:
print(f"Error while reading summary file: {e}")

exp_descriptions.append(exp_description)

Expand Down
5 changes: 2 additions & 3 deletions src/agentlab/llm/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import time
from copy import deepcopy
from functools import cache
from typing import TYPE_CHECKING
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Union
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -356,7 +355,7 @@ def add_image(self, image: np.ndarray | Image.Image | str, detail: str = None):
if detail:
self.add_content("image_url", {"url": image_url, "detail": detail})
else:
self.add_content("image_url", image_url)
self.add_content("image_url", {"url": image_url})

def to_markdown(self):
if isinstance(self["content"], str):
Expand Down

0 comments on commit cb787ab

Please sign in to comment.