From 244e06b3955780e8f76ccb2355ca3772493e4cc4 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Wed, 1 Apr 2026 18:40:52 +0200 Subject: [PATCH 1/7] feat: adds new `AskUITruncationStrategy` that dynamically removes images from the history thereby staying compatible with prompt caching --- src/askui/models/shared/conversation.py | 31 +- .../models/shared/truncation_strategies.py | 698 ++++++++++-------- src/askui/speaker/agent_speaker.py | 2 +- .../unit/models/test_truncation_strategies.py | 464 ++++++++++++ 4 files changed, 850 insertions(+), 345 deletions(-) create mode 100644 tests/unit/models/test_truncation_strategies.py diff --git a/src/askui/models/shared/conversation.py b/src/askui/models/shared/conversation.py index 3afcf105..dc4fb2b3 100644 --- a/src/askui/models/shared/conversation.py +++ b/src/askui/models/shared/conversation.py @@ -13,9 +13,8 @@ from askui.models.shared.settings import ActSettings from askui.models.shared.tools import ToolCollection from askui.models.shared.truncation_strategies import ( - SimpleTruncationStrategyFactory, + AskUITruncationStrategy, TruncationStrategy, - TruncationStrategyFactory, ) from askui.reporting import NULL_REPORTER, Reporter from askui.speaker.speaker import SpeakerResult, Speakers @@ -55,7 +54,7 @@ class Conversation: detection_provider: Detection provider (optional) reporter: Reporter for logging messages and actions cache_manager: Cache manager for recording/playback (optional) - truncation_strategy_factory: Factory for creating truncation strategies + truncation_strategy: truncation strategies (optional) callbacks: List of callbacks for conversation lifecycle hooks (optional) """ @@ -67,7 +66,7 @@ def __init__( detection_provider: DetectionProvider | None = None, reporter: Reporter = NULL_REPORTER, cache_manager: "CacheManager | None" = None, - truncation_strategy_factory: TruncationStrategyFactory | None = None, + truncation_strategy: TruncationStrategy | None = None, callbacks: "list[ConversationCallback] | None" = None, ) -> None: """Initialize conversation with speakers and model providers.""" @@ -90,10 +89,6 @@ def __init__( # Infrastructure self._reporter = reporter self.cache_manager = cache_manager - self._truncation_strategy_factory = ( - truncation_strategy_factory or SimpleTruncationStrategyFactory() - ) - self._truncation_strategy: TruncationStrategy | None = None self._callbacks: "list[ConversationCallback]" = callbacks or [] # State for current execution (set in start()) @@ -102,6 +97,11 @@ def __init__( self._reporters: list[Reporter] = [] self._step_index: int = 0 + # truncation strategy + self._truncation_strategy = truncation_strategy or AskUITruncationStrategy( + vlm_provider=vlm_provider, + ) + # Track if cache execution was used (to prevent recording during playback) self._executed_from_cache: bool = False @@ -180,6 +180,7 @@ def _setup_control_loop( reporters: list[Reporter] | None = None, ) -> None: # Reset state + self._truncation_strategy.reset(messages) self._executed_from_cache = False self.speakers.reset_state() @@ -191,16 +192,6 @@ def _setup_control_loop( # Auto-populate speaker descriptions and switch_speaker tool self._setup_speaker_handoff() - # Initialize truncation strategy - self._truncation_strategy = ( - self._truncation_strategy_factory.create_truncation_strategy( - tools=self.tools.to_params(), - system=self.settings.messages.system, - messages=messages, - model=self.vlm_provider.model_id, - ) - ) - @tracer.start_as_current_span("_execute_control_loop") def _execute_control_loop(self) -> None: self._on_control_loop_start() @@ -443,7 +434,9 @@ def get_messages(self) -> list[MessageParam]: Returns: List of messages in current conversation """ - return self._truncation_strategy.messages if self._truncation_strategy else [] + return ( + self._truncation_strategy.full_messages if self._truncation_strategy else [] + ) def get_truncation_strategy(self) -> TruncationStrategy | None: """Get current truncation strategy. diff --git a/src/askui/models/shared/truncation_strategies.py b/src/askui/models/shared/truncation_strategies.py index 1b0be223..eebaa720 100644 --- a/src/askui/models/shared/truncation_strategies.py +++ b/src/askui/models/shared/truncation_strategies.py @@ -1,377 +1,425 @@ -from dataclasses import dataclass -from typing import Annotated +"""Truncation strategies for managing conversation message history.""" + +import logging +from abc import ABC, abstractmethod -from pydantic import Field from typing_extensions import override +from askui.model_providers.vlm_provider import VlmProvider from askui.models.shared.agent_message_param import ( + Base64ImageSourceParam, CacheControlEphemeralParam, + ContentBlockParam, + ImageBlockParam, MessageParam, TextBlockParam, - ToolParam, + ToolResultBlockParam, + ToolUseBlockParam, ) -from askui.models.shared.prompts import ActSystemPrompt -from askui.models.shared.token_counter import SimpleTokenCounter, TokenCounter +from askui.models.shared.token_counter import SimpleTokenCounter + +logger = logging.getLogger(__name__) # needs to be below limits imposed by endpoint MAX_INPUT_TOKENS = 100_000 +# we will truncate as soon as we reach this threshold +TRUNCATION_THRESHOLD = 0.7 + # see https://docs.anthropic.com/en/api/messages#body-messages MAX_MESSAGES = 100_000 +IMAGE_REMOVED_PLACEHOLDER = "[Screenshot removed to reduce message history length]" +"""Text used to replace stripped base64 images.""" + + +class TruncationStrategy(ABC): + """Abstract base class for truncation strategies. -class TruncationStrategy: - """Abstract base class for truncation strategies.""" + Manages two separate message histories: + + - ``full_messages``: append-only, preserves all original + messages for cache recording + - ``truncated_messages``: may have images stripped and + history summarized for LLM calls + + Args: + max_messages: Maximum number of messages before forcing truncation. + max_input_tokens: Maximum input tokens for the endpoint. + truncation_threshold: Fraction of `max_input_tokens` at which to truncate. + """ def __init__( self, - tools: list[ToolParam] | None, - system: ActSystemPrompt | None, - messages: list[MessageParam], - model: str, + max_messages: int = MAX_MESSAGES, + max_input_tokens: int = MAX_INPUT_TOKENS, + truncation_threshold: float = TRUNCATION_THRESHOLD, ) -> None: - self._tools = tools - self._messages = messages - self._system = system - self._model = model + self._full_message_history: list[MessageParam] = [] + self._truncated_message_history: list[MessageParam] = [] + self._max_messages = max_messages + self._absolute_truncation_threshold = int( + max_input_tokens * truncation_threshold + ) + @abstractmethod def append_message(self, message: MessageParam) -> None: - self._messages.append(message) + """Append a message and apply any truncation logic.""" + ... + + @abstractmethod + def truncate(self) -> None: + """Force-truncate the message history.""" + ... + + def reset(self, messages: list[MessageParam] | None = None) -> None: + """Reset message histories with optional initial messages. + + Creates independent copies so modifications to truncated history + do not affect full history. + + Args: + messages: Initial messages to populate both histories. + If ``None``, both histories are cleared. + """ + if messages is not None: + self._full_message_history = list(messages) + self._truncated_message_history = list(messages) + else: + self._full_message_history = [] + self._truncated_message_history = [] @property - def messages(self) -> list[MessageParam]: - """Get the truncated messages.""" - return self._messages - - -def _is_tool_result_user_message(message: MessageParam) -> bool: - return message.role == "user" and ( - isinstance(message.content, list) - and any(block.type == "tool_result" for block in message.content) - ) - - -def _is_tool_use_assistant_message(message: MessageParam) -> bool: - return message.role == "assistant" and ( - isinstance(message.content, list) - and any(block.type == "tool_use" for block in message.content) - ) - - -def _is_end_of_loop( - message: MessageParam, previous_message: MessageParam | None -) -> bool: - return ( - not _is_tool_result_user_message(message) - and previous_message is not None - and previous_message.role == "assistant" - ) - - -@dataclass(kw_only=True) -class MessageContainer: - index: int - message: MessageParam - tokens: int - - -class SimpleTruncationStrategy(TruncationStrategy): - """Simple truncation strategy that truncates messages to stay within token and - message limits. - - Clusters messages into "tool calling loops" - sequences of messages starting with - a user message (not containing `tool_result` blocks) or the first message, and - ending with an assistant message before the next such user message or the last - message. - - The last tool calling loop is called the "open loop" and represents the current - conversation context being worked on. - - Truncation follows this priority order until both token and message thresholds - are met: - 1. Remove tool calling turns (assistant tool_use + user tool_result pairs) - from closed loops - 2. Remove entire closed loops (except first and last which usually contain - the most important context) - 3. Remove the first loop if it's not the open loop - 4. Remove tool calling turns from the open loop (except the first and last turn) - - We need to preserve the thinking block in first turn of open loop. - - Also these are the blocks with the most important context. - 5. Raise ValueError if still exceeds limits after all truncation attempts - - We truncate until a threshold that is way below the limits to make sure that - the threshold is not reached immediately afterwards again and caching can work - in that time. + def truncated_messages(self) -> list[MessageParam]: + """Get the truncated messages sent to the LLM.""" + return self._truncated_message_history + + @property + def full_messages(self) -> list[MessageParam]: + """Get the full, untruncated messages for cache recording.""" + return self._full_message_history + + +class AskUITruncationStrategy(TruncationStrategy): + """Truncation strategy that strips old images, manages cache breakpoints, + and summarizes. + + On each appended message: + + 1. Strips base64 images beyond `n_images_to_keep` (oldest first) + 2. Places dual cache breakpoints (at image-removal boundary + last user message) + 3. If token count exceeds threshold, summarizes the history via the VLM Args: - tools (list[ToolParam] | None): Available tools for the conversation - system (str | list[BetaTextBlockParam] | None): System prompt or blocks - messages (list[MessageParam]): Initial conversation messages - model (str): Model name for token counting - max_input_tokens (int, optional): Maximum input tokens allowed. Defaults to - 100,000. - input_token_truncation_threshold (float, optional): Fraction of max tokens to - truncate at. Defaults to 0.75. - max_messages (int, optional): Maximum messages allowed. Defaults to 100,000. - message_truncation_threshold (float, optional): Fraction of max messages to - truncate at. Defaults to 0.75. - token_counter (TokenCounter | None, optional): Token counter instance. Defaults - to SimpleTokenCounter. - - Raises: - ValueError: If conversation cannot be truncated below limits after all attempts. + vlm_provider: VLM provider used for summarization calls. + n_images_to_keep: Number of most-recent base64 images to retain. + n_messages_to_keep: Number of most-recent messages to preserve + during summarization. + max_messages: Maximum number of messages before forcing truncation. + max_input_tokens: Maximum input tokens for the endpoint. + truncation_threshold: Fraction of `max_input_tokens` at which to truncate. """ def __init__( self, - tools: list[ToolParam] | None, - system: ActSystemPrompt | None, - messages: list[MessageParam], - model: str, - max_input_tokens: int = MAX_INPUT_TOKENS, - input_token_truncation_threshold: Annotated[ - float, Field(gt=0.0, lt=1.0) - ] = 0.75, + vlm_provider: VlmProvider, + n_images_to_keep: int = 3, + n_messages_to_keep: int = 10, max_messages: int = MAX_MESSAGES, - message_truncation_threshold: Annotated[float, Field(gt=0.0, lt=1.0)] = 0.75, - token_counter: TokenCounter | None = None, + max_input_tokens: int = MAX_INPUT_TOKENS, + truncation_threshold: float = TRUNCATION_THRESHOLD, ) -> None: - super().__init__( - tools=tools, - system=system, - messages=messages, - model=model, - ) - self._max_input_tokens = max_input_tokens - self._max_input_tokens_after_truncation = int( - input_token_truncation_threshold * max_input_tokens - ) - self._max_messages = max_messages - self._max_messages_after_truncation = int( - message_truncation_threshold * max_messages - ) - self._token_counter = token_counter or SimpleTokenCounter() - self._token_counts = self._token_counter.count_tokens( - tools=tools, - system=system, - messages=messages, - ) + super().__init__(max_messages, max_input_tokens, truncation_threshold) + self._vlm_provider = vlm_provider + self._n_images_to_keep = n_images_to_keep + self._n_messages_to_keep = n_messages_to_keep + self._token_counter = SimpleTokenCounter() + self._image_removal_boundary_index: int | None = None @override def append_message(self, message: MessageParam) -> None: - super().append_message(message) - self._token_counts.append_message_tokens( - self._token_counter.count_tokens(messages=[message]).total - ) - if self._should_truncate(): - self._truncate() + """Append a message and apply image stripping, cache breakpoints, + and truncation. - def _should_truncate(self) -> bool: - return ( - self._token_counts.total > self._max_input_tokens - or len(self._messages) > self._max_messages + Args: + message: The message to append. + """ + self._full_message_history.append(message) + self._truncated_message_history.append(message) + + # Strip old base64 images (sets _image_removal_boundary_index) + self._remove_images() + + # Place cache breakpoints using the boundary index + self._move_cache_breakpoints() + + # Check if truncation is needed + token_counts = self._token_counter.count_tokens( + messages=self._truncated_message_history, ) + if ( + len(self._truncated_message_history) > self._max_messages + or token_counts.total > self._absolute_truncation_threshold + ): + self.truncate() - @property @override - def messages(self) -> list[MessageParam]: - self._move_cache_control_to_last_user_message() - return self._messages - - def _move_cache_control_to_last_user_message(self) -> None: - found_last = False - for message in reversed(self._messages): - if message.role == "user": - if not found_last: - found_last = True - if isinstance(message.content, str): - message.content = [ - TextBlockParam( - text=message.content, - cache_control=CacheControlEphemeralParam( - type="ephemeral", - ), - ) - ] - elif len(message.content) > 0: - last_content = message.content[-1] - if hasattr(last_content, "cache_control"): - last_content.cache_control = CacheControlEphemeralParam( - type="ephemeral", - ) - else: - if isinstance(message.content, list) and message.content: - last_content = message.content[-1] - if hasattr(last_content, "cache_control"): - last_content.cache_control = None - break - - def _truncate(self) -> None: # noqa: C901 - messages_to_remove_min = min( - len(self._messages) - self._max_messages_after_truncation, 0 - ) - tokens_to_remove_min = max( - self._token_counts.total - self._max_input_tokens_after_truncation, 0 - ) - messages_removed_indices: set[int] = set() - tokens_removed = 0 - loops = self._cluster_into_tool_calling_loops() - - # 1. Remove tool calling turns within closed loops - last_message_was_tool_use_assistant_message = False - for closed_loop in loops[:-1]: - for message_container in closed_loop: - if last_message_was_tool_use_assistant_message: - messages_removed_indices.add(message_container.index) - tokens_removed += message_container.tokens - if ( - len(messages_removed_indices) >= messages_to_remove_min - or tokens_removed >= tokens_to_remove_min - ): - self._remove_messages(messages_removed_indices) - return - - last_message_was_tool_use_assistant_message = False - if _is_tool_use_assistant_message(message_container.message): - last_message_was_tool_use_assistant_message = True - messages_removed_indices.add(message_container.index) - tokens_removed += message_container.tokens - - # 2. Remove loops except first and last (open) loop - for closed_loop in loops[1:-1]: - for message_container in closed_loop: - if message_container.index not in messages_removed_indices: - messages_removed_indices.add(message_container.index) - tokens_removed += message_container.tokens - if ( - len(messages_removed_indices) >= messages_to_remove_min - or tokens_removed >= tokens_to_remove_min + def truncate(self) -> None: + """Summarize old messages and keep only the most recent ones.""" + if len(self._truncated_message_history) <= self._n_messages_to_keep: + msg = "Cannot truncate as there are too few messages in the history" + logger.warning(msg) + return + + summary = self._summarize_message_history() + + # Keep most recent messages + recent = self._truncated_message_history[-self._n_messages_to_keep :] + + # Build new history starting with the summary as a user message + new_messages: list[MessageParam] = [ + MessageParam(role="user", content=summary), + ] + + # Ensure valid role alternation: if first recent message is also + # "user", insert a synthetic assistant acknowledgement. + if recent and recent[0].role == "user": + new_messages.append( + MessageParam( + role="assistant", + content=( + "Understood. I'll continue based on " + "the conversation summary above." + ), + ) + ) + + new_messages.extend(recent) + self._truncated_message_history = new_messages + self._image_removal_boundary_index = None + + # ------------------------------------------------------------------ + # Image removal + # ------------------------------------------------------------------ + + def _remove_images(self) -> None: + """Strip old base64 images from truncated history, keeping the most recent ones. + + Walks from the beginning of the message list and replaces excess + base64 `ImageBlockParam` blocks with text placeholders. Also + recurses into `ToolResultBlockParam.content` lists. URL-based + images are never stripped. + """ + total = self._count_base64_images(self._truncated_message_history) + to_remove = total - self._n_images_to_keep + if to_remove <= 0: + return + + removed = 0 + for i, msg in enumerate(self._truncated_message_history): + if removed >= to_remove: + break + if isinstance(msg.content, str): + continue + + new_content, removed_in_msg = self._strip_base64_images( + msg.content, to_remove - removed + ) + if removed_in_msg > 0: + self._truncated_message_history[i] = MessageParam( + role=msg.role, + content=new_content, + stop_reason=msg.stop_reason, + usage=msg.usage, + ) + self._image_removal_boundary_index = i + removed += removed_in_msg + + @staticmethod + def _count_base64_images(messages: list[MessageParam]) -> int: + """Count total base64 image blocks across all messages.""" + count = 0 + for msg in messages: + if isinstance(msg.content, str): + continue + for block in msg.content: + if isinstance(block, ImageBlockParam) and isinstance( + block.source, Base64ImageSourceParam + ): + count += 1 + elif isinstance(block, ToolResultBlockParam) and isinstance( + block.content, list + ): + for nested in block.content: + if isinstance(nested, ImageBlockParam) and isinstance( + nested.source, Base64ImageSourceParam + ): + count += 1 + return count + + @staticmethod + def _strip_base64_images( + content: list[ContentBlockParam], + max_to_strip: int, + ) -> tuple[list[ContentBlockParam], int]: + """Strip up to `max_to_strip` base64 images from a content block list. + + Args: + content: The content blocks to process. + max_to_strip: Maximum number of images to strip. + + Returns: + Tuple of (new content list, number of images stripped). + """ + stripped = 0 + new_content: list[ContentBlockParam] = [] + + for block in content: + if stripped >= max_to_strip: + new_content.append(block) + continue + + if isinstance(block, ImageBlockParam) and isinstance( + block.source, Base64ImageSourceParam ): - self._remove_messages(messages_removed_indices) - return - - # 3. Remove first loop if it is not the last (open) loop - if len(loops) > 1: - for message_container in loops[0]: - if message_container.index not in messages_removed_indices: - messages_removed_indices.add(message_container.index) - tokens_removed += message_container.tokens - if ( - len(messages_removed_indices) >= messages_to_remove_min - or tokens_removed >= tokens_to_remove_min + new_content.append(TextBlockParam(text=IMAGE_REMOVED_PLACEHOLDER)) + stripped += 1 + elif isinstance(block, ToolResultBlockParam) and isinstance( + block.content, list ): - self._remove_messages(messages_removed_indices) - return - - # 4. Remove tool calling turns within open loop except last turn - if len(loops) > 0: - open_loop = loops[-1] - last_message_was_tool_use_assistant_message = False - for i, message_container in enumerate(open_loop): - if last_message_was_tool_use_assistant_message: - messages_removed_indices.add(message_container.index) - tokens_removed += message_container.tokens + new_nested: list[TextBlockParam | ImageBlockParam] = [] + for nested in block.content: if ( - len(messages_removed_indices) >= messages_to_remove_min - or tokens_removed >= tokens_to_remove_min + stripped < max_to_strip + and isinstance(nested, ImageBlockParam) + and isinstance(nested.source, Base64ImageSourceParam) ): - self._remove_messages(messages_removed_indices) - return + new_nested.append( + TextBlockParam(text=IMAGE_REMOVED_PLACEHOLDER) + ) + stripped += 1 + else: + new_nested.append(nested) + new_content.append( + ToolResultBlockParam( + tool_use_id=block.tool_use_id, + content=new_nested, + is_error=block.is_error, + cache_control=block.cache_control, + ) + ) + else: + new_content.append(block) + + return new_content, stripped + + # ------------------------------------------------------------------ + # Cache breakpoints + # ------------------------------------------------------------------ + + def _move_cache_breakpoints(self) -> None: + """Place dual cache breakpoints on truncated history. + + - **Breakpoint 1** – at the image-removal boundary (stable prefix). + - **Breakpoint 2** – on the last user message (recent context). + """ + # Clear all existing cache_control + for msg in self._truncated_message_history: + self._clear_cache_control(msg) + + # Breakpoint 1: at image removal boundary + if ( + self._image_removal_boundary_index is not None + and self._image_removal_boundary_index + < len(self._truncated_message_history) + ): + self._set_cache_breakpoint( + self._truncated_message_history[self._image_removal_boundary_index] + ) - last_message_was_tool_use_assistant_message = False - if ( - _is_tool_use_assistant_message(message_container.message) - and 1 < i < len(open_loop) - 2 - ): - last_message_was_tool_use_assistant_message = True - messages_removed_indices.add(message_container.index) - tokens_removed += message_container.tokens - - # Everything that is left is the last non-tool-result user message - # and the last (open or closed) tool calling turn (if there is one) - error_msg = "Conversation too long. Please start a new conversation." - raise ValueError(error_msg) - - def _remove_messages(self, indices: set[int]) -> None: - self._token_counts.reset_message_tokens( - [ - self._token_counts.retrieve_message_tokens(i) - for i, _ in enumerate(self._messages) - if i not in indices - ] + # Breakpoint 2: last user message + for msg in reversed(self._truncated_message_history): + if msg.role == "user": + self._set_cache_breakpoint(msg) + break + + @staticmethod + def _clear_cache_control(msg: MessageParam) -> None: + """Clear ``cache_control`` on all blocks in a message.""" + if isinstance(msg.content, str): + return + cacheable = ( + ImageBlockParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlockParam, ) - self._messages = [ - message for i, message in enumerate(self._messages) if i not in indices - ] - - def _cluster_into_tool_calling_loops(self) -> list[list[MessageContainer]]: - loops: list[list[MessageContainer]] = [] - current_loop: list[MessageContainer] = [] - for i, message in enumerate(self._messages): - if _is_end_of_loop( - message, current_loop[-1].message if current_loop else None + for block in msg.content: + if isinstance(block, cacheable): + block.cache_control = None + if isinstance(block, ToolResultBlockParam) and isinstance( + block.content, list ): - loops.append(current_loop) - current_loop = [] - current_loop.append( - MessageContainer( - index=i, - message=message, - tokens=self._token_counts.retrieve_message_tokens(i), - ), - ) - loops.append(current_loop) - return loops + for nested in block.content: + nested.cache_control = None + + @staticmethod + def _set_cache_breakpoint(msg: MessageParam) -> None: + """Set a cache breakpoint on the last block of a message.""" + if isinstance(msg.content, str) or not msg.content: + return + last_block = msg.content[-1] + cacheable = ( + ImageBlockParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlockParam, + ) + if isinstance(last_block, cacheable): + last_block.cache_control = CacheControlEphemeralParam() + # ------------------------------------------------------------------ + # Summarization + # ------------------------------------------------------------------ -class TruncationStrategyFactory: - def create_truncation_strategy( - self, - tools: list[ToolParam] | None, - system: ActSystemPrompt | None, - messages: list[MessageParam], - model: str, - ) -> TruncationStrategy: - return TruncationStrategy( - tools=tools, - system=system, - messages=messages, - model=model, - ) + def _summarize_message_history(self) -> str: + """Ask the VLM to summarize the conversation history. + Returns: + A summary string of the conversation so far. + """ + messages_to_summarize = list(self._truncated_message_history) -class SimpleTruncationStrategyFactory(TruncationStrategyFactory): - def __init__( - self, - max_input_tokens: int = MAX_INPUT_TOKENS, - input_token_truncation_threshold: Annotated[ - float, Field(gt=0.0, lt=1.0) - ] = 0.75, - max_messages: int = MAX_MESSAGES, - message_truncation_threshold: Annotated[float, Field(gt=0.0, lt=1.0)] = 0.75, - token_counter: TokenCounter | None = None, - ) -> None: - self._max_input_tokens = max_input_tokens - self._input_token_truncation_threshold = input_token_truncation_threshold - self._max_messages = max_messages - self._message_truncation_threshold = message_truncation_threshold - self._token_counter = token_counter or SimpleTokenCounter() + # Ensure valid role alternation before the summarization prompt + if messages_to_summarize and messages_to_summarize[-1].role == "user": + messages_to_summarize.append( + MessageParam(role="assistant", content="I understand. Please go ahead.") + ) - def create_truncation_strategy( - self, - tools: list[ToolParam] | None, - system: ActSystemPrompt | None, - messages: list[MessageParam], - model: str, - ) -> TruncationStrategy: - return SimpleTruncationStrategy( - tools=tools, - system=system, - messages=messages, - model=model, - max_input_tokens=self._max_input_tokens, - input_token_truncation_threshold=self._input_token_truncation_threshold, - max_messages=self._max_messages, - message_truncation_threshold=self._message_truncation_threshold, - token_counter=self._token_counter, + messages_to_summarize.append( + MessageParam( + role="user", + content=( + "Please provide a concise summary of the conversation " + "history above. Focus on: key actions taken, results " + "observed, current state, and any pending tasks or " + "errors. This summary will replace the earlier " + "conversation history to save context space." + ), + ) + ) + + response = self._vlm_provider.create_message( + messages=messages_to_summarize, + max_tokens=2048, ) + + if isinstance(response.content, str): + return response.content + + texts = [ + block.text + for block in response.content + if isinstance(block, TextBlockParam) + ] + return "\n".join(texts) diff --git a/src/askui/speaker/agent_speaker.py b/src/askui/speaker/agent_speaker.py index f1155d24..cc03ef41 100644 --- a/src/askui/speaker/agent_speaker.py +++ b/src/askui/speaker/agent_speaker.py @@ -81,7 +81,7 @@ def handle_step( # Make API call to get agent response using VlmProvider response = conversation.vlm_provider.create_message( - messages=truncation_strategy.messages, + messages=truncation_strategy.truncated_messages, tools=conversation.tools, max_tokens=conversation.settings.messages.max_tokens, system=conversation.settings.messages.system, diff --git a/tests/unit/models/test_truncation_strategies.py b/tests/unit/models/test_truncation_strategies.py new file mode 100644 index 00000000..e50e5360 --- /dev/null +++ b/tests/unit/models/test_truncation_strategies.py @@ -0,0 +1,464 @@ +"""Unit tests for truncation strategies.""" + +from unittest.mock import MagicMock + +from askui.models.shared.agent_message_param import ( + Base64ImageSourceParam, + ContentBlockParam, + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlockParam, + UrlImageSourceParam, +) +from askui.models.shared.truncation_strategies import ( + AskUITruncationStrategy, +) + +IMAGE_REMOVED_PLACEHOLDER = "[Screenshot removed to reduce message history length]" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_base64_image_block() -> ImageBlockParam: + return ImageBlockParam( + source=Base64ImageSourceParam(data="abc123", media_type="image/png"), + ) + + +def _make_url_image_block() -> ImageBlockParam: + return ImageBlockParam( + source=UrlImageSourceParam(url="https://example.com/img.png"), + ) + + +def _make_tool_result_with_image(tool_use_id: str = "tool_1") -> ToolResultBlockParam: + return ToolResultBlockParam( + tool_use_id=tool_use_id, + content=[ + TextBlockParam(text="result text"), + _make_base64_image_block(), + ], + ) + + +def _make_vlm_provider() -> MagicMock: + provider = MagicMock() + provider.create_message.return_value = MessageParam( + role="assistant", + content="Summary of the conversation.", + ) + return provider + + +def _make_strategy( + vlm_provider: MagicMock | None = None, + n_images_to_keep: int = 3, + n_messages_to_keep: int = 10, + max_input_tokens: int = 100_000, +) -> AskUITruncationStrategy: + return AskUITruncationStrategy( + vlm_provider=vlm_provider or _make_vlm_provider(), + n_images_to_keep=n_images_to_keep, + n_messages_to_keep=n_messages_to_keep, + max_input_tokens=max_input_tokens, + ) + + +def _get_cache_control(block: ContentBlockParam) -> object: + """Safely get cache_control from a block (returns None for thinking blocks).""" + return getattr(block, "cache_control", None) + + +# --------------------------------------------------------------------------- +# Reset +# --------------------------------------------------------------------------- + + +class TestReset: + def test_reset_creates_independent_lists(self) -> None: + strategy = _make_strategy() + msgs = [MessageParam(role="user", content="hello")] + strategy.reset(msgs) + assert strategy.full_messages is not strategy.truncated_messages + + def test_reset_none_clears_both(self) -> None: + strategy = _make_strategy() + strategy.reset([MessageParam(role="user", content="hello")]) + strategy.reset() + assert strategy.full_messages == [] + assert strategy.truncated_messages == [] + + def test_reset_populates_both_histories(self) -> None: + strategy = _make_strategy() + msgs = [ + MessageParam(role="user", content="hi"), + MessageParam(role="assistant", content="hey"), + ] + strategy.reset(msgs) + assert len(strategy.full_messages) == 2 + assert len(strategy.truncated_messages) == 2 + + +# --------------------------------------------------------------------------- +# Append message +# --------------------------------------------------------------------------- + + +class TestAppendMessage: + def test_appends_to_both_histories(self) -> None: + strategy = _make_strategy() + msg = MessageParam(role="user", content="hello") + strategy.append_message(msg) + assert len(strategy.full_messages) == 1 + assert len(strategy.truncated_messages) == 1 + + def test_string_content_no_crash(self) -> None: + strategy = _make_strategy() + strategy.append_message(MessageParam(role="user", content="just text")) + assert strategy.truncated_messages[0].content == "just text" + + +# --------------------------------------------------------------------------- +# Image removal +# --------------------------------------------------------------------------- + + +class TestRemoveImages: + def test_strips_oldest_base64_images(self) -> None: + strategy = _make_strategy(n_images_to_keep=1) + # Append 3 messages each with a base64 image + for i in range(3): + role = "user" if i % 2 == 0 else "assistant" + strategy.append_message( + MessageParam( + role=role, + content=[_make_base64_image_block()], + ) + ) + # Only the last image should remain; first two should be placeholders + truncated = strategy.truncated_messages + # Message 0: stripped + assert isinstance(truncated[0].content, list) + assert isinstance(truncated[0].content[0], TextBlockParam) + assert truncated[0].content[0].text == IMAGE_REMOVED_PLACEHOLDER + # Message 1: stripped + assert isinstance(truncated[1].content, list) + assert isinstance(truncated[1].content[0], TextBlockParam) + assert truncated[1].content[0].text == IMAGE_REMOVED_PLACEHOLDER + # Message 2: preserved + assert isinstance(truncated[2].content, list) + assert isinstance(truncated[2].content[0], ImageBlockParam) + + def test_skips_url_images(self) -> None: + strategy = _make_strategy(n_images_to_keep=0) + strategy.append_message( + MessageParam( + role="user", + content=[_make_url_image_block()], + ) + ) + # URL image should not be stripped + content = strategy.truncated_messages[0].content + assert isinstance(content, list) + assert isinstance(content[0], ImageBlockParam) + + def test_strips_images_inside_tool_results(self) -> None: + strategy = _make_strategy(n_images_to_keep=0) + strategy.append_message( + MessageParam( + role="user", + content=[_make_tool_result_with_image("tool_1")], + ) + ) + content = strategy.truncated_messages[0].content + assert isinstance(content, list) + tool_result = content[0] + assert isinstance(tool_result, ToolResultBlockParam) + assert isinstance(tool_result.content, list) + # First block is text (kept), second was image (stripped) + assert isinstance(tool_result.content[0], TextBlockParam) + assert tool_result.content[0].text == "result text" + assert isinstance(tool_result.content[1], TextBlockParam) + assert tool_result.content[1].text == IMAGE_REMOVED_PLACEHOLDER + + def test_preserves_non_image_blocks(self) -> None: + strategy = _make_strategy(n_images_to_keep=0) + strategy.append_message( + MessageParam( + role="user", + content=[ + TextBlockParam(text="keep me"), + _make_base64_image_block(), + ], + ) + ) + content = strategy.truncated_messages[0].content + assert isinstance(content, list) + assert isinstance(content[0], TextBlockParam) + assert content[0].text == "keep me" + + def test_full_messages_unaffected_by_stripping(self) -> None: + strategy = _make_strategy(n_images_to_keep=0) + strategy.append_message( + MessageParam( + role="user", + content=[_make_base64_image_block()], + ) + ) + # Full history should still have the original image + full_content = strategy.full_messages[0].content + assert isinstance(full_content, list) + assert isinstance(full_content[0], ImageBlockParam) + + def test_no_stripping_when_under_limit(self) -> None: + strategy = _make_strategy(n_images_to_keep=5) + strategy.append_message( + MessageParam( + role="user", + content=[_make_base64_image_block()], + ) + ) + content = strategy.truncated_messages[0].content + assert isinstance(content, list) + assert isinstance(content[0], ImageBlockParam) + + +# --------------------------------------------------------------------------- +# Cache breakpoints +# --------------------------------------------------------------------------- + + +class TestCacheBreakpoints: + def test_breakpoint_on_last_user_message(self) -> None: + strategy = _make_strategy() + strategy.append_message( + MessageParam(role="user", content=[TextBlockParam(text="hello")]) + ) + strategy.append_message( + MessageParam(role="assistant", content=[TextBlockParam(text="hi")]) + ) + # Last user message (index 0) should have cache_control on its last block + user_msg = strategy.truncated_messages[0] + assert isinstance(user_msg.content, list) + assert _get_cache_control(user_msg.content[-1]) is not None + + def test_breakpoint_at_image_removal_boundary(self) -> None: + strategy = _make_strategy(n_images_to_keep=1) + # Add messages with images - first two will be stripped + strategy.append_message( + MessageParam( + role="user", + content=[_make_base64_image_block()], + ) + ) + strategy.append_message( + MessageParam( + role="assistant", + content=[_make_base64_image_block()], + ) + ) + strategy.append_message( + MessageParam( + role="user", + content=[_make_base64_image_block()], + ) + ) + # Boundary message (last stripped = index 1) should have cache_control + boundary_msg = strategy.truncated_messages[1] + assert isinstance(boundary_msg.content, list) + assert _get_cache_control(boundary_msg.content[-1]) is not None + + def test_clears_previous_breakpoints(self) -> None: + strategy = _make_strategy() + # First append sets breakpoint on message 0 + strategy.append_message( + MessageParam(role="user", content=[TextBlockParam(text="first")]) + ) + assert isinstance(strategy.truncated_messages[0].content, list) + assert ( + _get_cache_control(strategy.truncated_messages[0].content[-1]) is not None + ) + # Second append should clear old breakpoint and set on new last user + strategy.append_message( + MessageParam(role="assistant", content=[TextBlockParam(text="reply")]) + ) + strategy.append_message( + MessageParam(role="user", content=[TextBlockParam(text="second")]) + ) + # Old user message (index 0) should have cache_control cleared + # New user message (index 2) should have it set + old_content = strategy.truncated_messages[0].content + new_content = strategy.truncated_messages[2].content + assert isinstance(old_content, list) + assert isinstance(new_content, list) + assert _get_cache_control(old_content[-1]) is None + assert _get_cache_control(new_content[-1]) is not None + + +# --------------------------------------------------------------------------- +# Truncation / summarization +# --------------------------------------------------------------------------- + + +class TestTruncation: + def test_truncate_replaces_history_with_summary(self) -> None: + vlm = _make_vlm_provider() + strategy = _make_strategy(vlm_provider=vlm, n_messages_to_keep=2) + # Add enough messages to truncate + for i in range(6): + role = "user" if i % 2 == 0 else "assistant" + strategy.append_message(MessageParam(role=role, content=f"msg {i}")) + # Force truncation + strategy.truncate() + msgs = strategy.truncated_messages + # First message should be the summary (user role) + assert msgs[0].role == "user" + assert msgs[0].content == "Summary of the conversation." + # Last 2 messages preserved + assert msgs[-1].content == "msg 5" + assert msgs[-2].content == "msg 4" + + def test_truncate_inserts_synthetic_assistant_for_alternation(self) -> None: + vlm = _make_vlm_provider() + strategy = _make_strategy(vlm_provider=vlm, n_messages_to_keep=2) + for i in range(6): + role = "user" if i % 2 == 0 else "assistant" + strategy.append_message(MessageParam(role=role, content=f"msg {i}")) + strategy.truncate() + msgs = strategy.truncated_messages + # Summary (user) -> msgs[-2] is "msg 4" (user) + # So a synthetic assistant should be inserted between + assert msgs[0].role == "user" # summary + assert msgs[1].role == "assistant" # synthetic + assert "Understood" in str(msgs[1].content) + assert msgs[2].role == "user" # msg 4 + + def test_truncate_skips_when_too_few_messages(self) -> None: + strategy = _make_strategy(n_messages_to_keep=10) + for i in range(4): + role = "user" if i % 2 == 0 else "assistant" + strategy.append_message(MessageParam(role=role, content=f"msg {i}")) + strategy.truncate() + # Should not truncate - still 4 messages + assert len(strategy.truncated_messages) == 4 + + def test_truncate_resets_image_boundary(self) -> None: + strategy = _make_strategy(n_images_to_keep=0, n_messages_to_keep=2) + strategy.append_message( + MessageParam( + role="user", + content=[_make_base64_image_block()], + ) + ) + strategy.append_message( + MessageParam( + role="assistant", + content=[TextBlockParam(text="ok")], + ) + ) + strategy.append_message( + MessageParam(role="user", content=[TextBlockParam(text="more")]) + ) + strategy.append_message( + MessageParam( + role="assistant", + content=[TextBlockParam(text="sure")], + ) + ) + # _image_removal_boundary_index should be set after image stripping + assert strategy._image_removal_boundary_index is not None # noqa: SLF001 + strategy.truncate() + assert strategy._image_removal_boundary_index is None # noqa: SLF001 + + def test_full_messages_preserved_after_truncation(self) -> None: + vlm = _make_vlm_provider() + strategy = _make_strategy(vlm_provider=vlm, n_messages_to_keep=2) + for i in range(6): + role = "user" if i % 2 == 0 else "assistant" + strategy.append_message(MessageParam(role=role, content=f"msg {i}")) + strategy.truncate() + # Full messages should still have all 6 + assert len(strategy.full_messages) == 6 + # Truncated messages should be shorter + assert len(strategy.truncated_messages) < 6 + + def test_auto_truncation_on_token_limit(self) -> None: + vlm = _make_vlm_provider() + # Very low token threshold to trigger auto-truncation + strategy = _make_strategy( + vlm_provider=vlm, + n_messages_to_keep=2, + max_input_tokens=100, + ) + # Add messages with enough text to exceed 100 * 0.7 = 70 token threshold + strategy.append_message(MessageParam(role="user", content="x" * 300)) + strategy.append_message(MessageParam(role="assistant", content="y" * 300)) + strategy.append_message(MessageParam(role="user", content="z" * 300)) + # Should have been auto-truncated + vlm.create_message.assert_called_once() + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_empty_messages_no_crash(self) -> None: + strategy = _make_strategy() + strategy.reset([]) + assert strategy.truncated_messages == [] + assert strategy.full_messages == [] + + def test_single_message_with_many_images(self) -> None: + strategy = _make_strategy(n_images_to_keep=1) + content: list[ContentBlockParam] = [ + _make_base64_image_block() for _ in range(5) + ] + strategy.append_message(MessageParam(role="user", content=content)) + result = strategy.truncated_messages[0].content + assert isinstance(result, list) + # First 4 should be placeholders, last should be image + placeholders = [b for b in result if isinstance(b, TextBlockParam)] + images = [b for b in result if isinstance(b, ImageBlockParam)] + assert len(placeholders) == 4 + assert len(images) == 1 + + def test_mixed_base64_and_url_images(self) -> None: + strategy = _make_strategy(n_images_to_keep=0) + content: list[ContentBlockParam] = [ + _make_base64_image_block(), + _make_url_image_block(), + _make_base64_image_block(), + ] + strategy.append_message(MessageParam(role="user", content=content)) + result = strategy.truncated_messages[0].content + assert isinstance(result, list) + # base64 images stripped, URL image kept + assert isinstance(result[0], TextBlockParam) # was base64 + assert isinstance(result[1], ImageBlockParam) # URL kept + assert isinstance(result[2], TextBlockParam) # was base64 + + def test_tool_use_blocks_preserved(self) -> None: + strategy = _make_strategy(n_images_to_keep=0) + strategy.append_message( + MessageParam( + role="assistant", + content=[ + ToolUseBlockParam( + id="t1", + input={"x": 1}, + name="my_tool", + type="tool_use", + ), + ], + ) + ) + result = strategy.truncated_messages[0].content + assert isinstance(result, list) + assert isinstance(result[0], ToolUseBlockParam) From 8f52b17438c58220137507e32c2b339886933147 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Wed, 1 Apr 2026 18:52:47 +0200 Subject: [PATCH 2/7] feat: make truncation_strategy an init param of Agents --- src/askui/agent_base.py | 3 +++ src/askui/android_agent.py | 3 +++ src/askui/computer_agent.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 2b2ff8c4..f150cfc5 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -23,6 +23,7 @@ LocateSettings, ) from askui.models.shared.tools import Tool, ToolCollection +from askui.models.shared.truncation_strategies import TruncationStrategy from askui.prompts.act_prompts import CACHE_USE_PROMPT, create_default_prompt from askui.telemetry.otel import OtelSettings, setup_opentelemetry_tracing from askui.tools.agent_os import AgentOs @@ -59,6 +60,7 @@ def __init__( agent_os: AgentOs | AndroidAgentOs | None = None, settings: AgentSettings | None = None, callbacks: list[ConversationCallback] | None = None, + truncation_strategy: TruncationStrategy | None = None, ) -> None: load_dotenv() self._reporter: Reporter = reporter or CompositeReporter(reporters=None) @@ -87,6 +89,7 @@ def __init__( image_qa_provider=self._image_qa_provider, detection_provider=self._detection_provider, reporter=self._reporter, + truncation_strategy=truncation_strategy, callbacks=_callbacks, ) diff --git a/src/askui/android_agent.py b/src/askui/android_agent.py index b4fb0182..25438180 100644 --- a/src/askui/android_agent.py +++ b/src/askui/android_agent.py @@ -12,6 +12,7 @@ from askui.models.models import Point from askui.models.shared.settings import ActSettings, MessageSettings from askui.models.shared.tools import Tool +from askui.models.shared.truncation_strategies import TruncationStrategy from askui.prompts.act_prompts import create_android_agent_prompt from askui.tools.android.agent_os import ANDROID_KEY from askui.tools.android.agent_os_facade import AndroidAgentOsFacade @@ -74,6 +75,7 @@ def __init__( retry: Retry | None = None, act_tools: list[Tool] | None = None, callbacks: list[ConversationCallback] | None = None, + truncation_strategy: TruncationStrategy | None = None, ) -> None: reporter = CompositeReporter(reporters=reporters) self.os = PpadbAgentOs(device_identifier=device, reporter=reporter) @@ -85,6 +87,7 @@ def __init__( agent_os=self.os, settings=settings, callbacks=callbacks, + truncation_strategy=truncation_strategy, ) self.act_tool_collection.add_agent_os(self.act_agent_os_facade) # Override default act settings with Android-specific settings diff --git a/src/askui/computer_agent.py b/src/askui/computer_agent.py index d35a97a7..56232828 100644 --- a/src/askui/computer_agent.py +++ b/src/askui/computer_agent.py @@ -12,6 +12,7 @@ from askui.models.models import Point from askui.models.shared.settings import ActSettings, LocateSettings, MessageSettings from askui.models.shared.tools import Tool +from askui.models.shared.truncation_strategies import TruncationStrategy from askui.prompts.act_prompts import ( create_computer_agent_prompt, ) @@ -81,6 +82,7 @@ def __init__( retry: Retry | None = None, act_tools: list[Tool] | None = None, callbacks: list[ConversationCallback] | None = None, + truncation_strategy: TruncationStrategy | None = None, ) -> None: reporter = CompositeReporter(reporters=reporters) self.tools = tools or AgentToolbox( @@ -96,6 +98,7 @@ def __init__( agent_os=self.tools.os, settings=settings, callbacks=callbacks, + truncation_strategy=truncation_strategy, ) self.act_agent_os_facade: ComputerAgentOsFacade = ComputerAgentOsFacade( self.tools.os From c23f33d5a4107c0c10572f9419212a43223507d3 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Wed, 1 Apr 2026 21:58:19 +0200 Subject: [PATCH 3/7] feat: adds new `SummarizingTruncationStrategy` --- src/askui/models/shared/conversation.py | 9 +- .../models/shared/truncation_strategies.py | 545 ++++++++++++++---- .../unit/models/test_truncation_strategies.py | 283 ++++++++- 3 files changed, 731 insertions(+), 106 deletions(-) diff --git a/src/askui/models/shared/conversation.py b/src/askui/models/shared/conversation.py index dc4fb2b3..fe6f2901 100644 --- a/src/askui/models/shared/conversation.py +++ b/src/askui/models/shared/conversation.py @@ -13,7 +13,7 @@ from askui.models.shared.settings import ActSettings from askui.models.shared.tools import ToolCollection from askui.models.shared.truncation_strategies import ( - AskUITruncationStrategy, + SummarizingTruncationStrategy, TruncationStrategy, ) from askui.reporting import NULL_REPORTER, Reporter @@ -98,8 +98,11 @@ def __init__( self._step_index: int = 0 # truncation strategy - self._truncation_strategy = truncation_strategy or AskUITruncationStrategy( - vlm_provider=vlm_provider, + self._truncation_strategy = ( + truncation_strategy + or SummarizingTruncationStrategy( + vlm_provider=vlm_provider, + ) ) # Track if cache execution was used (to prevent recording during playback) diff --git a/src/askui/models/shared/truncation_strategies.py b/src/askui/models/shared/truncation_strategies.py index eebaa720..f83a9c93 100644 --- a/src/askui/models/shared/truncation_strategies.py +++ b/src/askui/models/shared/truncation_strategies.py @@ -1,7 +1,9 @@ """Truncation strategies for managing conversation message history.""" +import json import logging from abc import ABC, abstractmethod +from pathlib import Path from typing_extensions import override @@ -33,6 +35,103 @@ """Text used to replace stripped base64 images.""" +def _has_orphaned_tool_results(msg: MessageParam) -> bool: + """Check if a message contains tool_result blocks. + + Such a message cannot be the first in ``recent`` + because the preceding assistant message with the + matching tool_use would be lost. + """ + if msg.role != "user" or isinstance(msg.content, str): + return False + return any(isinstance(b, ToolResultBlockParam) for b in msg.content) + + +def _summarize_message_history( + vlm_provider: VlmProvider, + messages: list[MessageParam], +) -> str: + """Ask the VLM to summarize the conversation history. + + Args: + vlm_provider: VLM provider to use for summarization. + messages: Messages to summarize. + + Returns: + A summary string of the conversation so far. + """ + messages_to_summarize = list(messages) + + # Ensure valid role alternation + if messages_to_summarize and messages_to_summarize[-1].role == "user": + messages_to_summarize.append( + MessageParam( + role="assistant", + content="I understand. Please go ahead.", + ) + ) + + messages_to_summarize.append( + MessageParam( + role="user", + content=( + "Please provide a concise summary of the " + "conversation history above. Focus on: key " + "actions taken, results observed, current " + "state, and any pending tasks or errors. " + "This summary will replace the earlier " + "conversation history to save context space." + ), + ) + ) + + response = vlm_provider.create_message( + messages=messages_to_summarize, + max_tokens=2048, + ) + + if isinstance(response.content, str): + return response.content + + texts = [ + block.text for block in response.content if isinstance(block, TextBlockParam) + ] + return "\n".join(texts) + + +def _clear_cache_control(msg: MessageParam) -> None: + """Clear ``cache_control`` on all blocks of a message.""" + if isinstance(msg.content, str): + return + cacheable = ( + ImageBlockParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlockParam, + ) + for block in msg.content: + if isinstance(block, cacheable): + block.cache_control = None + if isinstance(block, ToolResultBlockParam) and isinstance(block.content, list): + for nested in block.content: + nested.cache_control = None + + +def _set_cache_breakpoint(msg: MessageParam) -> None: + """Set cache breakpoint on last block of a message.""" + if isinstance(msg.content, str) or not msg.content: + return + last_block = msg.content[-1] + cacheable = ( + ImageBlockParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlockParam, + ) + if isinstance(last_block, cacheable): + last_block.cache_control = CacheControlEphemeralParam() + + class TruncationStrategy(ABC): """Abstract base class for truncation strategies. @@ -44,9 +143,11 @@ class TruncationStrategy(ABC): history summarized for LLM calls Args: - max_messages: Maximum number of messages before forcing truncation. + max_messages: Maximum number of messages before + forcing truncation. max_input_tokens: Maximum input tokens for the endpoint. - truncation_threshold: Fraction of `max_input_tokens` at which to truncate. + truncation_threshold: Fraction of `max_input_tokens` + at which to truncate. """ def __init__( @@ -75,8 +176,8 @@ def truncate(self) -> None: def reset(self, messages: list[MessageParam] | None = None) -> None: """Reset message histories with optional initial messages. - Creates independent copies so modifications to truncated history - do not affect full history. + Creates independent copies so modifications to truncated + history do not affect full history. Args: messages: Initial messages to populate both histories. @@ -100,24 +201,33 @@ def full_messages(self) -> list[MessageParam]: return self._full_message_history -class AskUITruncationStrategy(TruncationStrategy): - """Truncation strategy that strips old images, manages cache breakpoints, - and summarizes. +class SlidingImageWindowSummarizingTruncationStrategy(TruncationStrategy): + """Truncation strategy that strips old images, manages + cache breakpoints, and summarizes. On each appended message: - 1. Strips base64 images beyond `n_images_to_keep` (oldest first) - 2. Places dual cache breakpoints (at image-removal boundary + last user message) - 3. If token count exceeds threshold, summarizes the history via the VLM + 1. Strips base64 images beyond `n_images_to_keep` + (oldest first) + 2. Places dual cache breakpoints (at image-removal + boundary + last user message) + 3. If token count exceeds threshold, summarizes + the history via the VLM Args: - vlm_provider: VLM provider used for summarization calls. - n_images_to_keep: Number of most-recent base64 images to retain. - n_messages_to_keep: Number of most-recent messages to preserve - during summarization. - max_messages: Maximum number of messages before forcing truncation. - max_input_tokens: Maximum input tokens for the endpoint. - truncation_threshold: Fraction of `max_input_tokens` at which to truncate. + vlm_provider: VLM provider used for summarization. + n_images_to_keep: Number of most-recent base64 images + to retain. + n_messages_to_keep: Number of most-recent messages to + preserve during summarization. + max_messages: Maximum number of messages before + forcing truncation. + max_input_tokens: Maximum input tokens for the + endpoint. + truncation_threshold: Fraction of `max_input_tokens` + at which to truncate. + debug_dir: When set, write diagnostic snapshots to + this directory after each append/truncate. """ def __init__( @@ -128,6 +238,7 @@ def __init__( max_messages: int = MAX_MESSAGES, max_input_tokens: int = MAX_INPUT_TOKENS, truncation_threshold: float = TRUNCATION_THRESHOLD, + debug_dir: Path | None = None, ) -> None: super().__init__(max_messages, max_input_tokens, truncation_threshold) self._vlm_provider = vlm_provider @@ -135,11 +246,34 @@ def __init__( self._n_messages_to_keep = n_messages_to_keep self._token_counter = SimpleTokenCounter() self._image_removal_boundary_index: int | None = None + self._debug_dir = debug_dir + self._debug_step = 0 + + msg = """CAUTION: The Truncation Strategy you are using is experimental! + While it will lead to faster executions in longer runs it might crash or + lead to overall unexpected behavior! If in doubt, we recommend using the + default truncation strategy instead.""" + logger.warning(msg) + + if self._debug_dir is not None: + self._debug_dir.mkdir(parents=True, exist_ok=True) + # Write config + config = { + "n_images_to_keep": n_images_to_keep, + "n_messages_to_keep": n_messages_to_keep, + "max_messages": max_messages, + "max_input_tokens": max_input_tokens, + "truncation_threshold": truncation_threshold, + "absolute_threshold": self._absolute_truncation_threshold, + } + (self._debug_dir / "config.json").write_text( + json.dumps(config, indent=2), encoding="utf-8" + ) @override def append_message(self, message: MessageParam) -> None: - """Append a message and apply image stripping, cache breakpoints, - and truncation. + """Append a message and apply image stripping, + cache breakpoints, and truncation. Args: message: The message to append. @@ -157,32 +291,56 @@ def append_message(self, message: MessageParam) -> None: token_counts = self._token_counter.count_tokens( messages=self._truncated_message_history, ) + truncated = False if ( len(self._truncated_message_history) > self._max_messages or token_counts.total > self._absolute_truncation_threshold ): self.truncate() + truncated = True + + self._write_debug_snapshot( + event="truncate" if truncated else "append", + token_total=token_counts.total, + ) @override def truncate(self) -> None: - """Summarize old messages and keep only the most recent ones.""" + """Summarize old messages and keep only recent ones.""" if len(self._truncated_message_history) <= self._n_messages_to_keep: - msg = "Cannot truncate as there are too few messages in the history" + msg = "Cannot truncate: too few messages in history" logger.warning(msg) return - summary = self._summarize_message_history() + logger.info("Summarizing message history") + summary = _summarize_message_history( + self._vlm_provider, self._truncated_message_history + ) - # Keep most recent messages - recent = self._truncated_message_history[-self._n_messages_to_keep :] + # Find a safe cut point that doesn't orphan tool_results. + # A user message with tool_result blocks requires the + # preceding assistant message to contain the matching + # tool_use blocks, so we must not start `recent` on one. + cut = len(self._truncated_message_history) - self._n_messages_to_keep + while cut > 0 and _has_orphaned_tool_results( + self._truncated_message_history[cut] + ): + cut -= 1 + + if cut <= 0: + msg = "Cannot truncate: no safe cut point found" + logger.warning(msg) + return - # Build new history starting with the summary as a user message + recent = self._truncated_message_history[cut:] + + # Build new history with the summary as a user message new_messages: list[MessageParam] = [ MessageParam(role="user", content=summary), ] - # Ensure valid role alternation: if first recent message is also - # "user", insert a synthetic assistant acknowledgement. + # Ensure valid role alternation: if first recent message + # is also "user", insert a synthetic assistant ack. if recent and recent[0].role == "user": new_messages.append( MessageParam( @@ -203,12 +361,12 @@ def truncate(self) -> None: # ------------------------------------------------------------------ def _remove_images(self) -> None: - """Strip old base64 images from truncated history, keeping the most recent ones. + """Strip old base64 images from truncated history. - Walks from the beginning of the message list and replaces excess - base64 `ImageBlockParam` blocks with text placeholders. Also - recurses into `ToolResultBlockParam.content` lists. URL-based - images are never stripped. + Walks from the beginning and replaces excess base64 + `ImageBlockParam` blocks with text placeholders. Also + recurses into `ToolResultBlockParam.content` lists. + URL-based images are never stripped. """ total = self._count_base64_images(self._truncated_message_history) to_remove = total - self._n_images_to_keep @@ -236,8 +394,10 @@ def _remove_images(self) -> None: removed += removed_in_msg @staticmethod - def _count_base64_images(messages: list[MessageParam]) -> int: - """Count total base64 image blocks across all messages.""" + def _count_base64_images( + messages: list[MessageParam], + ) -> int: + """Count total base64 image blocks across messages.""" count = 0 for msg in messages: if isinstance(msg.content, str): @@ -252,7 +412,8 @@ def _count_base64_images(messages: list[MessageParam]) -> int: ): for nested in block.content: if isinstance(nested, ImageBlockParam) and isinstance( - nested.source, Base64ImageSourceParam + nested.source, + Base64ImageSourceParam, ): count += 1 return count @@ -262,14 +423,14 @@ def _strip_base64_images( content: list[ContentBlockParam], max_to_strip: int, ) -> tuple[list[ContentBlockParam], int]: - """Strip up to `max_to_strip` base64 images from a content block list. + """Strip up to `max_to_strip` base64 images. Args: content: The content blocks to process. max_to_strip: Maximum number of images to strip. Returns: - Tuple of (new content list, number of images stripped). + Tuple of (new content list, count stripped). """ stripped = 0 new_content: list[ContentBlockParam] = [] @@ -292,7 +453,10 @@ def _strip_base64_images( if ( stripped < max_to_strip and isinstance(nested, ImageBlockParam) - and isinstance(nested.source, Base64ImageSourceParam) + and isinstance( + nested.source, + Base64ImageSourceParam, + ) ): new_nested.append( TextBlockParam(text=IMAGE_REMOVED_PLACEHOLDER) @@ -320,8 +484,8 @@ def _strip_base64_images( def _move_cache_breakpoints(self) -> None: """Place dual cache breakpoints on truncated history. - - **Breakpoint 1** – at the image-removal boundary (stable prefix). - - **Breakpoint 2** – on the last user message (recent context). + - **Breakpoint 1** – image-removal boundary. + - **Breakpoint 2** – last user message. """ # Clear all existing cache_control for msg in self._truncated_message_history: @@ -345,81 +509,262 @@ def _move_cache_breakpoints(self) -> None: @staticmethod def _clear_cache_control(msg: MessageParam) -> None: - """Clear ``cache_control`` on all blocks in a message.""" - if isinstance(msg.content, str): - return - cacheable = ( - ImageBlockParam, - TextBlockParam, - ToolResultBlockParam, - ToolUseBlockParam, - ) - for block in msg.content: - if isinstance(block, cacheable): - block.cache_control = None - if isinstance(block, ToolResultBlockParam) and isinstance( - block.content, list - ): - for nested in block.content: - nested.cache_control = None + """Clear ``cache_control`` on all blocks.""" + _clear_cache_control(msg) @staticmethod def _set_cache_breakpoint(msg: MessageParam) -> None: - """Set a cache breakpoint on the last block of a message.""" - if isinstance(msg.content, str) or not msg.content: - return - last_block = msg.content[-1] - cacheable = ( - ImageBlockParam, - TextBlockParam, - ToolResultBlockParam, - ToolUseBlockParam, - ) - if isinstance(last_block, cacheable): - last_block.cache_control = CacheControlEphemeralParam() + """Set cache breakpoint on last block of a message.""" + _set_cache_breakpoint(msg) # ------------------------------------------------------------------ - # Summarization + # Debug diagnostics # ------------------------------------------------------------------ - def _summarize_message_history(self) -> str: - """Ask the VLM to summarize the conversation history. + def _write_debug_snapshot( + self, + event: str, + token_total: int = 0, + ) -> None: + """Write a diagnostic snapshot to the debug dir. - Returns: - A summary string of the conversation so far. + Each snapshot summarises both message histories + compactly (no base64 data) so you can verify at a + glance that image stripping, cache breakpoints, and + truncation work correctly. + + Args: + event: ``"append"`` or ``"truncate"``. + token_total: Estimated token count for the + truncated history. """ - messages_to_summarize = list(self._truncated_message_history) + if self._debug_dir is None: + return - # Ensure valid role alternation before the summarization prompt - if messages_to_summarize and messages_to_summarize[-1].role == "user": - messages_to_summarize.append( - MessageParam(role="assistant", content="I understand. Please go ahead.") - ) + self._debug_step += 1 + + full_imgs = self._count_base64_images(self._full_message_history) + trunc_imgs = self._count_base64_images(self._truncated_message_history) + + snapshot: dict[str, object] = { + "step": self._debug_step, + "event": event, + "full_msg_count": len(self._full_message_history), + "full_base64_images": full_imgs, + "truncated_msg_count": len(self._truncated_message_history), + "truncated_base64_images": trunc_imgs, + "images_stripped": full_imgs - trunc_imgs, + "image_boundary_idx": (self._image_removal_boundary_index), + "token_estimate": token_total, + "threshold": self._absolute_truncation_threshold, + "truncated_messages": [ + self._summarise_message(i, m) + for i, m in enumerate(self._truncated_message_history) + ], + } + + filename = f"step_{self._debug_step:03d}_{event}.json" + (self._debug_dir / filename).write_text( + json.dumps(snapshot, indent=2, ensure_ascii=False), + encoding="utf-8", + ) - messages_to_summarize.append( - MessageParam( - role="user", - content=( - "Please provide a concise summary of the conversation " - "history above. Focus on: key actions taken, results " - "observed, current state, and any pending tasks or " - "errors. This summary will replace the earlier " - "conversation history to save context space." - ), + def _summarise_message(self, index: int, msg: MessageParam) -> dict[str, object]: + """Build a compact summary of a message for debug. + + Returns a dict with role, block descriptions, and + cache breakpoint info — no base64 data. + """ + result: dict[str, object] = { + "index": index, + "role": msg.role, + } + + if isinstance(msg.content, str): + preview = msg.content + if len(preview) > 120: + preview = preview[:120] + "..." + result["content"] = preview + return result + + blocks: list[str] = [] + has_cache_bp = False + + for block in msg.content: + desc = self._describe_block(block) + blocks.append(desc) + cc = getattr(block, "cache_control", None) + if cc is not None: + has_cache_bp = True + + result["blocks"] = blocks + result["has_cache_breakpoint"] = has_cache_bp + return result + + @staticmethod + def _describe_block( + block: ContentBlockParam, + ) -> str: + """One-line description of a content block.""" + if isinstance(block, TextBlockParam): + preview = block.text[:60] + if len(block.text) > 60: + preview += "..." + return f"text({preview})" + + if isinstance(block, ImageBlockParam): + kind = ( + "base64" if isinstance(block.source, Base64ImageSourceParam) else "url" ) + return f"image:{kind}" + + if isinstance(block, ToolUseBlockParam): + return f"tool_use({block.name})" + + if isinstance(block, ToolResultBlockParam): + return _describe_tool_result(block) + + return block.type + + +class SummarizingTruncationStrategy(TruncationStrategy): + """Truncation strategy that summarizes when limits are hit. + + Unlike `SlidingImageWindowSummarizingTruncationStrategy`, + this strategy does **not** strip images. It places a + single cache breakpoint on the last user message (moving + it forward on each append) and summarizes the conversation + history via the VLM when the token or message count + exceeds the configured threshold. + + Args: + vlm_provider: VLM provider used for summarization. + n_messages_to_keep: Number of most-recent messages to + preserve during summarization. + max_messages: Maximum number of messages before + forcing truncation. + max_input_tokens: Maximum input tokens for the + endpoint. + truncation_threshold: Fraction of `max_input_tokens` + at which to truncate. + """ + + def __init__( + self, + vlm_provider: VlmProvider, + n_messages_to_keep: int = 10, + max_messages: int = MAX_MESSAGES, + max_input_tokens: int = MAX_INPUT_TOKENS, + truncation_threshold: float = TRUNCATION_THRESHOLD, + ) -> None: + super().__init__(max_messages, max_input_tokens, truncation_threshold) + self._vlm_provider = vlm_provider + self._n_messages_to_keep = n_messages_to_keep + self._token_counter = SimpleTokenCounter() + + @override + def append_message(self, message: MessageParam) -> None: + """Append a message, move cache breakpoint, summarize + if limits are hit. + + Places a cache breakpoint on the last user message + and clears it from any previous position so the LLM + caches the full prefix optimally. + + Args: + message: The message to append. + """ + self._full_message_history.append(message) + self._truncated_message_history.append(message) + + # Move cache breakpoint to last user message + self._move_cache_breakpoint() + + token_counts = self._token_counter.count_tokens( + messages=self._truncated_message_history, ) + if ( + len(self._truncated_message_history) > self._max_messages + or token_counts.total > self._absolute_truncation_threshold + ): + self.truncate() + + def _move_cache_breakpoint(self) -> None: + """Place a cache breakpoint on the last user message. - response = self._vlm_provider.create_message( - messages=messages_to_summarize, - max_tokens=2048, + Clears ``cache_control`` from the previous last user + message first so only one breakpoint exists at a time. + """ + found_last = False + for msg in reversed(self._truncated_message_history): + if msg.role != "user": + continue + if not found_last: + found_last = True + _set_cache_breakpoint(msg) + else: + _clear_cache_control(msg) + break + + @override + def truncate(self) -> None: + """Summarize old messages and keep only recent ones.""" + if len(self._truncated_message_history) <= self._n_messages_to_keep: + msg = "Cannot truncate: too few messages in history" + logger.warning(msg) + return + + logger.info("Summarizing message history") + summary = _summarize_message_history( + self._vlm_provider, self._truncated_message_history ) - if isinstance(response.content, str): - return response.content + # Find a safe cut point that doesn't orphan + # tool_results from their tool_use. + cut = len(self._truncated_message_history) - self._n_messages_to_keep + while cut > 0 and _has_orphaned_tool_results( + self._truncated_message_history[cut] + ): + cut -= 1 + + if cut <= 0: + msg = "Cannot truncate: no safe cut point found" + logger.warning(msg) + return + + recent = self._truncated_message_history[cut:] - texts = [ - block.text - for block in response.content - if isinstance(block, TextBlockParam) + new_messages: list[MessageParam] = [ + MessageParam(role="user", content=summary), ] - return "\n".join(texts) + + # Ensure valid role alternation + if recent and recent[0].role == "user": + new_messages.append( + MessageParam( + role="assistant", + content=( + "Understood. I'll continue based on " + "the conversation summary above." + ), + ) + ) + + new_messages.extend(recent) + self._truncated_message_history = new_messages + + +def _describe_tool_result( + block: ToolResultBlockParam, +) -> str: + """One-line description of a tool result block.""" + if isinstance(block.content, str): + return f"tool_result({block.content[:40]})" + nested = [] + for n in block.content: + if isinstance(n, TextBlockParam): + nested.append(f"text({n.text[:30]})") + elif isinstance(n, ImageBlockParam): + kind = "b64" if isinstance(n.source, Base64ImageSourceParam) else "url" + nested.append(f"img:{kind}") + return f"tool_result[{', '.join(nested)}]" diff --git a/tests/unit/models/test_truncation_strategies.py b/tests/unit/models/test_truncation_strategies.py index e50e5360..a1055f28 100644 --- a/tests/unit/models/test_truncation_strategies.py +++ b/tests/unit/models/test_truncation_strategies.py @@ -13,7 +13,8 @@ UrlImageSourceParam, ) from askui.models.shared.truncation_strategies import ( - AskUITruncationStrategy, + SlidingImageWindowSummarizingTruncationStrategy, + SummarizingTruncationStrategy, ) IMAGE_REMOVED_PLACEHOLDER = "[Screenshot removed to reduce message history length]" @@ -60,8 +61,8 @@ def _make_strategy( n_images_to_keep: int = 3, n_messages_to_keep: int = 10, max_input_tokens: int = 100_000, -) -> AskUITruncationStrategy: - return AskUITruncationStrategy( +) -> SlidingImageWindowSummarizingTruncationStrategy: + return SlidingImageWindowSummarizingTruncationStrategy( vlm_provider=vlm_provider or _make_vlm_provider(), n_images_to_keep=n_images_to_keep, n_messages_to_keep=n_messages_to_keep, @@ -387,6 +388,71 @@ def test_full_messages_preserved_after_truncation(self) -> None: # Truncated messages should be shorter assert len(strategy.truncated_messages) < 6 + def test_truncate_preserves_tool_use_tool_result_pairs(self) -> None: + vlm = _make_vlm_provider() + # n_messages_to_keep=3: naive cut would start on the + # user tool_result, orphaning it from its tool_use. + strategy = _make_strategy(vlm_provider=vlm, n_messages_to_keep=3) + # Build a realistic tool-calling conversation: + # user(goal) -> asst(tool_use) -> user(tool_result) + # -> asst(tool_use) -> user(tool_result) + # -> asst(text) + strategy.append_message(MessageParam(role="user", content="do something")) + strategy.append_message( + MessageParam( + role="assistant", + content=[ + ToolUseBlockParam( + id="tu_1", input={}, name="tool_a", type="tool_use" + ), + ], + ) + ) + strategy.append_message( + MessageParam( + role="user", + content=[ + ToolResultBlockParam(tool_use_id="tu_1", content="result 1"), + ], + ) + ) + strategy.append_message( + MessageParam( + role="assistant", + content=[ + ToolUseBlockParam( + id="tu_2", input={}, name="tool_b", type="tool_use" + ), + ], + ) + ) + strategy.append_message( + MessageParam( + role="user", + content=[ + ToolResultBlockParam(tool_use_id="tu_2", content="result 2"), + ], + ) + ) + strategy.append_message(MessageParam(role="assistant", content="all done")) + + strategy.truncate() + msgs = strategy.truncated_messages + + # Every user message with tool_results must be preceded + # by an assistant message (not the summary or synthetic). + for i, m in enumerate(msgs): + if m.role != "user" or isinstance(m.content, str): + continue + has_tr = any(isinstance(b, ToolResultBlockParam) for b in m.content) + if has_tr: + prev = msgs[i - 1] + assert prev.role == "assistant" + assert ( + not isinstance(prev.content, str) + or "Understood" not in prev.content + ), f"tool_result at index {i} follows synthetic assistant" + def test_auto_truncation_on_token_limit(self) -> None: vlm = _make_vlm_provider() # Very low token threshold to trigger auto-truncation @@ -462,3 +528,214 @@ def test_tool_use_blocks_preserved(self) -> None: result = strategy.truncated_messages[0].content assert isinstance(result, list) assert isinstance(result[0], ToolUseBlockParam) + + +# --------------------------------------------------------------------------- +# SummarizingTruncationStrategy +# --------------------------------------------------------------------------- + + +def _make_summarizing_strategy( + vlm_provider: MagicMock | None = None, + n_messages_to_keep: int = 10, + max_input_tokens: int = 100_000, +) -> SummarizingTruncationStrategy: + return SummarizingTruncationStrategy( + vlm_provider=vlm_provider or _make_vlm_provider(), + n_messages_to_keep=n_messages_to_keep, + max_input_tokens=max_input_tokens, + ) + + +class TestSummarizingAppend: + def test_appends_to_both_histories(self) -> None: + strategy = _make_summarizing_strategy() + msg = MessageParam(role="user", content="hello") + strategy.append_message(msg) + assert len(strategy.full_messages) == 1 + assert len(strategy.truncated_messages) == 1 + + def test_does_not_strip_images(self) -> None: + strategy = _make_summarizing_strategy() + for i in range(5): + role = "user" if i % 2 == 0 else "assistant" + strategy.append_message( + MessageParam( + role=role, + content=[_make_base64_image_block()], + ) + ) + # All images should remain since no image stripping + for msg in strategy.truncated_messages: + assert isinstance(msg.content, list) + assert isinstance(msg.content[0], ImageBlockParam) + + def test_sets_cache_breakpoint_on_last_user_message(self) -> None: + strategy = _make_summarizing_strategy() + strategy.append_message( + MessageParam( + role="user", + content=[TextBlockParam(text="hello")], + ) + ) + strategy.append_message( + MessageParam( + role="assistant", + content=[TextBlockParam(text="hi")], + ) + ) + # Last (and only) user message should have cache breakpoint + user_msg = strategy.truncated_messages[0] + assert isinstance(user_msg.content, list) + assert _get_cache_control(user_msg.content[-1]) is not None + # Assistant message should not + asst_msg = strategy.truncated_messages[1] + assert isinstance(asst_msg.content, list) + assert _get_cache_control(asst_msg.content[-1]) is None + + def test_moves_cache_breakpoint_forward(self) -> None: + strategy = _make_summarizing_strategy() + strategy.append_message( + MessageParam( + role="user", + content=[TextBlockParam(text="first")], + ) + ) + strategy.append_message( + MessageParam( + role="assistant", + content=[TextBlockParam(text="reply")], + ) + ) + strategy.append_message( + MessageParam( + role="user", + content=[TextBlockParam(text="second")], + ) + ) + # Old user message (index 0) should have cache_control cleared + old_content = strategy.truncated_messages[0].content + assert isinstance(old_content, list) + assert _get_cache_control(old_content[-1]) is None + # New user message (index 2) should have it set + new_content = strategy.truncated_messages[2].content + assert isinstance(new_content, list) + assert _get_cache_control(new_content[-1]) is not None + + +class TestSummarizingTruncation: + def test_truncate_replaces_history_with_summary(self) -> None: + vlm = _make_vlm_provider() + strategy = _make_summarizing_strategy(vlm_provider=vlm, n_messages_to_keep=2) + for i in range(6): + role = "user" if i % 2 == 0 else "assistant" + strategy.append_message(MessageParam(role=role, content=f"msg {i}")) + strategy.truncate() + msgs = strategy.truncated_messages + assert msgs[0].role == "user" + assert msgs[0].content == "Summary of the conversation." + assert msgs[-1].content == "msg 5" + assert msgs[-2].content == "msg 4" + + def test_truncate_inserts_synthetic_assistant(self) -> None: + vlm = _make_vlm_provider() + strategy = _make_summarizing_strategy(vlm_provider=vlm, n_messages_to_keep=2) + for i in range(6): + role = "user" if i % 2 == 0 else "assistant" + strategy.append_message(MessageParam(role=role, content=f"msg {i}")) + strategy.truncate() + msgs = strategy.truncated_messages + assert msgs[0].role == "user" + assert msgs[1].role == "assistant" + assert "Understood" in str(msgs[1].content) + + def test_truncate_skips_when_too_few_messages(self) -> None: + strategy = _make_summarizing_strategy(n_messages_to_keep=10) + for i in range(4): + role = "user" if i % 2 == 0 else "assistant" + strategy.append_message(MessageParam(role=role, content=f"msg {i}")) + strategy.truncate() + assert len(strategy.truncated_messages) == 4 + + def test_full_messages_preserved_after_truncation(self) -> None: + vlm = _make_vlm_provider() + strategy = _make_summarizing_strategy(vlm_provider=vlm, n_messages_to_keep=2) + for i in range(6): + role = "user" if i % 2 == 0 else "assistant" + strategy.append_message(MessageParam(role=role, content=f"msg {i}")) + strategy.truncate() + assert len(strategy.full_messages) == 6 + assert len(strategy.truncated_messages) < 6 + + def test_preserves_tool_use_tool_result_pairs(self) -> None: + vlm = _make_vlm_provider() + strategy = _make_summarizing_strategy(vlm_provider=vlm, n_messages_to_keep=3) + strategy.append_message(MessageParam(role="user", content="do something")) + strategy.append_message( + MessageParam( + role="assistant", + content=[ + ToolUseBlockParam( + id="tu_1", + input={}, + name="tool_a", + type="tool_use", + ), + ], + ) + ) + strategy.append_message( + MessageParam( + role="user", + content=[ + ToolResultBlockParam(tool_use_id="tu_1", content="result 1"), + ], + ) + ) + strategy.append_message( + MessageParam( + role="assistant", + content=[ + ToolUseBlockParam( + id="tu_2", + input={}, + name="tool_b", + type="tool_use", + ), + ], + ) + ) + strategy.append_message( + MessageParam( + role="user", + content=[ + ToolResultBlockParam(tool_use_id="tu_2", content="result 2"), + ], + ) + ) + strategy.append_message(MessageParam(role="assistant", content="all done")) + strategy.truncate() + msgs = strategy.truncated_messages + for i, m in enumerate(msgs): + if m.role != "user" or isinstance(m.content, str): + continue + has_tr = any(isinstance(b, ToolResultBlockParam) for b in m.content) + if has_tr: + prev = msgs[i - 1] + assert prev.role == "assistant" + assert ( + not isinstance(prev.content, str) + or "Understood" not in prev.content + ), f"tool_result at index {i} follows synthetic" + + def test_auto_truncation_on_token_limit(self) -> None: + vlm = _make_vlm_provider() + strategy = _make_summarizing_strategy( + vlm_provider=vlm, + n_messages_to_keep=2, + max_input_tokens=100, + ) + strategy.append_message(MessageParam(role="user", content="x" * 300)) + strategy.append_message(MessageParam(role="assistant", content="y" * 300)) + strategy.append_message(MessageParam(role="user", content="z" * 300)) + vlm.create_message.assert_called_once() From 64f0c006fced6074fd284ce76b43c9f343b0d7b4 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Wed, 1 Apr 2026 22:16:12 +0200 Subject: [PATCH 4/7] feat: add debug writing logic --- .../models/shared/truncation_strategies.py | 173 ++---------------- 1 file changed, 16 insertions(+), 157 deletions(-) diff --git a/src/askui/models/shared/truncation_strategies.py b/src/askui/models/shared/truncation_strategies.py index f83a9c93..77db47fc 100644 --- a/src/askui/models/shared/truncation_strategies.py +++ b/src/askui/models/shared/truncation_strategies.py @@ -1,9 +1,7 @@ """Truncation strategies for managing conversation message history.""" -import json import logging from abc import ABC, abstractmethod -from pathlib import Path from typing_extensions import override @@ -226,8 +224,6 @@ class SlidingImageWindowSummarizingTruncationStrategy(TruncationStrategy): endpoint. truncation_threshold: Fraction of `max_input_tokens` at which to truncate. - debug_dir: When set, write diagnostic snapshots to - this directory after each append/truncate. """ def __init__( @@ -238,7 +234,6 @@ def __init__( max_messages: int = MAX_MESSAGES, max_input_tokens: int = MAX_INPUT_TOKENS, truncation_threshold: float = TRUNCATION_THRESHOLD, - debug_dir: Path | None = None, ) -> None: super().__init__(max_messages, max_input_tokens, truncation_threshold) self._vlm_provider = vlm_provider @@ -246,29 +241,13 @@ def __init__( self._n_messages_to_keep = n_messages_to_keep self._token_counter = SimpleTokenCounter() self._image_removal_boundary_index: int | None = None - self._debug_dir = debug_dir - self._debug_step = 0 - - msg = """CAUTION: The Truncation Strategy you are using is experimental! - While it will lead to faster executions in longer runs it might crash or - lead to overall unexpected behavior! If in doubt, we recommend using the - default truncation strategy instead.""" - logger.warning(msg) - - if self._debug_dir is not None: - self._debug_dir.mkdir(parents=True, exist_ok=True) - # Write config - config = { - "n_images_to_keep": n_images_to_keep, - "n_messages_to_keep": n_messages_to_keep, - "max_messages": max_messages, - "max_input_tokens": max_input_tokens, - "truncation_threshold": truncation_threshold, - "absolute_threshold": self._absolute_truncation_threshold, - } - (self._debug_dir / "config.json").write_text( - json.dumps(config, indent=2), encoding="utf-8" - ) + try: + from askui.models.shared.truncation_debug import TruncationDebugWriter + + self._debug_writer = TruncationDebugWriter() + except ImportError: + self._debug_writer = None + logger.exception("Could not add truncation debug writer") @override def append_message(self, message: MessageParam) -> None: @@ -299,10 +278,15 @@ def append_message(self, message: MessageParam) -> None: self.truncate() truncated = True - self._write_debug_snapshot( - event="truncate" if truncated else "append", - token_total=token_counts.total, - ) + if self._debug_writer: + self._debug_writer.write_snapshot( + event="truncate" if truncated else "append", + full_messages=self._full_message_history, + truncated_messages=self._truncated_message_history, + token_estimate=token_counts.total, + threshold=self._absolute_truncation_threshold, + image_boundary_idx=self._image_removal_boundary_index, + ) @override def truncate(self) -> None: @@ -517,115 +501,6 @@ def _set_cache_breakpoint(msg: MessageParam) -> None: """Set cache breakpoint on last block of a message.""" _set_cache_breakpoint(msg) - # ------------------------------------------------------------------ - # Debug diagnostics - # ------------------------------------------------------------------ - - def _write_debug_snapshot( - self, - event: str, - token_total: int = 0, - ) -> None: - """Write a diagnostic snapshot to the debug dir. - - Each snapshot summarises both message histories - compactly (no base64 data) so you can verify at a - glance that image stripping, cache breakpoints, and - truncation work correctly. - - Args: - event: ``"append"`` or ``"truncate"``. - token_total: Estimated token count for the - truncated history. - """ - if self._debug_dir is None: - return - - self._debug_step += 1 - - full_imgs = self._count_base64_images(self._full_message_history) - trunc_imgs = self._count_base64_images(self._truncated_message_history) - - snapshot: dict[str, object] = { - "step": self._debug_step, - "event": event, - "full_msg_count": len(self._full_message_history), - "full_base64_images": full_imgs, - "truncated_msg_count": len(self._truncated_message_history), - "truncated_base64_images": trunc_imgs, - "images_stripped": full_imgs - trunc_imgs, - "image_boundary_idx": (self._image_removal_boundary_index), - "token_estimate": token_total, - "threshold": self._absolute_truncation_threshold, - "truncated_messages": [ - self._summarise_message(i, m) - for i, m in enumerate(self._truncated_message_history) - ], - } - - filename = f"step_{self._debug_step:03d}_{event}.json" - (self._debug_dir / filename).write_text( - json.dumps(snapshot, indent=2, ensure_ascii=False), - encoding="utf-8", - ) - - def _summarise_message(self, index: int, msg: MessageParam) -> dict[str, object]: - """Build a compact summary of a message for debug. - - Returns a dict with role, block descriptions, and - cache breakpoint info — no base64 data. - """ - result: dict[str, object] = { - "index": index, - "role": msg.role, - } - - if isinstance(msg.content, str): - preview = msg.content - if len(preview) > 120: - preview = preview[:120] + "..." - result["content"] = preview - return result - - blocks: list[str] = [] - has_cache_bp = False - - for block in msg.content: - desc = self._describe_block(block) - blocks.append(desc) - cc = getattr(block, "cache_control", None) - if cc is not None: - has_cache_bp = True - - result["blocks"] = blocks - result["has_cache_breakpoint"] = has_cache_bp - return result - - @staticmethod - def _describe_block( - block: ContentBlockParam, - ) -> str: - """One-line description of a content block.""" - if isinstance(block, TextBlockParam): - preview = block.text[:60] - if len(block.text) > 60: - preview += "..." - return f"text({preview})" - - if isinstance(block, ImageBlockParam): - kind = ( - "base64" if isinstance(block.source, Base64ImageSourceParam) else "url" - ) - return f"image:{kind}" - - if isinstance(block, ToolUseBlockParam): - return f"tool_use({block.name})" - - if isinstance(block, ToolResultBlockParam): - return _describe_tool_result(block) - - return block.type - class SummarizingTruncationStrategy(TruncationStrategy): """Truncation strategy that summarizes when limits are hit. @@ -752,19 +627,3 @@ def truncate(self) -> None: new_messages.extend(recent) self._truncated_message_history = new_messages - - -def _describe_tool_result( - block: ToolResultBlockParam, -) -> str: - """One-line description of a tool result block.""" - if isinstance(block.content, str): - return f"tool_result({block.content[:40]})" - nested = [] - for n in block.content: - if isinstance(n, TextBlockParam): - nested.append(f"text({n.text[:30]})") - elif isinstance(n, ImageBlockParam): - kind = "b64" if isinstance(n.source, Base64ImageSourceParam) else "url" - nested.append(f"img:{kind}") - return f"tool_result[{', '.join(nested)}]" From ac5728f8938f13bf0f4f8c769b35b7d1f86f877c Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Thu, 2 Apr 2026 07:43:37 +0200 Subject: [PATCH 5/7] fix: qa --- src/askui/models/shared/truncation_strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/askui/models/shared/truncation_strategies.py b/src/askui/models/shared/truncation_strategies.py index 77db47fc..01c94375 100644 --- a/src/askui/models/shared/truncation_strategies.py +++ b/src/askui/models/shared/truncation_strategies.py @@ -242,7 +242,7 @@ def __init__( self._token_counter = SimpleTokenCounter() self._image_removal_boundary_index: int | None = None try: - from askui.models.shared.truncation_debug import TruncationDebugWriter + from askui.models.shared.truncation_debug import TruncationDebugWriter # noqa self._debug_writer = TruncationDebugWriter() except ImportError: From 9bb12b7242f4ef4d191b11eb1061dddeaf559fd5 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Thu, 2 Apr 2026 07:49:48 +0200 Subject: [PATCH 6/7] fix: qa --- src/askui/models/shared/truncation_strategies.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/askui/models/shared/truncation_strategies.py b/src/askui/models/shared/truncation_strategies.py index 01c94375..c8bba3c9 100644 --- a/src/askui/models/shared/truncation_strategies.py +++ b/src/askui/models/shared/truncation_strategies.py @@ -242,7 +242,9 @@ def __init__( self._token_counter = SimpleTokenCounter() self._image_removal_boundary_index: int | None = None try: - from askui.models.shared.truncation_debug import TruncationDebugWriter # noqa + from askui.models.shared.truncation_debug import ( # type: ignore[import-untyped] + TruncationDebugWriter, + ) self._debug_writer = TruncationDebugWriter() except ImportError: From b23dd88b5c5433bab9a4dcd61b21a63fd4b0244c Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Thu, 2 Apr 2026 13:35:55 +0200 Subject: [PATCH 7/7] feat: make summarization prompt for truncation a prompt entity --- src/askui/models/shared/truncation_strategies.py | 10 ++-------- src/askui/prompts/truncation.py | 8 ++++++++ 2 files changed, 10 insertions(+), 8 deletions(-) create mode 100644 src/askui/prompts/truncation.py diff --git a/src/askui/models/shared/truncation_strategies.py b/src/askui/models/shared/truncation_strategies.py index c8bba3c9..591f0462 100644 --- a/src/askui/models/shared/truncation_strategies.py +++ b/src/askui/models/shared/truncation_strategies.py @@ -17,6 +17,7 @@ ToolUseBlockParam, ) from askui.models.shared.token_counter import SimpleTokenCounter +from askui.prompts.truncation import SUMMARIZE_INSTRUCTION_PROMPT logger = logging.getLogger(__name__) @@ -72,14 +73,7 @@ def _summarize_message_history( messages_to_summarize.append( MessageParam( role="user", - content=( - "Please provide a concise summary of the " - "conversation history above. Focus on: key " - "actions taken, results observed, current " - "state, and any pending tasks or errors. " - "This summary will replace the earlier " - "conversation history to save context space." - ), + content=SUMMARIZE_INSTRUCTION_PROMPT, ) ) diff --git a/src/askui/prompts/truncation.py b/src/askui/prompts/truncation.py new file mode 100644 index 00000000..a1748928 --- /dev/null +++ b/src/askui/prompts/truncation.py @@ -0,0 +1,8 @@ +SUMMARIZE_INSTRUCTION_PROMPT = ( + "Please provide a concise summary of the " + "conversation history above. Focus on: key " + "actions taken, results observed, current " + "state, and any pending tasks or errors. " + "This summary will replace the earlier " + "conversation history to save context space." +)