Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions src/agentlab/agents/dynamic_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,16 +366,50 @@ def __init__(self, bid, visible: bool = True, prefix="") -> None:
"""


class Tabs(PromptElement):
def __init__(self, obs, visible: bool = True, prefix="") -> None:
super().__init__(visible=visible)
self.obs = obs
self.prefix = prefix

@property
def _prompt(self) -> str:
# by implementing this as a property, it's only coputed if visible
prompt_pieces = [f"\n{self.prefix}Currently open tabs:"]
for page_index, (page_url, page_title) in enumerate(
zip(self.obs["open_pages_urls"], self.obs["open_pages_titles"])
):
active_or_not = " (active tab)" if page_index == self.obs["active_page_index"] else ""
prompt_piece = f"""\
Tab {page_index}{active_or_not}:
Title: {page_title}
URL: {page_url}
"""
prompt_pieces.append(prompt_piece)
self._prompt = "\n".join(prompt_pieces)


def has_tab_action(action_set: bgym.HighLevelActionSetArgs):
return "tab" in action_set.subsets


class Observation(Shrinkable):
"""Observation of the current step.

Contains the html, the accessibility tree and the error logs.
"""

def __init__(self, obs, flags: ObsFlags) -> None:
def __init__(self, obs, flags: ObsFlags, use_tabs=False) -> None:
super().__init__()
self.flags = flags
self.obs = obs

self.tabs = Tabs(
obs,
visible=use_tabs,
prefix="## ",
)

self.html = HTML(
obs[flags.html_type],
visible_elements_only=flags.filter_visible_elements_only,
Expand Down Expand Up @@ -409,7 +443,7 @@ def shrink(self):
def _prompt(self) -> str:
return f"""
# Observation of current step:
{self.html.prompt}{self.ax_tree.prompt}{self.focused_element.prompt}{self.error.prompt}
{self.tabs}{self.html.prompt}{self.ax_tree.prompt}{self.focused_element.prompt}{self.error.prompt}

"""

Expand Down
6 changes: 5 additions & 1 deletion src/agentlab/agents/generic_agent/generic_agent_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ def __init__(
obs_history[-1]["goal"], extra_instructions=flags.extra_instructions
)

self.obs = dp.Observation(obs_history[-1], self.flags.obs)
self.obs = dp.Observation(
obs_history[-1],
self.flags.obs,
use_tabs=dp.has_tab_action(self.flags.action.action_set),
)

self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action)

Expand Down
36 changes: 27 additions & 9 deletions src/agentlab/agents/generic_agent/reproducibility_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
from dataclasses import dataclass
from pathlib import Path

import bgym
from browsergym.experiments.agent import AgentInfo
from browsergym.experiments.loop import ExpArgs, ExpResult, yield_all_exp_results
from bs4 import BeautifulSoup
from langchain.schema import AIMessage, BaseMessage
from langchain_community.adapters.openai import convert_message_to_dict

from agentlab.agents.agent_args import AgentArgs
from agentlab.agents.dynamic_prompting import ActionFlags
from agentlab.experiments.study import Study
from agentlab.llm.chat_api import make_assistant_message
from agentlab.llm.llm_utils import Discussion, messages_to_dict
Expand Down Expand Up @@ -141,32 +143,48 @@ def _format_messages(messages: list[dict]):
return "\n".join(f"{m['role']} message:\n{m['content']}\n" for m in messages)


def _make_backward_compatible(agent_args: GenericAgentArgs):
action_set = agent_args.flags.action.action_set
if isinstance(action_set, (str, list)):
if isinstance(action_set, str):
action_set = action_set.split("+")

agent_args.flags.action.action_set = bgym.HighLevelActionSetArgs(
subsets=action_set,
multiaction=agent_args.flags.action.multi_actions,
)

return agent_args


def reproduce_study(original_study_dir: Path | str, log_level=logging.INFO):
"""Reproduce a study by running the same experiments with the same agent."""

original_study_dir = Path(original_study_dir)

study = Study.load(original_study_dir)
study.dir = None
study.make_dir()

exp_args_list: list[ExpArgs] = []
for exp_result in yield_all_exp_results(original_study_dir, progress_fn=None):
agent_args = make_repro_agent(exp_result.exp_args.agent_args, exp_dir=exp_result.exp_dir)
agent_args = _make_backward_compatible(exp_result.exp_args.agent_args)
agent_args = make_repro_agent(agent_args, exp_dir=exp_result.exp_dir)
exp_args_list.append(
ExpArgs(
agent_args=agent_args,
env_args=exp_result.exp_args.env_args,
logging_level=log_level,
)
)

# infer benchmark name from task list for backward compatible
benchmark_name = exp_args_list[0].env_args.task_name.split(".")[0]

return Study(
exp_args_list=exp_args_list,
benchmark_name=benchmark_name,
agent_names=[agent_args.agent_name],
study = Study(
benchmark=benchmark_name,
agent_args=[agent_args],
)
# this exp_args_list has a different agent_args for each experiment as repro_agent takes the exp_dir as argument
# so we overwrite exp_args_list with the one we created above
study.exp_args_list = exp_args_list
return study


def make_repro_agent(agent_args: AgentArgs, exp_dir: Path | str):
Expand Down
30 changes: 17 additions & 13 deletions src/agentlab/analyze/agent_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,6 @@ def filter_agent_id(self, agent_id: list[tuple]):
white-space: normal !important;
word-wrap: break-word !important;
}
#task_table {
height: 500px !important;
}
#seed_table {
height: 500px !important;
}
"""


