From d8c8a7cbd18a1ffdb71c09533dc90d1fce46c07b Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Thu, 17 Oct 2024 09:58:19 -0400 Subject: [PATCH] added tmlr definitive config --- .../agents/generic_agent/tmlr_config.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 src/agentlab/agents/generic_agent/tmlr_config.py diff --git a/src/agentlab/agents/generic_agent/tmlr_config.py b/src/agentlab/agents/generic_agent/tmlr_config.py new file mode 100644 index 000000000..11860e691 --- /dev/null +++ b/src/agentlab/agents/generic_agent/tmlr_config.py @@ -0,0 +1,72 @@ +from copy import deepcopy + +from agentlab.agents import dynamic_prompting as dp +from agentlab.experiments import args +from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT + +from .generic_agent import GenericAgentArgs +from .generic_agent_prompt import GenericPromptFlags + +BASE_FLAGS = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=True, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=True, # gpt-4o config except for this line + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=True, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + multi_actions=False, + action_set="bid", + long_description=False, + individual_examples=False, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=None, + be_cautious=True, + extra_instructions=None, +) + + +def get_base_agent(llm_config: str): + return GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=BASE_FLAGS, + ) + + +def get_vision_agent(llm_config: str): + flags = deepcopy(BASE_FLAGS) + flags.obs.use_screenshot = True + return GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=flags, + ) + + +def get_som_agent(llm_config: str): + flags = deepcopy(BASE_FLAGS) + flags.obs.use_screenshot = True + flags.obs.use_som = True + return GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=flags, + )