diff --git a/src/agentlab/experiments/launch_exp.py b/src/agentlab/experiments/launch_exp.py index d62de9efa..f151f7ecb 100644 --- a/src/agentlab/experiments/launch_exp.py +++ b/src/agentlab/experiments/launch_exp.py @@ -56,6 +56,15 @@ def run_experiments( if parallel_backend == "joblib": from joblib import Parallel, delayed + # split sequential + sequential_exp_args, exp_args_list = _split_sequential_exp(exp_args_list) + + logging.info( + f"Running {len(sequential_exp_args)} in sequential first. The remaining {len(exp_args_list)} will be run in parallel." + ) + for exp_args in sequential_exp_args: + exp_args.run() + Parallel(n_jobs=n_jobs, prefer="processes")( delayed(exp_args.run)() for exp_args in exp_args_list ) @@ -98,9 +107,12 @@ def find_incomplete(study_dir: str | Path, relaunch_mode="incomplete_only"): ) exp_args_list = list(_yield_incomplete_experiments(study_dir, relaunch_mode=relaunch_mode)) + # sort according to exp_args.order + exp_args_list.sort(key=lambda exp_args: exp_args.order if exp_args.order is not None else 0) + if len(exp_args_list) == 0: logging.info(f"No incomplete experiments found in {study_dir}.") - return [], study_dir + return exp_args_list message = f"Make sure the processes that were running are all stopped. Otherwise, " f"there will be concurrent writing in the same directories.\n" @@ -140,3 +152,16 @@ def split_path(path: str): path = path.replace("/", ".") module_name, obj_name = path.rsplit(".", 1) return module_name, obj_name + + +def _split_sequential_exp(exp_args_list: list[ExpArgs]) -> tuple[list[ExpArgs], list[ExpArgs]]: + """split exp_args that are flagged as sequential from those that are not""" + sequential_exp_args = [] + parallel_exp_args = [] + for exp_args in exp_args_list: + if getattr(exp_args, "sequential", False): + sequential_exp_args.append(exp_args) + else: + parallel_exp_args.append(exp_args) + + return sequential_exp_args, parallel_exp_args diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py index f1b823af4..2da9771fb 100644 --- a/src/agentlab/experiments/study.py +++ b/src/agentlab/experiments/study.py @@ -68,19 +68,23 @@ def find_incomplete(self, relaunch_mode="incomplete_or_error"): """Find incomplete or errored experiments in the study directory for relaunching.""" self.exp_args_list = find_incomplete(self.dir, relaunch_mode=relaunch_mode) - def set_reproducibility_info(self, strict_reproducibility=False): + def set_reproducibility_info(self, strict_reproducibility=False, comment=None): """Gather relevant information that may affect the reproducibility of the experiment e.g.: versions of BrowserGym, benchmark, AgentLab...""" agent_names = [a.agent_name for a in self.agent_args] info = repro.get_reproducibility_info( - agent_names, self.benchmark, self.uuid, ignore_changes=not strict_reproducibility + agent_names, + self.benchmark, + self.uuid, + ignore_changes=not strict_reproducibility, + comment=comment, ) if self.reproducibility_info is not None: repro.assert_compatible(self.reproducibility_info, info) self.reproducibility_info = info - def run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False): + def run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False, comment=None): """Run all experiments in the study in parallel when possible. Args: @@ -98,7 +102,9 @@ def run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False) if self.exp_args_list is None: raise ValueError("exp_args_list is None. Please set exp_args_list before running.") - self.set_reproducibility_info(strict_reproducibility=strict_reproducibility) + self.set_reproducibility_info( + strict_reproducibility=strict_reproducibility, comment=comment + ) self.save() run_experiments(n_jobs, self.exp_args_list, self.dir, parallel_backend=parallel_backend) @@ -172,6 +178,12 @@ def load(dir: Path) -> "Study": with gzip.open(dir / "study.pkl.gz", "rb") as f: study = pickle.load(f) # type: Study study.dir = dir + + # just a check + for i, exp_args in enumerate(study.exp_args_list): + if exp_args.order != i: + logging.warning("The order of the experiments is not correct.") + return study @staticmethod @@ -270,6 +282,12 @@ def _agents_on_benchmark( if not isinstance(agents, (list, tuple)): agents = [agents] + if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"): + if len(agents) > 1: + raise ValueError( + f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation." + ) + for agent in agents: agent.set_benchmark(benchmark, demo_mode) # the agent can adapt (lightly?) to the benchmark @@ -277,13 +295,31 @@ def _agents_on_benchmark( if demo_mode: set_demo_mode(env_args_list) - return args.expand_cross_product( + exp_args_list = args.expand_cross_product( ExpArgs( agent_args=args.CrossProd(agents), env_args=args.CrossProd(env_args_list), logging_level=logging_level, ) - ) + ) # type: list[ExpArgs] + + for i, exp_args in enumerate(exp_args_list): + exp_args.order = i + + _flag_sequential_exp(exp_args_list, benchmark) + + return exp_args_list + + +def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark): + if benchmark.name.startswith("visualwebarena"): + sequential_subset = benchmark.subset_from_glob("requires_reset", "True") + sequential_subset = set( + [env_args.task_name for env_args in sequential_subset.env_args_list] + ) + for exp_args in exp_args_list: + if exp_args.env_args.task_name in sequential_subset: + exp_args.sequential = True # def ablation_study(start_agent: AgentArgs, changes, benchmark: str, demo_mode=False):