From 676df1b69b7515c7580b579eb6cc1b6ce17705fd Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 22 Oct 2024 11:48:06 -0400 Subject: [PATCH 1/7] updating generic agent to goal_object --- src/agentlab/agents/dynamic_prompting.py | 24 ++++++++++++++----- .../agents/generic_agent/generic_agent.py | 2 ++ .../generic_agent/generic_agent_prompt.py | 7 +++--- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/agentlab/agents/dynamic_prompting.py b/src/agentlab/agents/dynamic_prompting.py index 54d52f2cd..e1ee4ca85 100644 --- a/src/agentlab/agents/dynamic_prompting.py +++ b/src/agentlab/agents/dynamic_prompting.py @@ -443,24 +443,36 @@ def __init__(self, visible: bool = True) -> None: class GoalInstructions(PromptElement): - def __init__(self, goal, visible: bool = True, extra_instructions=None) -> None: + def __init__(self, goal_object, visible: bool = True, extra_instructions=None) -> None: super().__init__(visible) - self._prompt = f"""\ + self._prompt = [ + dict( + type="text", + text=f"""\ # Instructions Review the current state of the page and all other information to find the best possible next action to accomplish your goal. Your answer will be interpreted and executed by a program, make sure to follow the formatting instructions. ## Goal: -{goal} -""" +""", + ) + ] + + self._prompt += goal_object + if extra_instructions: - self._prompt += f""" + self._prompt += [ + dict( + type="text", + text=f""" ## Extra instructions: {extra_instructions} -""" +""", + ) + ] class ChatInstructions(PromptElement): diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index 26a4a276b..c9746d8c3 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -89,6 +89,7 @@ def get_action(self, obs): main_prompt = MainPrompt( action_set=self.action_set, obs_history=self.obs_history, + goal_object=obs["goal_object"], actions=self.actions, memories=self.memories, thoughts=self.thoughts, @@ -268,3 +269,4 @@ def get_action_post_hoc(agent: GenericAgent, obs: dict, ans_dict: dict): output += f"\n\n{action}\n" return system_prompt, instruction_prompt, output + return system_prompt, instruction_prompt, output diff --git a/src/agentlab/agents/generic_agent/generic_agent_prompt.py b/src/agentlab/agents/generic_agent/generic_agent_prompt.py index a655b42f3..7f60cd04d 100644 --- a/src/agentlab/agents/generic_agent/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent/generic_agent_prompt.py @@ -51,6 +51,7 @@ def __init__( self, action_set: AbstractActionSet, obs_history: list[dict], + goal_object: list[dict], actions: list[str], memories: list[str], thoughts: list[str], @@ -71,7 +72,7 @@ def __init__( "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." ) self.instructions = dp.GoalInstructions( - obs_history[-1]["goal"], extra_instructions=flags.extra_instructions + goal_object, extra_instructions=flags.extra_instructions ) self.obs = dp.Observation(obs_history[-1], self.flags.obs) @@ -93,9 +94,9 @@ def time_for_caution(): @property def _prompt(self) -> HumanMessage: - prompt = HumanMessage( + prompt = HumanMessage(self.instructions.prompt) + prompt.add_text( f"""\ -{self.instructions.prompt}\ {self.obs.prompt}\ {self.history.prompt}\ {self.action_prompt.prompt}\ From 64877f94781f5298542f46c49491712328dec964 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 22 Oct 2024 11:48:32 -0400 Subject: [PATCH 2/7] fixing image markdown display --- src/agentlab/llm/llm_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index b0e8e8a06..19c192719 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -365,8 +365,13 @@ def to_markdown(self): # add texts between ticks and images if elem["type"] == "text": res.append(f"\n```\n{elem['text']}\n```\n") - elif elem["type"] == "image": - res.append(f"![image]({elem['url']})") + elif elem["type"] == "image_url": + img_str = ( + elem["image_url"] + if isinstance(elem["image_url"], str) + else elem["image_url"]["url"] + ) + res.append(f"![image]({img_str})") return "\n".join(res) From a43fd76730d6400d16600bb3bdb2d162664e2548 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 22 Oct 2024 11:51:55 -0400 Subject: [PATCH 3/7] updating tests --- tests/agents/test_generic_prompt.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/agents/test_generic_prompt.py b/tests/agents/test_generic_prompt.py index e74e0a5c5..b3aa6db0a 100644 --- a/tests/agents/test_generic_prompt.py +++ b/tests/agents/test_generic_prompt.py @@ -24,6 +24,7 @@ OBS_HISTORY = [ { "goal": "do this and that", + "goal_object": [{"type": "text", "text": "do this and that"}], "chat_messages": [{"role": "user", "message": "do this and that"}], "pruned_html": html_template.format(1), "axtree_txt": "[1] Click me", @@ -32,6 +33,7 @@ }, { "goal": "do this and that", + "goal_object": [{"type": "text", "text": "do this and that"}], "chat_messages": [{"role": "user", "message": "do this and that"}], "pruned_html": html_template.format(2), "axtree_txt": "[1] Click me", @@ -40,6 +42,7 @@ }, { "goal": "do this and that", + "goal_object": [{"type": "text", "text": "do this and that"}], "chat_messages": [{"role": "user", "message": "do this and that"}], "pruned_html": html_template.format(3), "axtree_txt": "[1] Click me", @@ -47,6 +50,7 @@ "last_action_error": "Hey, there is an error now", }, ] +GOAL_OBJECT = [{"type": "text", "text": "do this and that"}] ACTIONS = ["click('41')", "click('42')"] MEMORIES = ["memory A", "memory B"] THOUGHTS = ["thought A", "thought B"] @@ -164,6 +168,7 @@ def test_shrinking_observation(): prompt_maker = MainPrompt( action_set=dp.HighLevelActionSet(), obs_history=OBS_HISTORY, + goal_object=GOAL_OBJECT, actions=ACTIONS, memories=MEMORIES, thoughts=THOUGHTS, @@ -208,6 +213,7 @@ def test_main_prompt_elements_gone_one_at_a_time(flag_name: str, expected_prompt MainPrompt( action_set=flags.action.action_set.make_action_set(), obs_history=OBS_HISTORY, + goal_object=GOAL_OBJECT, actions=ACTIONS, memories=memories, thoughts=THOUGHTS, @@ -230,6 +236,7 @@ def test_main_prompt_elements_present(): MainPrompt( action_set=dp.HighLevelActionSet(), obs_history=OBS_HISTORY, + goal_object=GOAL_OBJECT, actions=ACTIONS, memories=MEMORIES, thoughts=THOUGHTS, @@ -250,3 +257,8 @@ def test_main_prompt_elements_present(): test_main_prompt_elements_present() for flag, expected_prompts in FLAG_EXPECTED_PROMPT: test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts) + test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts) + test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts) + test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts) + test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts) + test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts) From 6f8075cadb0664fb0440008bd3f6723bafd30742 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 22 Oct 2024 12:52:32 -0400 Subject: [PATCH 4/7] fixing intruction BaseMessage --- src/agentlab/llm/llm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 19c192719..fc0b9f058 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -326,7 +326,7 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image): class BaseMessage(dict): def __init__(self, role: str, content: Union[str, list[dict]]): self["role"] = role - self["content"] = content + self["content"] = deepcopy(content) def __str__(self) -> str: if isinstance(self["content"], str): From 1bf1b176f4c2a4cfcc55d2e1aa70eacb178bd567 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 22 Oct 2024 15:15:31 -0400 Subject: [PATCH 5/7] added merge text in discussion --- src/agentlab/llm/llm_utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index fc0b9f058..feb9ac3be 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -374,6 +374,21 @@ def to_markdown(self): res.append(f"![image]({img_str})") return "\n".join(res) + def merge(self): + """Merges content elements of type 'text' if they are adjacent.""" + if isinstance(self["content"], str): + return + new_content = [] + for elem in self["content"]: + if elem["type"] == "text": + if new_content and new_content[-1]["type"] == "text": + new_content[-1]["text"] += "\n" + elem["text"] + else: + new_content.append(elem) + else: + new_content.append(elem) + self["content"] = new_content + class SystemMessage(BaseMessage): def __init__(self, content: Union[str, list[dict]]): From a7e4123ad1822a952a4b5a417e736572078f6123 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 22 Oct 2024 15:31:06 -0400 Subject: [PATCH 6/7] added merge to discussion class --- src/agentlab/llm/llm_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index feb9ac3be..dec6b7f7a 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -417,13 +417,19 @@ def __init__(self, messages: Union[list[BaseMessage], BaseMessage] = None): def last_message(self): return self.messages[-1] + def merge(self): + for m in self.messages: + m.merge() + def __str__(self) -> str: return "\n".join(str(m) for m in self.messages) def to_string(self): + self.merge() return str(self) def to_openai(self): + self.merge() return self.messages def add_message( @@ -464,6 +470,7 @@ def __getitem__(self, key): return self.messages[key] def to_markdown(self): + self.merge() return "\n".join([f"Message {i}\n{m.to_markdown()}\n" for i, m in enumerate(self.messages)]) From 86a23a90261652858029fbb57c5108a1948dd425 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 22 Oct 2024 15:44:35 -0400 Subject: [PATCH 7/7] added tests --- tests/llm/test_llm_utils.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py index d8c29a695..10febbac1 100644 --- a/tests/llm/test_llm_utils.py +++ b/tests/llm/test_llm_utils.py @@ -242,9 +242,38 @@ def hello_world(): assert llm_utils.extract_code_blocks(text) == expected_output +def test_message_merge_only_text(): + content = [ + {"type": "text", "text": "Hello, world!"}, + {"type": "text", "text": "This is a test."}, + ] + message = llm_utils.BaseMessage(role="system", content=content) + message.merge() + assert len(message["content"]) == 1 + assert message["content"][0]["text"] == "Hello, world!\nThis is a test." + + +def test_message_merge_text_image(): + content = [ + {"type": "text", "text": "Hello, world!"}, + {"type": "text", "text": "This is a test."}, + {"type": "image_url", "image_url": "this is a base64 image"}, + {"type": "text", "text": "This is another test."}, + {"type": "text", "text": "Goodbye, world!"}, + ] + message = llm_utils.BaseMessage(role="system", content=content) + message.merge() + assert len(message["content"]) == 3 + assert message["content"][0]["text"] == "Hello, world!\nThis is a test." + assert message["content"][1]["image_url"] == "this is a base64 image" + assert message["content"][2]["text"] == "This is another test.\nGoodbye, world!" + + if __name__ == "__main__": # test_retry_parallel() # test_rate_limit_max_wait_time() # test_successful_parse_before_max_retries() # test_unsuccessful_parse_before_max_retries() - test_extract_code_blocks() + # test_extract_code_blocks() + # test_message_merge_only_text() + test_message_merge_text_image()