Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/compress/search.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { SessionState, WithParts } from "../state"
import { formatBlockRef, parseBoundaryId } from "../message-ids"
import { isIgnoredUserMessage } from "../messages/query"
import { filterProcessableMessages } from "../messages/shape"
import { countAllMessageTokens } from "../token-utils"
import type { BoundaryReference, SearchContext, SelectionResolution } from "./types"

Expand All @@ -9,8 +10,7 @@ export async function fetchSessionMessages(client: any, sessionId: string): Prom
path: { id: sessionId },
})

const payload = (response?.data || response) as WithParts[]
return Array.isArray(payload) ? payload : []
return filterProcessableMessages(response?.data || response)
}

export function buildSearchContext(state: SessionState, rawMessages: WithParts[]): SearchContext {
Expand Down
39 changes: 24 additions & 15 deletions lib/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
consumeCompressionStart,
resolveCompressionDuration,
} from "./compress/timing"
import { filterProcessableMessages } from "./messages/shape"
import {
applyPendingManualTrigger,
handleContextCommand,
Expand Down Expand Up @@ -103,41 +104,49 @@ export function createChatMessageTransformHandler(
hostPermissions: HostPermissionSnapshot,
) {
return async (input: {}, output: { messages: WithParts[] }) => {
await checkSession(client, state, logger, output.messages, config.manualMode.enabled)
const messages = filterProcessableMessages(output.messages)
if (messages.length !== output.messages.length) {
logger.warn("Skipping messages with unexpected shape during chat transform", {
received: output.messages.length,
usable: messages.length,
})
}

await checkSession(client, state, logger, messages, config.manualMode.enabled)

syncCompressPermissionState(state, config, hostPermissions, output.messages)
syncCompressPermissionState(state, config, hostPermissions, messages)

if (state.isSubAgent && !config.experimental.allowSubAgents) {
return
}

stripHallucinations(output.messages)
cacheSystemPromptTokens(state, output.messages)
assignMessageRefs(state, output.messages)
syncCompressionBlocks(state, logger, output.messages)
syncToolCache(state, config, logger, output.messages)
buildToolIdList(state, output.messages)
prune(state, logger, config, output.messages)
cacheSystemPromptTokens(state, messages)
assignMessageRefs(state, messages)
syncCompressionBlocks(state, logger, messages)
syncToolCache(state, config, logger, messages)
buildToolIdList(state, messages)
prune(state, logger, config, messages)
await injectExtendedSubAgentResults(
client,
state,
logger,
output.messages,
messages,
config.experimental.allowSubAgents,
)
const compressionPriorities = buildPriorityMap(config, state, output.messages)
const compressionPriorities = buildPriorityMap(config, state, messages)
prompts.reload()
injectCompressNudges(
state,
config,
logger,
output.messages,
messages,
prompts.getRuntimePrompts(),
compressionPriorities,
)
injectMessageIds(state, config, output.messages, compressionPriorities)
applyPendingManualTrigger(state, output.messages, logger)
stripStaleMetadata(output.messages)
injectMessageIds(state, config, messages, compressionPriorities)
applyPendingManualTrigger(state, messages, logger)
stripStaleMetadata(messages)

if (state.sessionId) {
await logger.saveContext(state.sessionId, output.messages)
Expand Down Expand Up @@ -165,7 +174,7 @@ export function createCommandExecuteHandler(
const messagesResponse = await client.session.messages({
path: { id: input.sessionID },
})
const messages = (messagesResponse.data || messagesResponse) as WithParts[]
const messages = filterProcessableMessages(messagesResponse.data || messagesResponse)

await ensureSessionInitialized(
client,
Expand Down
4 changes: 2 additions & 2 deletions lib/messages/inject/subagent-results.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { Logger } from "../../logger"
import type { SessionState, WithParts } from "../../state"
import { filterProcessableMessages } from "../shape"
import {
buildSubagentResultText,
getSubAgentId,
Expand All @@ -12,8 +13,7 @@ async function fetchSubAgentMessages(client: any, sessionId: string): Promise<Wi
path: { id: sessionId },
})

const payload = (response?.data || response) as WithParts[]
return Array.isArray(payload) ? payload : []
return filterProcessableMessages(response?.data || response)
}

export const injectExtendedSubAgentResults = async (
Expand Down
16 changes: 16 additions & 0 deletions lib/messages/query.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { PluginConfig } from "../config"
import type { WithParts } from "../state"
import { isMessageWithInfo } from "./shape"

export const getLastUserMessage = (
messages: WithParts[],
Expand All @@ -8,6 +9,9 @@ export const getLastUserMessage = (
const start = startIndex ?? messages.length - 1
for (let i = start; i >= 0; i--) {
const msg = messages[i]
if (!isMessageWithInfo(msg)) {
continue
}
if (msg.info.role === "user" && !isIgnoredUserMessage(msg)) {
return msg
}
Expand All @@ -16,6 +20,10 @@ export const getLastUserMessage = (
}

export const messageHasCompress = (message: WithParts): boolean => {
if (!isMessageWithInfo(message)) {
return false
}

if (message.info.role !== "assistant") {
return false
}
Expand All @@ -28,6 +36,10 @@ export const messageHasCompress = (message: WithParts): boolean => {
}

export const isIgnoredUserMessage = (message: WithParts): boolean => {
if (!isMessageWithInfo(message)) {
return false
}

if (message.info.role !== "user") {
return false
}
Expand All @@ -47,6 +59,10 @@ export const isIgnoredUserMessage = (message: WithParts): boolean => {
}

export function isProtectedUserMessage(config: PluginConfig, message: WithParts): boolean {
if (!isMessageWithInfo(message)) {
return false
}

return (
config.compress.mode === "message" &&
config.compress.protectUserMessages &&
Expand Down
33 changes: 33 additions & 0 deletions lib/messages/shape.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import type { WithParts } from "../state"

export function isMessageWithInfo(message: unknown): message is WithParts {
if (!message || typeof message !== "object") {
return false
}

const info = (message as any).info
const parts = (message as any).parts
if (!info || typeof info !== "object") {
return false
}

return (
typeof info.id === "string" &&
info.id.length > 0 &&
typeof info.sessionID === "string" &&
info.sessionID.length > 0 &&
(info.role === "user" || info.role === "assistant") &&
info.time &&
typeof info.time === "object" &&
typeof info.time.created === "number" &&
Array.isArray(parts)
)
}

export function filterProcessableMessages(messages: unknown): WithParts[] {
if (!Array.isArray(messages)) {
return []
}

return messages.filter(isMessageWithInfo)
}
11 changes: 11 additions & 0 deletions lib/state/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@ import type {
WithParts,
} from "./types"
import { isIgnoredUserMessage, messageHasCompress } from "../messages/query"
import { isMessageWithInfo } from "../messages/shape"
import { countTokens } from "../token-utils"

export const isMessageCompacted = (state: SessionState, msg: WithParts): boolean => {
if (!isMessageWithInfo(msg)) {
return false
}

if (msg.info.time.created < state.lastCompaction) {
return true
}
Expand Down Expand Up @@ -58,6 +63,9 @@ export async function isSubAgentSession(client: any, sessionID: string): Promise
export function findLastCompactionTimestamp(messages: WithParts[]): number {
for (let i = messages.length - 1; i >= 0; i--) {
const msg = messages[i]
if (!isMessageWithInfo(msg)) {
continue
}
if (msg.info.role === "assistant" && msg.info.summary === true) {
return msg.info.time.created
}
Expand All @@ -68,6 +76,9 @@ export function findLastCompactionTimestamp(messages: WithParts[]): number {
export function countTurns(state: SessionState, messages: WithParts[]): number {
let turnCount = 0
for (const msg of messages) {
if (!isMessageWithInfo(msg)) {
continue
}
if (isMessageCompacted(state, msg)) {
continue
}
Expand Down
38 changes: 38 additions & 0 deletions tests/hooks-permission.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,44 @@ test("chat message transform strips hallucinated tags even when compress is deni
assert.equal((output.messages[0]?.parts[0] as any).text, "alpha omega")
})

test("chat message transform ignores messages without info instead of crashing", async () => {
const state = createSessionState()
const logger = new Logger(false)
const config = buildConfig("deny")
const handler = createChatMessageTransformHandler(
{ session: { get: async () => ({}) } } as any,
state,
logger,
config,
{
reload() {},
getRuntimePrompts() {
return {} as any
},
} as any,
{ global: undefined, agents: {} },
)
const output = {
messages: [
{
role: "user",
time: 1,
parts: [
{
type: "text",
text: "Carica le skill di laravel",
},
],
} as any,
],
}

await handler({}, output as any)

assert.equal(state.sessionId, null)
assert.equal(output.messages.length, 1)
})

test("command execute exits after effective permission resolves to deny", async () => {
let sessionMessagesCalls = 0
const output = { parts: [] as any[] }
Expand Down
Loading