From 7bcc56c8f0ab7b26d8d08ba506701456bae1f82f Mon Sep 17 00:00:00 2001 From: recursix Date: Tue, 22 Oct 2024 15:53:15 -0400 Subject: [PATCH 1/5] add tabs --- src/agentlab/agents/dynamic_prompting.py | 33 +++++++++++++++++-- .../generic_agent/generic_agent_prompt.py | 6 +++- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/src/agentlab/agents/dynamic_prompting.py b/src/agentlab/agents/dynamic_prompting.py index 54d52f2cd..c90e3299e 100644 --- a/src/agentlab/agents/dynamic_prompting.py +++ b/src/agentlab/agents/dynamic_prompting.py @@ -366,16 +366,45 @@ def __init__(self, bid, visible: bool = True, prefix="") -> None: """ +class Tabs(PromptElement): + def __init__(self, obs, visible: bool = True, prefix="") -> None: + super().__init__(visible=visible) + + prompt_pieces = [f"\n{prefix}Currently open tabs:"] + for page_index, (page_url, page_title) in enumerate( + zip(obs["open_pages_urls"], obs["open_pages_titles"]) + ): + active_or_not = " (active tab)" if page_index == obs["active_page_index"] else "" + prompt_piece = f"""\ +Tab {page_index}{active_or_not}: + Title: {page_title} + URL: {page_url} +""" + prompt_pieces.append(prompt_piece) + self._prompt = "\n".join(prompt_pieces) + + +def has_tab_action(action_set: bgym.HighLevelActionSetArgs): + return "tab" in action_set.subsets + + class Observation(Shrinkable): """Observation of the current step. Contains the html, the accessibility tree and the error logs. """ - def __init__(self, obs, flags: ObsFlags) -> None: + def __init__(self, obs, flags: ObsFlags, use_tabs=False) -> None: super().__init__() self.flags = flags self.obs = obs + + self.tabs = Tabs( + obs, + visible=use_tabs, + prefix="## ", + ) + self.html = HTML( obs[flags.html_type], visible_elements_only=flags.filter_visible_elements_only, @@ -409,7 +438,7 @@ def shrink(self): def _prompt(self) -> str: return f""" # Observation of current step: -{self.html.prompt}{self.ax_tree.prompt}{self.focused_element.prompt}{self.error.prompt} +{self.tabs}{self.html.prompt}{self.ax_tree.prompt}{self.focused_element.prompt}{self.error.prompt} """ diff --git a/src/agentlab/agents/generic_agent/generic_agent_prompt.py b/src/agentlab/agents/generic_agent/generic_agent_prompt.py index a655b42f3..50eeeed21 100644 --- a/src/agentlab/agents/generic_agent/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent/generic_agent_prompt.py @@ -74,7 +74,11 @@ def __init__( obs_history[-1]["goal"], extra_instructions=flags.extra_instructions ) - self.obs = dp.Observation(obs_history[-1], self.flags.obs) + self.obs = dp.Observation( + obs_history[-1], + self.flags.obs, + use_tabs=dp.has_tab_action(self.flags.action.action_set), + ) self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action) From c9634f8f78d99d6d3a8748bbed9450cfdcf7649e Mon Sep 17 00:00:00 2001 From: recursix Date: Tue, 22 Oct 2024 15:55:06 -0400 Subject: [PATCH 2/5] this workaround is worst --- src/agentlab/analyze/agent_xray.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 8274ed262..38968fd6c 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -153,12 +153,6 @@ def filter_agent_id(self, agent_id: list[tuple]): white-space: normal !important; word-wrap: break-word !important; } -#task_table { - height: 500px !important; -} -#seed_table { - height: 500px !important; -} """ @@ -227,7 +221,7 @@ def run_gradio(results_dir: Path): content. You have to sort back with the Idx column to align the click with the order.""" ) - agent_table = gr.DataFrame(height=500, show_label=False, interactive=False) + agent_table = gr.DataFrame(max_height=500, show_label=False, interactive=False) with gr.Tab("Select Task and Seed", id="Select Task"): with gr.Row(): with gr.Column(scale=4): @@ -244,7 +238,10 @@ def run_gradio(results_dir: Path): refresh_results_button = gr.Button("↺", scale=0, size="sm") task_table = gr.DataFrame( - height=500, show_label=False, interactive=False, elem_id="task_table" + max_height=500, + show_label=False, + interactive=False, + elem_id="task_table", ) with gr.Column(scale=2): @@ -259,7 +256,10 @@ def run_gradio(results_dir: Path): ) seed_table = gr.DataFrame( - height=500, show_label=False, interactive=False, elem_id="seed_table" + max_height=500, + show_label=False, + interactive=False, + elem_id="seed_table", ) with gr.Tab("Constants and Variables"): @@ -272,7 +272,9 @@ def run_gradio(results_dir: Path): **all** agents. They are displayed as a table with the name and value of the constant.""" ) - constants = gr.DataFrame(height=500, show_label=False, interactive=False) + constants = gr.DataFrame( + max_height=500, show_label=False, interactive=False + ) with gr.Column(scale=2): with gr.Accordion("Variables", open=False): gr.Markdown( @@ -281,9 +283,11 @@ def run_gradio(results_dir: Path): They are displayed as a table with the name, value and count of unique values. A maximum of 3 different values are displayed.""" ) - variables = gr.DataFrame(height=500, show_label=False, interactive=False) + variables = gr.DataFrame( + max_height=500, show_label=False, interactive=False + ) with gr.Tab("Global Stats"): - global_stats = gr.DataFrame(height=500, show_label=False, interactive=False) + global_stats = gr.DataFrame(max_height=500, show_label=False, interactive=False) with gr.Row(): episode_info = gr.Markdown(label="Episode Info", elem_classes="my-markdown") @@ -356,7 +360,7 @@ def run_gradio(results_dir: Path): logs = gr.Code(language=None, **code_args) with gr.Tab("Stats") as tab_stats: - stats = gr.DataFrame(height=500, show_label=False, interactive=False) + stats = gr.DataFrame(max_height=500, show_label=False, interactive=False) with gr.Tab("Agent Info HTML") as tab_agent_info_html: with gr.Row(): From b08c3e7d77987f10ea13261c0b9777e1b2173561 Mon Sep 17 00:00:00 2001 From: recursix Date: Tue, 22 Oct 2024 15:55:57 -0400 Subject: [PATCH 3/5] bug fix --- src/agentlab/experiments/reproducibility_util.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/agentlab/experiments/reproducibility_util.py b/src/agentlab/experiments/reproducibility_util.py index e9b9dd90f..754347953 100644 --- a/src/agentlab/experiments/reproducibility_util.py +++ b/src/agentlab/experiments/reproducibility_util.py @@ -262,14 +262,10 @@ def _verify_report(report_df: pd.DataFrame, agent_names=list[str], strict_reprod unique_agent_names = report_df["agent.agent_name"].unique() if set(agent_names) != set(unique_agent_names): raise ValueError( - f"Agent names in the report {unique_agent_names} do not match the agent names {agent_names}.", - raise_error=strict_reproducibility, + f"Agent names in the report {unique_agent_names} do not match the agent names {agent_names}." ) if len(set(agent_names)) != len(agent_names): - raise ValueError( - f"Duplicate agent names {agent_names}.", - raise_error=strict_reproducibility, - ) + raise ValueError(f"Duplicate agent names {agent_names}.") report_df = report_df.set_index("agent.agent_name", inplace=False) From a6fe3a90f1ce0e4eb3c8a2f26ba951dc48ad323a Mon Sep 17 00:00:00 2001 From: recursix Date: Tue, 22 Oct 2024 15:56:21 -0400 Subject: [PATCH 4/5] fix reproduce study --- .../generic_agent/reproducibility_agent.py | 36 ++++++++++++++----- src/agentlab/experiments/study.py | 25 +++++++++---- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/src/agentlab/agents/generic_agent/reproducibility_agent.py b/src/agentlab/agents/generic_agent/reproducibility_agent.py index 091cf6cf5..ffec1111a 100644 --- a/src/agentlab/agents/generic_agent/reproducibility_agent.py +++ b/src/agentlab/agents/generic_agent/reproducibility_agent.py @@ -17,6 +17,7 @@ from dataclasses import dataclass from pathlib import Path +import bgym from browsergym.experiments.agent import AgentInfo from browsergym.experiments.loop import ExpArgs, ExpResult, yield_all_exp_results from bs4 import BeautifulSoup @@ -24,6 +25,7 @@ from langchain_community.adapters.openai import convert_message_to_dict from agentlab.agents.agent_args import AgentArgs +from agentlab.agents.dynamic_prompting import ActionFlags from agentlab.experiments.study import Study from agentlab.llm.chat_api import make_assistant_message from agentlab.llm.llm_utils import Discussion, messages_to_dict @@ -141,18 +143,29 @@ def _format_messages(messages: list[dict]): return "\n".join(f"{m['role']} message:\n{m['content']}\n" for m in messages) +def _make_backward_compatible(agent_args: GenericAgentArgs): + action_set = agent_args.flags.action.action_set + if isinstance(action_set, (str, list)): + if isinstance(action_set, str): + action_set = action_set.split("+") + + agent_args.flags.action.action_set = bgym.HighLevelActionSetArgs( + subsets=action_set, + multiaction=agent_args.flags.action.multi_actions, + ) + + return agent_args + + def reproduce_study(original_study_dir: Path | str, log_level=logging.INFO): """Reproduce a study by running the same experiments with the same agent.""" original_study_dir = Path(original_study_dir) - study = Study.load(original_study_dir) - study.dir = None - study.make_dir() - exp_args_list: list[ExpArgs] = [] for exp_result in yield_all_exp_results(original_study_dir, progress_fn=None): - agent_args = make_repro_agent(exp_result.exp_args.agent_args, exp_dir=exp_result.exp_dir) + agent_args = _make_backward_compatible(exp_result.exp_args.agent_args) + agent_args = make_repro_agent(agent_args, exp_dir=exp_result.exp_dir) exp_args_list.append( ExpArgs( agent_args=agent_args, @@ -160,13 +173,18 @@ def reproduce_study(original_study_dir: Path | str, log_level=logging.INFO): logging_level=log_level, ) ) + + # infer benchmark name from task list for backward compatible benchmark_name = exp_args_list[0].env_args.task_name.split(".")[0] - return Study( - exp_args_list=exp_args_list, - benchmark_name=benchmark_name, - agent_names=[agent_args.agent_name], + study = Study( + benchmark=benchmark_name, + agent_args=[agent_args], ) + # this exp_args_list has a different agent_args for each experiment as repro_agent takes the exp_dir as argument + # so we overwrite exp_args_list with the one we created above + study.exp_args_list = exp_args_list + return study def make_repro_agent(agent_args: AgentArgs, exp_dir: Path | str): diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py index 60f2166bb..c83019524 100644 --- a/src/agentlab/experiments/study.py +++ b/src/agentlab/experiments/study.py @@ -23,8 +23,9 @@ class Study: This is part of the high level API to help keep experiments organized and reproducible. Attributes: - benchmark: Benchmark - The benchmark to evaluate the agents on. + benchmark: Benchmark | str + The benchmark to evaluate the agents on. If a string is provided, it will be + converted to the corresponding benchmark using bgym.BENCHMARKS. agent_args: list[AgentArgs] The list of agents to evaluate. @@ -43,7 +44,7 @@ class Study: """ agent_args: list[AgentArgs] = None - benchmark: Benchmark = None + benchmark: Benchmark | str = None dir: Path = None suffix: str = "" # used for adding a personnal comment to the study name uuid: str = None @@ -157,10 +158,20 @@ def get_report(self, ignore_cache=False, ignore_stale=False): @staticmethod def load(dir: Path) -> "Study": - with gzip.open(dir / "study.pkl.gz", "rb") as f: - study = pickle.load(f) # type: Study - - study.dir = dir + dir = Path(dir) + study_path = dir / "study.pkl.gz" + if not study_path.exists() and dir.is_dir(): + # For backward compatibility + first_result = next( + inspect_results.yield_all_exp_results(savedir_base=dir, progress_fn=None) + ) + benchmark_name = first_result.exp_args.env_args.task_name.split(".")[0] + agent_args = first_result.exp_args.agent_args + study = Study(agent_args=agent_args, benchmark=benchmark_name, dir=dir) + else: + with gzip.open(dir / "study.pkl.gz", "rb") as f: + study = pickle.load(f) # type: Study + study.dir = dir return study @staticmethod From 9420fbea847b3173d764c7d925f7119d8b8a2457 Mon Sep 17 00:00:00 2001 From: recursix Date: Tue, 22 Oct 2024 16:01:28 -0400 Subject: [PATCH 5/5] make sure it's not computed if not visible --- src/agentlab/agents/dynamic_prompting.py | 11 ++++++++--- tests/experiments/test_reproducibility_util.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/agentlab/agents/dynamic_prompting.py b/src/agentlab/agents/dynamic_prompting.py index c90e3299e..5abd7ea60 100644 --- a/src/agentlab/agents/dynamic_prompting.py +++ b/src/agentlab/agents/dynamic_prompting.py @@ -369,12 +369,17 @@ def __init__(self, bid, visible: bool = True, prefix="") -> None: class Tabs(PromptElement): def __init__(self, obs, visible: bool = True, prefix="") -> None: super().__init__(visible=visible) + self.obs = obs + self.prefix = prefix - prompt_pieces = [f"\n{prefix}Currently open tabs:"] + @property + def _prompt(self) -> str: + # by implementing this as a property, it's only coputed if visible + prompt_pieces = [f"\n{self.prefix}Currently open tabs:"] for page_index, (page_url, page_title) in enumerate( - zip(obs["open_pages_urls"], obs["open_pages_titles"]) + zip(self.obs["open_pages_urls"], self.obs["open_pages_titles"]) ): - active_or_not = " (active tab)" if page_index == obs["active_page_index"] else "" + active_or_not = " (active tab)" if page_index == self.obs["active_page_index"] else "" prompt_piece = f"""\ Tab {page_index}{active_or_not}: Title: {page_title} diff --git a/tests/experiments/test_reproducibility_util.py b/tests/experiments/test_reproducibility_util.py index 57299ae3e..6008bb30e 100644 --- a/tests/experiments/test_reproducibility_util.py +++ b/tests/experiments/test_reproducibility_util.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( "benchmark_name", - ["miniwob_all", "workarena_l1", "webarena", "visualwebarena"], + ["miniwob", "workarena_l1", "webarena", "visualwebarena"], ) def test_get_reproducibility_info(benchmark_name):