Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
d76e5c3
downgrading ubuntu version for github tests (#62)
TLSDC Oct 15, 2024
3697c4d
Llm api update (#59)
TLSDC Oct 15, 2024
9f58f15
Reproducibility again (#61)
recursix Oct 16, 2024
79b9202
version bump
TLSDC Oct 16, 2024
b09ea93
Patching minor stuff (#69)
TLSDC Oct 16, 2024
3ceaa0f
Improve agent xray app (#70)
xhluca Oct 17, 2024
7bba275
added tmlr definitive config (#71)
TLSDC Oct 17, 2024
1b6b217
downgrading gradio version (#77)
TLSDC Oct 19, 2024
f95df4a
Merge remote-tracking branch 'origin/main' into dev
TLSDC Oct 19, 2024
98acd0c
Study refactor (#73)
recursix Oct 20, 2024
98e5a22
adding message class and updating generic agent accordingly (#68)
TLSDC Oct 21, 2024
a6c1f93
version bump
TLSDC Oct 21, 2024
d085e81
Updating generic_agent to use BGym's goal_object (#83)
TLSDC Oct 22, 2024
59ad7cc
Minor revert (#86)
TLSDC Oct 22, 2024
0e83133
Add tabs (#84)
recursix Oct 22, 2024
86fe572
Fix reproduce study (#87)
recursix Oct 23, 2024
682e0f4
upgrading gradio dependency (#88)
TLSDC Oct 23, 2024
176fe8a
bgym update (#90)
TLSDC Oct 23, 2024
605c503
Workarena TMLR experiments (#89)
TLSDC Oct 24, 2024
96b5cd6
handling sequential in VWA (#91)
recursix Oct 24, 2024
13840fc
Tmlr workarena (#92)
TLSDC Oct 24, 2024
024481a
tmp
recursix Oct 24, 2024
519f51e
reformat
recursix Oct 24, 2024
8f235f8
adding assistantbench to reproducibility_util.py
TLSDC Oct 24, 2024
6e18fb8
gitignore (#97)
gasse Oct 30, 2024
05448cf
Vision fix (#105)
TLSDC Nov 5, 2024
f6ac587
L2 tmlr (#93)
TLSDC Nov 5, 2024
f8d1e47
Replacing Dask with Ray (#100)
recursix Nov 6, 2024
6684e3d
switching to 2 for loops in _agents_on_benchmark (#107)
TLSDC Nov 6, 2024
dab1a48
yet another way to kill timedout jobs (#108)
recursix Nov 6, 2024
aa59a4a
Fix prompt formatting in Observation and add static method to Study c…
recursix Nov 7, 2024
feda734
Bug fix (#111)
recursix Nov 7, 2024
7a5b91e
Fixing openrouter pricing rate limit (#112)
TLSDC Nov 7, 2024
3e94570
updating max prompt configs, vision support (#109)
TLSDC Nov 8, 2024
1ebb896
Cross-product deepcopy fix (#106)
jardinetsouffleton Nov 8, 2024
f35dea0
slugify study_name (#114)
gasse Nov 11, 2024
c5dfb17
Improve timeout handling in task polling logic
recursix Nov 6, 2024
cf05bc6
Add method to override max_steps in Study class
recursix Nov 7, 2024
5297157
add support for tab visibility in observation flags and update relate…
recursix Nov 8, 2024
4279d5c
fix tests
recursix Nov 8, 2024
ad374fc
Fix sorting bug.
recursix Nov 8, 2024
ef33f1f
fix test
recursix Nov 8, 2024
f86b505
black
recursix Nov 8, 2024
1196455
Weblinx results (#104)
gasse Nov 12, 2024
5c8d627
Max new tokens fix (#118)
gasse Nov 12, 2024
6e61f84
Merge branch 'main' into dev
TLSDC Nov 12, 2024
b644474
version bump (#119)
TLSDC Nov 12, 2024
e695e11
fix format (#120)
TLSDC Nov 12, 2024
16e7526
Clean pipeline (#117)
recursix Nov 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ jobs:
- name: Check MiniWob availability
run: curl -I "http://localhost:8080/miniwob/" || echo "MiniWob not reachable"

- name: Pre-download nltk resources
run: python -c "import nltk; nltk.download('punkt_tab')"

- name: Run AgentLab Unit Tests
env:
MINIWOB_URL: "http://localhost:8080/miniwob/"
Expand Down
29 changes: 2 additions & 27 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,35 +161,10 @@ cython_debug/
**/.DS_Store

.vscode
allowed_selenium.json

# Torchtune
finetuning/torchtune

# PyLLMD repo for finetuning
pyllmd_tune/research-pyllmd/
pyllmd_tune/data/


datasets/*
_sandbox.py
node_modules/
/test-results/
/playwright-report/
/blob-report/
/playwright/.cache/
/test-results/
/playwright-report/
/blob-report/
/playwright/.cache/


results/

# personal (optimass)
ICML_deadline/
mass_utils/
pyllmd_tune/

# don't ignore the miniwob_tasks_all.csv file
!miniwob_tasks_all.csv
# gradio
.gradio/
28 changes: 16 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,28 @@
"""

import logging

from agentlab.agents.generic_agent import (
RANDOM_SEARCH_AGENT,
AGENT_4o,
AGENT_4o_MINI,
AGENT_LLAMA3_70B,
AGENT_LLAMA31_70B,
)
from agentlab.analyze.inspect_results import get_most_recent_folder
from agentlab.experiments import study_generators
from agentlab.experiments.study import Study

logging.getLogger().setLevel(logging.INFO)

# choose your agent or provide a new agent
agent_args = [AGENT_4o_MINI]
# agent_args = [AGENT_4o]

## select the benchmark to run on

# ## select the benchmark to run on
benchmark = "miniwob_tiny_test"
# benchmark = "miniwob"
# benchmark = "workarena.l1"
# benchmark = "workarena.l2"
# benchmark = "workarena.l3"
# benchmark = "workarena_l1"
# benchmark = "workarena_l2"
# benchmark = "workarena_l3"
# benchmark = "webarena"

# Set reproducibility_mode = True for reproducibility
Expand All @@ -53,13 +52,18 @@

if relaunch:
# relaunch an existing study
study_dir = get_most_recent_folder()
study = study_generators.make_relaunch_study(study_dir, relaunch_mode="incomplete_or_error")
study = Study.load_most_recent(contains=None)
study.find_incomplete(include_errors=True)

else:
study = study_generators.run_agents_on_benchmark(agent_args, benchmark)

study.run(n_jobs=n_jobs, parallel_backend="joblib", strict_reproducibility=reproducibility_mode)
study = Study(agent_args, benchmark, logging_level_stdout=logging.WARNING)

study.run(
n_jobs=n_jobs,
parallel_backend="ray",
strict_reproducibility=reproducibility_mode,
n_relaunch=3,
)

if reproducibility_mode:
study.append_to_journal(strict_reproducibility=True)
58 changes: 47 additions & 11 deletions reproducibility_journal.csv

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ contexttimer
ipython
pyyaml>=6
pandas
gradio
gradio>=5.5 # issue with DataFrame scrolling before 5.5
gitpython # for the reproducibility script
requests
requests
matplotlib
ray[default]
python-slugify
2 changes: 1 addition & 1 deletion src/agentlab/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.2"
__version__ = "0.3.0"
3 changes: 2 additions & 1 deletion src/agentlab/agents/agent_args.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from bgym import AbstractAgentArgs
import bgym


class AgentArgs(AbstractAgentArgs):

def set_benchmark(self, benchmark: str, demo_mode: bool):
def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode: bool):
"""Optional method to set benchmark specific flags.

This allows the agent to have minor adjustments based on the benchmark.
Expand Down
121 changes: 82 additions & 39 deletions src/agentlab/agents/dynamic_prompting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
import difflib
import logging
import platform
import time
Expand All @@ -9,12 +8,12 @@
from typing import Literal
from warnings import warn

import bgym
from browsergym.core.action.base import AbstractActionSet
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.core.action.python import PythonActionSet
from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, overlay_som, prune_html

from agentlab.llm.llm_utils import (
BaseMessage,
ParseError,
count_tokens,
extract_code_blocks,
Expand Down Expand Up @@ -70,6 +69,7 @@ class ObsFlags(Flags):

use_html: bool = True
use_ax_tree: bool = False
use_tabs: bool = False
use_focused_element: bool = False
use_error_logs: bool = False
use_history: bool = False
Expand All @@ -94,13 +94,14 @@ class ObsFlags(Flags):

@dataclass
class ActionFlags(Flags):
multi_actions: bool = False
action_set: str = "bid"
is_strict: bool = False
demo_mode: Literal["off", "default", "all_blue", "only_visible_elements"] = "off"
action_set: bgym.HighLevelActionSetArgs = None # should be set by the set_benchmark method
long_description: bool = True
individual_examples: bool = False

# for backward compatibility
multi_actions: bool = None
is_strict: bool = None


class PromptElement:
"""Base class for all prompt elements. Prompt elements can be hidden."""
Expand All @@ -121,7 +122,7 @@ def __init__(self, visible: bool = True) -> None:
self._visible = visible

@property
def prompt(self):
def prompt(self) -> str | BaseMessage:
"""Avoid overriding this method. Override _prompt instead."""
if self.is_visible:
return self._prompt
Expand Down Expand Up @@ -252,7 +253,14 @@ def fit_tokens(
if isinstance(prompt, str):
prompt_str = prompt
elif isinstance(prompt, list):
# warn deprecated
warn(
"Using list of prompts is deprecated. Use a Discussion object instead.",
DeprecationWarning,
)
prompt_str = "\n".join([p["text"] for p in prompt if p["type"] == "text"])
elif isinstance(prompt, BaseMessage):
prompt_str = str(prompt)
else:
raise ValueError(f"Unrecognized type for prompt: {type(prompt)}")
n_token = count_tokens(prompt_str, model=model_name)
Expand Down Expand Up @@ -357,6 +365,29 @@ def __init__(self, bid, visible: bool = True, prefix="") -> None:
"""


class Tabs(PromptElement):
"""Prompt element that lists the browser tabs currently open.

Reads ``open_pages_urls``, ``open_pages_titles`` and ``active_page_index``
from the observation dict and renders one entry per tab (index, title,
URL), marking the active tab. ``prefix`` is prepended to the section
heading (e.g. a markdown heading marker).
"""

def __init__(self, obs, visible: bool = True, prefix="") -> None:
super().__init__(visible=visible)
self.obs = obs
self.prefix = prefix

@property
def _prompt(self) -> str:
# by implementing this as a property, it's only computed if visible
prompt_pieces = [f"\n{self.prefix}Currently open tabs:"]
for page_index, (page_url, page_title) in enumerate(
zip(self.obs["open_pages_urls"], self.obs["open_pages_titles"])
):
active_or_not = " (active tab)" if page_index == self.obs["active_page_index"] else ""
prompt_piece = f"""\
Tab {page_index}{active_or_not}:
Title: {page_title}
URL: {page_url}
"""
prompt_pieces.append(prompt_piece)
return "\n".join(prompt_pieces)


class Observation(Shrinkable):
"""Observation of the current step.

Expand All @@ -367,6 +398,13 @@ def __init__(self, obs, flags: ObsFlags) -> None:
super().__init__()
self.flags = flags
self.obs = obs

self.tabs = Tabs(
obs,
visible=lambda: flags.use_tabs,
prefix="## ",
)

self.html = HTML(
obs[flags.html_type],
visible_elements_only=flags.filter_visible_elements_only,
Expand Down Expand Up @@ -400,25 +438,18 @@ def shrink(self):
def _prompt(self) -> str:
return f"""
# Observation of current step:
{self.html.prompt}{self.ax_tree.prompt}{self.focused_element.prompt}{self.error.prompt}
{self.tabs.prompt}{self.html.prompt}{self.ax_tree.prompt}{self.focused_element.prompt}{self.error.prompt}

"""

def add_screenshot(self, prompt):
def add_screenshot(self, prompt: BaseMessage) -> BaseMessage:
if self.flags.use_screenshot:
if isinstance(prompt, str):
prompt = [{"type": "text", "text": prompt}]
if self.flags.use_som:
screenshot = self.obs["screenshot_som"]
else:
screenshot = self.obs["screenshot"]
img_url = image_to_jpg_base64_url(screenshot)
prompt.append(
{
"type": "image_url",
"image_url": {"url": img_url, "detail": self.flags.openai_vision_detail},
}
)
prompt.add_image(img_url, detail=self.flags.openai_vision_detail)
return prompt


Expand All @@ -441,24 +472,36 @@ def __init__(self, visible: bool = True) -> None:


class GoalInstructions(PromptElement):
def __init__(self, goal, visible: bool = True, extra_instructions=None) -> None:
def __init__(self, goal_object, visible: bool = True, extra_instructions=None) -> None:
super().__init__(visible)
self._prompt = f"""\
self._prompt = [
dict(
type="text",
text=f"""\
# Instructions
Review the current state of the page and all other information to find the best
possible next action to accomplish your goal. Your answer will be interpreted
and executed by a program, make sure to follow the formatting instructions.

## Goal:
{goal}
"""
""",
)
]

self._prompt += goal_object

if extra_instructions:
self._prompt += f"""
self._prompt += [
dict(
type="text",
text=f"""

## Extra instructions:

{extra_instructions}
"""
""",
)
]


class ChatInstructions(PromptElement):
Expand Down Expand Up @@ -592,24 +635,24 @@ def _parse_answer(self, text_answer):
return ans_dict


def make_action_set(action_flags: ActionFlags) -> AbstractActionSet:
# def make_action_set(action_flags: ActionFlags) -> AbstractActionSet:

if action_flags.action_set == "python":
action_set = PythonActionSet(strict=action_flags.is_strict)
if action_flags.demo_mode != "off":
warn(
f'Action_set "python" is incompatible with demo_mode={repr(action_flags.demo_mode)}.'
)
return action_set
# if action_flags.action_set == "python":
# action_set = PythonActionSet(strict=action_flags.is_strict)
# if action_flags.demo_mode != "off":
# warn(
# f'Action_set "python" is incompatible with demo_mode={repr(action_flags.demo_mode)}.'
# )
# return action_set

action_set = HighLevelActionSet(
subsets=list(set(["chat"] + ["infeas"] + action_flags.action_set.split("+"))),
multiaction=action_flags.multi_actions,
strict=action_flags.is_strict,
demo_mode=action_flags.demo_mode,
)
# action_set = HighLevelActionSet(
# subsets=list(set(["chat"] + ["infeas"] + action_flags.action_set.split("+"))),
# multiaction=action_flags.multi_actions,
# strict=action_flags.is_strict,
# demo_mode=action_flags.demo_mode,
# )

return action_set
# return action_set


class Think(PromptElement):
Expand Down
Loading