diff --git a/src/agentlab/agents/dynamic_prompting.py b/src/agentlab/agents/dynamic_prompting.py index 54d52f2cd..e1ee4ca85 100644 --- a/src/agentlab/agents/dynamic_prompting.py +++ b/src/agentlab/agents/dynamic_prompting.py @@ -443,24 +443,36 @@ def __init__(self, visible: bool = True) -> None: class GoalInstructions(PromptElement): - def __init__(self, goal, visible: bool = True, extra_instructions=None) -> None: + def __init__(self, goal_object, visible: bool = True, extra_instructions=None) -> None: super().__init__(visible) - self._prompt = f"""\ + self._prompt = [ + dict( + type="text", + text=f"""\ # Instructions Review the current state of the page and all other information to find the best possible next action to accomplish your goal. Your answer will be interpreted and executed by a program, make sure to follow the formatting instructions. ## Goal: -{goal} -""" +""", + ) + ] + + self._prompt += goal_object + if extra_instructions: - self._prompt += f""" + self._prompt += [ + dict( + type="text", + text=f""" ## Extra instructions: {extra_instructions} -""" +""", + ) + ] class ChatInstructions(PromptElement): diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index 26a4a276b..c9746d8c3 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -89,6 +89,7 @@ def get_action(self, obs): main_prompt = MainPrompt( action_set=self.action_set, obs_history=self.obs_history, + goal_object=obs["goal_object"], actions=self.actions, memories=self.memories, thoughts=self.thoughts, diff --git a/src/agentlab/agents/generic_agent/generic_agent_prompt.py b/src/agentlab/agents/generic_agent/generic_agent_prompt.py index 
a655b42f3..7f60cd04d 100644 --- a/src/agentlab/agents/generic_agent/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent/generic_agent_prompt.py @@ -51,6 +51,7 @@ def __init__( self, action_set: AbstractActionSet, obs_history: list[dict], + goal_object: list[dict], actions: list[str], memories: list[str], thoughts: list[str], @@ -71,7 +72,7 @@ def __init__( "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." ) self.instructions = dp.GoalInstructions( - obs_history[-1]["goal"], extra_instructions=flags.extra_instructions + goal_object, extra_instructions=flags.extra_instructions ) self.obs = dp.Observation(obs_history[-1], self.flags.obs) @@ -93,9 +94,9 @@ def time_for_caution(): @property def _prompt(self) -> HumanMessage: - prompt = HumanMessage( + prompt = HumanMessage(self.instructions.prompt) + prompt.add_text( f"""\ -{self.instructions.prompt}\ {self.obs.prompt}\ {self.history.prompt}\ {self.action_prompt.prompt}\ diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index b0e8e8a06..dec6b7f7a 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -326,7 +326,7 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image): class BaseMessage(dict): def __init__(self, role: str, content: Union[str, list[dict]]): self["role"] = role - self["content"] = content + self["content"] = deepcopy(content) def __str__(self) -> str: if isinstance(self["content"], str): @@ -365,10 +365,30 @@ def to_markdown(self): # add texts between ticks and images if elem["type"] == "text": res.append(f"\n```\n{elem['text']}\n```\n") - elif elem["type"] == "image": - res.append(f"![image]({elem['url']})") + elif elem["type"] == "image_url": + img_str = ( + elem["image_url"] + if isinstance(elem["image_url"], str) + else elem["image_url"]["url"] + ) + res.append(f"![image]({img_str})") return "\n".join(res) + def merge(self): + """Merges content elements of type 
'text' if they are adjacent.""" + if isinstance(self["content"], str): + return + new_content = [] + for elem in self["content"]: + if elem["type"] == "text": + if new_content and new_content[-1]["type"] == "text": + new_content[-1]["text"] += "\n" + elem["text"] + else: + new_content.append(elem) + else: + new_content.append(elem) + self["content"] = new_content + class SystemMessage(BaseMessage): def __init__(self, content: Union[str, list[dict]]): @@ -397,13 +417,19 @@ def __init__(self, messages: Union[list[BaseMessage], BaseMessage] = None): def last_message(self): return self.messages[-1] + def merge(self): + for m in self.messages: + m.merge() + def __str__(self) -> str: return "\n".join(str(m) for m in self.messages) def to_string(self): + self.merge() return str(self) def to_openai(self): + self.merge() return self.messages def add_message( @@ -444,6 +470,7 @@ def __getitem__(self, key): return self.messages[key] def to_markdown(self): + self.merge() return "\n".join([f"Message {i}\n{m.to_markdown()}\n" for i, m in enumerate(self.messages)]) diff --git a/tests/agents/test_generic_prompt.py b/tests/agents/test_generic_prompt.py index e74e0a5c5..b3aa6db0a 100644 --- a/tests/agents/test_generic_prompt.py +++ b/tests/agents/test_generic_prompt.py @@ -24,6 +24,7 @@ OBS_HISTORY = [ { "goal": "do this and that", + "goal_object": [{"type": "text", "text": "do this and that"}], "chat_messages": [{"role": "user", "message": "do this and that"}], "pruned_html": html_template.format(1), "axtree_txt": "[1] Click me", @@ -32,6 +33,7 @@ }, { "goal": "do this and that", + "goal_object": [{"type": "text", "text": "do this and that"}], "chat_messages": [{"role": "user", "message": "do this and that"}], "pruned_html": html_template.format(2), "axtree_txt": "[1] Click me", @@ -40,6 +42,7 @@ }, { "goal": "do this and that", + "goal_object": [{"type": "text", "text": "do this and that"}], "chat_messages": [{"role": "user", "message": "do this and that"}], "pruned_html": 
html_template.format(3), "axtree_txt": "[1] Click me", @@ -47,6 +50,7 @@ "last_action_error": "Hey, there is an error now", }, ] +GOAL_OBJECT = [{"type": "text", "text": "do this and that"}] ACTIONS = ["click('41')", "click('42')"] MEMORIES = ["memory A", "memory B"] THOUGHTS = ["thought A", "thought B"] @@ -164,6 +168,7 @@ def test_shrinking_observation(): prompt_maker = MainPrompt( action_set=dp.HighLevelActionSet(), obs_history=OBS_HISTORY, + goal_object=GOAL_OBJECT, actions=ACTIONS, memories=MEMORIES, thoughts=THOUGHTS, @@ -208,6 +213,7 @@ def test_main_prompt_elements_gone_one_at_a_time(flag_name: str, expected_prompt MainPrompt( action_set=flags.action.action_set.make_action_set(), obs_history=OBS_HISTORY, + goal_object=GOAL_OBJECT, actions=ACTIONS, memories=memories, thoughts=THOUGHTS, @@ -230,6 +236,7 @@ def test_main_prompt_elements_present(): MainPrompt( action_set=dp.HighLevelActionSet(), obs_history=OBS_HISTORY, + goal_object=GOAL_OBJECT, actions=ACTIONS, memories=MEMORIES, thoughts=THOUGHTS, diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py index d8c29a695..10febbac1 100644 --- a/tests/llm/test_llm_utils.py +++ b/tests/llm/test_llm_utils.py @@ -242,9 +242,38 @@ def hello_world(): assert llm_utils.extract_code_blocks(text) == expected_output +def test_message_merge_only_text(): + content = [ + {"type": "text", "text": "Hello, world!"}, + {"type": "text", "text": "This is 
a test."}, + ] + message = llm_utils.BaseMessage(role="system", content=content) + message.merge() + assert len(message["content"]) == 1 + assert message["content"][0]["text"] == "Hello, world!\nThis is a test." + + +def test_message_merge_text_image(): + content = [ + {"type": "text", "text": "Hello, world!"}, + {"type": "text", "text": "This is a test."}, + {"type": "image_url", "image_url": "this is a base64 image"}, + {"type": "text", "text": "This is another test."}, + {"type": "text", "text": "Goodbye, world!"}, + ] + message = llm_utils.BaseMessage(role="system", content=content) + message.merge() + assert len(message["content"]) == 3 + assert message["content"][0]["text"] == "Hello, world!\nThis is a test." + assert message["content"][1]["image_url"] == "this is a base64 image" + assert message["content"][2]["text"] == "This is another test.\nGoodbye, world!" + + if __name__ == "__main__": # test_retry_parallel() # test_rate_limit_max_wait_time() # test_successful_parse_before_max_retries() # test_unsuccessful_parse_before_max_retries() - test_extract_code_blocks() + # test_extract_code_blocks() + # test_message_merge_only_text() + test_message_merge_text_image()