diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 7a0e312a6..3342ba540 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -58,6 +58,9 @@ jobs: - name: Check MiniWob availability run: curl -I "http://localhost:8080/miniwob/" || echo "MiniWob not reachable" + - name: Pre-download nltk ressources + run: python -c "import nltk; nltk.download('punkt_tab')" + - name: Run AgentLab Unit Tests env: MINIWOB_URL: "http://localhost:8080/miniwob/" diff --git a/.gitignore b/.gitignore index d0037afc9..aa26dc9dc 100644 --- a/.gitignore +++ b/.gitignore @@ -161,35 +161,10 @@ cython_debug/ **/.DS_Store .vscode -allowed_selenium.json -# Torchtune -finetuning/torchtune - -# PyLLMD repo for finetuning -pyllmd_tune/research-pyllmd/ -pyllmd_tune/data/ - - -datasets/* _sandbox.py -node_modules/ -/test-results/ -/playwright-report/ -/blob-report/ -/playwright/.cache/ -/test-results/ -/playwright-report/ -/blob-report/ -/playwright/.cache/ - results/ -# personal (optimass) -ICML_deadline/ -mass_utils/ -pyllmd_tune/ - -# don't ignore the miniwob_tasks_all.csv file -!miniwob_tasks_all.csv +# gradio +.gradio/ \ No newline at end of file diff --git a/main.py b/main.py index 7a038b6a9..4f9f57c84 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,6 @@ """ import logging - from agentlab.agents.generic_agent import ( RANDOM_SEARCH_AGENT, AGENT_4o, @@ -15,8 +14,7 @@ AGENT_LLAMA3_70B, AGENT_LLAMA31_70B, ) -from agentlab.analyze.inspect_results import get_most_recent_folder -from agentlab.experiments import study_generators +from agentlab.experiments.study import Study logging.getLogger().setLevel(logging.INFO) @@ -24,12 +22,13 @@ agent_args = [AGENT_4o_MINI] # agent_args = [AGENT_4o] -## select the benchmark to run on + +# ## select the benchmark to run on benchmark = "miniwob_tiny_test" # benchmark = "miniwob" -# benchmark = "workarena.l1" -# benchmark = "workarena.l2" -# benchmark = "workarena.l3" +# benchmark = "workarena_l1" +# benchmark = "workarena_l2" +# benchmark = "workarena_l3" # benchmark = "webarena" # Set reproducibility_mode = True for reproducibility @@ -53,13 +52,18 @@ if relaunch: # relaunch an existing study - study_dir = get_most_recent_folder() - study = study_generators.make_relaunch_study(study_dir, relaunch_mode="incomplete_or_error") + study = Study.load_most_recent(contains=None) + study.find_incomplete(include_errors=True) else: - study = study_generators.run_agents_on_benchmark(agent_args, benchmark) - - study.run(n_jobs=n_jobs, parallel_backend="joblib", strict_reproducibility=reproducibility_mode) + study = Study(agent_args, benchmark, logging_level_stdout=logging.WARNING) + + study.run( + n_jobs=n_jobs, + parallel_backend="ray", + strict_reproducibility=reproducibility_mode, + n_relaunch=3, + ) if reproducibility_mode: study.append_to_journal(strict_reproducibility=True) diff --git a/reproducibility_journal.csv b/reproducibility_journal.csv index df2ff7478..ad2bfaa81 100644 --- a/reproducibility_journal.csv +++ b/reproducibility_journal.csv @@ -1,12 +1,48 @@ -git_user,agent_name,benchmark,benchmark_version,date,avg_reward,std_err,n_err,n_completed,comment,os,python_version,playwright_version,agentlab_version,agentlab_git_hash,agentlab__local_modifications,browsergym_version,browsergym_git_hash,browsergym__local_modifications -recursix,GenericAgent-gpt-4o-mini-2024-07-18,miniwob_tiny_test,0.6.3,2024-09-19_21-07-34,0.75,0.217,0,4/4,,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,c99bdf74c98f323cc6a646467ba5f21154b6fd18,,0.6.4,b73531271d2ce688c104eb4dfba2819583f1ba36, -recursix,GenericAgent-gpt-4o-mini-2024-07-18,miniwob_tiny_test,0.6.3,2024-09-19_21-28-58,1.0,0.0,0,4/4,,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,c99bdf74c98f323cc6a646467ba5f21154b6fd18," M: reproducibility_journal.csv +git_user,agent_name,benchmark,benchmark_version,date,study_id,avg_reward,std_err,n_err,n_completed,comment,os,python_version,playwright_version,agentlab_version,agentlab_git_hash,agentlab__local_modifications,browsergym_version,browsergym_git_hash,browsergym__local_modifications +recursix,GenericAgent-gpt-4o-mini-2024-07-18,miniwob_tiny_test,0.6.3,2024-09-19_21-07-34,,0.75,0.217,0,4/4,,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,c99bdf74c98f323cc6a646467ba5f21154b6fd18,,0.6.4,b73531271d2ce688c104eb4dfba2819583f1ba36, +recursix,GenericAgent-gpt-4o-mini-2024-07-18,miniwob_tiny_test,0.6.3,2024-09-19_21-28-58,,1.0,0.0,0,4/4,,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,c99bdf74c98f323cc6a646467ba5f21154b6fd18," M: reproducibility_journal.csv M: src/agentlab/experiments/task_collections.py",0.6.4,b73531271d2ce688c104eb4dfba2819583f1ba36, -recursix,GenericAgent-gpt-4o-mini-2024-07-18,miniwob,0.6.3,2024-09-20_07-16-21,0.546,0.02,0,625/625,,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,295f01005faf8f2c73a31be6a18cec19d563b54b,,0.6.4,b73531271d2ce688c104eb4dfba2819583f1ba36, -recursix,GenericAgent-gpt-4o-2024-05-13,miniwob,0.6.3,2024-09-20_22-09-43,0.656,0.019,0,625/625,,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,f6216486d5faac2c8b3fb0a63e114e5a4bafde47,,0.6.4,8cef8fe34940ff490d0cc06b0c8f100180d09d43, -recursix,GenericAgent-gpt-4o-2024-05-13,miniwob,0.6.3,2024-09-21_12-04-39,0.656,0.019,0,625/625,None,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,fe561b93c5f053e9f9625358862f542523b5e14a,,0.7.0,ed6d6992ef64bfb91aca7002d33cb6ed5ec031ef, -recursix,GenericAgent-gpt-4o-mini-2024-07-18,miniwob,0.6.3,2024-10-01_11-45-23,0.539,0.02,0,625/625,None,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,fe27819a99b163fd9240ba3e144e010413bff24d,,0.7.1,b0ad675572e01cac0d7255100112de0828877148, -recursix,GenericAgent-gpt-4o-mini-2024-07-18,workarena.l1,0.3.2,2024-10-05_13-21-27,0.23,0.023,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,aadf86b397cd36c581e1a61e491aec649ac5a140, M: main.py,0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, -recursix,GenericAgent-gpt-4o-2024-05-13,workarena.l1,0.3.2,2024-10-05_15-45-42,0.382,0.027,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,ab447e997af589bbd022de7a5189a7685ddfa6ef,,0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, -recursix,GenericAgent-meta-llama_llama-3.1-70b-instruct,miniwob_tiny_test,0.7.0,2024-10-05_17-49-15,1.0,0.0,0,4/4,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,a98fa24426a6ddde8443e8be44ed94cd9522e5ca,,0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, -recursix,GenericAgent-meta-llama_llama-3-70b-instruct,workarena.l1,0.3.2,2024-10-09_21-16-37,0.176,0.021,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,c847dbd334184271b32b252409a1b6c1042d7442,,0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, +recursix,GenericAgent-gpt-4o-mini-2024-07-18,miniwob,0.6.3,2024-09-20_07-16-21,,0.546,0.02,0,625/625,,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,295f01005faf8f2c73a31be6a18cec19d563b54b,,0.6.4,b73531271d2ce688c104eb4dfba2819583f1ba36, +recursix,GenericAgent-gpt-4o-2024-05-13,miniwob,0.6.3,2024-09-20_22-09-43,,0.656,0.019,0,625/625,,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,f6216486d5faac2c8b3fb0a63e114e5a4bafde47,,0.6.4,8cef8fe34940ff490d0cc06b0c8f100180d09d43, +recursix,GenericAgent-gpt-4o-2024-05-13,miniwob,0.6.3,2024-09-21_12-04-39,,0.656,0.019,0,625/625,None,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,fe561b93c5f053e9f9625358862f542523b5e14a,,0.7.0,ed6d6992ef64bfb91aca7002d33cb6ed5ec031ef, +recursix,GenericAgent-gpt-4o-mini-2024-07-18,miniwob,0.6.3,2024-10-01_11-45-23,,0.539,0.02,0,625/625,None,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.2,1.39.0,0.2.1,fe27819a99b163fd9240ba3e144e010413bff24d,,0.7.1,b0ad675572e01cac0d7255100112de0828877148, +recursix,GenericAgent-gpt-4o-mini-2024-07-18,workarena.l1,0.3.2,2024-10-05_13-21-27,,0.23,0.023,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,aadf86b397cd36c581e1a61e491aec649ac5a140," M: main.py",0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, +recursix,GenericAgent-gpt-4o-2024-05-13,workarena.l1,0.3.2,2024-10-05_15-45-42,,0.382,0.027,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,ab447e997af589bbd022de7a5189a7685ddfa6ef,,0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, +recursix,GenericAgent-meta-llama_llama-3-70b-instruct,workarena.l1,0.3.2,2024-10-09_21-16-37,,0.176,0.021,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,c847dbd334184271b32b252409a1b6c1042d7442,,0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, +recursix,GenericAgent-meta-llama_llama-3.1-70b-instruct,miniwob_tiny_test,0.7.0,2024-10-05_17-49-15,,1.0,0.0,0,4/4,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,a98fa24426a6ddde8443e8be44ed94cd9522e5ca,,0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, +ThibaultLSDC,GenericAgent-gpt-4o-mini-2024-07-18,miniwob,0.8.1,2024-10-17_10-13-28,,0.557,0.02,0,625/625,None,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.7,1.39.0,0.2.2,7bba275c004f1f90dfd83eaaab963ab5066e2baf,,0.8.1,None, +ThibaultLSDC,GenericAgent-gpt-4o-mini-2024-07-18,miniwob,0.8.1,2024-10-17_10-50-53,,0.563,0.02,0,625/625,None,Darwin (Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:00 PDT 2024; root:xnu-10063.141.2~1/RELEASE_X86_64),3.12.7,1.39.0,0.2.2,057b7d4a201cc1cd1ebd7bc884f6a91e104c479d,,0.8.1,None, +ThibaultLSDC,GenericAgent-gpt-4o-mini-2024-07-18,workarena.l1,0.4.1,2024-10-17_17-30-43,,0.258,0.024,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.2,7bba275c004f1f90dfd83eaaab963ab5066e2baf,,0.8.1,None, +ThibaultLSDC,GenericAgent-gpt-4o-mini-2024-07-18,workarena.l1,0.4.1,2024-10-17_18-30-28,,0.273,0.025,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.2,8b2b3f39a2bdb9efafad97791536a0b8cff4e708,,0.8.1,None, +ThibaultLSDC,GenericAgent-gpt-4o-mini-2024-07-18,miniwob_all,0.9.0,2024-10-20_01-54-16,2024-10-20_01-54-02,0.588,0.014,0,1250/1250,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.2,1770eba87fabfe1e32cdf6078d71032fe00db736,,0.9.0,None, +ThibaultLSDC,GenericAgent-gpt-4o-mini-2024-07-18,workarena_l1,0.4.1,2024-10-22_18-41-55,2024-10-22_15-24-53,0.215,0.023,1,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,a6c1f93c59fb7a838d06ca02ef6c62abe2ce278c,,0.9.0,None, +ThibaultLSDC,GenericAgent-gpt-4o-mini,workarena_l1,0.4.1,2024-10-23_12-17-24,2024-10-23_02-00-49,0.252,0.024,1,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,d364ce7d5f566889830cdc0ef58b320d2093694e,,0.9.0,None, +ThibaultLSDC,GenericAgent-gpt-4o,workarena_l1,0.4.1,2024-10-23_12-17-24,2024-10-23_02-00-49,0.488,0.028,1,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,d364ce7d5f566889830cdc0ef58b320d2093694e,,0.9.0,None, +ThibaultLSDC,GenericAgent-anthropic_claude-3.5-sonnet:beta,workarena_l1,0.4.1,2024-10-23_12-17-24,2024-10-23_02-00-49,0.579,0.027,2,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,d364ce7d5f566889830cdc0ef58b320d2093694e,,0.9.0,None, +ThibaultLSDC,GenericAgent-meta-llama_llama-3.1-70b-instruct,workarena_l1,0.4.1,2024-10-23_12-17-24,2024-10-23_02-00-49,0.309,0.025,2,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,d364ce7d5f566889830cdc0ef58b320d2093694e,,0.9.0,None, +ThibaultLSDC,GenericAgent-openai_o1-mini-2024-09-12,workarena_l1,0.4.1,2024-10-23_12-17-24,2024-10-23_02-00-49,0.527,0.027,2,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,d364ce7d5f566889830cdc0ef58b320d2093694e,,0.9.0,None, +ThibaultLSDC,GenericAgent-gpt-4o-mini,workarena_l1,0.4.1,2024-10-23_22-30-06,2024-10-23_14-17-40,0.27,0.024,1,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,4cd1e2d4189ddfbeb94129f7b0c9a00c3400ebac,,0.9.0,f25bdcd6b946fc4a79cdbee5fbcad53548af8724, +ThibaultLSDC,GenericAgent-gpt-4o,workarena_l1,0.4.1,2024-10-23_22-30-06,2024-10-23_14-17-40,0.455,0.027,1,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,4cd1e2d4189ddfbeb94129f7b0c9a00c3400ebac,,0.9.0,f25bdcd6b946fc4a79cdbee5fbcad53548af8724, +ThibaultLSDC,GenericAgent-anthropic_claude-3.5-sonnet:beta,workarena_l1,0.4.1,2024-10-23_22-30-06,2024-10-23_14-17-40,0.564,0.027,1,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,4cd1e2d4189ddfbeb94129f7b0c9a00c3400ebac,,0.9.0,f25bdcd6b946fc4a79cdbee5fbcad53548af8724, +ThibaultLSDC,GenericAgent-meta-llama_llama-3.1-70b-instruct,workarena_l1,0.4.1,2024-10-23_22-30-06,2024-10-23_14-17-40,0.279,0.025,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,4cd1e2d4189ddfbeb94129f7b0c9a00c3400ebac,,0.9.0,f25bdcd6b946fc4a79cdbee5fbcad53548af8724, +ThibaultLSDC,GenericAgent-openai_o1-mini-2024-09-12,workarena_l1,0.4.1,2024-10-23_22-30-06,2024-10-23_14-17-40,0.567,0.027,4,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,4cd1e2d4189ddfbeb94129f7b0c9a00c3400ebac,,0.9.0,f25bdcd6b946fc4a79cdbee5fbcad53548af8724, +recursix,GenericAgent-anthropic_claude-3.5-sonnet:beta,webarena,0.11.3,2024-11-02_23-50-17,22a9d3f5-9d86-455e-b451-3ea17690ce8a,0.329,0.016,0,812/812,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.3,418a05d90c74800cd66371b7846ef861185b8c47,,0.11.3,160167ff0d2631826f0131e8e30b92ef448d6881, +ThibaultLSDC,GenericAgent-gpt-4o-mini,workarena_l2_agent_curriculum_eval,0.4.1,2024-10-24_17-08-53,2024-10-23_17-10-46,0.013,0.007,2,235/235,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,827d847995f19dc337f3899427340bdddbd81cd5,,0.10.0,None, +ThibaultLSDC,GenericAgent-gpt-4o,workarena_l2_agent_curriculum_eval,0.4.1,2024-10-24_17-08-53,2024-10-23_17-10-46,0.085,0.018,3,233/235,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,827d847995f19dc337f3899427340bdddbd81cd5,,0.10.0,None, +ThibaultLSDC,GenericAgent-anthropic_claude-3.5-sonnet:beta,workarena_l2_agent_curriculum_eval,0.4.1,2024-10-24_17-08-53,2024-10-23_17-10-46,0.391,0.032,3,235/235,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,827d847995f19dc337f3899427340bdddbd81cd5,,0.10.0,None, +ThibaultLSDC,GenericAgent-meta-llama_llama-3.1-70b-instruct,workarena_l2_agent_curriculum_eval,0.4.1,2024-10-24_17-08-53,2024-10-23_17-10-46,0.021,0.009,2,235/235,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,827d847995f19dc337f3899427340bdddbd81cd5,,0.10.0,None, +ThibaultLSDC,GenericAgent-openai_o1-mini-2024-09-12,workarena_l2_agent_curriculum_eval,0.4.1,2024-10-24_17-08-53,2024-10-23_17-10-46,0.149,0.023,1,235/235,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,827d847995f19dc337f3899427340bdddbd81cd5,,0.10.0,None, +ThibaultLSDC,GenericAgent-anthropic_claude-3.5-sonnet:beta,workarena_l3_agent_curriculum_eval,0.4.1,2024-10-24_23-03-30,2024-10-24_18-06-57,0.004,0.004,1,235/235,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,de67ed8ad4321740ff05cf26ab889978be706460,,0.10.2,a9e44a88139798543ba53fc8c45d44997665ccca, +ThibaultLSDC,GenericAgent-gpt-4o-mini,miniwob,0.10.2,2024-10-25_17-16-23,2024-10-25_06-08-16,0.566,0.02,0,625/625,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,f12887f776525bcad6a0c42cb49651ff4f65af43,,0.10.2,a9e44a88139798543ba53fc8c45d44997665ccca, +ThibaultLSDC,GenericAgent-gpt-4o,miniwob,0.10.2,2024-10-25_17-16-23,2024-10-25_06-08-16,0.638,0.019,0,625/625,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,f12887f776525bcad6a0c42cb49651ff4f65af43,,0.10.2,a9e44a88139798543ba53fc8c45d44997665ccca, +ThibaultLSDC,GenericAgent-anthropic_claude-3.5-sonnet:beta,miniwob,0.10.2,2024-10-25_17-16-23,2024-10-25_06-08-16,0.698,0.018,0,625/625,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,f12887f776525bcad6a0c42cb49651ff4f65af43,,0.10.2,a9e44a88139798543ba53fc8c45d44997665ccca, +ThibaultLSDC,GenericAgent-meta-llama_llama-3.1-70b-instruct,miniwob,0.10.2,2024-10-25_17-16-23,2024-10-25_06-08-16,0.576,0.02,0,625/625,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,f12887f776525bcad6a0c42cb49651ff4f65af43,,0.10.2,a9e44a88139798543ba53fc8c45d44997665ccca, +ThibaultLSDC,GenericAgent-openai_o1-mini-2024-09-12,miniwob,0.10.2,2024-10-25_17-16-23,2024-10-25_06-08-16,0.678,0.019,0,625/625,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,f12887f776525bcad6a0c42cb49651ff4f65af43,,0.10.2,a9e44a88139798543ba53fc8c45d44997665ccca, +ThibaultLSDC,GenericAgent-meta-llama_llama-3.1-405b-instruct,workarena_l1,0.4.1,2024-10-25_20-32-26,2024-10-25_17-34-45,0.433,0.027,1,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,177ba72a7469e5610e6b615adf1bdcde58cb0298,,0.10.2,a9e44a88139798543ba53fc8c45d44997665ccca, +Maxime Gasse,GenericAgent-gpt-4o-2024-05-13,weblinx_test,0.0.1.dev13,2024-11-04_16-01-14,2024-11-04_15-59-12,0.123,0.006,0,2650/2650,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.0,1.39.0,0.2.3,6e18fb818a64ec1e3f379c1a6480411d2fd0628b,,0.11.3,3ab1843edb14bfce7d39485f0106d0dc0c2d7486, +ThibaultLSDC,GenericAgent-gpt-4o-mini,weblinx_test,0.0.1.dev13,2024-11-07_21-42-30,b9451759-4f0e-492c-a3c8-fa5109d2d9b1,0.116,0.006,0,2650/2650,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,7a5b91e62056fa8fb26efdd2f64f5b25a92b817c,,0.12.0,8633c30c31e6a5a1d5122835c035aa56d18f3f0a, +ThibaultLSDC,GenericAgent-gpt-4o,weblinx_test,0.0.1.dev13,2024-11-07_21-42-30,b9451759-4f0e-492c-a3c8-fa5109d2d9b1,0.125,0.006,0,2650/2650,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,7a5b91e62056fa8fb26efdd2f64f5b25a92b817c,,0.12.0,8633c30c31e6a5a1d5122835c035aa56d18f3f0a, +ThibaultLSDC,GenericAgent-anthropic_claude-3.5-sonnet:beta,weblinx_test,0.0.1.dev13,2024-11-07_21-42-30,b9451759-4f0e-492c-a3c8-fa5109d2d9b1,0.137,0.006,0,2650/2650,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,7a5b91e62056fa8fb26efdd2f64f5b25a92b817c,,0.12.0,8633c30c31e6a5a1d5122835c035aa56d18f3f0a, +ThibaultLSDC,GenericAgent-meta-llama_llama-3.1-70b-instruct,weblinx_test,0.0.1.dev13,2024-11-07_21-42-30,b9451759-4f0e-492c-a3c8-fa5109d2d9b1,0.089,0.005,0,2650/2650,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,7a5b91e62056fa8fb26efdd2f64f5b25a92b817c,,0.12.0,8633c30c31e6a5a1d5122835c035aa56d18f3f0a, +ThibaultLSDC,GenericAgent-openai_o1-mini-2024-09-12,weblinx_test,0.0.1.dev13,2024-11-07_21-42-30,b9451759-4f0e-492c-a3c8-fa5109d2d9b1,0.125,0.006,0,2650/2650,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,7a5b91e62056fa8fb26efdd2f64f5b25a92b817c,,0.12.0,8633c30c31e6a5a1d5122835c035aa56d18f3f0a, +ThibaultLSDC,GenericAgent-meta-llama_llama-3.1-405b-instruct,weblinx_test,0.0.1.dev13,2024-11-07_21-42-30,b9451759-4f0e-492c-a3c8-fa5109d2d9b1,0.079,0.005,0,2650/2650,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,7a5b91e62056fa8fb26efdd2f64f5b25a92b817c,,0.12.0,8633c30c31e6a5a1d5122835c035aa56d18f3f0a, diff --git a/requirements.txt b/requirements.txt index e96fa61ee..453f312d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,9 @@ contexttimer ipython pyyaml>=6 pandas -gradio +gradio>=5.5 # issue with DataFrame scrolling before 5.5 gitpython # for the reproducibility script -requests \ No newline at end of file +requests +matplotlib +ray[default] +python-slugify diff --git a/src/agentlab/__init__.py b/src/agentlab/__init__.py index b5fdc7530..493f7415d 100644 --- a/src/agentlab/__init__.py +++ b/src/agentlab/__init__.py @@ -1 +1 @@ -__version__ = "0.2.2" +__version__ = "0.3.0" diff --git a/src/agentlab/agents/agent_args.py b/src/agentlab/agents/agent_args.py index 0e0d6d8b9..40810f6b8 100644 --- a/src/agentlab/agents/agent_args.py +++ b/src/agentlab/agents/agent_args.py @@ -1,9 +1,10 @@ from bgym import AbstractAgentArgs +import bgym class AgentArgs(AbstractAgentArgs): - def set_benchmark(self, benchmark: str, demo_mode: bool): + def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode: bool): """Optional method to set benchmark specific flags. This allows the agent to have minor adjustments based on the benchmark. diff --git a/src/agentlab/agents/dynamic_prompting.py b/src/agentlab/agents/dynamic_prompting.py index 47aed2264..73688f0f4 100644 --- a/src/agentlab/agents/dynamic_prompting.py +++ b/src/agentlab/agents/dynamic_prompting.py @@ -1,5 +1,4 @@ import abc -import difflib import logging import platform import time @@ -9,12 +8,12 @@ from typing import Literal from warnings import warn +import bgym from browsergym.core.action.base import AbstractActionSet -from browsergym.core.action.highlevel import HighLevelActionSet -from browsergym.core.action.python import PythonActionSet from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, overlay_som, prune_html from agentlab.llm.llm_utils import ( + BaseMessage, ParseError, count_tokens, extract_code_blocks, @@ -70,6 +69,7 @@ class ObsFlags(Flags): use_html: bool = True use_ax_tree: bool = False + use_tabs: bool = False use_focused_element: bool = False use_error_logs: bool = False use_history: bool = False @@ -94,13 +94,14 @@ class ObsFlags(Flags): @dataclass class ActionFlags(Flags): - multi_actions: bool = False - action_set: str = "bid" - is_strict: bool = False - demo_mode: Literal["off", "default", "all_blue", "only_visible_elements"] = "off" + action_set: bgym.HighLevelActionSetArgs = None # should be set by the set_benchmark method long_description: bool = True individual_examples: bool = False + # for backward compatibility + multi_actions: bool = None + is_strict: bool = None + class PromptElement: """Base class for all prompt elements. Prompt elements can be hidden.""" @@ -121,7 +122,7 @@ def __init__(self, visible: bool = True) -> None: self._visible = visible @property - def prompt(self): + def prompt(self) -> str | BaseMessage: """Avoid overriding this method. Override _prompt instead.""" if self.is_visible: return self._prompt @@ -252,7 +253,14 @@ def fit_tokens( if isinstance(prompt, str): prompt_str = prompt elif isinstance(prompt, list): + # warn deprecated + warn( + "Using list of prompts is deprecated. Use a Discussion object instead.", + DeprecationWarning, + ) prompt_str = "\n".join([p["text"] for p in prompt if p["type"] == "text"]) + elif isinstance(prompt, BaseMessage): + prompt_str = str(prompt) else: raise ValueError(f"Unrecognized type for prompt: {type(prompt)}") n_token = count_tokens(prompt_str, model=model_name) @@ -357,6 +365,29 @@ def __init__(self, bid, visible: bool = True, prefix="") -> None: """ +class Tabs(PromptElement): + def __init__(self, obs, visible: bool = True, prefix="") -> None: + super().__init__(visible=visible) + self.obs = obs + self.prefix = prefix + + @property + def _prompt(self) -> str: + # by implementing this as a property, it's only coputed if visible + prompt_pieces = [f"\n{self.prefix}Currently open tabs:"] + for page_index, (page_url, page_title) in enumerate( + zip(self.obs["open_pages_urls"], self.obs["open_pages_titles"]) + ): + active_or_not = " (active tab)" if page_index == self.obs["active_page_index"] else "" + prompt_piece = f"""\ +Tab {page_index}{active_or_not}: + Title: {page_title} + URL: {page_url} +""" + prompt_pieces.append(prompt_piece) + return "\n".join(prompt_pieces) + + class Observation(Shrinkable): """Observation of the current step. @@ -367,6 +398,13 @@ def __init__(self, obs, flags: ObsFlags) -> None: super().__init__() self.flags = flags self.obs = obs + + self.tabs = Tabs( + obs, + visible=lambda: flags.use_tabs, + prefix="## ", + ) + self.html = HTML( obs[flags.html_type], visible_elements_only=flags.filter_visible_elements_only, @@ -400,25 +438,18 @@ def shrink(self): def _prompt(self) -> str: return f""" # Observation of current step: -{self.html.prompt}{self.ax_tree.prompt}{self.focused_element.prompt}{self.error.prompt} +{self.tabs.prompt}{self.html.prompt}{self.ax_tree.prompt}{self.focused_element.prompt}{self.error.prompt} """ - def add_screenshot(self, prompt): + def add_screenshot(self, prompt: BaseMessage) -> BaseMessage: if self.flags.use_screenshot: - if isinstance(prompt, str): - prompt = [{"type": "text", "text": prompt}] if self.flags.use_som: screenshot = self.obs["screenshot_som"] else: screenshot = self.obs["screenshot"] img_url = image_to_jpg_base64_url(screenshot) - prompt.append( - { - "type": "image_url", - "image_url": {"url": img_url, "detail": self.flags.openai_vision_detail}, - } - ) + prompt.add_image(img_url, detail=self.flags.openai_vision_detail) return prompt @@ -441,24 +472,36 @@ def __init__(self, visible: bool = True) -> None: class GoalInstructions(PromptElement): - def __init__(self, goal, visible: bool = True, extra_instructions=None) -> None: + def __init__(self, goal_object, visible: bool = True, extra_instructions=None) -> None: super().__init__(visible) - self._prompt = f"""\ + self._prompt = [ + dict( + type="text", + text=f"""\ # Instructions Review the current state of the page and all other information to find the best possible next action to accomplish your goal. Your answer will be interpreted and executed by a program, make sure to follow the formatting instructions. ## Goal: -{goal} -""" +""", + ) + ] + + self._prompt += goal_object + if extra_instructions: - self._prompt += f""" + self._prompt += [ + dict( + type="text", + text=f""" ## Extra instructions: {extra_instructions} -""" +""", + ) + ] class ChatInstructions(PromptElement): @@ -592,24 +635,24 @@ def _parse_answer(self, text_answer): return ans_dict -def make_action_set(action_flags: ActionFlags) -> AbstractActionSet: +# def make_action_set(action_flags: ActionFlags) -> AbstractActionSet: - if action_flags.action_set == "python": - action_set = PythonActionSet(strict=action_flags.is_strict) - if action_flags.demo_mode != "off": - warn( - f'Action_set "python" is incompatible with demo_mode={repr(action_flags.demo_mode)}.' - ) - return action_set +# if action_flags.action_set == "python": +# action_set = PythonActionSet(strict=action_flags.is_strict) +# if action_flags.demo_mode != "off": +# warn( +# f'Action_set "python" is incompatible with demo_mode={repr(action_flags.demo_mode)}.' +# ) +# return action_set - action_set = HighLevelActionSet( - subsets=list(set(["chat"] + ["infeas"] + action_flags.action_set.split("+"))), - multiaction=action_flags.multi_actions, - strict=action_flags.is_strict, - demo_mode=action_flags.demo_mode, - ) +# action_set = HighLevelActionSet( +# subsets=list(set(["chat"] + ["infeas"] + action_flags.action_set.split("+"))), +# multiaction=action_flags.multi_actions, +# strict=action_flags.is_strict, +# demo_mode=action_flags.demo_mode, +# ) - return action_set +# return action_set class Think(PromptElement): diff --git a/src/agentlab/agents/generic_agent/agent_configs.py b/src/agentlab/agents/generic_agent/agent_configs.py index 4c0a39a74..a5db8c906 100644 --- a/src/agentlab/agents/generic_agent/agent_configs.py +++ b/src/agentlab/agents/generic_agent/agent_configs.py @@ -1,3 +1,5 @@ +import bgym + from agentlab.agents import dynamic_prompting as dp from agentlab.experiments import args from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT @@ -25,8 +27,10 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - multi_actions=False, - action_set="bid", + action_set=bgym.HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), long_description=False, individual_examples=True, ), @@ -38,7 +42,7 @@ use_abstract_example=True, use_hints=True, enable_chat=False, - max_prompt_tokens=None, + max_prompt_tokens=40_000, be_cautious=True, extra_instructions=None, ) @@ -71,8 +75,10 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - multi_actions=False, # often detrimental - action_set="bid", + action_set=bgym.HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), long_description=False, individual_examples=True, ), @@ -84,7 +90,7 @@ use_abstract_example=True, # useful use_hints=True, # useful enable_chat=False, - max_prompt_tokens=None, + max_prompt_tokens=40_000, be_cautious=True, extra_instructions=None, ) @@ -116,8 +122,10 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - multi_actions=False, - action_set="bid", + action_set=bgym.HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), long_description=False, individual_examples=True, ), @@ -129,7 +137,7 @@ use_abstract_example=True, use_hints=True, enable_chat=False, - max_prompt_tokens=None, + max_prompt_tokens=40_000, be_cautious=True, extra_instructions=None, add_missparsed_messages=True, @@ -164,8 +172,10 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - multi_actions=True, - action_set="bid", + action_set=bgym.HighLevelActionSetArgs( + subsets=["bid"], + multiaction=True, + ), long_description=False, individual_examples=True, ), @@ -177,7 +187,7 @@ use_abstract_example=True, use_hints=True, enable_chat=False, - max_prompt_tokens=None, + max_prompt_tokens=40_000, be_cautious=True, extra_instructions=None, add_missparsed_messages=True, @@ -210,8 +220,10 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - multi_actions=False, - action_set="bid", + action_set=bgym.HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), long_description=False, individual_examples=False, ), @@ -223,7 +235,7 @@ use_abstract_example=True, use_hints=True, enable_chat=False, - max_prompt_tokens=None, + max_prompt_tokens=40_000, be_cautious=True, extra_instructions=None, ) @@ -270,10 +282,12 @@ filter_visible_elements_only=args.Choice([True, False], p=[0.3, 0.7]), ), action=dp.ActionFlags( - multi_actions=args.Choice([True, False], p=[0.7, 0.3]), - action_set=args.Choice(["bid", "bid+coord"]), - # action_set=args.Choice(["python", "bid", "coord", - # "bid+coord"]), + action_set=bgym.HighLevelActionSetArgs( + subsets=args.Choice([["bid"], ["bid", "coord"]]), + multiaction=args.Choice([True, False], p=[0.7, 0.3]), + ), + long_description=False, + individual_examples=False, ), # drop_ax_tree_first=True, # this flag is no longer active, according to browsergym doc use_plan=args.Choice([True, False]), @@ -285,7 +299,7 @@ use_hints=args.Choice([True, False], p=[0.7, 0.3]), be_cautious=args.Choice([True, False]), enable_chat=False, - max_prompt_tokens=None, + max_prompt_tokens=40_000, extra_instructions=None, ) diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index 7c65e3cd6..98026dc1f 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -1,13 +1,15 @@ +from copy import deepcopy from dataclasses import asdict, dataclass from functools import partial from warnings import warn +import bgym from browsergym.experiments.agent import Agent, AgentInfo from agentlab.agents import dynamic_prompting as dp from agentlab.agents.agent_args import AgentArgs from agentlab.llm.chat_api import BaseModelArgs, make_system_message, make_user_message -from agentlab.llm.llm_utils import ParseError, retry +from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry from agentlab.llm.tracking import cost_tracker_decorator from .generic_agent_prompt import GenericPromptFlags, MainPrompt @@ -25,13 +27,23 @@ def __post_init__(self): except AttributeError: pass - def set_benchmark(self, benchmark, demo_mode): + def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode): """Override Some flags based on the benchmark.""" - if benchmark == "miniwob": + if benchmark.name.startswith("miniwob"): self.flags.obs.use_html = True + self.flags.obs.use_tabs = benchmark.is_multi_tab + self.flags.action.action_set = deepcopy(benchmark.high_level_action_set_args) + + # for backward compatibility with old traces + if self.flags.action.multi_actions is not None: + self.flags.action.action_set.multiaction = self.flags.action.multi_actions + if self.flags.action.is_strict is not None: + self.flags.action.action_set.strict = self.flags.action.is_strict + + # verify if we can remove this if demo_mode: - self.flags.action.demo_mode = "all_blue" + self.action_set.demo_mode = "all_blue" def set_reproducibility_mode(self): self.chat_model_args.temperature = 0 @@ -62,7 +74,7 @@ def __init__( self.max_retry = max_retry self.flags = flags - self.action_set = dp.make_action_set(self.flags.action) + self.action_set = self.flags.action.action_set.make_action_set() self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs) self._check_flag_constancy() @@ -88,9 +100,9 @@ def get_action(self, obs): max_prompt_tokens, max_trunc_itr = self._get_maxes() - system_prompt = dp.SystemPrompt().prompt + system_prompt = SystemMessage(dp.SystemPrompt().prompt) - prompt = dp.fit_tokens( + human_prompt = dp.fit_tokens( shrinkable=main_prompt, max_prompt_tokens=max_prompt_tokens, model_name=self.chat_model_args.model_name, @@ -101,10 +113,7 @@ def get_action(self, obs): # TODO, we would need to further shrink the prompt if the retry # cause it to be too long - chat_messages = [ - make_system_message(system_prompt), - make_user_message(prompt), - ] + chat_messages = Discussion([system_prompt, human_prompt]) ans_dict = retry( self.chat_llm, chat_messages, diff --git a/src/agentlab/agents/generic_agent/generic_agent_prompt.py b/src/agentlab/agents/generic_agent/generic_agent_prompt.py index 81450847b..67899f182 100644 --- a/src/agentlab/agents/generic_agent/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent/generic_agent_prompt.py @@ -1,9 +1,11 @@ -from dataclasses import dataclass import logging +from dataclasses import dataclass + from browsergym.core import action from browsergym.core.action.base import AbstractActionSet + from agentlab.agents import dynamic_prompting as dp -from agentlab.llm.llm_utils import parse_html_tags_raise +from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise @dataclass @@ -69,17 +71,20 @@ def __init__( "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." ) self.instructions = dp.GoalInstructions( - obs_history[-1]["goal"], extra_instructions=flags.extra_instructions + obs_history[-1]["goal_object"], extra_instructions=flags.extra_instructions ) - self.obs = dp.Observation(obs_history[-1], self.flags.obs) + self.obs = dp.Observation( + obs_history[-1], + self.flags.obs, + ) self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action) def time_for_caution(): # no need for caution if we're in single action mode return flags.be_cautious and ( - flags.action.multi_actions or flags.action.action_set == "python" + flags.action.action_set.multiaction or flags.action.action_set == "python" ) self.be_cautious = dp.BeCautious(visible=time_for_caution) @@ -90,9 +95,10 @@ def time_for_caution(): self.memory = Memory(visible=lambda: flags.use_memory) @property - def _prompt(self) -> str: - prompt = f"""\ -{self.instructions.prompt}\ + def _prompt(self) -> HumanMessage: + prompt = HumanMessage(self.instructions.prompt) + prompt.add_text( + f"""\ {self.obs.prompt}\ {self.history.prompt}\ {self.action_prompt.prompt}\ @@ -103,9 +109,11 @@ def _prompt(self) -> str: {self.memory.prompt}\ {self.criticise.prompt}\ """ + ) if self.flags.use_abstract_example: - prompt += f""" + prompt.add_text( + f""" # Abstract Example Here is an abstract version of the answer with description of the content of @@ -117,9 +125,11 @@ def _prompt(self) -> str: {self.criticise.abstract_ex}\ {self.action_prompt.abstract_ex}\ """ + ) if self.flags.use_concrete_example: - prompt += f""" + prompt.add_text( + f""" # Concrete Example Here is a concrete example of how to format your answer. @@ -130,6 +140,7 @@ def _prompt(self) -> str: {self.criticise.concrete_ex}\ {self.action_prompt.concrete_ex}\ """ + ) return self.obs.add_screenshot(prompt) def shrink(self): @@ -242,77 +253,3 @@ class Criticise(dp.PromptElement): def _parse_answer(self, text_answer): return parse_html_tags_raise(text_answer, optional_keys=["action_draft", "criticise"]) - - -if __name__ == "__main__": - html_template = """ - - -
- Hello World. - Step {}. -
- - - """ - - OBS_HISTORY = [ - { - "goal": "do this and that", - "pruned_html": html_template.format(1), - "axtree_txt": "[1] Click me", - "last_action_error": "", - "focused_element_bid": "32", - }, - { - "goal": "do this and that", - "pruned_html": html_template.format(2), - "axtree_txt": "[1] Click me", - "last_action_error": "", - "focused_element_bid": "32", - }, - { - "goal": "do this and that", - "pruned_html": html_template.format(3), - "axtree_txt": "[1] Click me", - "last_action_error": "Hey, there is an error now", - "focused_element_bid": "32", - }, - ] - ACTIONS = ["click('41')", "click('42')"] - MEMORIES = ["memory A", "memory B"] - THOUGHTS = ["thought A", "thought B"] - - flags = dp.ObsFlags( - use_html=True, - use_ax_tree=True, - use_plan=True, - use_criticise=True, - use_thinking=True, - use_error_logs=True, - use_past_error_logs=True, - use_history=True, - use_action_history=True, - use_memory=True, - use_diff=True, - html_type="pruned_html", - use_concrete_example=True, - use_abstract_example=True, - multi_actions=True, - use_screenshot=False, - ) - - print( - MainPrompt( - action_set=dp.make_action_set( - "bid", is_strict=False, multiaction=True, demo_mode="off" - ), - obs_history=OBS_HISTORY, - actions=ACTIONS, - memories=MEMORIES, - thoughts=THOUGHTS, - previous_plan="No plan yet", - step=0, - flags=flags, - ).prompt - ) diff --git a/src/agentlab/agents/generic_agent/reproducibility_agent.py b/src/agentlab/agents/generic_agent/reproducibility_agent.py index c197b76e9..ffec1111a 100644 --- a/src/agentlab/agents/generic_agent/reproducibility_agent.py +++ b/src/agentlab/agents/generic_agent/reproducibility_agent.py @@ -17,6 +17,7 @@ from dataclasses import dataclass from pathlib import Path +import bgym from browsergym.experiments.agent import AgentInfo from browsergym.experiments.loop import ExpArgs, ExpResult, yield_all_exp_results from bs4 import BeautifulSoup @@ -24,9 +25,10 @@ from langchain_community.adapters.openai import convert_message_to_dict from agentlab.agents.agent_args import AgentArgs -from agentlab.experiments.study_generators import Study +from agentlab.agents.dynamic_prompting import ActionFlags +from agentlab.experiments.study import Study from agentlab.llm.chat_api import make_assistant_message -from agentlab.llm.llm_utils import messages_to_dict +from agentlab.llm.llm_utils import Discussion, messages_to_dict from .generic_agent import GenericAgent, GenericAgentArgs @@ -43,7 +45,7 @@ def __init__(self, old_messages, delay=1) -> None: self.old_messages = old_messages self.delay = delay - def __call__(self, messages: list): + def __call__(self, messages: list | Discussion): self.new_messages = copy(messages) if len(messages) >= len(self.old_messages): @@ -95,7 +97,7 @@ def get_action(self, obs): # same answers step = len(self.actions) step_info = self.exp_result.get_step_info(step) - old_chat_messages = step_info.agent_info.get("chat_messages", None) + old_chat_messages = step_info.agent_info.get("chat_messages", None) # type: Discussion if old_chat_messages is None: err_msg = self.exp_result.summary_info["err_msg"] @@ -135,20 +137,35 @@ def _make_agent_stats(action, agent_info, step_info, old_chat_messages, new_chat def _format_messages(messages: list[dict]): + if isinstance(messages, Discussion): + return messages.to_string() messages = messages_to_dict(messages) return "\n".join(f"{m['role']} message:\n{m['content']}\n" for m in messages) +def _make_backward_compatible(agent_args: GenericAgentArgs): + action_set = agent_args.flags.action.action_set + if isinstance(action_set, (str, list)): + if isinstance(action_set, str): + action_set = action_set.split("+") + + agent_args.flags.action.action_set = bgym.HighLevelActionSetArgs( + subsets=action_set, + multiaction=agent_args.flags.action.multi_actions, + ) + + return agent_args + + def reproduce_study(original_study_dir: Path | str, log_level=logging.INFO): """Reproduce a study by running the same experiments with the same agent.""" original_study_dir = Path(original_study_dir) - study_name = f"reproducibility_of_{original_study_dir.name}" - exp_args_list: list[ExpArgs] = [] for exp_result in yield_all_exp_results(original_study_dir, progress_fn=None): - agent_args = make_repro_agent(exp_result.exp_args.agent_args, exp_dir=exp_result.exp_dir) + agent_args = _make_backward_compatible(exp_result.exp_args.agent_args) + agent_args = make_repro_agent(agent_args, exp_dir=exp_result.exp_dir) exp_args_list.append( ExpArgs( agent_args=agent_args, @@ -156,13 +173,18 @@ def reproduce_study(original_study_dir: Path | str, log_level=logging.INFO): logging_level=log_level, ) ) + + # infer benchmark name from task list for backward compatible benchmark_name = exp_args_list[0].env_args.task_name.split(".")[0] - return Study( - exp_args_list=exp_args_list, - benchmark_name=benchmark_name, - agent_names=[agent_args.agent_name], + study = Study( + benchmark=benchmark_name, + agent_args=[agent_args], ) + # this exp_args_list has a different agent_args for each experiment as repro_agent takes the exp_dir as argument + # so we overwrite exp_args_list with the one we created above + study.exp_args_list = exp_args_list + return study def make_repro_agent(agent_args: AgentArgs, exp_dir: Path | str): diff --git a/src/agentlab/agents/generic_agent/tmlr_config.py b/src/agentlab/agents/generic_agent/tmlr_config.py new file mode 100644 index 000000000..48a28c682 --- /dev/null +++ b/src/agentlab/agents/generic_agent/tmlr_config.py @@ -0,0 +1,72 @@ +from copy import deepcopy + +from agentlab.agents import dynamic_prompting as dp +from agentlab.experiments import args +from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT + +from .generic_agent import GenericAgentArgs +from .generic_agent_prompt import GenericPromptFlags + +BASE_FLAGS = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=True, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=True, # gpt-4o config except for this line + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=True, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + multi_actions=False, + action_set="bid", + long_description=False, + individual_examples=False, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, +) + + +def get_base_agent(llm_config: str): + return GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=BASE_FLAGS, + ) + + +def get_vision_agent(llm_config: str): + flags = deepcopy(BASE_FLAGS) + flags.obs.use_screenshot = True + return GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=flags, + ) + + +def get_som_agent(llm_config: str): + flags = deepcopy(BASE_FLAGS) + flags.obs.use_screenshot = True + flags.obs.use_som = True + return GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=flags, + ) diff --git a/src/agentlab/agents/most_basic_agent/most_basic_agent.py b/src/agentlab/agents/most_basic_agent/most_basic_agent.py index 2e0cfcbe0..9da6d9368 100644 --- a/src/agentlab/agents/most_basic_agent/most_basic_agent.py +++ b/src/agentlab/agents/most_basic_agent/most_basic_agent.py @@ -4,11 +4,18 @@ import bgym +from agentlab.agents.agent_args import AgentArgs from agentlab.llm.chat_api import make_system_message, make_user_message from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT -from agentlab.llm.llm_utils import ParseError, extract_code_blocks, retry +from agentlab.llm.llm_utils import ( + Discussion, + HumanMessage, + ParseError, + SystemMessage, + extract_code_blocks, + retry, +) from agentlab.llm.tracking import cost_tracker_decorator -from agentlab.agents.agent_args import AgentArgs if TYPE_CHECKING: from agentlab.llm.chat_api import BaseModelArgs @@ -51,25 +58,25 @@ def __init__( @cost_tracker_decorator def get_action(self, obs: Any) -> tuple[str, dict]: - system_prompt = f""" -You are a web assistant. -""" - prompt = f""" + messages = Discussion(SystemMessage("You are a web assistant.")) + messages.append( + HumanMessage( + f""" You are helping a user to accomplish the following goal on a website: {obs["goal"]} -Here is the current state of the website, in the form of an html: - -{obs["dom_txt"]} - To do so, you can interact with the environment using the following actions: {self.action_set.describe(with_long_description=False)} The inputs to those functions are the bids given in the html. -The action you provide must be in between triple ticks. +Here is the current state of the website, in the form of an html: + +{obs["pruned_html"]} + +The action you provide must be in between triple ticks and leverage the 'bid=' information provided in the html. Here is an example of how to use the bid action: ``` @@ -79,15 +86,14 @@ def get_action(self, obs: Any) -> tuple[str, dict]: Please provide a single action at a time and wait for the next observation. Provide only a single action per step. Focus on the bid that are given in the html, and use them to perform the actions. """ + ) + ) if self.use_chain_of_thought: - prompt += f""" + messages.add_text( + f""" Provide a chain of thoughts reasoning to decompose the task into smaller steps. And execute only the next step. """ - - messages = [ - make_system_message(system_prompt), - make_user_message(prompt), - ] + ) def parser(response: str) -> tuple[dict, bool, str]: blocks = extract_code_blocks(response) @@ -108,7 +114,7 @@ def parser(response: str) -> tuple[dict, bool, str]: think=thought, chat_messages=messages, # put any stats that you care about as long as it is a number or a dict of numbers - stats={"prompt_length": len(prompt), "response_length": len(thought)}, + stats={"prompt_length": len(messages), "response_length": len(thought)}, markdown_page="Add any txt information here, including base 64 images, to display in xray", extra_info={"chat_model_args": asdict(self.chat_model_args)}, ), @@ -147,6 +153,12 @@ def parser(response: str) -> tuple[dict, bool, str]: ), ] +AGENT_4o_MINI = MostBasicAgentArgs( + temperature=0.3, + use_chain_of_thought=True, + chat_model_args=chat_model_args, +) + def experiment_config(): return exp_args diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 228901b39..da7e98d39 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -19,8 +19,9 @@ from agentlab.analyze import inspect_results from agentlab.experiments.exp_utils import RESULTS_DIR +from agentlab.experiments.study import get_most_recent_study from agentlab.llm.chat_api import make_system_message, make_user_message -from agentlab.llm.llm_utils import image_to_jpg_base64_url +from agentlab.llm.llm_utils import Discussion select_dir_instructions = "Select Experiment Directory" AGENT_NAME_KEY = "agent.agent_name" @@ -141,6 +142,10 @@ def filter_agent_id(self, agent_id: list[tuple]): max-height: 400px; overflow-y: auto; } +.error-report { + max-height: 700px; + overflow-y: auto; +} .my-code-view { max-height: 300px; overflow-y: auto; @@ -183,8 +188,6 @@ def run_gradio(results_dir: Path): 2. **Select Task**: Select the task you want to analyze, this will trigger an update of the available seeds. - **IMPORTANT NOTE**: Due to a gradio bug, if you sort the columns of the table, the task - selection will not correspond to the right one. 3. **Select the Seed**: You might have multiple repetition for a given task, you will be able to select the seed you want to analyze. @@ -215,12 +218,11 @@ def run_gradio(results_dir: Path): """\ Click on a row to select an agent. It will trigger the update of other fields. - - **GRADIO BUG**: If you sort the columns the click will not match the - content. You have to sort back with the Idx column to align the click with - the order.""" + + The update mechanism is somewhat flacky, please help figure out why (or is it just gradio?). + """ ) - agent_table = gr.DataFrame(height=500, show_label=False, interactive=False) + agent_table = gr.DataFrame(max_height=500, show_label=False, interactive=False) with gr.Tab("Select Task and Seed", id="Select Task"): with gr.Row(): with gr.Column(scale=4): @@ -230,13 +232,17 @@ def run_gradio(results_dir: Path): """\ Click on a row to select a task. It will trigger the update of other fields. - **GRADIO BUG**: If you sort the columns the click will not match the - content. You have to sort back with the Idx column to align the click with - the order.""" + The update mechanism is somewhat flacky, please help figure out why (or is it just gradio?). + """ ) refresh_results_button = gr.Button("↺", scale=0, size="sm") - task_table = gr.DataFrame(height=500, show_label=False, interactive=False) + task_table = gr.DataFrame( + max_height=500, + show_label=False, + interactive=False, + elem_id="task_table", + ) with gr.Column(scale=2): with gr.Accordion("Seed Selector (click for help)", open=False): @@ -244,12 +250,16 @@ def run_gradio(results_dir: Path): """\ Click on a row to select a seed. It will trigger the update of other fields. - **GRADIO BUG**: If you sort the columns the click will not match the - content. You have to sort back with the Idx column to align the click with - the order.""" + The update mechanism is somewhat flacky, please help figure out why (or is it just gradio?). + """ ) - seed_table = gr.DataFrame(height=500, show_label=False, interactive=False) + seed_table = gr.DataFrame( + max_height=500, + show_label=False, + interactive=False, + elem_id="seed_table", + ) with gr.Tab("Constants and Variables"): with gr.Row(): @@ -261,7 +271,9 @@ def run_gradio(results_dir: Path): **all** agents. They are displayed as a table with the name and value of the constant.""" ) - constants = gr.DataFrame(height=500, show_label=False, interactive=False) + constants = gr.DataFrame( + max_height=500, show_label=False, interactive=False + ) with gr.Column(scale=2): with gr.Accordion("Variables", open=False): gr.Markdown( @@ -270,10 +282,14 @@ def run_gradio(results_dir: Path): They are displayed as a table with the name, value and count of unique values. A maximum of 3 different values are displayed.""" ) - variables = gr.DataFrame(height=500, show_label=False, interactive=False) + variables = gr.DataFrame( + max_height=500, show_label=False, interactive=False + ) with gr.Tab("Global Stats"): - global_stats = gr.DataFrame(height=500, show_label=False, interactive=False) + global_stats = gr.DataFrame(max_height=500, show_label=False, interactive=False) + with gr.Tab("Error Report"): + error_report = gr.Markdown(elem_classes="error-report", show_copy_button=True) with gr.Row(): episode_info = gr.Markdown(label="Episode Info", elem_classes="my-markdown") action_info = gr.Markdown(label="Action Info", elem_classes="my-markdown") @@ -345,7 +361,7 @@ def run_gradio(results_dir: Path): logs = gr.Code(language=None, **code_args) with gr.Tab("Stats") as tab_stats: - stats = gr.DataFrame(height=500, show_label=False, interactive=False) + stats = gr.DataFrame(max_height=500, show_label=False, interactive=False) with gr.Tab("Agent Info HTML") as tab_agent_info_html: with gr.Row(): @@ -401,7 +417,7 @@ def run_gradio(results_dir: Path): exp_dir_choice.change( fn=new_exp_dir, inputs=exp_dir_choice, - outputs=[agent_table, agent_id, constants, variables, global_stats], + outputs=[agent_table, agent_id, constants, variables, global_stats, error_report], ) agent_table.select(fn=on_select_agent, inputs=agent_table, outputs=[agent_id]) @@ -482,7 +498,9 @@ def run_gradio(results_dir: Path): tabs.select(tab_select) demo.queue() - demo.launch(server_port=int(os.getenv("AGENTXRAY_APP_PORT", 7899)), share=True) + + do_share = os.getenv("AGENTXRAY_SHARE_GRADIO", "false").lower() == "true" + demo.launch(server_port=int(os.getenv("AGENTXRAY_APP_PORT", "7899")), share=do_share) def tab_select(evt: gr.SelectData): @@ -569,7 +587,9 @@ def update_chat_messages(): global info agent_info = info.exp_result.steps_info[info.step].agent_info chat_messages = agent_info.get("chat_messages", ["No Chat Messages"]) - messages = [] + if isinstance(chat_messages, Discussion): + return chat_messages.to_markdown() + messages = [] # TODO(ThibaultLSDC) remove this at some point for i, m in enumerate(chat_messages): if isinstance(m, BaseMessage): # TODO remove once langchain is deprecated m = m.content @@ -805,22 +825,22 @@ def extract_columns(row: pd.Series): ) seed_df = result_df.apply(extract_columns, axis=1) - seed_df["Idx"] = seed_df.index return seed_df def on_select_agent(evt: gr.SelectData, df: pd.DataFrame): - global info + # TODO try to find a clever way to solve the sort bug here return info.get_agent_id(df.iloc[evt.index[0]]) def on_select_task(evt: gr.SelectData, df: pd.DataFrame, agent_id: list[tuple]): - return (agent_id, df.iloc[evt.index[0]][TASK_NAME_KEY]) + # get col index + col_idx = df.columns.get_loc(TASK_NAME_KEY) + return (agent_id, evt.row_value[col_idx]) def update_seeds(agent_task_id: tuple): agent_id, task_name = agent_task_id - global info seed_df = get_seeds_df(info.agent_df, task_name) first_seed = seed_df.iloc[0]["seed"] return seed_df, EpisodeId(agent_id=agent_id, task_name=task_name, seed=first_seed) @@ -828,7 +848,8 @@ def update_seeds(agent_task_id: tuple): def on_select_seed(evt: gr.SelectData, df: pd.DataFrame, agent_task_id: tuple): agent_id, task_name = agent_task_id - seed = df.iloc[evt.index[0]]["seed"] + col_idx = df.columns.get_loc("seed") + seed = evt.row_value[col_idx] # seed should be the first column return EpisodeId(agent_id=agent_id, task_name=task_name, seed=seed) @@ -903,18 +924,25 @@ def get_agent_report(result_df: pd.DataFrame): def update_global_stats(): - global info stats = inspect_results.global_report(info.result_df, reduce_fn=inspect_results.summarize_stats) stats.reset_index(inplace=True) return stats +def update_error_report(): + report_files = list(info.exp_list_dir.glob("error_report*.md")) + if len(report_files) == 0: + return "No error report found" + report_files = sorted(report_files, key=os.path.getctime, reverse=True) + return report_files[0].read_text() + + def new_exp_dir(exp_dir, progress=gr.Progress(), just_refresh=False): if exp_dir == select_dir_instructions: return None, None - global info + exp_dir = exp_dir.split(" - ")[0] if len(exp_dir) == 0: info.exp_list_dir = None @@ -924,15 +952,25 @@ def new_exp_dir(exp_dir, progress=gr.Progress(), just_refresh=False): info.result_df = inspect_results.load_result_df(info.exp_list_dir, progress_fn=progress.tqdm) info.result_df = remove_args_from_col(info.result_df) - agent_report = display_table(get_agent_report(info.result_df)) + study_summary = inspect_results.summarize_study(info.result_df) + # save study_summary + study_summary.to_csv(info.exp_list_dir / "summary_df.csv", index=False) + agent_report = display_table(study_summary) + info.agent_id_keys = agent_report.index.names agent_report.reset_index(inplace=True) - agent_report["Idx"] = agent_report.index agent_id = info.get_agent_id(agent_report.iloc[0]) constants, variables = format_constant_and_variables() - return agent_report, agent_id, constants, variables, update_global_stats() + return ( + agent_report, + agent_id, + constants, + variables, + update_global_stats(), + update_error_report(), + ) def new_agent_id(agent_id: list[tuple]): @@ -941,7 +979,6 @@ def new_agent_id(agent_id: list[tuple]): info.tasks_df = inspect_results.reduce_episodes(info.agent_df).reset_index() info.tasks_df = info.tasks_df.drop(columns=["std_err"]) - info.tasks_df["Idx"] = info.tasks_df.index # task name of first element task_name = info.tasks_df.iloc[0][TASK_NAME_KEY] @@ -949,14 +986,34 @@ def new_agent_id(agent_id: list[tuple]): def get_directory_contents(results_dir: Path): - directories = sorted( - [str(file.name) for file in results_dir.iterdir() if file.is_dir()], reverse=True - ) - return [select_dir_instructions] + directories + exp_descriptions = [] + for dir in results_dir.iterdir(): + if not dir.is_dir(): + continue + + exp_description = dir.name + # get summary*.csv files and find the most recent + summary_files = list(dir.glob("summary*.csv")) + if len(summary_files) != 0: + most_recent_summary = max(summary_files, key=os.path.getctime) + summary_df = pd.read_csv(most_recent_summary) + + # get row with max avg_reward + max_reward_row = summary_df.loc[summary_df["avg_reward"].idxmax()] + reward = max_reward_row["avg_reward"] * 100 + completed = max_reward_row["n_completed"] + n_err = max_reward_row["n_err"] + exp_description += ( + f" - avg-reward: {reward:.1f}% - completed: {completed} - errors: {n_err}" + ) + + exp_descriptions.append(exp_description) + + return [select_dir_instructions] + sorted(exp_descriptions, reverse=True) def most_recent_folder(results_dir: Path): - return inspect_results.get_most_recent_folder(results_dir).name + return get_most_recent_study(results_dir).name def refresh_exp_dir_choices(exp_dir_choice): diff --git a/src/agentlab/analyze/inspect_results.ipynb b/src/agentlab/analyze/inspect_results.ipynb index 6db090926..b4b3828ae 100644 --- a/src/agentlab/analyze/inspect_results.ipynb +++ b/src/agentlab/analyze/inspect_results.ipynb @@ -8,6 +8,8 @@ "source": [ "from agentlab.experiments.exp_utils import RESULTS_DIR\n", "from agentlab.analyze import inspect_results\n", + "from agentlab.experiments.study import get_most_recent_study\n", + "\n", "import pandas as pd\n", "pd.set_option('display.max_rows', 200)\n", "\n", @@ -52,7 +54,7 @@ "# result_dir = RESULTS_DIR / \"2024-05-28_01-13-04_generic_agent_eval_llm\" \n", "# result_dir = RESULTS_DIR / \"2024-05-28_01-44-29_generic_agent_eval_llm\"\n", "\n", - "result_dir = inspect_results.get_most_recent_folder(RESULTS_DIR, contains=None)\n", + "result_dir = get_most_recent_study(RESULTS_DIR, contains=None)\n", "\n", "print(result_dir)\n", "result_df = inspect_results.load_result_df(result_dir)" @@ -149,7 +151,11 @@ "metadata": {}, "outputs": [], "source": [ - "print(inspect_results.error_report(result_df, max_stack_trace=1))" + "from IPython.display import Markdown, display\n", + "\n", + "report = inspect_results.error_report(result_df, max_stack_trace=2, use_log=True)\n", + "# display(Markdown(report))\n", + "print(report)" ] }, { @@ -164,7 +170,7 @@ ], "metadata": { "kernelspec": { - "display_name": "ui-copilot", + "display_name": "Python 3", "language": "python", "name": "python3" }, diff --git a/src/agentlab/analyze/inspect_results.py b/src/agentlab/analyze/inspect_results.py index 7d46113c9..09ba23a05 100644 --- a/src/agentlab/analyze/inspect_results.py +++ b/src/agentlab/analyze/inspect_results.py @@ -16,6 +16,7 @@ from IPython.display import display from tqdm import tqdm + from agentlab.analyze.error_categorization import ( ERR_CLASS_MAP, is_critical_server_error, @@ -245,9 +246,9 @@ def get_std_err(df, metric): if np.all(np.isin(data, [0, 1])): mean = np.mean(data) std_err = np.sqrt(mean * (1 - mean) / len(data)) + return mean, std_err else: return get_sample_std_err(df, metric) - return mean, std_err def get_sample_std_err(df, metric): @@ -258,7 +259,7 @@ def get_sample_std_err(df, metric): mean = np.mean(data) std_err = np.std(data, ddof=1) / np.sqrt(len(data)) if np.isnan(std_err): - std_err = 0 + std_err = np.float64(0) return mean, std_err @@ -295,6 +296,8 @@ def summarize(sub_df, use_bootstrap=False): n_completed=f"{n_completed}/{len(sub_df)}", n_err=err.sum(skipna=True), ) + if "stats.cum_cost" in sub_df: + record["cum_cost"] = sub_df["stats.cum_cost"].sum(skipna=True).round(4) return pd.Series(record) @@ -509,41 +512,6 @@ def flag_report(report: pd.DataFrame, metric: str = "avg_reward", round_digits: return flag_report -def get_most_recent_folder( - root_dir: Path = None, date_format: str = "%Y-%m-%d_%H-%M-%S", contains=None -): - """Return the most recent directory based on the date in the folder name. - - Args: - root_dir: The directory to search in - date_format: The format of the date in the folder name - contains: If not None, only consider folders that contains this string - - Returns: - Path: The most recent folder satisfying the conditions - """ - - if root_dir is None: - root_dir = RESULTS_DIR - - most_recent_folder = None - most_recent_time = datetime.min - - for item in root_dir.iterdir(): - if item.is_dir() and not item.name.startswith("_"): - if contains is not None and contains not in item.name: - continue - try: - folder_date = datetime.strptime("_".join(item.name.split("_")[:2]), date_format) - if folder_date > most_recent_time: - most_recent_time = folder_date - most_recent_folder = item - except (ValueError, IndexError): - continue - - return most_recent_folder - - def display_report( report: pd.DataFrame, apply_shrink_columns: bool = True, @@ -615,10 +583,12 @@ def set_wrap_style(df): # ------------ -def map_err_key(err_msg): +def map_err_key(err_msg: str): if err_msg is None: return err_msg + # remove logs from the message if any + err_msg = err_msg[: err_msg.find("=== logs ===")].rstrip() regex_replacements = [ ( r"your messages resulted in \d+ tokens", @@ -635,7 +605,7 @@ def map_err_key(err_msg): return err_msg -def error_report(df: pd.DataFrame, max_stack_trace=10): +def error_report(df: pd.DataFrame, max_stack_trace=10, use_log=False): """Report the error message for each agent.""" if "err_key" not in df: @@ -645,35 +615,62 @@ def error_report(df: pd.DataFrame, max_stack_trace=10): report = [] for err_key, count in unique_counts.items(): report.append("-------------------") - report.append(f"{count}x : {err_key}\n") + report.append(f"## {count}x : " + err_key.replace("\n", "
") + "\n") + # find sub_df with this error message sub_df = df[df["err_key"] == err_key] idx = 0 exp_result_list = [get_exp_result(row.exp_dir) for _, row in sub_df.iterrows()] - task_names = [exp_result.exp_args.env_args.task_name for exp_result in exp_result_list] - - # count unique using numpy - unique_task_names, counts = np.unique(task_names, return_counts=True) - task_and_count = sorted(zip(unique_task_names, counts), key=lambda x: x[1], reverse=True) - for task_name, count in task_and_count: - report.append(f"{count:2d} {task_name}") + exp_result_list = sorted(exp_result_list, key=lambda x: x.exp_args.env_args.task_name) + for exp_result in exp_result_list: + report.append( + f"* {exp_result.exp_args.env_args.task_name} seed: {exp_result.exp_args.env_args.task_seed}" + ) report.append(f"\nShowing Max {max_stack_trace} stack traces:\n") for exp_result in exp_result_list: if idx >= max_stack_trace: break - # print task name and stack trace - stack_trace = exp_result.summary_info.get("stack_trace", "") - report.append(f"Task Name: {exp_result.exp_args.env_args.task_name}\n") - report.append(f"exp_dir: {exp_result.exp_dir}\n") - report.append(f"Stack Trace: \n {stack_trace}\n") - report.append("\n") + + if not use_log: + # print task name and stack trace + stack_trace = exp_result.summary_info.get("stack_trace", "") + report.append(f"Task Name: {exp_result.exp_args.env_args.task_name}\n") + report.append(f"exp_dir: {exp_result.exp_dir}\n") + report.append(f"Stack Trace: \n {stack_trace}\n") + report.append("\n") + else: + report.append(f"```bash\n{_format_log(exp_result)}\n```") + idx += 1 return "\n".join(report) +def _format_log(exp_result: ExpResult, head_lines=10, tail_lines=50): + """Extract head and tail of the log. Try to find the traceback.""" + log = exp_result.logs + if log is None: + return "No log found" + + log_lines = log.split("\n") + if len(log_lines) <= head_lines + tail_lines: + return log + + # first 10 lines: + log_head = "\n".join(log_lines[:head_lines]) + + try: + traceback_idx = log.rindex("Traceback (most recent call last):") + tail_idx = log.rindex("action:", 0, traceback_idx) + log_tail = log[tail_idx:] + except ValueError: + log_tail = "\n".join(log_lines[-tail_lines:]) + + return log_head + "\n...\n...truncated middle of the log\n...\n" + log_tail + + def categorize_error(row): if pd.isna(row.get("err_msg", None)): return None diff --git a/src/agentlab/experiments/args.py b/src/agentlab/experiments/args.py index bbbb3b7b4..6a4fa804e 100644 --- a/src/agentlab/experiments/args.py +++ b/src/agentlab/experiments/args.py @@ -105,13 +105,19 @@ def expand_cross_product(obj: Any | list[Any]): for obj in obj_list: cprod_paths = _find_cprod_with_paths(obj) if not cprod_paths: - return [copy.deepcopy(obj)] + result.append(copy.deepcopy(obj)) + continue paths, cprod_objects = zip(*cprod_paths) combinations = product(*[cprod_obj.elements for cprod_obj in cprod_objects]) + # create a base object with empty fields to make fast deep copies from + base_obj = copy.deepcopy(obj) + for path in paths: + _set_value(base_obj, path, None) + for combo in combinations: - new_obj = copy.deepcopy(obj) + new_obj = copy.deepcopy(base_obj) for path, value in zip(paths, combo): _set_value(new_obj, path, value) result.append(new_obj) diff --git a/src/agentlab/experiments/exp_utils.py b/src/agentlab/experiments/exp_utils.py index 3ae88deff..97ce527d9 100644 --- a/src/agentlab/experiments/exp_utils.py +++ b/src/agentlab/experiments/exp_utils.py @@ -4,6 +4,12 @@ from tqdm import tqdm import logging from browsergym.experiments.loop import ExpArgs +from contextlib import contextmanager +import signal +import sys +from time import time, sleep + +logger = logging.getLogger(__name__) # Get logger based on module name # TODO move this to a more appropriate place @@ -19,8 +25,149 @@ RESULTS_DIR.mkdir(parents=True, exist_ok=True) +def run_exp(exp_arg: ExpArgs, *dependencies, avg_step_timeout=60): + """Run exp_args.run() with a timeout and handle dependencies.""" + # episode_timeout = _episode_timeout(exp_arg, avg_step_timeout=avg_step_timeout) + # logger.warning(f"Running {exp_arg.exp_id} with timeout of {episode_timeout} seconds.") + # with timeout_manager(seconds=episode_timeout): + # this timeout method is not robust enough. using ray.cancel instead + return exp_arg.run() + + +def _episode_timeout(exp_arg: ExpArgs, avg_step_timeout=60): + """Some logic to determine the episode timeout.""" + max_steps = getattr(exp_arg.env_args, "max_steps", None) + if max_steps is None: + episode_timeout_global = 10 * 60 * 60 # 10 hours + else: + episode_timeout_global = exp_arg.env_args.max_steps * avg_step_timeout + + episode_timeout_exp = getattr(exp_arg, "episode_timeout", episode_timeout_global) + + return min(episode_timeout_global, episode_timeout_exp) + + +@contextmanager +def timeout_manager(seconds: int = None): + """Context manager to handle timeouts.""" + + if isinstance(seconds, float): + seconds = max(1, int(seconds)) # make sure seconds is at least 1 + + if seconds is None or sys.platform == "win32": + try: + logger.warning("Timeouts are not supported on Windows.") + yield + finally: + pass + return + + def alarm_handler(signum, frame): + + logger.warning(f"Operation timed out after {seconds}s, raising TimeoutError.") + # send sigint + # os.kill(os.getpid(), signal.SIGINT) # this doesn't seem to do much I don't know why + + # Still raise TimeoutError for immediate handling + # This works, but it doesn't seem enough to kill the job + raise TimeoutError(f"Operation timed out after {seconds} seconds") + + previous_handler = signal.signal(signal.SIGALRM, alarm_handler) + signal.alarm(seconds) + + try: + yield + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, previous_handler) + + +def add_dependencies(exp_args_list: list[ExpArgs], task_dependencies: dict[str, list[str]] = None): + """Add dependencies to a list of ExpArgs. + + Args: + exp_args_list: list[ExpArgs] + A list of experiments to run. + task_dependencies: dict + A dictionary mapping task names to a list of task names that they + depend on. If None or empty, no dependencies are added. + + Returns: + list[ExpArgs] + The modified exp_args_list with dependencies added. + """ + + if task_dependencies is None or all([len(dep) == 0 for dep in task_dependencies.values()]): + # nothing to be done + return exp_args_list + + for exp_args in exp_args_list: + exp_args.make_id() # makes sure there is an exp_id + + exp_args_map = {exp_args.env_args.task_name: exp_args for exp_args in exp_args_list} + if len(exp_args_map) != len(exp_args_list): + raise ValueError( + ( + "Task names are not unique in exp_args_map, " + "you can't run multiple seeds with task dependencies." + ) + ) + + for task_name in exp_args_map.keys(): + if task_name not in task_dependencies: + raise ValueError(f"Task {task_name} is missing from task_dependencies") + + # turn dependencies from task names to exp_ids + for task_name, exp_args in exp_args_map.items(): + exp_args.depends_on = tuple( + exp_args_map[dep_name].exp_id for dep_name in task_dependencies[task_name] + ) + + return exp_args_list + + +# Mock implementation of the ExpArgs class with timestamp checks for unit testing +class MockedExpArgs: + def __init__(self, exp_id, depends_on=None): + self.exp_id = exp_id + self.depends_on = depends_on if depends_on else [] + self.start_time = None + self.end_time = None + self.env_args = None + + def run(self): + self.start_time = time() + + # # simulate playright code, (this was causing issues due to python async loop) + # import playwright.sync_api + + # pw = playwright.sync_api.sync_playwright().start() + # pw.selectors.set_test_id_attribute("mytestid") + sleep(3) # Simulate task execution time + self.end_time = time() + return self + + +def make_seeds(n, offset=42): + raise DeprecationWarning("This function will be removed. Comment out this error if needed.") + return [seed + offset for seed in range(n)] + + +def order(exp_args_list: list[ExpArgs]): + raise DeprecationWarning("This function will be removed. Comment out this error if needed.") + """Store the order of the list of experiments to be able to sort them back. + + This is important for progression or ablation studies. + """ + for i, exp_args in enumerate(exp_args_list): + exp_args.order = i + return exp_args_list + + +# This was an old function for filtering some issue with the experiments. def hide_some_exp(base_dir, filter: callable, just_test): """Move all experiments that match the filter to a new name.""" + raise DeprecationWarning("This function will be removed. Comment out this error if needed.") exp_list = list(yield_all_exp_results(base_dir, progress_fn=None)) msg = f"Searching {len(exp_list)} experiments to move to _* expriments where `filter(exp_args)` is True." @@ -38,17 +185,3 @@ def hide_some_exp(base_dir, filter: callable, just_test): _move_old_exp(exp.exp_dir) filtered_out.append(exp) return filtered_out - - -def make_seeds(n, offset=42): - return [seed + offset for seed in range(n)] - - -def order(exp_args_list: list[ExpArgs]): - """Store the order of the list of experiments to be able to sort them back. - - This is important for progression or ablation studies. - """ - for i, exp_args in enumerate(exp_args_list): - exp_args.order = i - return exp_args_list diff --git a/src/agentlab/experiments/get_ray_url.py b/src/agentlab/experiments/get_ray_url.py new file mode 100644 index 000000000..b652254cb --- /dev/null +++ b/src/agentlab/experiments/get_ray_url.py @@ -0,0 +1,5 @@ +import ray + +context = ray.init(address="auto", ignore_reinit_error=True) + +print(context) diff --git a/src/agentlab/experiments/graph_execution.py b/src/agentlab/experiments/graph_execution.py deleted file mode 100644 index c12a1048b..000000000 --- a/src/agentlab/experiments/graph_execution.py +++ /dev/null @@ -1,96 +0,0 @@ -from dask import compute, delayed -from browsergym.experiments.loop import ExpArgs -from distributed import LocalCluster, Client - - -def _run(exp_arg: ExpArgs, *dependencies): - return exp_arg.run() - - -def make_dask_client(n_worker): - """Create a Dask client with a LocalCluster backend. - - I struggled to find an appropriate configuration. - I believe it has to do with the interplay of playwright async loop (even if - used in sync mode) and the fact that dask uses asyncio under the hood. - Making sure we use processes and 1 thread per worker seems to work. - - Args: - n_worker: int - Number of workers to create. - - Returns: - A Dask client object. - """ - cluster = LocalCluster( - n_workers=n_worker, - processes=True, - threads_per_worker=1, - ) - - return Client(cluster) - - -def execute_task_graph(exp_args_list: list[ExpArgs]): - """Execute a task graph in parallel while respecting dependencies.""" - exp_args_map = {exp_args.exp_id: exp_args for exp_args in exp_args_list} - - tasks = {} - - def get_task(exp_arg: ExpArgs): - if exp_arg.exp_id not in tasks: - dependencies = [get_task(exp_args_map[dep_key]) for dep_key in exp_arg.depends_on] - tasks[exp_arg.exp_id] = delayed(_run)(exp_arg, *dependencies) - return tasks[exp_arg.exp_id] - - for exp_arg in exp_args_list: - get_task(exp_arg) - - task_ids, task_list = zip(*tasks.items()) - results = compute(*task_list) - - return {task_id: result for task_id, result in zip(task_ids, results)} - - -def add_dependencies(exp_args_list: list[ExpArgs], task_dependencies: dict[list] = None): - """Add dependencies to a list of ExpArgs. - - Args: - exp_args_list: list[ExpArgs] - A list of experiments to run. - task_dependencies: dict - A dictionary mapping task names to a list of task names that they - depend on. If None or empty, no dependencies are added. - - Returns: - list[ExpArgs] - The modified exp_args_list with dependencies added. - """ - - if task_dependencies is None or all([len(dep) == 0 for dep in task_dependencies.values()]): - # nothing to be done - return exp_args_list - - exp_args_map = {exp_args.env_args.task_name: exp_args for exp_args in exp_args_list} - if len(exp_args_map) != len(exp_args_list): - raise ValueError( - ( - "Task names are not unique in exp_args_map, " - "you can't run multiple seeds with task dependencies." - ) - ) - - for task_name in exp_args_map.keys(): - if task_name not in task_dependencies: - raise ValueError(f"Task {task_name} is missing from task_dependencies") - - # turn dependencies from task names to exp_ids - for task_name, exp_args in exp_args_map.items(): - - exp_args.depends_on = tuple( - exp_args_map[dep_name].exp_id - for dep_name in task_dependencies[task_name] - if dep_name in exp_args_map # ignore dependencies that are not to be run - ) - - return exp_args_list diff --git a/src/agentlab/experiments/graph_execution_dask.py b/src/agentlab/experiments/graph_execution_dask.py new file mode 100644 index 000000000..dc51dd518 --- /dev/null +++ b/src/agentlab/experiments/graph_execution_dask.py @@ -0,0 +1,64 @@ +from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError + +from contextlib import contextmanager +import threading +from dask import compute, delayed +from bgym import ExpArgs +from distributed import LocalCluster, Client +from agentlab.experiments.exp_utils import _episode_timeout + +# from agentlab.experiments.exp_utils import run_exp + + +def run_exp(exp_arg: ExpArgs, *dependencies, avg_step_timeout=60): + """Run exp_args.run() with a timeout and handle dependencies.""" + # dask can't use the timeout_manager define in exp_utils.py + # ValueError: signal only works in main thread of the main interpreter + # most alternative I try doesn't work + episode_timeout = _episode_timeout(exp_arg, avg_step_timeout=avg_step_timeout) + return exp_arg.run() + + +def make_dask_client(n_worker): + """Create a Dask client with a LocalCluster backend. + + I struggled to find an appropriate configuration. + I believe it has to do with the interplay of playwright async loop (even if + used in sync mode) and the fact that dask uses asyncio under the hood. + Making sure we use processes and 1 thread per worker seems to work. + + Args: + n_worker: int + Number of workers to create. + + Returns: + A Dask client object. + """ + cluster = LocalCluster( + n_workers=n_worker, + processes=True, + threads_per_worker=1, + ) + + return Client(cluster) + + +def execute_task_graph(exp_args_list: list[ExpArgs]): + """Execute a task graph in parallel while respecting dependencies.""" + exp_args_map = {exp_args.exp_id: exp_args for exp_args in exp_args_list} + + tasks = {} + + def get_task(exp_arg: ExpArgs): + if exp_arg.exp_id not in tasks: + dependencies = [get_task(exp_args_map[dep_key]) for dep_key in exp_arg.depends_on] + tasks[exp_arg.exp_id] = delayed(run_exp)(exp_arg, *dependencies) + return tasks[exp_arg.exp_id] + + for exp_arg in exp_args_list: + get_task(exp_arg) + + task_ids, task_list = zip(*tasks.items()) + results = compute(*task_list) + + return {task_id: result for task_id, result in zip(task_ids, results)} diff --git a/src/agentlab/experiments/graph_execution_ray.py b/src/agentlab/experiments/graph_execution_ray.py new file mode 100644 index 000000000..5dd18d4ae --- /dev/null +++ b/src/agentlab/experiments/graph_execution_ray.py @@ -0,0 +1,90 @@ +# import os + +# # Disable Ray log deduplication +# os.environ["RAY_DEDUP_LOGS"] = "0" +import logging +import time + +import bgym +import ray +from ray.util import state + +from agentlab.experiments.exp_utils import _episode_timeout, run_exp + +logger = logging.getLogger(__name__) + +run_exp = ray.remote(run_exp) + + +def execute_task_graph(exp_args_list: list[bgym.ExpArgs], avg_step_timeout=60): + """Execute a task graph in parallel while respecting dependencies using Ray.""" + + exp_args_map = {exp_args.exp_id: exp_args for exp_args in exp_args_list} + task_map = {} + + def get_task(exp_arg: bgym.ExpArgs): + if exp_arg.exp_id not in task_map: + # Get all dependency tasks first + dependency_tasks = [get_task(exp_args_map[dep_key]) for dep_key in exp_arg.depends_on] + + # Create new task that depends on the dependency results + task_map[exp_arg.exp_id] = run_exp.remote( + exp_arg, *dependency_tasks, avg_step_timeout=avg_step_timeout + ) + return task_map[exp_arg.exp_id] + + # Build task graph + for exp_arg in exp_args_list: + get_task(exp_arg) + + max_timeout = max([_episode_timeout(exp_args, avg_step_timeout) for exp_args in exp_args_list]) + + return poll_for_timeout(task_map, max_timeout, poll_interval=max_timeout * 0.1) + + +def poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_interval: float = 1.0): + """Cancel tasks that exceeds the timeout + + I tried various different methods for killing a job that hangs. so far it's + the only one that seems to work reliably (hopefully) + """ + task_list = list(tasks.values()) + task_ids = list(tasks.keys()) + + logger.warning(f"Any task exceeding {timeout} seconds will be cancelled.") + + while True: + ready, not_ready = ray.wait(task_list, num_returns=len(task_list), timeout=poll_interval) + for task in not_ready: + elapsed_time = get_elapsed_time(task) + # print(f"Task {task.task_id().hex()} elapsed time: {elapsed_time}") + if elapsed_time is not None and elapsed_time > timeout: + msg = f"Task {task.task_id().hex()} hase been running for {elapsed_time}s, more than the timeout: {timeout}s." + if elapsed_time < timeout + 60 + poll_interval: + logger.warning(msg + " Cancelling task.") + ray.cancel(task, force=False, recursive=False) + else: + logger.warning(msg + " Force killing.") + ray.cancel(task, force=True, recursive=False) + if len(ready) == len(task_list): + results = [] + for task in ready: + try: + result = ray.get(task) + except Exception as e: + result = e + results.append(result) + + return {task_id: result for task_id, result in zip(task_ids, results)} + + +def get_elapsed_time(task_ref: ray.ObjectRef): + task_id = task_ref.task_id().hex() + task_info = state.get_task(task_id, address="auto") + if task_info and task_info.start_time_ms is not None: + start_time_s = task_info.start_time_ms / 1000.0 # Convert ms to s + current_time_s = time.time() + elapsed_time = current_time_s - start_time_s + return elapsed_time + else: + return None # Task has not started yet diff --git a/src/agentlab/experiments/launch_exp.py b/src/agentlab/experiments/launch_exp.py index fd7d2b6b3..cb331a99f 100644 --- a/src/agentlab/experiments/launch_exp.py +++ b/src/agentlab/experiments/launch_exp.py @@ -2,24 +2,17 @@ from importlib import import_module from pathlib import Path +import bgym from browsergym.experiments.loop import ExpArgs, yield_all_exp_results - - -def import_object(path: str): - module_name, obj_name = split_path(path) - try: - module = import_module(module_name) - obj = getattr(module, obj_name) - except (ImportError, AttributeError) as e: - raise ImportError(f"Error importing {path}: {e}") - return obj +from agentlab.experiments.exp_utils import run_exp def run_experiments( n_jobs, exp_args_list: list[ExpArgs], study_dir, - parallel_backend="joblib", + parallel_backend="ray", + avg_step_timeout=60, ): """Run a list of ExpArgs in parallel. @@ -34,7 +27,10 @@ def run_experiments( exp_dir: Path Directory where the experiments will be saved. parallel_backend: str - Parallel backend to use. Either "joblib", "dask" or "sequential". + Parallel backend to use. Either "joblib", "ray" or "sequential". + The only backend that supports webarena graph dependencies correctly is ray or sequential. + avg_step_timeout: int + Will raise a TimeoutError if the episode is not finished after env_args.max_steps * avg_step_timeout seconds. """ if len(exp_args_list) == 0: @@ -44,9 +40,9 @@ def run_experiments( study_dir = Path(study_dir) study_dir.mkdir(parents=True, exist_ok=True) - if n_jobs == 1 and parallel_backend != "sequential": - logging.warning("Only 1 job, switching to sequential backend.") - parallel_backend = "sequential" + # if n_jobs == 1 and parallel_backend != "sequential": + # logging.warning("Only 1 job, switching to sequential backend.") + # parallel_backend = "sequential" logging.info(f"Saving experiments to {study_dir}") for exp_args in exp_args_list: @@ -56,18 +52,40 @@ def run_experiments( if parallel_backend == "joblib": from joblib import Parallel, delayed - Parallel(n_jobs=n_jobs, prefer="processes")( - delayed(exp_args.run)() for exp_args in exp_args_list + # split sequential (should be no longer needed with dependencies) + sequential_exp_args, exp_args_list = _split_sequential_exp(exp_args_list) + + logging.info( + f"Running {len(sequential_exp_args)} in sequential first. The remaining {len(exp_args_list)} will be run in parallel." ) + for exp_args in sequential_exp_args: + run_exp(exp_args, avg_step_timeout=avg_step_timeout) - elif parallel_backend == "dask": - from agentlab.experiments.graph_execution import execute_task_graph, make_dask_client + Parallel(n_jobs=n_jobs, prefer="processes")( + delayed(run_exp)(exp_args, avg_step_timeout=avg_step_timeout) + for exp_args in exp_args_list + ) - with make_dask_client(n_worker=n_jobs): - execute_task_graph(exp_args_list) + # dask will be deprecated, as there was issues. use ray instead + # elif parallel_backend == "dask": + # from agentlab.experiments.graph_execution_dask import ( + # execute_task_graph, + # make_dask_client, + # ) + + # with make_dask_client(n_worker=n_jobs): + # execute_task_graph(exp_args_list) + elif parallel_backend == "ray": + from agentlab.experiments.graph_execution_ray import execute_task_graph, ray + + ray.init(num_cpus=n_jobs) + try: + execute_task_graph(exp_args_list, avg_step_timeout=avg_step_timeout) + finally: + ray.shutdown() elif parallel_backend == "sequential": for exp_args in exp_args_list: - exp_args.run() + run_exp(exp_args, avg_step_timeout=avg_step_timeout) else: raise ValueError(f"Unknown parallel_backend: {parallel_backend}") finally: @@ -79,13 +97,16 @@ def run_experiments( logging.info("Experiment finished.") -def relaunch_study(study_dir: str | Path, relaunch_mode="incomplete_only"): - """Return exp_args_list and study_dir +def find_incomplete(study_dir: str | Path, include_errors=True): + """Find all incomplete experiments for relaunching. + + Note: completed experiments are kept but are replaced by dummy exp_args + with nothing to run. This help keeping the dependencies between tasks. Args: study_dir: Path The directory where the experiments are saved. - relaunch_mode: str + include_errors: str Find all incomplete experiments and relaunch them. - "incomplete_only": relaunch only the incomplete experiments. - "incomplete_or_error": relaunch incomplete or errors. @@ -96,47 +117,87 @@ def relaunch_study(study_dir: str | Path, relaunch_mode="incomplete_only"): raise ValueError( f"You asked to relaunch an existing experiment but {study_dir} does not exist." ) - exp_args_list = list(_yield_incomplete_experiments(study_dir, relaunch_mode=relaunch_mode)) - if len(exp_args_list) == 0: + exp_result_list = list(yield_all_exp_results(study_dir, progress_fn=None)) + exp_args_list = [_hide_completed(exp_result, include_errors) for exp_result in exp_result_list] + # sort according to exp_args.order + exp_args_list.sort(key=lambda exp_args: exp_args.order if exp_args.order is not None else 0) + + job_count = non_dummy_count(exp_args_list) + + if job_count == 0: logging.info(f"No incomplete experiments found in {study_dir}.") - return [], study_dir + return exp_args_list + else: + logging.info(f"Found {job_count} incomplete experiments in {study_dir}.") message = f"Make sure the processes that were running are all stopped. Otherwise, " f"there will be concurrent writing in the same directories.\n" logging.info(message) - return exp_args_list, study_dir + return exp_args_list + +def non_dummy_count(exp_args_list: list[ExpArgs]) -> int: + return sum([not exp_args.is_dummy for exp_args in exp_args_list]) -def _yield_incomplete_experiments(exp_root, relaunch_mode="incomplete_only"): - """Find all incomplete experiments and relaunch them.""" - # TODO(make relanch_mode a callable, for flexibility) - for exp_result in yield_all_exp_results(exp_root, progress_fn=None): # type: ExpArgs - try: - # TODO implement has_finished instead of dealing with FileNotFoundError - summary_info = exp_result.summary_info - except FileNotFoundError: - yield exp_result.exp_args - continue +def noop(*args, **kwargs): + pass - if relaunch_mode == "incomplete_only": - continue - err_msg = summary_info.get("err_msg", None) +def _hide_completed(exp_result: bgym.ExpResult, include_errors: bool = True): + """Hide completed experiments from the list. - if err_msg is not None: - if relaunch_mode == "incomplete_or_error": - yield exp_result.exp_args - else: - raise ValueError(f"Unknown relaunch_mode: {relaunch_mode}") + This little hack, allows an elegant way to keep the task dependencies for e.g. webarena + while skipping the tasks that are completed when relaunching. + """ + + hide = False + if exp_result.status == "done": + hide = True + if exp_result.status == "error" and (not include_errors): + hide = True + + exp_args = exp_result.exp_args + exp_args.is_dummy = hide # just to keep track + exp_args.status = exp_result.status + if hide: + # make those function do nothing since they are finished. + exp_args.run = noop + exp_args.prepare = noop + + return exp_args + + +# TODO remove this function once ray backend is stable +def _split_sequential_exp(exp_args_list: list[ExpArgs]) -> tuple[list[ExpArgs], list[ExpArgs]]: + """split exp_args that are flagged as sequential from those that are not""" + sequential_exp_args = [] + parallel_exp_args = [] + for exp_args in exp_args_list: + if getattr(exp_args, "sequential", False): + sequential_exp_args.append(exp_args) + else: + parallel_exp_args.append(exp_args) + + return sequential_exp_args, parallel_exp_args -def split_path(path: str): +def _split_path(path: str): """Split a path into a module name and an object name.""" if "/" in path: path = path.replace("/", ".") module_name, obj_name = path.rsplit(".", 1) return module_name, obj_name + + +def import_object(path: str): + module_name, obj_name = _split_path(path) + try: + module = import_module(module_name) + obj = getattr(module, obj_name) + except (ImportError, AttributeError) as e: + raise ImportError(f"Error importing {path}: {e}") + return obj diff --git a/src/agentlab/experiments/reproducibility_util.py b/src/agentlab/experiments/reproducibility_util.py index 3ef7d8ef6..52e4e62a3 100644 --- a/src/agentlab/experiments/reproducibility_util.py +++ b/src/agentlab/experiments/reproducibility_util.py @@ -1,27 +1,30 @@ import csv -import json import logging import os import platform -from copy import deepcopy from datetime import datetime from importlib import metadata from pathlib import Path +import bgym import pandas as pd -from browsergym.experiments.loop import ExpArgs from git import InvalidGitRepositoryError, Repo from git.config import GitConfigParser import agentlab -from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs def _get_repo(module): return Repo(Path(module.__file__).resolve().parent, search_parent_directories=True) -def _get_benchmark_version(benchmark_name): +def _get_benchmark_version(benchmark: bgym.Benchmark) -> str: + benchmark_name = benchmark.name + + if hasattr(benchmark, "get_version"): + return benchmark.get_version() + + # in between 2 pull requests if benchmark_name.startswith("miniwob"): return metadata.distribution("browsergym.miniwob").version elif benchmark_name.startswith("workarena"): @@ -35,6 +38,8 @@ def _get_benchmark_version(benchmark_name): return metadata.distribution("weblinx_browsergym").version except metadata.PackageNotFoundError: return "0.0.1rc1" + elif benchmark_name.startswith("assistantbench"): + return metadata.distribution("browsergym.assistantbench").version else: raise ValueError(f"Unknown benchmark {benchmark_name}") @@ -166,13 +171,15 @@ def _get_git_info(module, changes_white_list=()) -> tuple[str, list[tuple[str, P def get_reproducibility_info( - agent_name: str | list[str], - benchmark_name, + agent_names: str | list[str], + benchmark: bgym.Benchmark, + study_id: str = "", comment=None, changes_white_list=( # Files that are often modified during experiments but do not affect reproducibility "*/reproducibility_script.py", "*reproducibility_journal.csv", "*main.py", + "*inspect_results.ipynb", ), ignore_changes=False, ): @@ -183,15 +190,16 @@ def get_reproducibility_info( import agentlab - if isinstance(agent_name, str): - agent_name = [agent_name] + if isinstance(agent_names, str): + agent_names = [agent_names] info = { "git_user": _get_git_username(_get_repo(agentlab)), - "agent_names": agent_name, - "benchmark": benchmark_name, + "agent_names": agent_names, + "benchmark": benchmark.name, + "study_id": study_id, "comment": comment, - "benchmark_version": _get_benchmark_version(benchmark_name), + "benchmark_version": _get_benchmark_version(benchmark), "date": datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), "os": f"{platform.system()} ({platform.version()})", "python_version": platform.python_version(), @@ -226,7 +234,7 @@ def add_git_info(module_name, module): return info -def _assert_compatible(info: dict, old_info: dict, raise_if_incompatible=True): +def assert_compatible(info: dict, old_info: dict, raise_if_incompatible=True): """Make sure that the two info dicts are compatible.""" # TODO may need to adapt if there are multiple agents, and the re-run on # error only has a subset of agents. Hence old_info.agent_name != info.agent_name @@ -234,81 +242,12 @@ def _assert_compatible(info: dict, old_info: dict, raise_if_incompatible=True): if key in ("date", "avg_reward", "std_err", "n_completed", "n_err"): continue if info[key] != old_info[key]: - if not raise_if_incompatible: - logging.warning( - f"Reproducibility info already exist and is not compatible." - f"Key {key} has changed from {old_info[key]} to {info[key]}." - ) - else: - raise ValueError( - f"Reproducibility info already exist and is not compatible." - f"Key {key} has changed from {old_info[key]} to {info[key]}." - f"Set strict_reproducibility=False to bypass this error." - ) - - -# def _benchmark_from_task_name(task_name: str): -# """Extract the benchmark from the task name. -# TODO should be more robost, e.g. handle workarna.L1, workarena.L2, etc. -# """ -# return task_name.split(".")[0] - - -# def infer_agent(exp_args_list: list[ExpArgs]): -# return list(set(exp_args.agent_args.agent_name for exp_args in exp_args_list)) - - -# def infer_benchmark(exp_args_list: list[ExpArgs]): -# bench_name = set( -# _benchmark_from_task_name(exp_args.env_args.task_name) for exp_args in exp_args_list -# ) -# if len(bench_name) > 1: -# raise ValueError( -# f"Multiple benchmarks in the same study are not well supported: {bench_name}." -# "Comment out the reproducibility part of the code to proceed at your own risk." -# ) - -# return bench_name.pop() - - -# def write_reproducibility_info( -# study_dir, agent_name, benchmark_name, comment=None, strict_reproducibility=True -# ): -# info = get_reproducibility_info( -# agent_name, benchmark_name, comment, ignore_changes=not strict_reproducibility -# ) -# return save_reproducibility_info(study_dir, info, strict_reproducibility) - - -def save_reproducibility_info(study_dir, info, strict_reproducibility=True): - """ - Save a JSON file containing reproducibility information to the specified directory. - """ - - info_path = Path(study_dir) / "reproducibility_info.json" - - if info_path.exists(): - with open(info_path, "r") as f: - existing_info = json.load(f) - _assert_compatible(info, existing_info, raise_if_incompatible=strict_reproducibility) - logging.info( - "Reproducibility info already exists and is compatible. Overwriting the old one." - ) - - with open(info_path, "w") as f: - json.dump(info, f, indent=4) - - info_str = json.dumps(info, indent=4) - logging.info(f"Reproducibility info saved to {info_path}. Info: {info_str}") - - return info - - -def load_reproducibility_info(study_dir) -> dict[str]: - """Retrieve the reproducibility info from the study directory.""" - info_path = Path(study_dir) / "reproducibility_info.json" - with open(info_path, "r") as f: - return json.load(f) + _raise_or_warn( + f"Reproducibility info already exist and is not compatible." + f"Key {key} has changed from {old_info[key]} to {info[key]}." + f"Set strict_reproducibility=False to bypass this error.", + raise_error=raise_if_incompatible, + ) def _raise_or_warn(msg, raise_error=True): @@ -325,14 +264,10 @@ def _verify_report(report_df: pd.DataFrame, agent_names=list[str], strict_reprod unique_agent_names = report_df["agent.agent_name"].unique() if set(agent_names) != set(unique_agent_names): raise ValueError( - f"Agent names in the report {unique_agent_names} do not match the agent names {agent_names}.", - raise_error=strict_reproducibility, + f"Agent names in the report {unique_agent_names} do not match the agent names {agent_names}." ) if len(set(agent_names)) != len(agent_names): - raise ValueError( - f"Duplicate agent names {agent_names}.", - raise_error=strict_reproducibility, - ) + raise ValueError(f"Duplicate agent names {agent_names}.") report_df = report_df.set_index("agent.agent_name", inplace=False) diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py new file mode 100644 index 000000000..b42f0bb5d --- /dev/null +++ b/src/agentlab/experiments/study.py @@ -0,0 +1,489 @@ +import gzip +import logging +import pickle +import re +import uuid +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path + +import bgym +from bgym import Benchmark, EnvArgs, ExpArgs +from slugify import slugify + +from agentlab.agents.agent_args import AgentArgs +from agentlab.analyze import inspect_results +from agentlab.experiments import args +from agentlab.experiments import reproducibility_util as repro +from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies +from agentlab.experiments.launch_exp import ( + find_incomplete, + non_dummy_count, + run_experiments, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class Study: + """A study coresponds to one or multiple agents evaluated on a benchmark. + + This is part of the high level API to help keep experiments organized and reproducible. + + Attributes: + benchmark: Benchmark | str + The benchmark to evaluate the agents on. If a string is provided, it will be + converted to the corresponding benchmark using bgym.DEFAULT_BENCHMARKS. + + agent_args: list[AgentArgs] + The list of agents to evaluate. + + dir: Path + The directory where the results will be saved. + + suffix: str + A suffix to add to the study name + + uuid: str + A unique identifier for the study + + reproducibility_info: dict + The reproducibility information for the study. + """ + + agent_args: list[AgentArgs] = None + benchmark: Benchmark | str = None + dir: Path = None + suffix: str = "" # used for adding a personnal comment to the study name + uuid: str = None + reproducibility_info: dict = None + logging_level: int = logging.DEBUG + logging_level_stdout: int = logging.WARNING + comment: str = None # Extra comments from the authors of this study + ignore_dependencies: bool = False + + def __post_init__(self): + self.uuid = uuid.uuid4() + if isinstance(self.benchmark, str): + self.benchmark = bgym.DEFAULT_BENCHMARKS[self.benchmark]() + if isinstance(self.dir, str): + self.dir = Path(self.dir) + self.make_exp_args_list() + + def make_exp_args_list(self): + self.exp_args_list = _agents_on_benchmark( + self.agent_args, + self.benchmark, + logging_level=self.logging_level, + logging_level_stdout=self.logging_level_stdout, + ignore_dependencies=self.ignore_dependencies, + ) + + def find_incomplete(self, include_errors=True): + """Find incomplete or errored experiments in the study directory for relaunching. + + Args: + include_errors: bool + If True, include errored experiments in the list. + + Returns: + list[ExpArgs]: The list of all experiments with completed ones replaced by a + dummy exp_args to keep the task dependencies. + """ + self.exp_args_list = find_incomplete(self.dir, include_errors=include_errors) + n_incomplete = non_dummy_count(self.exp_args_list) + n_error = [ + getattr(exp_args, "status", "incomplete") == "error" for exp_args in self.exp_args_list + ].count(True) + return n_incomplete, n_error + + def load_exp_args_list(self): + logger.info(f"Loading experiments from {self.dir}") + self.exp_args_list = list(inspect_results.yield_all_exp_results(savedir_base=self.dir)) + + def set_reproducibility_info(self, strict_reproducibility=False, comment=None): + """Gather relevant information that may affect the reproducibility of the experiment + + e.g.: versions of BrowserGym, benchmark, AgentLab...""" + agent_names = [a.agent_name for a in self.agent_args] + info = repro.get_reproducibility_info( + agent_names, + self.benchmark, + self.uuid, + ignore_changes=not strict_reproducibility, + comment=comment, + ) + if self.reproducibility_info is not None: + repro.assert_compatible( + self.reproducibility_info, info, raise_if_incompatible=strict_reproducibility + ) + self.reproducibility_info = info + + def run( + self, + n_jobs=1, + parallel_backend="ray", + strict_reproducibility=False, + n_relaunch=3, + relaunch_errors=True, + ): + + self.set_reproducibility_info( + strict_reproducibility=strict_reproducibility, comment=self.comment + ) + self.save() + + n_exp = len(self.exp_args_list) + last_error_count = None + + for i in range(n_relaunch): + logger.info(f"Launching study {self.name} - trial {i + 1} / {n_relaunch}") + self._run(n_jobs, parallel_backend, strict_reproducibility) + + suffix = f"trial_{i + 1}_of_{n_relaunch}" + _, summary_df, error_report = self.get_results(suffix=suffix) + logger.info("\n" + str(summary_df)) + + n_incomplete, n_error = self.find_incomplete(include_errors=relaunch_errors) + + if n_error / n_exp > 0.3: + logger.warning(f"More than 30% of the experiments errored. Stopping the study.") + return + + if last_error_count is not None and n_error >= last_error_count: + logger.warning( + f"Last trial did not reduce the number of errors. Stopping the study." + ) + return + + if n_incomplete == 0: + logger.info(f"Study {self.name} finished.") + return + + logger.warning( + f"Study {self.name} did not finish after {n_relaunch} trials. There are {n_incomplete} incomplete experiments." + ) + + def _run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False): + """Run all experiments in the study in parallel when possible. + + Args: + n_jobs: int + Number of parallel jobs. + + parallel_backend: str + Parallel backend to use. Either "joblib", "dask" or "sequential". + + strict_reproducibility: bool + If True, all modifications have to be committed before running the experiments. + Also, if relaunching a study, it will not be possible if the code has changed. + """ + + if self.exp_args_list is None: + raise ValueError("exp_args_list is None. Please set exp_args_list before running.") + + logger.info("Preparing backends...") + self.benchmark.prepare_backends() + logger.info("Backends ready.") + + run_experiments(n_jobs, self.exp_args_list, self.dir, parallel_backend=parallel_backend) + + def append_to_journal(self, strict_reproducibility=True): + """Append the study to the journal. + + Args: + strict_reproducibility: bool + If True, incomplete experiments will raise an error. + + Raises: + ValueError: If the reproducibility information is not compatible + with the report. + """ + repro.append_to_journal( + self.reproducibility_info, + self.get_report(), + strict_reproducibility=strict_reproducibility, + ) + + def get_results(self, suffix="", also_save=True): + result_df = inspect_results.load_result_df(self.dir) + error_report = inspect_results.error_report(result_df, max_stack_trace=3, use_log=True) + summary_df = inspect_results.summarize_study(result_df) + + if also_save: + suffix = f"_{suffix}" if suffix else "" + result_df.to_csv(self.dir / f"result_df{suffix}.csv") + summary_df.to_csv(self.dir / f"summary_df{suffix}.csv") + (self.dir / f"error_report{suffix}.md").write_text(error_report) + + return result_df, summary_df, error_report + + @property + def name(self): + agent_names = [a.agent_name for a in self.agent_args] + if len(agent_names) == 1: + study_name = f"{agent_names[0]}_on_{self.benchmark.name}" + else: + study_name = f"{len(agent_names)}_agents_on_{self.benchmark.name}" + + study_name = slugify(study_name, max_length=100, allow_unicode=True) + + if self.suffix: + study_name += f"_{self.suffix}" + return study_name + + def make_dir(self, exp_root=RESULTS_DIR): + if self.dir is None: + dir_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{self.name}" + + self.dir = Path(exp_root) / dir_name + self.dir.mkdir(parents=True, exist_ok=True) + + def save(self): + """Pickle the study to the directory""" + + # TODO perhaps remove exp_args_list before pickling and when loading bring them from the individual directories + + self.make_dir() + + with gzip.open(self.dir / "study.pkl.gz", "wb") as f: + pickle.dump(self, f) + + def get_report(self, ignore_cache=False, ignore_stale=False): + return inspect_results.get_study_summary( + self.dir, ignore_cache=ignore_cache, ignore_stale=ignore_stale + ) + + def override_max_steps(self, max_steps): + for exp_args in self.exp_args_list: + exp_args.env_args.max_steps = max_steps + + @staticmethod + def load(dir: Path) -> "Study": + dir = Path(dir) + study_path = dir / "study.pkl.gz" + if not study_path.exists() and dir.is_dir(): + # For backward compatibility + first_result = next( + inspect_results.yield_all_exp_results(savedir_base=dir, progress_fn=None) + ) + benchmark_name = first_result.exp_args.env_args.task_name.split(".")[0] + agent_args = first_result.exp_args.agent_args + study = Study(agent_args=agent_args, benchmark=benchmark_name, dir=dir) + else: + with gzip.open(dir / "study.pkl.gz", "rb") as f: + study = pickle.load(f) # type: Study + study.dir = dir + + # # just a check + # for i, exp_args in enumerate(study.exp_args_list): + # if exp_args.order != i: + # logging.warning(f"The order of the experiments is not correct. {exp_args.order} != {i}") + + return study + + @staticmethod + def load_most_recent(root_dir: Path = None, contains=None) -> "Study": + return Study.load(get_most_recent_study(root_dir, contains=contains)) + + +def get_most_recent_study( + root_dir: Path = None, date_format: str = "%Y-%m-%d_%H-%M-%S", contains=None +): + """Return the most recent directory based on the date in the folder name. + + Args: + root_dir: The directory to search in + date_format: The format of the date in the folder name + contains: If not None, only consider folders that contains this string + + Returns: + Path: The most recent folder satisfying the conditions + """ + + if root_dir is None: + root_dir = RESULTS_DIR + + most_recent_folder = None + most_recent_time = datetime.min + + for item in root_dir.iterdir(): + if item.is_dir() and not item.name.startswith("_"): + if contains is not None and contains not in item.name: + continue + try: + folder_date = datetime.strptime("_".join(item.name.split("_")[:2]), date_format) + if folder_date > most_recent_time: + most_recent_time = folder_date + most_recent_folder = item + except (ValueError, IndexError): + continue + + return most_recent_folder + + +def set_demo_mode(env_args_list: list[EnvArgs]): + + for env_args in env_args_list: + env_args.viewport = {"width": 1280, "height": 720} + env_args.record_video = True + env_args.wait_for_user_message = False + env_args.slow_mo = 1000 + + +def _agents_on_benchmark( + agents: list[AgentArgs] | AgentArgs, + benchmark: bgym.Benchmark, + demo_mode=False, + logging_level: int = logging.INFO, + logging_level_stdout: int = logging.INFO, + ignore_dependencies=False, +): + """Run one or multiple agents on a benchmark. + + Args: + agents: list[AgentArgs] | AgentArgs + The agent configuration(s) to run. + benchmark: bgym.Benchmark + The benchmark to run the agents on. + demo_mode: bool + If True, the experiments will be run in demo mode. + logging_level: int + The logging level for individual jobs. + + Returns: + study: Study + """ + + if not isinstance(agents, (list, tuple)): + agents = [agents] + + if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"): + if len(agents) > 1: + raise ValueError( + f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation." + ) + + for agent in agents: + agent.set_benchmark(benchmark, demo_mode) # the agent can adapt (lightly?) to the benchmark + + env_args_list = benchmark.env_args_list + if demo_mode: + set_demo_mode(env_args_list) + + # exp_args_list = args.expand_cross_product( + # ExpArgs( + # agent_args=args.CrossProd(agents), + # env_args=args.CrossProd(env_args_list), + # logging_level=logging_level, + # logging_level_stdout=logging_level_stdout, + # ) + # ) # type: list[ExpArgs] + + exp_args_list = [] + + for agent in agents: + for env_args in env_args_list: + exp_args = ExpArgs( + agent_args=agent, + env_args=env_args, + logging_level=logging_level, + logging_level_stdout=logging_level_stdout, + ) + exp_args_list.append(exp_args) + + for i, exp_args in enumerate(exp_args_list): + exp_args.order = i + + # not required with ray, but keeping around if we would need it for visualwebareana on joblib + # _flag_sequential_exp(exp_args_list, benchmark) + + if not ignore_dependencies: + # populate the depends_on field based on the task dependencies in the benchmark + exp_args_list = add_dependencies(exp_args_list, benchmark.dependency_graph_over_tasks()) + else: + logger.warning( + f"Ignoring dependencies for benchmark {benchmark.name}. This could lead to different results." + ) + + return exp_args_list + + +# def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark): +# if benchmark.name.startswith("visualwebarena"): +# sequential_subset = benchmark.subset_from_glob("requires_reset", "True") +# sequential_subset = set( +# [env_args.task_name for env_args in sequential_subset.env_args_list] +# ) +# for exp_args in exp_args_list: +# if exp_args.env_args.task_name in sequential_subset: +# exp_args.sequential = True + + +# def ablation_study(start_agent: AgentArgs, changes, benchmark: str, demo_mode=False): +# """Ablation study of an agent. + +# Changes is a list of tuples (path_to_attribute, value) to change in the agent +# configuration. + +# Args: +# start_agent: AgentArgs +# The agent configuration to start from. + +# changes: list[tuple] +# The changes to apply to the agent configuration. + +# benchmark: str +# The benchmark to use. + +# demo_mode: bool +# If True, the experiments will be run in demo mode. + +# Returns: +# Study +# """ +# agents = args.make_ablation_study(start_agent, changes) +# study = run_agents_on_benchmark(agents, benchmark, demo_mode=demo_mode) +# study.suffix = "ablation_study" +# return study + + +# def random_search( +# random_agent: AgentArgs = RANDOM_SEARCH_AGENT, +# n_samples=10, +# benchmark: str = "miniwob", +# demo_mode=False, +# ): +# """ +# Random search of AgentArgs (NOTE: not fully tested since refactoring) + +# The random search mechanism will recursively search through dataclasses and +# dict to find attributes of type args.Choice. It will sample iid and replace +# with the corresponding value. + +# *WARINING* The standard errror of the experiment will usually be relatively high and +# the search space is usually big so the false discovery rate will likely be +# high. Make sure to analyze the results with caution and don't actually draw +# final conclusions from these experiments. + +# Args: +# agent: AgentArgs +# The agent configuration, with some sub-arguments defined as args.Choice. + +# n_samples: int +# The number of samples to take. + +# benchmark: str +# The benchmark to use. + +# demo_mode: bool +# If True, the experiments will be run in demo mode. + +# Returns: +# Study +# """ +# agents = args.sample_and_expand_cross_product(random_agent, n_samples) +# study = run_agents_on_benchmark(agents, benchmark, demo_mode=demo_mode) +# study.suffix = "random_search" +# return study diff --git a/src/agentlab/experiments/study_generators.py b/src/agentlab/experiments/study_generators.py deleted file mode 100644 index 3a2567d51..000000000 --- a/src/agentlab/experiments/study_generators.py +++ /dev/null @@ -1,272 +0,0 @@ -from dataclasses import dataclass -from datetime import datetime -import logging -from pathlib import Path - -from bgym import ExpArgs, EnvArgs - -from agentlab.agents.agent_args import AgentArgs -from agentlab.agents.generic_agent.agent_configs import RANDOM_SEARCH_AGENT, AGENT_4o_MINI -from agentlab.analyze import inspect_results -from agentlab.experiments import args -from agentlab.experiments import task_collections as tasks -from agentlab.experiments.launch_exp import run_experiments, relaunch_study -from agentlab.experiments.exp_utils import RESULTS_DIR -from agentlab.experiments import reproducibility_util as repro - - -@dataclass -class Study: - """A study coresponds to one or multiple agents evaluated on a benchmark. - - This is part of the high level API to help keep experiments organized and reproducible. - - Attributes: - exp_args_list: list[ExpArgs] - The list of experiments to run. - - benchmark_name: str - The name of the benchmark. - - agent_names: list[str] - The names of the agents. - - dir: Path - The directory where the results will be saved. - - suffix: str - A suffix to add to the study name - """ - - exp_args_list: list[ExpArgs] = None - benchmark_name: str = None - agent_names: list[str] = None - dir: Path = None - suffix: str = "" # used for adding a personnal comment to the study name - - def run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False): - """Run all experiments in the study in parallel when possible. - - Args: - n_jobs: int - Number of parallel jobs. - - parallel_backend: str - Parallel backend to use. Either "joblib", "dask" or "sequential". - - strict_reproducibility: bool - If True, you will have to commit all your files before running the experiments. - """ - - if self.exp_args_list is None: - raise ValueError("exp_args_list is None. Please set exp_args_list before running.") - - self.make_dir() - self.write_reproducibility_info(strict_reproducibility=strict_reproducibility) - - run_experiments(n_jobs, self.exp_args_list, self.dir, parallel_backend=parallel_backend) - report_df = self.get_report(ignore_cache=True) - logging.info(f"Study {self.name} finished.") - logging.info("\n" + str(report_df)) - - def append_to_journal(self, strict_reproducibility=True): - """Append the study to the journal. - - Args: - strict_reproducibility: bool - If True, incomplete experiments will raise an error. - - Raises: - ValueError: If the reproducibility information is not compatible - with the report. - """ - repro.append_to_journal( - self.load_reproducibility_info(), - self.get_report(), - strict_reproducibility=strict_reproducibility, - ) - - @property - def name(self): - if len(self.agent_names) == 1: - study_name = f"{self.agent_names[0]}_on_{self.benchmark_name}" - else: - study_name = f"{len(self.agent_names)}_agents_on_{self.benchmark_name}" - if self.suffix: - study_name += f"_{self.suffix}" - return study_name - - def make_dir(self, exp_root=RESULTS_DIR): - if self.dir is None: - dir_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{self.name}" - - self.dir = Path(exp_root) / dir_name - self.dir.mkdir(parents=True, exist_ok=True) - - def write_reproducibility_info(self, comment=None, strict_reproducibility=False): - info = repro.get_reproducibility_info( - self.agent_names, - self.benchmark_name, - comment, - ignore_changes=not strict_reproducibility, - ) - return repro.save_reproducibility_info(self.dir, info, strict_reproducibility) - - def get_report(self, ignore_cache=False, ignore_stale=False): - return inspect_results.get_study_summary( - self.dir, ignore_cache=ignore_cache, ignore_stale=ignore_stale - ) - - def load_reproducibility_info(self): - return repro.load_reproducibility_info(self.dir) - - -def make_relaunch_study(study_dir, relaunch_mode="incomplete_or_error"): - """Create a study from an existing study directory. - - It will search for all experiments that needs to be relaunched depending on - `relaunch_mode`. - - Args: - study_dir: Path - The directory where the experiments are saved. - relaunch_mode: str - Find all incomplete experiments and relaunch them. - - "incomplete_only": relaunch only the incomplete experiments. - - "incomplete_or_error": relaunch incomplete or errors. - """ - study = Study(dir=study_dir) - study.exp_args_list, _ = relaunch_study(study.dir, relaunch_mode=relaunch_mode) - info = study.load_reproducibility_info() - study.benchmark_name = info["benchmark"] - study.agent_names = info["agent_names"] - return study - - -def set_demo_mode(env_args_list: list[EnvArgs]): - - for env_args in env_args_list: - env_args.viewport = {"width": 1280, "height": 720} - env_args.record_video = True - env_args.wait_for_user_message = False - env_args.slow_mo = 1000 - - -def run_agents_on_benchmark( - agents: list[AgentArgs] | AgentArgs = AGENT_4o_MINI, - benchmark: str = "miniwob", - demo_mode=False, - log_level=logging.INFO, -): - """Run one or multiple agents on a benchmark. - - Args: - agents: list[AgentArgs] | AgentArgs - The agent configuration(s) to run. - - benchmark: str - The benchmark to use. One of: - * miniwob - * webarena - * workarena.l1 - * workarena.l2 - * workarena.l3 - * miniwob_tiny_test - - Returns: - study: Study - """ - - if not isinstance(agents, (list, tuple)): - agents = [agents] - - for agent in agents: - agent.set_benchmark(benchmark, demo_mode) # the agent can adapt (lightly?) to the benchmark - - env_args_list = tasks.get_benchmark_env_args( - benchmark, meta_seed=43, max_steps=None, n_repeat=None - ) - if demo_mode: - set_demo_mode(env_args_list) - - exp_args_list = args.expand_cross_product( - ExpArgs( - agent_args=args.CrossProd(agents), - env_args=args.CrossProd(env_args_list), - logging_level=log_level, - ) - ) - - return Study( - exp_args_list=exp_args_list, - benchmark_name=benchmark, - agent_names=[a.agent_name for a in agents], - ) - - -def ablation_study(start_agent: AgentArgs, changes, benchmark: str, demo_mode=False): - """Ablation study of an agent. - - Changes is a list of tuples (path_to_attribute, value) to change in the agent - configuration. - - Args: - start_agent: AgentArgs - The agent configuration to start from. - - changes: list[tuple] - The changes to apply to the agent configuration. - - benchmark: str - The benchmark to use. - - demo_mode: bool - If True, the experiments will be run in demo mode. - - Returns: - Study - """ - agents = args.make_ablation_study(start_agent, changes) - study = run_agents_on_benchmark(agents, benchmark, demo_mode=demo_mode) - study.suffix = "ablation_study" - return study - - -def random_search( - random_agent: AgentArgs = RANDOM_SEARCH_AGENT, - n_samples=10, - benchmark: str = "miniwob", - demo_mode=False, -): - """ - Random search of AgentArgs (NOTE: not fully tested since refactoring) - - The random search mechanism will recursively search through dataclasses and - dict to find attributes of type args.Choice. It will sample iid and replace - with the corresponding value. - - *WARINING* The standard errror of the experiment will usually be relatively high and - the search space is usually big so the false discovery rate will likely be - high. Make sure to analyze the results with caution and don't actually draw - final conclusions from these experiments. - - Args: - agent: AgentArgs - The agent configuration, with some sub-arguments defined as args.Choice. - - n_samples: int - The number of samples to take. - - benchmark: str - The benchmark to use. - - demo_mode: bool - If True, the experiments will be run in demo mode. - - Returns: - Study - """ - agents = args.sample_and_expand_cross_product(random_agent, n_samples) - study = run_agents_on_benchmark(agents, benchmark, demo_mode=demo_mode) - study.suffix = "random_search" - return study diff --git a/src/agentlab/experiments/task_collections.py b/src/agentlab/experiments/task_collections.py deleted file mode 100644 index 66bf00b79..000000000 --- a/src/agentlab/experiments/task_collections.py +++ /dev/null @@ -1,212 +0,0 @@ -import logging -import time as t -from pathlib import Path - -import numpy as np -import pandas as pd - -logger = logging.getLogger(__name__) - -from browsergym.experiments import EnvArgs -from browsergym.webarena import ALL_WEBARENA_TASK_IDS - -df = pd.read_csv(Path(__file__).parent / "miniwob_tasks_all.csv") -# append miniwob. to task_name column -df["task_name"] = "miniwob." + df["task_name"] -MINIWOB_ALL = df["task_name"].tolist() -tasks_eval = df[df["miniwob_category"].isin(["original", "additional", "hidden test"])][ - "task_name" -].tolist() -miniwob_debug = df[df["miniwob_category"].isin(["debug"])]["task_name"].tolist() -MINIWOB_TINY_TEST = ["miniwob.click-dialog", "miniwob.click-checkboxes"] - -assert len(MINIWOB_ALL) == 125 -assert len(tasks_eval) == 107 -assert len(miniwob_debug) == 12 -assert len(MINIWOB_TINY_TEST) == 2 - - -webgum_tasks = [ - "miniwob.book-flight", - "miniwob.choose-date", - "miniwob.choose-date-easy", - "miniwob.choose-date-medium", - "miniwob.choose-list", - "miniwob.click-button", - "miniwob.click-button-sequence", - "miniwob.click-checkboxes", - "miniwob.click-checkboxes-large", - "miniwob.click-checkboxes-soft", - "miniwob.click-checkboxes-transfer", - "miniwob.click-collapsible", - "miniwob.click-collapsible-2", - "miniwob.click-color", - "miniwob.click-dialog", - "miniwob.click-dialog-2", - "miniwob.click-link", - "miniwob.click-menu", - "miniwob.click-option", - "miniwob.click-pie", - "miniwob.click-scroll-list", - "miniwob.click-shades", - "miniwob.click-shape", - "miniwob.click-tab", - "miniwob.click-tab-2", - "miniwob.click-tab-2-hard", - "miniwob.click-test", - "miniwob.click-test-2", - "miniwob.click-widget", - "miniwob.count-shape", - "miniwob.email-inbox", - "miniwob.email-inbox-forward-nl", - "miniwob.email-inbox-forward-nl-turk", - "miniwob.email-inbox-nl-turk", - "miniwob.enter-date", - "miniwob.enter-password", - "miniwob.enter-text", - "miniwob.enter-text-dynamic", - "miniwob.enter-time", - "miniwob.focus-text", - "miniwob.focus-text-2", - "miniwob.grid-coordinate", - "miniwob.guess-number", - "miniwob.identify-shape", - "miniwob.login-user", - "miniwob.login-user-popup", - "miniwob.multi-layouts", - "miniwob.multi-orderings", - "miniwob.navigate-tree", - "miniwob.search-engine", - "miniwob.social-media", - "miniwob.social-media-all", - "miniwob.social-media-some", - "miniwob.tic-tac-toe", - "miniwob.use-autocomplete", - "miniwob.use-spinner", -] - - -# TODO add miniwob_tiny_test as benchmarks -def get_benchmark_env_args( - benchmark_name: str, meta_seed=42, max_steps=None, n_repeat=None -) -> list[EnvArgs]: - """ - Returns a list of EnvArgs for the given benchmark_name. - - Args: - benchmark_name: A string representing the benchmark name. - meta_seed: The seed for the random number generator. - max_steps: None or int. The maximum number of steps for each task. - if None, it will use the default value for the benchmark. - n_repeat: None or int. The number of seeds for each task. - if None, it will use the default value for the benchmark. - is_agent_curriculum: wether to use the agent curriculum or the human curriculum. - - Returns: - A list of EnvArgs. - - Raises: - ValueError: If the benchmark_name is not recognized, or if the benchmark_name is not - followed by a subcategory for workarena. - """ - env_args_list = [] - rng = np.random.RandomState(meta_seed) - - filters = benchmark_name.split(".") - benchmark_id = filters[0] - if filters[0] == "workarena": - benchmark_id = "workarena." + filters[1] - - max_steps_default = { - "workarena.l1": 15, - "workarena.l2": 50, - "workarena.l3": 50, - "webarena": 15, - "miniwob": 10, - "miniwob_tiny_test": 5, - "weblinx": None, - } - - n_repeat_default = { - "workarena.l1": 10, - "workarena.l2": 1, - "workarena.l3": 1, - "webarena": 1, - "miniwob": 5, - "miniwob_tiny_test": 2, - "weblinx": 1, - } - - if max_steps is None: - max_steps = max_steps_default.get(benchmark_id, None) - if n_repeat is None: - n_repeat = n_repeat_default.get(benchmark_id, 1) - else: - if benchmark_id == "webarena" and n_repeat != 1: - logger.warning( - f"webarena is expected to have only one seed per task. Ignoring n_seeds_default = {n_repeat}" - ) - n_repeat = 1 - - if benchmark_name.startswith("workarena"): - t0 = t.time() - from browsergym.workarena import ALL_WORKARENA_TASKS, ATOMIC_TASKS, get_all_tasks_agents - - dt = t.time() - t0 - print(f"done importing workarena, took {dt:.2f} seconds") - - if len(filters) < 2: - raise ValueError(f"You must specify the sub set of workarena, e.g.: workarena.l2.") - - if benchmark_name == "workarena.l1.sort": - task_names = [task.get_task_id() for task in ATOMIC_TASKS] - task_names = [task for task in task_names if "sort" in task] - env_args_list = _make_env_args(task_names, max_steps, n_repeat, rng) - - else: - for task, seed in get_all_tasks_agents( - filter=".".join(filters[1:]), - meta_seed=meta_seed, - n_seed_l1=n_repeat, - ): - task_name = task.get_task_id() - env_args_list.append( - EnvArgs(task_name=task_name, task_seed=seed, max_steps=max_steps) - ) - - elif benchmark_name == "webarena": - from browsergym.webarena import ALL_WEBARENA_TASK_IDS - - env_args_list = _make_env_args(ALL_WEBARENA_TASK_IDS, max_steps, n_repeat, rng) - elif benchmark_name.startswith("miniwob"): - miniwob_benchmarks_map = { - "miniwob": MINIWOB_ALL, - "miniwob_tiny_test": MINIWOB_TINY_TEST, - } - env_args_list = _make_env_args( - miniwob_benchmarks_map[benchmark_name], max_steps, n_repeat, rng - ) - elif benchmark_name.startswith("weblinx"): - from weblinx_browsergym import ALL_WEBLINX_TASK_IDS - - env_args_list = _make_env_args(ALL_WEBLINX_TASK_IDS, max_steps, n_repeat, rng) - else: - raise ValueError(f"Unknown benchmark name: {benchmark_name}") - - return env_args_list - - -def _make_env_args(task_list, max_steps, n_seeds_default, rng): - env_args_list = [] - for task in task_list: - for seed in rng.randint(0, 100, n_seeds_default): - env_args_list.append(EnvArgs(task_name=task, task_seed=int(seed), max_steps=max_steps)) - return env_args_list - - -if __name__ == "__main__": - env_args_list = get_benchmark_env_args("workarena.l2") - print(f"Number of tasks: {len(env_args_list)}") - for env_args in env_args_list: - if "infeasible" in env_args.task_name: - print(env_args.task_seed, env_args.task_name) diff --git a/src/agentlab/experiments/view_dep_graph.py b/src/agentlab/experiments/view_dep_graph.py new file mode 100644 index 000000000..0639507bc --- /dev/null +++ b/src/agentlab/experiments/view_dep_graph.py @@ -0,0 +1,322 @@ +import math +import bgym +import matplotlib.pyplot as plt + +import networkx as nx +import numpy as np + + +def clean_dict(dependency_dict: dict[str, list[str]]) -> dict[str, list[str]]: + new_dep = {} + for key, deps in dependency_dict.items(): + new_key = key.split(".")[-1] + + new_dep[new_key] = [dep.split(".")[-1] for dep in deps] + return new_dep + + +def dict_to_networkx(dependency_dict: dict[str, list[str]]) -> nx.DiGraph: + + G = nx.DiGraph() + i = 0 + # Add edges from each node to its dependencies + for node, dependencies in dependency_dict.items(): + i += 1 + if i > 20: + pass + + print(node, dependencies) + # Add edges from the node to each of its dependencies + for dep in dependencies: + G.add_edge(dep, node) + return G + + +def plot_graph(G, ax, title=None, node_color="lightblue", node_size=40, font_size=8): + """ + Plot a single graph component on the given matplotlib axis. + + Args: + G: NetworkX graph (should be a single connected component) + ax: Matplotlib axis to plot on + title: Optional title for the subplot + node_color: Color for the nodes + node_size: Size of the nodes + font_size: Size of the node labels + """ + # Use a simple layout for better performance + # pos = nx.spring_layout(G, k=0.1, iterations=100) + + pos = nx.kamada_kawai_layout(G) + + # pos = nx.spectral_layout(G) + + def name_to_size(name): + if "-" in name: + start, end = name.split("-") + + n_nodes = int(end) - int(start) + 1 + else: + n_nodes = 1 + size_factor = node_size / 10 + return n_nodes * size_factor + + # compute size based on name + sizes = [name_to_size(name) for name in G.nodes] + + nx.draw( + G, + pos, + ax=ax, + with_labels=True, + node_color=node_color, + node_size=sizes, + font_size=font_size, + font_weight="normal", + arrows=True, + arrowsize=15, + ) + + if title: + ax.set_title(title) + ax.axis("off") + + +def plot_components_grid( + components, max_cols=4, node_color="lightblue", node_size=2000, font_size=10 +): + """ + Plot components in a grid layout. + + Args: + components: List of NetworkX graphs, one per component + max_cols: Maximum number of columns in the grid + node_color: Color for the nodes + node_size: Size of the nodes + font_size: Size of the node labels + + Returns: + matplotlib figure + """ + n_components = len(components) + + if n_components == 0: + print("No components found") + return None + + # Calculate grid dimensions + ncols = min(n_components, max_cols) + nrows = math.ceil(n_components / ncols) + + # Create figure with a reasonable size per subplot + fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows)) + fig.suptitle("Dependency Graph Components", size=16) + + # Make axes iterable even if there's only one + if n_components == 1: + axes = np.array([[axes]]) + elif nrows == 1: + axes = np.array([axes]) + elif ncols == 1: + axes = axes.reshape(-1, 1) + + # Plot each component + for idx, component in enumerate(components): + i, j = divmod(idx, ncols) + title = f"Component {idx+1} ({component.number_of_nodes()} nodes)" + plot_graph( + component, + axes[i, j], + title, + node_color=node_color, + node_size=node_size, + font_size=font_size, + ) + + # Remove empty subplots + for idx in range(n_components, nrows * ncols): + i, j = divmod(idx, ncols) + axes[i, j].remove() + + plt.tight_layout() + return fig + + +def compress_sequential_chains(dep_dict: dict[str, list[str]]) -> dict[str, list[str]]: + """ + Compress chains of sequential numbers in a dependency dictionary. + Returns a new dictionary with compressed chains using range notation. + + Args: + dep_dict: Dictionary mapping string numbers to list of string number dependencies + + Returns: + Dictionary with compressed chains using range notation + """ + # Convert to integers for easier processing + int_dict = {int(k): [int(x) for x in v] for k, v in dep_dict.items()} + + # Find chains + chains = [] + current_chain = [] + + # Sort nodes for sequential processing + nodes = sorted(int_dict.keys()) + + i = 0 + while i < len(nodes): + node = nodes[i] + + # Start new chain + if not current_chain: + current_chain = [node] + i += 1 + continue + + # Check if this node continues the chain + last_node = current_chain[-1] + + # Conditions for chain continuation: + # 1. Numbers are consecutive + # 2. Current node has exactly one dependency + # 3. That dependency is the previous node in chain + # 4. The previous node has exactly one successor + is_consecutive = node == last_node + 1 + has_single_dep = len(int_dict[node]) == 1 + deps_on_last = has_single_dep and int_dict[node][0] == last_node + last_has_single_successor = sum(1 for k, v in int_dict.items() if last_node in v) == 1 + + if is_consecutive and deps_on_last and last_has_single_successor: + current_chain.append(node) + else: + if len(current_chain) > 1: + chains.append(current_chain) + current_chain = [node] + + i += 1 + + # Add last chain if it exists + if len(current_chain) > 1: + chains.append(current_chain) + + # Create compressed dictionary + compressed_dict = {} + processed_nodes = set() + + # Add compressed chains + for chain in chains: + chain_name = f"{chain[0]}-{chain[-1]}" + # Find dependencies of first node in chain + deps = int_dict[chain[0]] + compressed_dict[chain_name] = [str(d) for d in deps] + processed_nodes.update(chain) + + # Add remaining non-chain nodes + for node in nodes: + if node not in processed_nodes: + compressed_dict[str(node)] = [str(d) for d in int_dict[node]] + + # Update dependencies to use compressed names + for k in compressed_dict: + deps = compressed_dict[k] + new_deps = [] + for dep in deps: + dep_int = int(dep) + # Find if this dependency is part of a chain + chain_found = False + for chain in chains: + if dep_int in chain: + new_deps.append(f"{chain[0]}-{chain[-1]}") + chain_found = True + break + if not chain_found: + new_deps.append(dep) + compressed_dict[k] = new_deps + + return compressed_dict + + +def compress_chains(G): + """ + Compress chains in a directed graph by merging nodes that have single parent and single child. + + Args: + G: NetworkX directed graph + + Returns: + NetworkX directed graph with compressed chains + """ + G_compressed = G.copy() + processed_nodes = set() + + while True: + # Find nodes with exactly one parent and one child + nodes_to_compress = [] + for node in list( + G_compressed.nodes() + ): # Create a list to avoid modification during iteration + if node in processed_nodes: + continue + + predecessors = list(G_compressed.predecessors(node)) + successors = list(G_compressed.successors(node)) + + if len(predecessors) == 1 and len(successors) == 1: + pred = predecessors[0] + succ = successors[0] + + # Skip if any node in the chain is already processed + if pred in processed_nodes or succ in processed_nodes: + continue + + # Only compress if middle node has single parent/child + pred_preds = list(G_compressed.predecessors(pred)) + succ_succs = list(G_compressed.successors(succ)) + + if len(pred_preds) <= 1 and len(succ_succs) <= 1: + nodes_to_compress.append((pred, node, succ)) + processed_nodes.update([pred, node, succ]) + + if not nodes_to_compress: + break + + # Process each chain + for pred, mid, succ in nodes_to_compress: + if not all(G_compressed.has_node(n) for n in [pred, mid, succ]): + continue + + # Create new merged node name + new_node = ",".join(str(n) for n in [pred, mid, succ]) + + # Add the new node + G_compressed.add_node(new_node) + + # Add edges from all predecessors of first node + for p in list(G_compressed.predecessors(pred)): + G_compressed.add_edge(p, new_node) + + # Add edges to all successors of last node + for s in list(G_compressed.successors(succ)): + G_compressed.add_edge(new_node, s) + + # Remove the old nodes + G_compressed.remove_nodes_from([pred, mid, succ]) + + return G_compressed + + +# benchmark = bgym.DEFAULT_BENCHMARKS["webarena"]() +benchmark = bgym.DEFAULT_BENCHMARKS["visualwebarena"]() + +dep_graph = benchmark.dependency_graph_over_tasks() +dep_graph = clean_dict(dep_graph) + +dep_graph = compress_sequential_chains(dep_graph) +graph = dict_to_networkx(dep_graph) + +# graph = compress_chains(graph) + +components = nx.weakly_connected_components(graph) +components = [graph.subgraph(component).copy() for component in components] +plot_components_grid(components) +plt.show() diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index a4df0a977..2ed8f0d6e 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -13,6 +13,7 @@ import agentlab.llm.tracking as tracking from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs from agentlab.llm.huggingface_utils import HFBaseChatModel +from agentlab.llm.llm_utils import Discussion def make_system_message(content: str) -> dict: @@ -30,8 +31,18 @@ def make_assistant_message(content: str) -> dict: class CheatMiniWoBLLM(AbstractChatModel): """For unit-testing purposes only. It only work with miniwob.click-test task.""" + def __init__(self, wait_time=0) -> None: + self.wait_time = wait_time + def __call__(self, messages) -> str: - prompt = messages[-1]["content"] + if self.wait_time > 0: + print(f"Waiting for {self.wait_time} seconds") + time.sleep(self.wait_time) + + if isinstance(messages, Discussion): + prompt = messages.to_string() + else: + prompt = messages[1].get("content", "") match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE) if match: @@ -54,9 +65,10 @@ class CheatMiniWoBLLMArgs: max_total_tokens = 10240 max_input_tokens = 8000 max_new_tokens = 128 + wait_time: int = 0 def make_model(self): - return CheatMiniWoBLLM() + return CheatMiniWoBLLM(self.wait_time) def prepare_server(self): pass @@ -196,6 +208,10 @@ def handle_error(error, itr, min_retry_wait_time, max_retry): return error_type +class OpenRouterError(openai.OpenAIError): + pass + + class ChatModel(AbstractChatModel): def __init__( self, @@ -262,6 +278,12 @@ def __call__(self, messages: list[dict]) -> dict: temperature=self.temperature, max_tokens=self.max_tokens, ) + + if completion.usage is None: + raise OpenRouterError( + "The completion object does not contain usage information. This is likely a bug in the OpenRouter API." + ) + self.success = True break except openai.OpenAIError as e: diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index 30889be3d..ec6086868 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -20,41 +20,41 @@ "openai/gpt-4o-mini-2024-07-18": OpenAIModelArgs( model_name="gpt-4o-mini-2024-07-18", max_total_tokens=128_000, - max_input_tokens=40_000, - max_new_tokens=4000, + max_input_tokens=128_000, + max_new_tokens=16_384, vision_support=True, ), "openai/gpt-4-1106-preview": OpenAIModelArgs( model_name="gpt-4-1106-preview", max_total_tokens=128_000, - max_input_tokens=40_000, # make sure we don't bust budget - max_new_tokens=4000, + max_input_tokens=128_000, + max_new_tokens=4_096, ), "openai/gpt-4-vision-preview": OpenAIModelArgs( model_name="gpt-4-vision-preview", max_total_tokens=128_000, - max_input_tokens=40_000, # make sure we don't bust budget - max_new_tokens=4000, # I think this model has very small default value if we don't set max_new_tokens + max_input_tokens=128_000, + max_new_tokens=16_384, # I think this model has very small default value if we don't set max_new_tokens vision_support=True, ), "openai/gpt-4o-2024-05-13": OpenAIModelArgs( model_name="gpt-4o-2024-05-13", max_total_tokens=128_000, - max_input_tokens=40_000, # make sure we don't bust budget - max_new_tokens=4000, # I think this model has very small default value if we don't set max_new_tokens + max_input_tokens=128_000, + max_new_tokens=4_096, # I think this model has very small default value if we don't set max_new_tokens vision_support=True, ), "openai/gpt-3.5-turbo-0125": OpenAIModelArgs( model_name="gpt-3.5-turbo-0125", max_total_tokens=16_384, - max_input_tokens=15_000, - max_new_tokens=1_000, + max_input_tokens=16_384, + max_new_tokens=4096, ), "openai/gpt-3.5-turbo-1106": OpenAIModelArgs( model_name="gpt-3.5-turbo-1106", max_total_tokens=16_384, - max_input_tokens=15_000, - max_new_tokens=1_000, + max_input_tokens=16_384, + max_new_tokens=4096, ), "azure/gpt-35-turbo/gpt-35-turbo": AzureModelArgs( model_name="gpt-35-turbo", @@ -67,15 +67,25 @@ model_name="gpt-4o", deployment_name="gpt-4o-2024-05-13", max_total_tokens=128_000, - max_input_tokens=40_000, - max_new_tokens=4_000, + max_input_tokens=100_000, + max_new_tokens=16_384, + vision_support=True, ), "azure/gpt-4o-2024-08-06": AzureModelArgs( model_name="gpt-4o", deployment_name="gpt-4o-2024-08-06", max_total_tokens=128_000, - max_input_tokens=40_000, - max_new_tokens=4_000, + max_input_tokens=128_000, + max_new_tokens=16_384, + vision_support=True, + ), + "azure/gpt-4o-mini-2024-07-18": AzureModelArgs( + model_name="gpt-4o-mini", + deployment_name="gpt-4o-mini-2024-07-18", + max_total_tokens=128_000, + max_input_tokens=128_000, + max_new_tokens=16_384, + vision_support=True, ), # ---------------- OSS LLMs ----------------# "meta-llama/Meta-Llama-3-70B-Instruct": SelfHostedModelArgs( @@ -106,44 +116,45 @@ "openrouter/meta-llama/llama-3.1-405b-instruct": OpenRouterModelArgs( model_name="meta-llama/llama-3.1-405b-instruct", max_total_tokens=128_000, - max_input_tokens=40_000, - max_new_tokens=4000, + max_input_tokens=100_000, + max_new_tokens=28_000, temperature=1e-1, ), "openrouter/meta-llama/llama-3.1-70b-instruct": OpenRouterModelArgs( model_name="meta-llama/llama-3.1-70b-instruct", max_total_tokens=128_000, - max_input_tokens=40_000, - max_new_tokens=4000, + max_input_tokens=100_000, + max_new_tokens=28_000, temperature=1e-1, ), "openrouter/meta-llama/llama-3-70b-instruct": OpenRouterModelArgs( model_name="meta-llama/llama-3-70b-instruct", max_total_tokens=128_000, - max_input_tokens=40_000, - max_new_tokens=4000, + max_input_tokens=100_000, + max_new_tokens=28_000, temperature=1e-1, ), "openrouter/meta-llama/llama-3.1-8b-instruct:free": OpenRouterModelArgs( model_name="meta-llama/llama-3.1-8b-instruct:free", max_total_tokens=128_000, - max_input_tokens=40_000, - max_new_tokens=4000, + max_input_tokens=100_000, + max_new_tokens=28_000, temperature=1e-1, ), "openrouter/meta-llama/llama-3.1-8b-instruct": OpenRouterModelArgs( model_name="meta-llama/llama-3.1-8b-instruct", max_total_tokens=128_000, - max_input_tokens=40_000, - max_new_tokens=4000, + max_input_tokens=100_000, + max_new_tokens=28_000, temperature=1e-1, ), "openrouter/anthropic/claude-3.5-sonnet:beta": OpenRouterModelArgs( model_name="anthropic/claude-3.5-sonnet:beta", max_total_tokens=200_000, - max_input_tokens=40_000, - max_new_tokens=4000, + max_input_tokens=200_000, + max_new_tokens=8_192, temperature=1e-1, + vision_support=True, ), "openrouter/qwen/qwen-2-72b-instruct": OpenRouterModelArgs( model_name="qwen/qwen-2-72b-instruct", @@ -152,4 +163,11 @@ max_new_tokens=2_000, temperature=1e-1, ), + "openrouter/openai/o1-mini-2024-09-12": OpenRouterModelArgs( + model_name="openai/o1-mini-2024-09-12", + max_total_tokens=128_000, + max_input_tokens=128_000, + max_new_tokens=64_000, + temperature=1e-1, + ), } diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index c3d750098..eaa2a5e02 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -6,8 +6,10 @@ import os import re import time +from copy import deepcopy from functools import cache from typing import TYPE_CHECKING +from typing import Any, Union from warnings import warn import numpy as np @@ -23,14 +25,14 @@ def messages_to_dict(messages: list[dict] | list[BaseMessage]) -> dict: - new_messages = [] + new_messages = Discussion() for m in messages: if isinstance(m, dict): - new_messages.append(m) + new_messages.add_message(m) elif isinstance(m, str): - new_messages.append({"role": "", "content": m}) + new_messages.add_message({"role": "", "content": m}) elif isinstance(m, BaseMessage): - new_messages.append(convert_message_to_dict(m)) + new_messages.add_message(convert_message_to_dict(m)) else: raise ValueError(f"Unknown message type: {type(m)}") return new_messages @@ -42,7 +44,7 @@ class RetryError(ValueError): def retry( chat: "ChatModel", - messages: list[dict], + messages: "Discussion", n_retry: int, parser: callable, log: bool = True, @@ -80,8 +82,8 @@ def retry( tries = 0 while tries < n_retry: answer = chat(messages) - messages.append(answer) # TODO: could we change this to not use inplace modifications ? - + # TODO: could we change this to not use inplace modifications ? + messages.append(answer) try: return parser(answer["content"]) except ParseError as parsing_error: @@ -322,6 +324,157 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image): return f"data:image/jpeg;base64,{image_base64}" +class BaseMessage(dict): + def __init__(self, role: str, content: Union[str, list[dict]]): + self["role"] = role + self["content"] = deepcopy(content) + + def __str__(self) -> str: + if isinstance(self["content"], str): + return self["content"] + if not all(elem["type"] == "text" for elem in self["content"]): + logging.warning( + "The content of the message has images, which are not displayed in the string representation." + ) + return "\n".join([elem["text"] for elem in self["content"] if elem["type"] == "text"]) + + def add_content(self, type: str, content: Any): + if isinstance(self["content"], str): + text = self["content"] + self["content"] = [] + self["content"].append({"type": "text", "text": text}) + self["content"].append({"type": type, type: content}) + + def add_text(self, text: str): + self.add_content("text", text) + + def add_image(self, image: np.ndarray | Image.Image | str, detail: str = None): + if not isinstance(image, str): + image_url = image_to_jpg_base64_url(image) + else: + image_url = image + if detail: + self.add_content("image_url", {"url": image_url, "detail": detail}) + else: + self.add_content("image_url", image_url) + + def to_markdown(self): + if isinstance(self["content"], str): + return f"\n```\n{self['content']}\n```\n" + res = [] + for elem in self["content"]: + # add texts between ticks and images + if elem["type"] == "text": + res.append(f"\n```\n{elem['text']}\n```\n") + elif elem["type"] == "image_url": + img_str = ( + elem["image_url"] + if isinstance(elem["image_url"], str) + else elem["image_url"]["url"] + ) + res.append(f"![image]({img_str})") + return "\n".join(res) + + def merge(self): + """Merges content elements of type 'text' if they are adjacent.""" + if isinstance(self["content"], str): + return + new_content = [] + for elem in self["content"]: + if elem["type"] == "text": + if new_content and new_content[-1]["type"] == "text": + new_content[-1]["text"] += "\n" + elem["text"] + else: + new_content.append(elem) + else: + new_content.append(elem) + self["content"] = new_content + + +class SystemMessage(BaseMessage): + def __init__(self, content: Union[str, list[dict]]): + super().__init__("system", content) + + +class HumanMessage(BaseMessage): + def __init__(self, content: Union[str, list[dict]]): + super().__init__("user", content) + + +class AIMessage(BaseMessage): + def __init__(self, content: Union[str, list[dict]]): + super().__init__("assistant", content) + + +class Discussion: + def __init__(self, messages: Union[list[BaseMessage], BaseMessage] = None): + if isinstance(messages, BaseMessage): + messages = [messages] + elif messages is None: + messages = [] + self.messages = messages + + @property + def last_message(self): + return self.messages[-1] + + def merge(self): + for m in self.messages: + m.merge() + + def __str__(self) -> str: + return "\n".join(str(m) for m in self.messages) + + def to_string(self): + self.merge() + return str(self) + + def to_openai(self): + self.merge() + return self.messages + + def add_message( + self, + message: BaseMessage | dict = None, + role: str = None, + content: Union[str, list[dict]] = None, + ): + if message is None: + message = BaseMessage(role, content) + else: + if isinstance(message, dict): + message = BaseMessage(**message) + self.messages.append(message) + + def append(self, message: BaseMessage | dict): + self.add_message(message) + + def add_content(self, type: str, content: Any): + """Add content to the last message.""" + self.last_message.add_content(type, content) + + def add_text(self, text: str): + """Add text to the last message.""" + self.last_message.add_text(text) + + def add_image(self, image: np.ndarray | Image.Image | str, detail: str = None): + """Add an image to the last message.""" + self.last_message.add_image(image, detail) + + def __iter__(self): + return iter(self.messages) + + def __len__(self): + return len(self.messages) + + def __getitem__(self, key): + return self.messages[key] + + def to_markdown(self): + self.merge() + return "\n".join([f"Message {i}\n{m.to_markdown()}\n" for i, m in enumerate(self.messages)]) + + if __name__ == "__main__": # model_to_download = "THUDM/agentlm-70b" diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index 7e2761ac6..8e3d812a0 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -1,3 +1,4 @@ +from functools import cache import os import threading from contextlib import contextmanager @@ -61,6 +62,7 @@ def wrapper(self, obs): return wrapper +@cache def get_pricing_openrouter(): api_key = os.getenv("OPENROUTER_API_KEY") assert api_key, "OpenRouter API key is required" diff --git a/tests/agents/test_agent.py b/tests/agents/test_agent.py index 0b2c31f28..2632f66b5 100644 --- a/tests/agents/test_agent.py +++ b/tests/agents/test_agent.py @@ -11,6 +11,7 @@ from agentlab.analyze import inspect_results from agentlab.experiments import launch_exp from agentlab.llm.chat_api import BaseModelArgs, CheatMiniWoBLLMArgs +from agentlab.llm.llm_utils import Discussion def test_generic_agent(): @@ -24,7 +25,9 @@ def test_generic_agent(): with tempfile.TemporaryDirectory() as tmp_dir: - launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test") + launch_exp.run_experiments( + 1, [exp_args], Path(tmp_dir) / "generic_agent_test", parallel_backend="joblib" + ) result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None) @@ -55,7 +58,10 @@ def __call__(self, messages) -> str: self.retry_count += 1 return dict(role="assistant", content="I'm retrying") - prompt = messages[1].get("content", "") + if isinstance(messages, Discussion): + prompt = messages.to_string() + else: + prompt = messages[1].get("content", "") match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE) if match: @@ -93,7 +99,10 @@ class CheatLLM_LLMError: def __call__(self, messages) -> str: if self.success: - prompt = messages[1].get("content", "") + if isinstance(messages, Discussion): + prompt = messages.to_string() + else: + prompt = messages[1].get("content", "") match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE) if match: @@ -137,9 +146,12 @@ def test_generic_agent_parse_retry(): ) with tempfile.TemporaryDirectory() as tmp_dir: - launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test") + # TODO why these tests don't work with ray backend? + launch_exp.run_experiments( + 1, [exp_args], Path(tmp_dir) / "generic_agent_test", parallel_backend="joblib" + ) result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None) - + print(result_record) target = { "stats.cum_n_retry": 2, "stats.cum_busted_retry": 0, @@ -162,7 +174,9 @@ def test_bust_parse_retry(): ) with tempfile.TemporaryDirectory() as tmp_dir: - launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test") + launch_exp.run_experiments( + 1, [exp_args], Path(tmp_dir) / "generic_agent_test", parallel_backend="joblib" + ) result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None) target = { @@ -188,7 +202,9 @@ def test_llm_error_success(): ) with tempfile.TemporaryDirectory() as tmp_dir: - launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test") + launch_exp.run_experiments( + 1, [exp_args], Path(tmp_dir) / "generic_agent_test", parallel_backend="joblib" + ) result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None) target = { @@ -213,7 +229,9 @@ def test_llm_error_no_success(): ) with tempfile.TemporaryDirectory() as tmp_dir: - launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test") + launch_exp.run_experiments( + 1, [exp_args], Path(tmp_dir) / "generic_agent_test", parallel_backend="joblib" + ) result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None) target = { @@ -229,4 +247,4 @@ def test_llm_error_no_success(): if __name__ == "__main__": # test_generic_agent() - test_llm_error_no_success() + test_generic_agent_parse_retry() diff --git a/tests/agents/test_generic_prompt.py b/tests/agents/test_generic_prompt.py index 712bc4db2..cc1f9036d 100644 --- a/tests/agents/test_generic_prompt.py +++ b/tests/agents/test_generic_prompt.py @@ -1,10 +1,14 @@ from copy import deepcopy +import bgym import pytest from agentlab.agents import dynamic_prompting as dp from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_3_5 -from agentlab.agents.generic_agent.generic_agent_prompt import GenericPromptFlags, MainPrompt +from agentlab.agents.generic_agent.generic_agent_prompt import ( + GenericPromptFlags, + MainPrompt, +) from agentlab.llm.llm_utils import count_tokens html_template = """ @@ -19,30 +23,31 @@ """ +base_obs = { + "goal": "do this and that", + "goal_object": [{"type": "text", "text": "do this and that"}], + "chat_messages": [{"role": "user", "message": "do this and that"}], + "axtree_txt": "[1] Click me", + "focused_element_bid": "45-256", + "open_pages_urls": ["https://example.com"], + "open_pages_titles": ["Example"], + "active_page_index": 0, +} OBS_HISTORY = [ - { - "goal": "do this and that", - "chat_messages": [{"role": "user", "message": "do this and that"}], + base_obs + | { "pruned_html": html_template.format(1), - "axtree_txt": "[1] Click me", - "focused_element_bid": "45-256", "last_action_error": "", }, - { - "goal": "do this and that", - "chat_messages": [{"role": "user", "message": "do this and that"}], + base_obs + | { "pruned_html": html_template.format(2), - "axtree_txt": "[1] Click me", - "focused_element_bid": "45-256", "last_action_error": "Hey, this is an error in the past", }, - { - "goal": "do this and that", - "chat_messages": [{"role": "user", "message": "do this and that"}], + base_obs + | { "pruned_html": html_template.format(3), - "axtree_txt": "[1] Click me", - "focused_element_bid": "45-256", "last_action_error": "Hey, there is an error now", }, ] @@ -54,6 +59,7 @@ obs=dp.ObsFlags( use_html=True, use_ax_tree=True, + use_tabs=True, use_focused_element=True, use_error_logs=True, use_history=True, @@ -70,7 +76,12 @@ filter_visible_elements_only=True, ), action=dp.ActionFlags( - multi_actions=True, + action_set=bgym.HighLevelActionSetArgs( + subsets=["bid"], + multiaction=True, + ), + long_description=True, + individual_examples=True, ), use_plan=True, use_criticise=True, @@ -95,6 +106,10 @@ "obs.use_ax_tree", ("AXTree:", "Click me"), ), + ( + "obs.use_tabs", + ("Currently open tabs:", "(active tab)"), + ), ( "obs.use_focused_element", ("Focused element:", "bid='45-256'"), @@ -144,10 +159,10 @@ "use_abstract_example", ("# Abstract Example",), ), - ( - "action.multi_actions", - ("One or several actions, separated by new lines",), - ), + # ( + # "action.action_set.multiaction", + # ("One or several actions, separated by new lines",), + # ), ] @@ -156,7 +171,7 @@ def test_shrinking_observation(): flags.obs.use_html = True prompt_maker = MainPrompt( - action_set=dp.HighLevelActionSet(), + action_set=bgym.HighLevelActionSet(), obs_history=OBS_HISTORY, actions=ACTIONS, memories=MEMORIES, @@ -166,9 +181,9 @@ def test_shrinking_observation(): flags=flags, ) - prompt = prompt_maker.prompt - new_prompt = dp.fit_tokens( - prompt_maker, max_prompt_tokens=count_tokens(prompt) - 1, max_iterations=7 + prompt = str(prompt_maker.prompt) + new_prompt = str( + dp.fit_tokens(prompt_maker, max_prompt_tokens=count_tokens(prompt) - 1, max_iterations=7) ) assert count_tokens(new_prompt) < count_tokens(prompt) assert "[1] Click me" in prompt @@ -198,16 +213,18 @@ def test_main_prompt_elements_gone_one_at_a_time(flag_name: str, expected_prompt memories = MEMORIES # Initialize MainPrompt - prompt = MainPrompt( - action_set=dp.make_action_set(flags.action), - obs_history=OBS_HISTORY, - actions=ACTIONS, - memories=memories, - thoughts=THOUGHTS, - previous_plan="1- think\n2- do it", - step=2, - flags=flags, - ).prompt + prompt = str( + MainPrompt( + action_set=flags.action.action_set.make_action_set(), + obs_history=OBS_HISTORY, + actions=ACTIONS, + memories=memories, + thoughts=THOUGHTS, + previous_plan="1- think\n2- do it", + step=2, + flags=flags, + ).prompt + ) # Verify all elements are not present for expected in expected_prompts: @@ -218,16 +235,18 @@ def test_main_prompt_elements_present(): # Make sure the flag is enabled # Initialize MainPrompt - prompt = MainPrompt( - action_set=dp.HighLevelActionSet(), - obs_history=OBS_HISTORY, - actions=ACTIONS, - memories=MEMORIES, - thoughts=THOUGHTS, - previous_plan="1- think\n2- do it", - step=2, - flags=ALL_TRUE_FLAGS, - ).prompt + prompt = str( + MainPrompt( + action_set=bgym.HighLevelActionSet(), + obs_history=OBS_HISTORY, + actions=ACTIONS, + memories=MEMORIES, + thoughts=THOUGHTS, + previous_plan="1- think\n2- do it", + step=2, + flags=ALL_TRUE_FLAGS, + ).prompt + ) # Verify all elements are not present for _, expected_prompts in FLAG_EXPECTED_PROMPT: for expected in expected_prompts: @@ -238,5 +257,5 @@ def test_main_prompt_elements_present(): # for debugging test_shrinking_observation() test_main_prompt_elements_present() - for flag, expected_prompts in FLAG_EXPECTED_PROMPT: - test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts) + # for flag, expected_prompts in FLAG_EXPECTED_PROMPT: + # test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts) diff --git a/tests/experiments/test_dask.py b/tests/experiments/test_dask.py new file mode 100644 index 000000000..39822634f --- /dev/null +++ b/tests/experiments/test_dask.py @@ -0,0 +1,41 @@ +from agentlab.experiments.graph_execution_dask import execute_task_graph, make_dask_client +from agentlab.experiments.exp_utils import MockedExpArgs + +TASK_TIME = 3 + + +def test_execute_task_graph(): + # Define a list of ExpArgs with dependencies + exp_args_list = [ + MockedExpArgs(exp_id="task1", depends_on=[]), + MockedExpArgs(exp_id="task2", depends_on=["task1"]), + MockedExpArgs(exp_id="task3", depends_on=["task1"]), + MockedExpArgs(exp_id="task4", depends_on=["task2", "task3"]), + ] + + with make_dask_client(n_worker=5): + results = execute_task_graph(exp_args_list) + + exp_args_list = [results[task_id] for task_id in ["task1", "task2", "task3", "task4"]] + + # Verify that all tasks were executed in the proper order + assert exp_args_list[0].start_time < exp_args_list[1].start_time + assert exp_args_list[0].start_time < exp_args_list[2].start_time + assert exp_args_list[1].end_time < exp_args_list[3].start_time + assert exp_args_list[2].end_time < exp_args_list[3].start_time + + # # Verify that parallel tasks (task2 and task3) started within a short time of each other + # parallel_start_diff = abs(exp_args_list[1].start_time - exp_args_list[2].start_time) + # print(f"parallel_start_diff: {parallel_start_diff}") + # assert parallel_start_diff < 1.5 # Allow for a small delay + + # Ensure that the entire task graph took the expected amount of time + total_time = exp_args_list[-1].end_time - exp_args_list[0].start_time + assert ( + total_time >= TASK_TIME * 3 + ) # Since the critical path involves at least 1.5 seconds of work + + +if __name__ == "__main__": + test_execute_task_graph() + # test_add_dependencies() diff --git a/tests/experiments/test_exp_configs.py b/tests/experiments/test_exp_configs.py index bf1b07f64..8a1b096bb 100644 --- a/tests/experiments/test_exp_configs.py +++ b/tests/experiments/test_exp_configs.py @@ -1,20 +1 @@ -from agentlab.experiments import study_generators - - -def test_all_configs(): - generators = [ - # study_generators.ablation_study, - study_generators.run_agents_on_benchmark, - study_generators.random_search, - ] - - for generator in generators: - study = generator() - assert isinstance(study, study_generators.Study) - assert isinstance(study.exp_args_list, list) - assert len(study.exp_args_list) > 0 - assert isinstance(study.exp_args_list[0], study_generators.ExpArgs) - - -if __name__ == "__main__": - test_all_configs() +from agentlab.experiments import study diff --git a/tests/experiments/test_launch_exp.py b/tests/experiments/test_launch_exp.py index daa6f2063..782a9edca 100644 --- a/tests/experiments/test_launch_exp.py +++ b/tests/experiments/test_launch_exp.py @@ -1,3 +1,4 @@ +import math import tempfile from pathlib import Path @@ -7,36 +8,41 @@ from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_3_5, AGENT_4o_MINI from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs from agentlab.analyze import inspect_results -from agentlab.experiments.launch_exp import relaunch_study, run_experiments -from agentlab.experiments.study_generators import run_agents_on_benchmark +from agentlab.experiments.launch_exp import find_incomplete, run_experiments, non_dummy_count +from agentlab.experiments.study import Study from agentlab.llm.chat_api import CheatMiniWoBLLMArgs def test_relaunch_study(): study_dir = Path(__file__).parent.parent / "data" / "test_study" - exp_args_list, study_dir_ = relaunch_study(study_dir, relaunch_mode="incomplete_only") + exp_args_list = find_incomplete(study_dir, include_errors=False) - assert study_dir_ == study_dir - assert len(exp_args_list) == 1 + assert non_dummy_count(exp_args_list) == 1 assert exp_args_list[0].env_args.task_name == "miniwob.ascending-numbers" - exp_args_list, study_dir_ = relaunch_study(study_dir, relaunch_mode="incomplete_or_error") + exp_args_list = find_incomplete(study_dir, include_errors=True) - assert study_dir_ == study_dir - assert len(exp_args_list) == 2 + assert non_dummy_count(exp_args_list) == 2 -@pytest.mark.repeat(3) # there was stochastic bug caused by asyncio loop not started -def test_launch_system(backend="dask"): +def _test_launch_system(backend="ray", cause_timeout=False): + + if cause_timeout: + wait_time = 10 + avg_step_timeout = 0.5 + else: + wait_time = 0 + avg_step_timeout = 10 + exp_args_list = [] for seed in range(3): exp_args_list.append( ExpArgs( agent_args=GenericAgentArgs( - chat_model_args=CheatMiniWoBLLMArgs(), + chat_model_args=CheatMiniWoBLLMArgs(wait_time=wait_time), flags=FLAGS_GPT_3_5, ), - env_args=EnvArgs(task_name="miniwob.click-test", task_seed=seed), + env_args=EnvArgs(task_name="miniwob.click-test", task_seed=seed, max_steps=5), ) ) @@ -44,7 +50,11 @@ def test_launch_system(backend="dask"): study_dir = Path(tmp_dir) / "generic_agent_test" run_experiments( - n_jobs=2, exp_args_list=exp_args_list, study_dir=study_dir, parallel_backend=backend + n_jobs=2, + exp_args_list=exp_args_list, + study_dir=study_dir, + parallel_backend=backend, + avg_step_timeout=avg_step_timeout, ) results_df = inspect_results.load_result_df(study_dir, progress_fn=None) @@ -53,22 +63,36 @@ def test_launch_system(backend="dask"): for _, row in results_df.iterrows(): if row.stack_trace is not None: print(row.stack_trace) - assert row.err_msg is None - assert row.cum_reward == 1.0 + if cause_timeout: + # assert row.err_msg is not None + assert math.isnan(row.cum_reward) or row.cum_reward == 0 + else: + assert row.err_msg is None + assert row.cum_reward == 1.0 study_summary = inspect_results.summarize_study(results_df) assert len(study_summary) == 1 assert study_summary.std_err.iloc[0] == 0 - assert study_summary.n_completed.iloc[0] == "3/3" - assert study_summary.avg_reward.iloc[0] == 1.0 + + if not cause_timeout: + assert study_summary.n_completed.iloc[0] == "3/3" + assert study_summary.avg_reward.iloc[0] == 1.0 def test_launch_system_joblib(): - test_launch_system(backend="joblib") + _test_launch_system(backend="joblib") def test_launch_system_sequntial(): - test_launch_system(backend="sequential") + _test_launch_system(backend="sequential") + + +def test_launch_system_ray(): + _test_launch_system(backend="ray") + + +def test_timeout_ray(): + _test_launch_system(backend="ray", cause_timeout=True) @pytest.mark.pricy @@ -76,7 +100,7 @@ def test_4o_mini_on_miniwob_tiny_test(): """Run with `pytest -m pricy`.""" with tempfile.TemporaryDirectory() as tmp_dir: - study = run_agents_on_benchmark(agents=AGENT_4o_MINI, benchmark="miniwob_tiny_test") + study = Study(agent_args=[AGENT_4o_MINI], benchmark="miniwob_tiny_test", dir=tmp_dir) study.run(n_jobs=4) @@ -96,6 +120,7 @@ def test_4o_mini_on_miniwob_tiny_test(): if __name__ == "__main__": - test_4o_mini_on_miniwob_tiny_test() - # test_launch_system() + test_timeout_ray() + # test_4o_mini_on_miniwob_tiny_test() + # test_launch_system_ray() # test_launch_system_sequntial() diff --git a/tests/experiments/test_graph_execution.py b/tests/experiments/test_ray.py similarity index 63% rename from tests/experiments/test_graph_execution.py rename to tests/experiments/test_ray.py index 9235358db..9af5959a1 100644 --- a/tests/experiments/test_graph_execution.py +++ b/tests/experiments/test_ray.py @@ -1,36 +1,12 @@ +import bgym import pytest -from agentlab.experiments.graph_execution import ( - execute_task_graph, - add_dependencies, - make_dask_client, -) -from time import time, sleep -from browsergym.experiments.loop import ExpArgs, EnvArgs +import ray +from agentlab.experiments.graph_execution_ray import execute_task_graph +from agentlab.experiments.exp_utils import MockedExpArgs, add_dependencies TASK_TIME = 3 -# Mock implementation of the ExpArgs class with timestamp checks -class MockedExpArgs: - def __init__(self, exp_id, depends_on=None): - self.exp_id = exp_id - self.depends_on = depends_on if depends_on else [] - self.start_time = None - self.end_time = None - - def run(self): - self.start_time = time() - - # # simulate playright code, (this was causing issues due to python async loop) - # import playwright.sync_api - - # pw = playwright.sync_api.sync_playwright().start() - # pw.selectors.set_test_id_attribute("mytestid") - sleep(TASK_TIME) # Simulate task execution time - self.end_time = time() - return self - - def test_execute_task_graph(): # Define a list of ExpArgs with dependencies exp_args_list = [ @@ -40,8 +16,9 @@ def test_execute_task_graph(): MockedExpArgs(exp_id="task4", depends_on=["task2", "task3"]), ] - with make_dask_client(n_worker=5): - results = execute_task_graph(exp_args_list) + ray.init(num_cpus=4) + results = execute_task_graph(exp_args_list) + ray.shutdown() exp_args_list = [results[task_id] for task_id in ["task1", "task2", "task3", "task4"]] @@ -52,8 +29,9 @@ def test_execute_task_graph(): assert exp_args_list[2].end_time < exp_args_list[3].start_time # Verify that parallel tasks (task2 and task3) started within a short time of each other - # parallel_start_diff = abs(exp_args_list[1].start_time - exp_args_list[2].start_time) - # assert parallel_start_diff < 1.5 # Allow for a small delay + parallel_start_diff = abs(exp_args_list[1].start_time - exp_args_list[2].start_time) + print(f"parallel_start_diff: {parallel_start_diff}") + assert parallel_start_diff < 1.5 # Allow for a small delay # Ensure that the entire task graph took the expected amount of time total_time = exp_args_list[-1].end_time - exp_args_list[0].start_time @@ -66,7 +44,9 @@ def test_add_dependencies(): # Prepare a simple list of ExpArgs def make_exp_args(task_name, exp_id): - return ExpArgs(agent_args=None, env_args=EnvArgs(task_name=task_name), exp_id=exp_id) + return bgym.ExpArgs( + agent_args=None, env_args=bgym.EnvArgs(task_name=task_name), exp_id=exp_id + ) exp_args_list = [ make_exp_args("task1", "1"), diff --git a/tests/experiments/test_reproducibility_util.py b/tests/experiments/test_reproducibility_util.py index 15b056580..aa10ff47f 100644 --- a/tests/experiments/test_reproducibility_util.py +++ b/tests/experiments/test_reproducibility_util.py @@ -1,20 +1,26 @@ -from pathlib import Path +import json import tempfile import time +from pathlib import Path + +import bgym +import pytest + +from agentlab.agents.generic_agent import AGENT_4o_MINI from agentlab.analyze import inspect_results from agentlab.experiments import reproducibility_util -from agentlab.agents.generic_agent import AGENT_4o_MINI -import pytest -import json @pytest.mark.parametrize( "benchmark_name", - ["miniwob", "workarena.l1", "webarena", "visualwebarena"], + ["miniwob", "workarena_l1", "webarena", "visualwebarena"], ) def test_get_reproducibility_info(benchmark_name): + + benchmark = bgym.DEFAULT_BENCHMARKS[benchmark_name]() + info = reproducibility_util.get_reproducibility_info( - "test_agent", benchmark_name, ignore_changes=True + "test_agent", benchmark, "test_id", ignore_changes=True ) print("reproducibility info:") @@ -32,68 +38,68 @@ def test_get_reproducibility_info(benchmark_name): assert "browsergym__local_modifications" in info -def test_save_reproducibility_info(): - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_dir = Path(tmp_dir) - - info1 = reproducibility_util.save_reproducibility_info( - study_dir=tmp_dir, - info=reproducibility_util.get_reproducibility_info( - agent_name="GenericAgent", - benchmark_name="miniwob", - ignore_changes=True, - ), - ) - time.sleep(1) # make sure the date changes by at least 1s - - # this should overwrite the previous info since they are the same beside - # the date - info2 = reproducibility_util.save_reproducibility_info( - study_dir=tmp_dir, - info=reproducibility_util.get_reproducibility_info( - agent_name="GenericAgent", - benchmark_name="miniwob", - ignore_changes=True, - ), - ) - - reproducibility_util._assert_compatible(info1, info2) - - # this should not overwrite info2 as the agent name is different, it - # should raise an error - with pytest.raises(ValueError): - reproducibility_util.save_reproducibility_info( - study_dir=tmp_dir, - info=reproducibility_util.get_reproducibility_info( - agent_name="GenericAgent_alt", - benchmark_name="miniwob", - ignore_changes=True, - ), - ) - - # load json - info3 = reproducibility_util.load_reproducibility_info(tmp_dir) - - assert info2 == info3 - assert info1 != info3 - - test_study_dir = Path(__file__).parent.parent / "data" / "test_study" - result_df = inspect_results.load_result_df(test_study_dir, progress_fn=None) - report_df = inspect_results.summarize_study(result_df) - - with pytest.raises(ValueError): - reproducibility_util.append_to_journal( - info3, report_df, journal_path=tmp_dir / "journal.csv" - ) - - reproducibility_util.append_to_journal( - info3, report_df, journal_path=tmp_dir / "journal.csv", strict_reproducibility=False - ) - - print((tmp_dir / "journal.csv").read_text()) +# def test_save_reproducibility_info(): +# with tempfile.TemporaryDirectory() as tmp_dir: +# tmp_dir = Path(tmp_dir) + +# info1 = reproducibility_util.save_reproducibility_info( +# study_dir=tmp_dir, +# info=reproducibility_util.get_reproducibility_info( +# agents_args="GenericAgent", +# benchmark_name="miniwob", +# ignore_changes=True, +# ), +# ) +# time.sleep(1) # make sure the date changes by at least 1s + +# # this should overwrite the previous info since they are the same beside +# # the date +# info2 = reproducibility_util.save_reproducibility_info( +# study_dir=tmp_dir, +# info=reproducibility_util.get_reproducibility_info( +# agents_args="GenericAgent", +# benchmark_name="miniwob", +# ignore_changes=True, +# ), +# ) + +# reproducibility_util.assert_compatible(info1, info2) + +# # this should not overwrite info2 as the agent name is different, it +# # should raise an error +# with pytest.raises(ValueError): +# reproducibility_util.save_reproducibility_info( +# study_dir=tmp_dir, +# info=reproducibility_util.get_reproducibility_info( +# agents_args="GenericAgent_alt", +# benchmark_name="miniwob", +# ignore_changes=True, +# ), +# ) + +# # load json +# info3 = reproducibility_util.load_reproducibility_info(tmp_dir) + +# assert info2 == info3 +# assert info1 != info3 + +# test_study_dir = Path(__file__).parent.parent / "data" / "test_study" +# result_df = inspect_results.load_result_df(test_study_dir, progress_fn=None) +# report_df = inspect_results.summarize_study(result_df) + +# with pytest.raises(ValueError): +# reproducibility_util.append_to_journal( +# info3, report_df, journal_path=tmp_dir / "journal.csv" +# ) + +# reproducibility_util.append_to_journal( +# info3, report_df, journal_path=tmp_dir / "journal.csv", strict_reproducibility=False +# ) + +# print((tmp_dir / "journal.csv").read_text()) if __name__ == "__main__": # test_set_temp() test_get_reproducibility_info("miniwob") - test_save_reproducibility_info() + # test_save_reproducibility_info() diff --git a/tests/experiments/test_task_collection.py b/tests/experiments/test_task_collection.py deleted file mode 100644 index ea6267dcf..000000000 --- a/tests/experiments/test_task_collection.py +++ /dev/null @@ -1,27 +0,0 @@ -import pytest - -from agentlab.experiments.task_collections import get_benchmark_env_args - - -@pytest.mark.pricy -@pytest.mark.parametrize( - "benchmark_name, expected_length", - [ - ("workarena.l1", 330), - ("workarena.l2", 235), - ("workarena.l3", 235), - ("webarena", 812), - ("miniwob", 625), - ], -) -def test_get_benchmark_env_args(benchmark_name, expected_length): - result = get_benchmark_env_args(benchmark_name) - assert len(result) == expected_length - - -if __name__ == "__main__": - test_get_benchmark_env_args("workarena.l1", 5) - test_get_benchmark_env_args("workarena.l2", 5) - test_get_benchmark_env_args("workarena.l3", 5) - test_get_benchmark_env_args("webarena", 5) - test_get_benchmark_env_args("miniwob", 5) diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py index 7e5bb87cc..10febbac1 100644 --- a/tests/llm/test_llm_utils.py +++ b/tests/llm/test_llm_utils.py @@ -183,7 +183,7 @@ def test_successful_parse_before_max_retries(): ] ) - result = llm_utils.retry(mock_chat, [], 5, mock_parser) + result = llm_utils.retry(mock_chat, llm_utils.Discussion(), 5, mock_parser) assert result == "Parsed value" assert mock_chat.call.call_count == 3 @@ -202,7 +202,7 @@ def test_unsuccessful_parse_before_max_retries(): ] ) with pytest.raises(llm_utils.ParseError): - result = llm_utils.retry(mock_chat, [], 2, mock_parser) + result = llm_utils.retry(mock_chat, llm_utils.Discussion(), 2, mock_parser) assert mock_chat.call.call_count == 2 @@ -213,7 +213,7 @@ def test_retry_parse_raises(): parser_raises = Mock(side_effect=ValueError("Parser error")) with pytest.raises(ValueError): - llm_utils.retry(mock_chat, [], 3, parser_raises) + llm_utils.retry(mock_chat, llm_utils.Discussion(), 3, parser_raises) def test_extract_code_blocks(): @@ -242,9 +242,38 @@ def hello_world(): assert llm_utils.extract_code_blocks(text) == expected_output +def test_message_merge_only_text(): + content = [ + {"type": "text", "text": "Hello, world!"}, + {"type": "text", "text": "This is a test."}, + ] + message = llm_utils.BaseMessage(role="system", content=content) + message.merge() + assert len(message["content"]) == 1 + assert message["content"][0]["text"] == "Hello, world!\nThis is a test." + + +def test_message_merge_text_image(): + content = [ + {"type": "text", "text": "Hello, world!"}, + {"type": "text", "text": "This is a test."}, + {"type": "image_url", "image_url": "this is a base64 image"}, + {"type": "text", "text": "This is another test."}, + {"type": "text", "text": "Goodbye, world!"}, + ] + message = llm_utils.BaseMessage(role="system", content=content) + message.merge() + assert len(message["content"]) == 3 + assert message["content"][0]["text"] == "Hello, world!\nThis is a test." + assert message["content"][1]["image_url"] == "this is a base64 image" + assert message["content"][2]["text"] == "This is another test.\nGoodbye, world!" + + if __name__ == "__main__": # test_retry_parallel() # test_rate_limit_max_wait_time() # test_successful_parse_before_max_retries() # test_unsuccessful_parse_before_max_retries() - test_extract_code_blocks() + # test_extract_code_blocks() + # test_message_merge_only_text() + test_message_merge_text_image()