Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/agentlab/experiments/launch_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ def run_experiments(
if parallel_backend == "joblib":
from joblib import Parallel, delayed

# split sequential
sequential_exp_args, exp_args_list = _split_sequential_exp(exp_args_list)

logging.info(
f"Running {len(sequential_exp_args)} in sequential first. The remaining {len(exp_args_list)} will be run in parallel."
)
for exp_args in sequential_exp_args:
exp_args.run()

Parallel(n_jobs=n_jobs, prefer="processes")(
delayed(exp_args.run)() for exp_args in exp_args_list
)
Expand Down Expand Up @@ -98,9 +107,12 @@ def find_incomplete(study_dir: str | Path, relaunch_mode="incomplete_only"):
)
exp_args_list = list(_yield_incomplete_experiments(study_dir, relaunch_mode=relaunch_mode))

# sort according to exp_args.order
exp_args_list.sort(key=lambda exp_args: exp_args.order if exp_args.order is not None else 0)

if len(exp_args_list) == 0:
logging.info(f"No incomplete experiments found in {study_dir}.")
return [], study_dir
return exp_args_list

message = f"Make sure the processes that were running are all stopped. Otherwise, "
f"there will be concurrent writing in the same directories.\n"
Expand Down Expand Up @@ -140,3 +152,16 @@ def split_path(path: str):
path = path.replace("/", ".")
module_name, obj_name = path.rsplit(".", 1)
return module_name, obj_name


def _split_sequential_exp(exp_args_list: list[ExpArgs]) -> tuple[list[ExpArgs], list[ExpArgs]]:
"""split exp_args that are flagged as sequential from those that are not"""
sequential_exp_args = []
parallel_exp_args = []
for exp_args in exp_args_list:
if getattr(exp_args, "sequential", False):
sequential_exp_args.append(exp_args)
else:
parallel_exp_args.append(exp_args)

return sequential_exp_args, parallel_exp_args
48 changes: 42 additions & 6 deletions src/agentlab/experiments/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,23 @@ def find_incomplete(self, relaunch_mode="incomplete_or_error"):
"""Find incomplete or errored experiments in the study directory for relaunching."""
self.exp_args_list = find_incomplete(self.dir, relaunch_mode=relaunch_mode)

def set_reproducibility_info(self, strict_reproducibility=False):
def set_reproducibility_info(self, strict_reproducibility=False, comment=None):
"""Gather relevant information that may affect the reproducibility of the experiment

e.g.: versions of BrowserGym, benchmark, AgentLab..."""
agent_names = [a.agent_name for a in self.agent_args]
info = repro.get_reproducibility_info(
agent_names, self.benchmark, self.uuid, ignore_changes=not strict_reproducibility
agent_names,
self.benchmark,
self.uuid,
ignore_changes=not strict_reproducibility,
comment=comment,
)
if self.reproducibility_info is not None:
repro.assert_compatible(self.reproducibility_info, info)
self.reproducibility_info = info

def run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False):
def run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False, comment=None):
"""Run all experiments in the study in parallel when possible.

Args:
Expand All @@ -98,7 +102,9 @@ def run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False)
if self.exp_args_list is None:
raise ValueError("exp_args_list is None. Please set exp_args_list before running.")

self.set_reproducibility_info(strict_reproducibility=strict_reproducibility)
self.set_reproducibility_info(
strict_reproducibility=strict_reproducibility, comment=comment
)
self.save()

run_experiments(n_jobs, self.exp_args_list, self.dir, parallel_backend=parallel_backend)
Expand Down Expand Up @@ -172,6 +178,12 @@ def load(dir: Path) -> "Study":
with gzip.open(dir / "study.pkl.gz", "rb") as f:
study = pickle.load(f) # type: Study
study.dir = dir

# just a check
for i, exp_args in enumerate(study.exp_args_list):
if exp_args.order != i:
logging.warning("The order of the experiments is not correct.")

return study

@staticmethod
Expand Down Expand Up @@ -270,20 +282,44 @@ def _agents_on_benchmark(
if not isinstance(agents, (list, tuple)):
agents = [agents]

if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"):
if len(agents) > 1:
raise ValueError(
f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation."
)

for agent in agents:
agent.set_benchmark(benchmark, demo_mode) # the agent can adapt (lightly?) to the benchmark

env_args_list = benchmark.env_args_list
if demo_mode:
set_demo_mode(env_args_list)

return args.expand_cross_product(
exp_args_list = args.expand_cross_product(
ExpArgs(
agent_args=args.CrossProd(agents),
env_args=args.CrossProd(env_args_list),
logging_level=logging_level,
)
)
) # type: list[ExpArgs]

for i, exp_args in enumerate(exp_args_list):
exp_args.order = i

_flag_sequential_exp(exp_args_list, benchmark)

return exp_args_list


def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark):
if benchmark.name.startswith("visualwebarena"):
sequential_subset = benchmark.subset_from_glob("requires_reset", "True")
sequential_subset = set(
[env_args.task_name for env_args in sequential_subset.env_args_list]
)
for exp_args in exp_args_list:
if exp_args.env_args.task_name in sequential_subset:
exp_args.sequential = True


# def ablation_study(start_agent: AgentArgs, changes, benchmark: str, demo_mode=False):
Expand Down