Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/groundlight/experimental_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import json
import time
from io import BufferedReader, BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -40,6 +41,7 @@
)
from urllib3.response import HTTPResponse

from groundlight.edge.config import EdgeEndpointConfig
from groundlight.images import parse_supported_image_types
from groundlight.internalapi import _generate_request_id
from groundlight.optional_imports import Image, np
Expand Down Expand Up @@ -817,3 +819,72 @@ def make_generic_api_request( # noqa: PLR0913 # pylint: disable=too-many-argume
auth_settings=["ApiToken"],
_preload_content=False, # This returns the urllib3 response rather than trying any type of processing
)

def _edge_base_url(self) -> str:
"""Return the scheme+host+port of the configured endpoint, without the /device-api path."""
from urllib.parse import urlparse, urlunparse

parsed = urlparse(self.configuration.host)
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))

def get_edge_config(self) -> EdgeEndpointConfig:
"""Retrieve the active edge endpoint configuration.

Only works when the client is pointed at an edge endpoint
(via GROUNDLIGHT_ENDPOINT or the endpoint constructor arg).
"""
url = f"{self._edge_base_url()}/edge-config"
headers = self.get_raw_headers()
response = requests.get(url, headers=headers, verify=self.configuration.verify_ssl)
response.raise_for_status()
return EdgeEndpointConfig.from_payload(response.json())

def get_edge_detector_readiness(self) -> dict[str, bool]:
"""Check which configured detectors have inference pods ready to serve.

Only works when the client is pointed at an edge endpoint.

:return: Dict mapping detector_id to readiness (True/False).
"""
url = f"{self._edge_base_url()}/edge-detector-readiness"
headers = self.get_raw_headers()
response = requests.get(url, headers=headers, verify=self.configuration.verify_ssl)
response.raise_for_status()
return {det_id: info["ready"] for det_id, info in response.json().items()}

def set_edge_config(
self,
config: EdgeEndpointConfig,
mode: str = "REPLACE",
timeout_sec: float = 300,
poll_interval_sec: float = 1,
) -> EdgeEndpointConfig:
"""Send a new edge endpoint configuration and wait until all detectors are ready.

Only works when the client is pointed at an edge endpoint.

:param config: The new configuration to apply.
:param mode: Currently only "REPLACE" is supported.
:param timeout_sec: Max seconds to wait for all detectors to become ready.
:param poll_interval_sec: How often to poll readiness while waiting.
:return: The applied configuration as reported by the edge endpoint.
"""
if mode != "REPLACE":
raise ValueError(f"Unsupported mode: {mode!r}. Currently only 'REPLACE' is supported.")

url = f"{self._edge_base_url()}/edge-config"
headers = self.get_raw_headers()
response = requests.put(url, json=config.to_payload(), headers=headers, verify=self.configuration.verify_ssl)
response.raise_for_status()

desired_ids = {d.detector_id for d in config.detectors if d.detector_id}
deadline = time.time() + timeout_sec
while time.time() < deadline:
readiness = self.get_edge_detector_readiness()
if desired_ids and all(readiness.get(did, False) for did in desired_ids):
return self.get_edge_config()
time.sleep(poll_interval_sec)

raise TimeoutError(
f"Edge detectors were not all ready within {timeout_sec}s. The edge endpoint may still be converging."
)
27 changes: 27 additions & 0 deletions test/unit/test_edge_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,30 @@ def test_inference_config_validation_errors():
always_return_edge_prediction=True,
min_time_between_escalations=-1.0,
)


def test_get_edge_config_parses_response():
"""ExperimentalApi.get_edge_config() parses the HTTP response into an EdgeEndpointConfig."""
from unittest.mock import Mock, patch

from groundlight import ExperimentalApi

payload = {
"global_config": {"refresh_rate": REFRESH_RATE_SECONDS},
"edge_inference_configs": {"default": {"enabled": True}},
"detectors": [{"detector_id": "det_1", "edge_inference_config": "default"}],
}

mock_response = Mock()
mock_response.json.return_value = payload
mock_response.raise_for_status = Mock()

gl = ExperimentalApi()
with patch("requests.get", return_value=mock_response) as mock_get:
config = gl.get_edge_config()

mock_get.assert_called_once()
assert isinstance(config, EdgeEndpointConfig)
assert config.global_config.refresh_rate == REFRESH_RATE_SECONDS
assert config.edge_inference_configs["default"].name == "default"
assert [d.detector_id for d in config.detectors] == ["det_1"]