Expand Down Expand Up @@ -227,7 +221,7 @@ def run_gradio(results_dir: Path):
content. You have to sort back with the Idx column to align the click with
the order."""
)
agent_table = gr.DataFrame(height=500, show_label=False, interactive=False)
agent_table = gr.DataFrame(max_height=500, show_label=False, interactive=False)
with gr.Tab("Select Task and Seed", id="Select Task"):
with gr.Row():
with gr.Column(scale=4):
Expand All @@ -244,7 +238,10 @@ def run_gradio(results_dir: Path):
refresh_results_button = gr.Button("↺", scale=0, size="sm")

task_table = gr.DataFrame(
height=500, show_label=False, interactive=False, elem_id="task_table"
max_height=500,
show_label=False,
interactive=False,
elem_id="task_table",
)

with gr.Column(scale=2):
Expand All @@ -259,7 +256,10 @@ def run_gradio(results_dir: Path):
)

seed_table = gr.DataFrame(
height=500, show_label=False, interactive=False, elem_id="seed_table"
max_height=500,
show_label=False,
interactive=False,
elem_id="seed_table",
)

with gr.Tab("Constants and Variables"):
Expand All @@ -272,7 +272,9 @@ def run_gradio(results_dir: Path):
**all** agents. They are displayed as a table with the name and value of the
constant."""
)
constants = gr.DataFrame(height=500, show_label=False, interactive=False)
constants = gr.DataFrame(
max_height=500, show_label=False, interactive=False
)
with gr.Column(scale=2):
with gr.Accordion("Variables", open=False):
gr.Markdown(
Expand All @@ -281,9 +283,11 @@ def run_gradio(results_dir: Path):
They are displayed as a table with the name, value and count of unique
values. A maximum of 3 different values are displayed."""
)
variables = gr.DataFrame(height=500, show_label=False, interactive=False)
variables = gr.DataFrame(
max_height=500, show_label=False, interactive=False
)
with gr.Tab("Global Stats"):
global_stats = gr.DataFrame(height=500, show_label=False, interactive=False)
global_stats = gr.DataFrame(max_height=500, show_label=False, interactive=False)

with gr.Row():
episode_info = gr.Markdown(label="Episode Info", elem_classes="my-markdown")
Expand Down Expand Up @@ -356,7 +360,7 @@ def run_gradio(results_dir: Path):
logs = gr.Code(language=None, **code_args)

with gr.Tab("Stats") as tab_stats:
stats = gr.DataFrame(height=500, show_label=False, interactive=False)
stats = gr.DataFrame(max_height=500, show_label=False, interactive=False)

with gr.Tab("Agent Info HTML") as tab_agent_info_html:
with gr.Row():
Expand Down
8 changes: 2 additions & 6 deletions src/agentlab/experiments/reproducibility_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,10 @@ def _verify_report(report_df: pd.DataFrame, agent_names=list[str], strict_reprod
unique_agent_names = report_df["agent.agent_name"].unique()
if set(agent_names) != set(unique_agent_names):
raise ValueError(
f"Agent names in the report {unique_agent_names} do not match the agent names {agent_names}.",
raise_error=strict_reproducibility,
f"Agent names in the report {unique_agent_names} do not match the agent names {agent_names}."
)
if len(set(agent_names)) != len(agent_names):
raise ValueError(
f"Duplicate agent names {agent_names}.",
raise_error=strict_reproducibility,
)
raise ValueError(f"Duplicate agent names {agent_names}.")

report_df = report_df.set_index("agent.agent_name", inplace=False)

Expand Down
25 changes: 18 additions & 7 deletions src/agentlab/experiments/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ class Study:
This is part of the high level API to help keep experiments organized and reproducible.

Attributes:
benchmark: Benchmark
The benchmark to evaluate the agents on.
benchmark: Benchmark | str
The benchmark to evaluate the agents on. If a string is provided, it will be
converted to the corresponding benchmark using bgym.BENCHMARKS.

agent_args: list[AgentArgs]
The list of agents to evaluate.
Expand All @@ -43,7 +44,7 @@ class Study:
"""

agent_args: list[AgentArgs] = None
benchmark: Benchmark = None
benchmark: Benchmark | str = None
dir: Path = None
suffix: str = "" # used for adding a personnal comment to the study name
uuid: str = None
Expand Down Expand Up @@ -157,10 +158,20 @@ def get_report(self, ignore_cache=False, ignore_stale=False):

@staticmethod
def load(dir: Path) -> "Study":
with gzip.open(dir / "study.pkl.gz", "rb") as f:
study = pickle.load(f) # type: Study

study.dir = dir
dir = Path(dir)
study_path = dir / "study.pkl.gz"
if not study_path.exists() and dir.is_dir():
# For backward compatibility
first_result = next(
inspect_results.yield_all_exp_results(savedir_base=dir, progress_fn=None)
)
benchmark_name = first_result.exp_args.env_args.task_name.split(".")[0]
agent_args = first_result.exp_args.agent_args
study = Study(agent_args=agent_args, benchmark=benchmark_name, dir=dir)
else:
with gzip.open(dir / "study.pkl.gz", "rb") as f:
study = pickle.load(f) # type: Study
study.dir = dir
return study

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/experiments/test_reproducibility_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.mark.parametrize(
"benchmark_name",
["miniwob_all", "workarena_l1", "webarena", "visualwebarena"],
["miniwob", "workarena_l1", "webarena", "visualwebarena"],
)
def test_get_reproducibility_info(benchmark_name):

Expand Down