Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions openviking/retrieve/hierarchical_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,22 @@ async def retrieve(

collection = self._type_to_collection(query.context_type)

target_dirs = [d for d in (query.target_directories or []) if d]

# Create context_type filter
type_filter = {"op": "must", "field": "context_type", "conds": [query.context_type.value]}

# Merge all filters
filters_to_merge = [type_filter]
if target_dirs:
target_filter = {
"op": "or",
"conds": [
{"op": "prefix", "field": "uri", "prefix": target_dir}
for target_dir in target_dirs
],
}
filters_to_merge.append(target_filter)
if metadata_filter:
filters_to_merge.append(metadata_filter)

Expand All @@ -124,8 +135,11 @@ async def retrieve(
query_vector = result.dense_vector
sparse_query_vector = result.sparse_vector

# Step 1: Determine starting directories based on context_type
root_uris = self._get_root_uris_for_type(query.context_type)
# Step 1: Determine starting directories based on target_directories or context_type
if target_dirs:
root_uris = target_dirs
else:
root_uris = self._get_root_uris_for_type(query.context_type)

# Step 2: Global vector search to supplement starting points
global_results = await self._global_vector_search(
Expand Down
70 changes: 70 additions & 0 deletions tests/retrieve/test_hierarchical_retriever_target_dirs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: Apache-2.0

"""Hierarchical retriever target_directories tests."""

import pytest

from openviking.retrieve.hierarchical_retriever import HierarchicalRetriever
from openviking_cli.retrieve.types import ContextType, TypedQuery


class DummyStorage:
    """Minimal storage stub that records the arguments of every search.

    Only the two coroutines defined here are provided. Each ``search`` call
    is appended to ``search_calls`` so a test can inspect the filter that
    was passed in.
    """

    def __init__(self) -> None:
        # One dict per search() call, in call order, with the keys
        # "collection", "filter", "limit" and "offset".
        self.search_calls: list = []

    async def collection_exists(self, _name: str) -> bool:
        """Report every collection as existing, whatever its name."""
        return True

    async def search(
        self,
        collection: str,
        query_vector=None,
        sparse_query_vector=None,
        filter=None,  # noqa: A002 - keyword name is fixed by the callers
        limit: int = 10,
        offset: int = 0,
        output_fields=None,
        with_vector: bool = False,
    ) -> list:
        """Record the call and return no hits.

        Only ``collection``, ``filter``, ``limit`` and ``offset`` are
        recorded; the vector arguments, ``output_fields`` and
        ``with_vector`` are accepted and ignored.

        Returns:
            A new empty list on every call.
        """
        self.search_calls.append(
            {
                "collection": collection,
                "filter": filter,
                "limit": limit,
                "offset": offset,
            }
        )
        return []


def _contains_prefix_filter(obj, prefix: str) -> bool:
    """Return True if *obj* holds a ``uri`` prefix condition for *prefix*.

    A matching condition is a dict whose ``"op"`` is ``"prefix"``, whose
    ``"field"`` is ``"uri"`` and whose ``"prefix"`` equals *prefix*. The
    search is recursive: dict values and list/tuple items are inspected at
    any depth, so the condition is found however the filter is nested.

    Args:
        obj: Filter structure to inspect (dict, list, tuple, or any leaf).
        prefix: Exact prefix string the condition must carry.

    Returns:
        True if a matching condition exists anywhere in *obj*, else False.
        Leaves that are not dict/list/tuple (including None) never match.
    """
    if isinstance(obj, dict):
        is_match = (
            obj.get("op") == "prefix"
            and obj.get("field") == "uri"
            and obj.get("prefix") == prefix
        )
        # A non-matching dict may still wrap a matching one in its values.
        return is_match or any(_contains_prefix_filter(v, prefix) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return any(_contains_prefix_filter(item, prefix) for item in obj)
    return False


@pytest.mark.asyncio
async def test_retrieve_honors_target_directories_prefix_filter():
    """Check that ``retrieve`` scopes its search to ``target_directories``.

    Runs a retrieval against a recording storage stub with one target
    directory, then checks that the result reports exactly that directory
    as searched and that the first recorded search filter contains a
    ``uri`` prefix condition for it.
    """
    target_uri = "viking://resources/foo"
    storage = DummyStorage()
    # NOTE(review): embedder=None presumably makes retrieve() skip query
    # embedding so no real model is needed - confirm against the retriever.
    retriever = HierarchicalRetriever(storage=storage, embedder=None, rerank_config=None)

    query = TypedQuery(
        query="test",
        context_type=ContextType.RESOURCE,
        intent="",
        target_directories=[target_uri],
    )

    result = await retriever.retrieve(query, limit=3)

    assert result.searched_directories == [target_uri]
    assert storage.search_calls, "retrieve() issued no storage search"
    # Only the first recorded search is inspected.
    assert _contains_prefix_filter(
        storage.search_calls[0]["filter"], target_uri
    ), "first search filter has no uri prefix condition for the target directory"
Loading