diff --git a/.gitignore b/.gitignore index cabce75..408f194 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,4 @@ cython_debug/ #.idea/ test/ +debug.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c00853..df1e042 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,116 @@ # CHANGELOG +## Version 0.61.1 (March 2026) + +**Released**: March 15, 2026 + +This release improves async support with a new `arun()` helper, makes the Device Utilities module fully non-blocking, adds VT100 terminal emulation for screen-based commands, and introduces interactive SSH shell access for EX/SRX devices. (PR #16) + +--- + +### 1. NEW FEATURES + +#### **`mistapi.arun()` — Async Helper** +New helper function to run any sync mistapi function without blocking the event loop. Wraps the function call in `asyncio.to_thread()` so blocking HTTP requests run in a thread pool. + +```python +import asyncio +import mistapi +from mistapi.api.v1.sites import devices + +async def main(): + session = mistapi.APISession(env_file="~/.mist_env") + session.login() + + # Run sync API call without blocking the event loop + response = await mistapi.arun(devices.listSiteDevices, session, site_id) + print(response.data) + +asyncio.run(main()) +``` + +#### **Interactive SSH Shell** (`device_utils.ex` / `device_utils.srx`) +New `interactiveShell()` and `createShellSession()` functions for SSH-over-WebSocket access to EX and SRX devices. + +- `interactiveShell()` — takes over the terminal for human SSH access (uses `sshkeyboard`) +- `createShellSession()` — returns a `ShellSession` object for programmatic send/recv +- `ShellSession` — bidirectional WebSocket session with `send()`, `recv()`, `resize()`, context manager support + +```python +from mistapi.device_utils import ex + +# Interactive (human at the keyboard) +ex.interactiveShell(apisession, site_id, device_id) + +# Programmatic +with ex.createShellSession(apisession, site_id, device_id) as session: + session.send_text("show version\r\n") + import time; time.sleep(3) + while (data := session.recv(timeout=0.5)): + print(data.decode("utf-8", errors="replace"), end="") +``` + +#### **`topCommand`** (`device_utils.ex` / `device_utils.srx`) +New `topCommand()` function to stream `top` output from EX and SRX devices. Uses VT100 screen-buffer rendering for proper in-place display. + +#### **VT100 Terminal Emulation** +Added ANSI escape stripping and a minimal VT100 screen-buffer renderer for device command output. Stream-mode commands (ping, traceroute) have ANSI codes stripped automatically. Screen-mode commands (top, monitor interface) are rendered through a virtual terminal buffer. + +--- + +### 2. IMPROVEMENTS + +#### **Non-Blocking Device Utilities** +All `mistapi.device_utils` functions now return immediately. The HTTP trigger and WebSocket streaming run in background threads, allowing your code to continue executing while data is collected. + +**UtilResponse Object:** +| Method/Property | Description | +|-----------------|-------------| +| `.ws_data` | List of processed messages | +| `.done` | `True` if data collection is complete | +| `.wait(timeout)` | Block until complete, returns self | +| `.receive()` | Generator yielding messages as they arrive | +| `.disconnect()` | Stop the WebSocket connection early | +| `await response` | Async-friendly wait (non-blocking event loop) | + +**Example Usage:** +```python +from mistapi.device_utils import ex + +# Non-blocking - returns immediately, data collected in background +response = ex.ping(apisession, site_id, device_id, host="8.8.8.8") +do_other_work() # Can do other things while waiting +response.wait() # Block when ready to collect results +print(response.ws_data) + +# Generator style - process messages as they arrive +for msg in response.receive(): + print(msg) + +# Async-friendly - doesn't block the event loop +await response +``` + +#### **Binary WebSocket Frame Support** +`_MistWebsocket._handle_message()` now handles binary frames (strips null bytes, decodes UTF-8 with replacement characters). + +#### **Trigger-Only Commands Run Synchronously** +Fire-and-forget device commands (e.g., `clearMacTable`, `clearBpduError`, `clearHitCount`) that don't require a WebSocket stream now run the API trigger synchronously, ensuring `trigger_api_response` is immediately available on the returned `UtilResponse`. + +--- + +### 3. BUG FIXES + +- Fixed double-space typo in API token privilege mismatch error message +- Fixed `first_message_timeout` timer stop to check timer is active before stopping + +--- + +### 4. DEPENDENCIES + +- Added `sshkeyboard>=2.3.1` (for `interactiveShell()`) + +--- + ## Version 0.61.0 (March 2026) **Released**: March 13, 2026 @@ -19,7 +131,7 @@ Complete real-time event streaming support with flexible consumption patterns: |-------|-------------| | `mistapi.websockets.orgs.InsightsEvents` | Real-time insights events for an organization | | `mistapi.websockets.orgs.MxEdgesStatsEvents` | Real-time MX edges stats for an organization | -| `mistapi.websockets.orgs.MxEdgesUpgradesEvents` | Real-time MX edges upgrades events for an organization | +| `mistapi.websockets.orgs.MxEdgesEvents` | Real-time MX edges events for an organization | * Site Channels @@ -28,7 +140,8 @@ Complete real-time event streaming support with flexible consumption patterns: | `mistapi.websockets.sites.ClientsStatsEvents` | Real-time clients stats for a site | | `mistapi.websockets.sites.DeviceCmdEvents` | Real-time device command events for a site | | `mistapi.websockets.sites.DeviceStatsEvents` | Real-time device stats for a site | -| `mistapi.websockets.sites.DeviceUpgradesEvents` | Real-time device upgrades events for a site | +| `mistapi.websockets.sites.DeviceEvents` | Real-time device events for a site | +| `mistapi.websockets.sites.MxEdgesEvents` | Real-time MX edges events for a site | | `mistapi.websockets.sites.MxEdgesStatsEvents` | Real-time MX edges stats for a site | | `mistapi.websockets.sites.PcapEvents` | Real-time PCAP events for a site | @@ -89,7 +202,7 @@ print(result.ws_data) def handle(msg): print("got:", msg) -result = ex.cableTest(apisession, site_id, device_id, port="ge-0/0/0", on_message=handle) +result = ex.cableTest(apisession, site_id, device_id, port_id="ge-0/0/0", on_message=handle) ``` #### **1.3 New API Endpoints** @@ -153,49 +266,6 @@ result = ex.cableTest(apisession, site_id, device_id, port="ge-0/0/0", on_messag --- -## Version 0.60.3 (February 2026) - -**Released**: February 21, 2026 - -This release add a missing query parameter to the `searchOrgWanClients()` function. - ---- - -### 1. CHANGES - -##### **API Function Updates** -- Updated `searchOrgWanClients()` and related functions in `orgs/wan_clients.py`. - ---- - -## Version 0.60.1 (February 2026) - -**Released**: February 21, 2026 - -This release includes function updates and bug fixes in the self/logs.py and sites/sle.py modules. - ---- - -### 1. CHANGES - -##### **API Function Updates** -- Updated `listSelfAuditLogs()` and related functions in `self/logs.py`. -- Updated deprecated and new SLE classifier functions in `sites/sle.py`. - ---- - -### 2. BUG FIXES - -- Minor bug fixes and improvements in API modules. - ---- - -### Breaking Changes - -No breaking changes in this release. - ---- - ## Version 0.60.4 (March 2026) **Released**: March 3, 2026 @@ -715,4 +785,4 @@ Previous stable release. See commit history for details. **Author**: Thomas Munzer **License**: MIT License -**Python Compatibility**: Python 3.8+ +**Python Compatibility**: Python 3.10+ diff --git a/README.md b/README.md index 09d27e4..b5c343a 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,10 @@ A comprehensive Python package to interact with the Mist Cloud APIs, built from - [Callbacks](#callbacks) - [Available Channels](#available-channels) - [Usage Patterns](#usage-patterns) +- [Async Usage](#async-usage) + - [Running API Calls Asynchronously](#running-api-calls-asynchronously) + - [Concurrent API Calls](#concurrent-api-calls) + - [Combining with Device Utilities](#combining-with-device-utilities) - [Device Utilities](#device-utilities) - [Supported Devices](#supported-devices) - [Usage](#device-utilities-usage) @@ -63,9 +67,10 @@ Support for all Mist cloud instances worldwide: ### Core Features - **Complete API Coverage**: Auto-generated from OpenAPI specs +- **Async Support**: Run any API call asynchronously with `mistapi.arun()` — no changes to existing code - **Automatic Pagination**: Built-in support for paginated responses - **WebSocket Streaming**: Real-time event streaming for devices, clients, and location data -- **Device Diagnostics**: High-level utilities for ping, traceroute, ARP, BGP, OSPF, and more +- **Device Diagnostics**: High-level, non-blocking utilities for ping, traceroute, ARP, BGP, OSPF, and more - **Error Handling**: Detailed error responses and logging - **Proxy Support**: HTTP/HTTPS proxy configuration - **Log Sanitization**: Automatic redaction of sensitive data in logs @@ -492,6 +497,82 @@ events = mistapi.api.v1.orgs.clients.searchOrgClientsEvents( --- +## Async Usage + +All API functions in `mistapi.api.v1` are synchronous by default. To use them in an `asyncio` context (e.g., FastAPI, aiohttp, or any async application) without blocking the event loop, use `mistapi.arun()`. + +`arun()` wraps any sync mistapi function in `asyncio.to_thread()`, running the blocking HTTP request in a thread pool while the event loop continues. No changes are needed to the existing API functions. + +### Running API Calls Asynchronously + +```python +import asyncio +import mistapi +from mistapi.api.v1.sites import devices + +apisession = mistapi.APISession(env_file="~/.mist_env") +apisession.login() + +async def main(): + # Wrap any sync API call with mistapi.arun() + response = await mistapi.arun( + devices.listSiteDevices, apisession, site_id + ) + print(response.data) + +asyncio.run(main()) +``` + +### Concurrent API Calls + +Use `asyncio.gather()` to run multiple API calls concurrently: + +```python +import asyncio +import mistapi +from mistapi.api.v1.orgs import orgs +from mistapi.api.v1.sites import devices + +async def main(): + org_info, site_devices = await asyncio.gather( + mistapi.arun(orgs.getOrg, apisession, org_id), + mistapi.arun(devices.listSiteDevices, apisession, site_id), + ) + print(f"Org: {org_info.data['name']}") + print(f"Devices: {len(site_devices.data)}") + +asyncio.run(main()) +``` + +### Combining with Device Utilities + +Device utility functions are already non-blocking and return a `UtilResponse` that supports `await`. You can mix `arun()` for API calls and `await` for device utilities: + +```python +import asyncio +import mistapi +from mistapi.api.v1.sites import devices +from mistapi.device_utils import ex + +async def main(): + # Start device utility — returns immediately, collects data in a background thread + response = ex.retrieveArpTable(apisession, site_id, device_id) + + # Meanwhile, run an API call via arun() — both execute concurrently + device_info = await mistapi.arun( + devices.getSiteDevice, apisession, site_id, device_id + ) + print(f"Device: {device_info.data['name']}") + + # Wait for the device utility background thread to finish + await response + print(f"ARP entries: {len(response.ws_data)}") + +asyncio.run(main()) +``` + +--- + ## WebSocket Streaming The package provides a WebSocket client for real-time event streaming from the Mist API (`wss://{host}/api-ws/v1/stream`). Authentication is handled automatically using the same session credentials (API token or login/password). @@ -533,7 +614,7 @@ ws.connect() |-------|---------|-------------| | `mistapi.websockets.orgs.InsightsEvents` | `/orgs/{org_id}/insights/summary` | Real-time insights events for an organization | | `mistapi.websockets.orgs.MxEdgesStatsEvents` | `/orgs/{org_id}/stats/mxedges` | Real-time MX edges stats for an organization | -| `mistapi.websockets.orgs.MxEdgesUpgradesEvents` | `/orgs/{org_id}/mxedges` | Real-time MX edges upgrades events for an organization | +| `mistapi.websockets.orgs.MxEdgesEvents` | `/orgs/{org_id}/mxedges` | Real-time MX edges events for an organization | #### Site Channels @@ -542,7 +623,7 @@ ws.connect() | `mistapi.websockets.sites.ClientsStatsEvents` | `/sites/{site_id}/stats/clients` | Real-time clients stats for a site | | `mistapi.websockets.sites.DeviceCmdEvents` | `/sites/{site_id}/devices/{device_id}/cmd` | Real-time device command events for a site | | `mistapi.websockets.sites.DeviceStatsEvents` | `/sites/{site_id}/stats/devices` | Real-time device stats for a site | -| `mistapi.websockets.sites.DeviceUpgradesEvents` | `/sites/{site_id}/devices` | Real-time device upgrades events for a site | +| `mistapi.websockets.sites.DeviceEvents` | `/sites/{site_id}/devices` | Real-time device events for a site | | `mistapi.websockets.sites.MxEdgesStatsEvents` | `/sites/{site_id}/stats/mxedges` | Real-time MX edges stats for a site | | `mistapi.websockets.sites.PcapEvents` | `/sites/{site_id}/pcap` | Real-time PCAP events for a site | @@ -631,46 +712,173 @@ with mistapi.websockets.sites.DeviceStatsEvents(apisession, site_ids=[" | Module | Device Type | Functions | |--------|-------------|-----------| | `device_utils.ap` | Mist Access Points | `ping`, `traceroute`, `retrieveArpTable` | -| `device_utils.ex` | Juniper EX Switches | `ping`, `monitorTraffic`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `retrieveMacTable`, `clearMacTable`, `clearLearnedMac`, `clearBpduError`, `clearDot1xSessions`, `clearHitCount`, `bouncePort`, `cableTest` | -| `device_utils.srx` | Juniper SRX Firewalls | `ping`, `monitorTraffic`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `showDatabase`, `showNeighbors`, `showInterfaces`, `bouncePort`, `retrieveRoutes` | -| `device_utils.ssr` | Juniper SSR Routers | `ping`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `showDatabase`, `showNeighbors`, `showInterfaces`, `bouncePort`, `retrieveRoutes`, `showServicePath` | +| `device_utils.ex` | Juniper EX Switches | `ping`, `monitorTraffic`, `topCommand`, `interactiveShell`, `createShellSession`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `retrieveMacTable`, `clearMacTable`, `clearLearnedMac`, `clearBpduError`, `clearDot1xSessions`, `clearHitCount`, `bouncePort`, `cableTest` | +| `device_utils.srx` | Juniper SRX Firewalls | `ping`, `monitorTraffic`, `topCommand`, `interactiveShell`, `createShellSession`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `retrieveOspfDatabase`, `retrieveOspfNeighbors`, `retrieveOspfInterfaces`, `retrieveOspfSummary`, `retrieveSessions`, `clearSessions`, `bouncePort`, `retrieveRoutes` | +| `device_utils.ssr` | Juniper SSR Routers | `ping`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `retrieveOspfDatabase`, `retrieveOspfNeighbors`, `retrieveOspfInterfaces`, `retrieveOspfSummary`, `retrieveSessions`, `clearSessions`, `bouncePort`, `retrieveRoutes`, `showServicePath` | ### Device Utilities Usage -```python -from mistapi.device_utils import ap, ex +All device utility functions are **non-blocking**: they trigger the REST API call, start a WebSocket stream in the background, and return a `UtilResponse` immediately. Your script can continue processing while data streams in. + +#### Callback style -# Ping from an AP -result = ap.ping(apisession, site_id, device_id, host="8.8.8.8") -print(result.ws_data) +Pass an `on_message` callback to process each result as it arrives: -# Retrieve ARP table from a switch -result = ex.retrieveArpTable(apisession, site_id, device_id) -print(result.ws_data) +```python +from mistapi.device_utils import ex -# With real-time callback def handle(msg): - print("got:", msg) + print("Live:", msg) + +response = ex.retrieveArpTable(apisession, site_id, device_id, on_message=handle) +# returns immediately — on_message fires for each message in the background + +do_other_work() + +response.wait() # block until streaming is complete +print(response.ws_data) # all collected data +``` + +#### Generator style + +Iterate over processed messages as they arrive, similar to `_MistWebsocket.receive()`: -result = ex.cableTest(apisession, site_id, device_id, port="ge-0/0/0", on_message=handle) +```python +response = ex.retrieveMacTable(apisession, site_id, device_id) +for msg in response.receive(): # blocking generator, yields each message + print(msg, end="", flush=True) +# loop ends when the WebSocket closes +print(response.ws_data) +``` + +#### Context manager + +`disconnect()` is called automatically when the context exits: + +```python +with ex.cableTest(apisession, site_id, device_id, port_id="ge-0/0/0") as response: + for msg in response.receive(): + print(msg, end="", flush=True) +# WebSocket disconnected, data ready +print(response.ws_data) +``` + +#### Polling + +Check `response.done` to avoid blocking: + +```python +response = ex.retrieveBgpSummary(apisession, site_id, device_id) +while not response.done: + do_other_work() +print(response.ws_data) +``` + +#### Cancel early + +Stop a long-running stream before it completes: + +```python +response = ex.monitorTraffic(apisession, site_id, device_id, port_id="ge-0/0/0") +do_some_work() +response.disconnect() # stop the WebSocket +print(response.ws_data) # data collected so far +``` + +#### Async await + +Works in `asyncio` contexts without blocking the event loop: + +```python +import asyncio +from mistapi.device_utils import ex + +async def main(): + response = ex.retrieveArpTable(apisession, site_id, device_id) + await response # non-blocking await + print(response.ws_data) + +asyncio.run(main()) ``` ### UtilResponse Object All device utility functions return a `UtilResponse` object: +#### Attributes + | Attribute | Type | Description | |-----------|------|-------------| | `trigger_api_response` | `APIResponse` | The initial REST API response that triggered the device command. Contains `status_code`, `data`, and `headers` from the trigger request. | | `ws_required` | `bool` | `True` if the command required a WebSocket connection to stream results (most diagnostic commands do). `False` if the REST response alone was sufficient. | -| `ws_data` | `list[str]` | Parsed result data extracted from the WebSocket stream. Each entry is a processed output line from the device (e.g., a line of ping output or an ARP table row). | +| `ws_data` | `list[str]` | Parsed result data extracted from the WebSocket stream. This list is **live** — it grows as messages arrive in the background, even before `wait()` is called. | | `ws_raw_events` | `list[str]` | Raw, unprocessed WebSocket event payloads as received from the Mist API. Useful for debugging or custom parsing. | +#### Properties and Methods + +| Method / Property | Returns | Description | +|-------------------|---------|-------------| +| `done` | `bool` | `True` if data collection is complete (or no WS was needed). | +| `wait(timeout=None)` | `UtilResponse` | Block until data collection is complete. Returns `self`. | +| `receive()` | `Generator` | Blocking generator that yields each processed message as it arrives. Exits when the WebSocket closes. | +| `disconnect()` | `None` | Stop the WebSocket connection early. | +| `await response` | `UtilResponse` | Non-blocking await for `asyncio` contexts. | + +`UtilResponse` also supports the context manager protocol (`with` statement). + ### Enums - `ap.TracerouteProtocol` — `ICMP`, `UDP` (for `ap.traceroute()`) - `srx.Node` / `ssr.Node` — `NODE0`, `NODE1` (for dual-node devices) +### Interactive Shell + +`interactiveShell()` and `createShellSession()` provide SSH-over-WebSocket access to EX and SRX devices. Unlike the diagnostic utilities above, the shell is **bidirectional** — you send keystrokes and receive terminal output in real time. + +#### Interactive mode (human at the keyboard) + +Takes over the terminal. Blocks until the connection closes or you press Ctrl+C: + +```python +from mistapi.device_utils import ex + +ex.interactiveShell(apisession, site_id, device_id) +``` + +Requires the `sshkeyboard` package (installed automatically as a dependency). + +#### Programmatic mode + +Use `createShellSession()` to get a `ShellSession` object for scripting: + +```python +from mistapi.device_utils import ex +import time + +with ex.createShellSession(apisession, site_id, device_id) as session: + session.send_text("show version\r\n") + time.sleep(3) + while True: + data = session.recv(timeout=0.5) + if data is None: + break + print(data.decode("utf-8", errors="replace"), end="") +``` + +#### ShellSession API + +| Method / Property | Returns | Description | +|-------------------|---------|-------------| +| `connect()` | `None` | Open the WebSocket connection. Called automatically by `createShellSession()`. | +| `disconnect()` | `None` | Close the WebSocket connection. | +| `connected` | `bool` | `True` if the WebSocket is currently connected. | +| `send(data)` | `None` | Send raw bytes (keystrokes) to the device. | +| `send_text(text)` | `None` | Send a text string to the device (auto-prefixed with `\x00`). | +| `recv(timeout=0.1)` | `bytes \| None` | Receive output from the device. Returns `None` on timeout or if disconnected. | +| `resize(rows, cols)` | `None` | Send a terminal resize message. | + +`ShellSession` also supports the context manager protocol (`with` statement). + --- ## Development and Testing diff --git a/pyproject.toml b/pyproject.toml index 844d985..2025baa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mistapi" -version = "0.61.0" +version = "0.61.1" authors = [{ name = "Thomas Munzer", email = "tmunzer@juniper.net" }] description = "Python package to simplify the Mist System APIs usage" keywords = ["Mist", "Juniper", "API"] @@ -28,6 +28,7 @@ dependencies = [ "hvac>=2.3.0", "keyring>=24.3.0", "websocket-client>=1.8.0", + "sshkeyboard>=2.3.1", ] [project.urls] diff --git a/src/mistapi/__api_request.py b/src/mistapi/__api_request.py index 1ceb605..3b777fc 100644 --- a/src/mistapi/__api_request.py +++ b/src/mistapi/__api_request.py @@ -17,6 +17,7 @@ import json import os import re +import threading import time import urllib.parse from collections.abc import Callable @@ -30,6 +31,26 @@ from mistapi.__models.privilege import Privileges +def _apitoken_sanitizer(apitoken: str) -> str: + """ + Return a substring of the API token to be used in the logs, to avoid + logging the full token value. + + PARAMS + ----------- + apitoken : str + API token value + + RETURN + ----------- + str + Substring of the API token to be used in the logs + """ + if len(apitoken) <= 16: + return "***hidden***" + return f"{apitoken[:4]}...{apitoken[-4:]}" + + class APIRequest: """ Class handling API Request to the Mist Cloud @@ -45,6 +66,7 @@ def __init__(self) -> None: self._count: int = 0 self._apitoken: list[str] = [] self._apitoken_index: int = -1 + self._token_lock: threading.Lock = threading.Lock() def get_request_count(self): """ @@ -86,41 +108,43 @@ def _log_proxy(self) -> None: ) def _next_apitoken(self) -> None: - logger.info("apirequest:_next_apitoken:rotating API Token") - logger.debug( - "apirequest:_next_apitoken:current API Token is %s...%s", - self._apitoken[self._apitoken_index][:4], - self._apitoken[self._apitoken_index][-4:], - ) - new_index = self._apitoken_index + 1 - if new_index >= len(self._apitoken): - new_index = 0 - if self._apitoken_index != new_index: - self._apitoken_index = new_index - self._session.headers.update( - {"Authorization": "Token " + self._apitoken[self._apitoken_index]} - ) + with self._token_lock: + logger.info("apirequest:_next_apitoken:rotating API Token") + masked = _apitoken_sanitizer(self._apitoken[self._apitoken_index]) logger.debug( - "apirequest:_next_apitoken:new API Token is %s...%s", - self._apitoken[self._apitoken_index][:4], - self._apitoken[self._apitoken_index][-4:], - ) - else: - logger.critical(" /!\\ API TOKEN CRITICAL ERROR /!\\") - logger.critical( - " There is no other API Token to use and the API" - " Request limit has been reached for the current one" - ) - logger.critical( - " For large organization, it is recommended to configure" - " multiple API Tokens (comma separated list) to avoid this issue" - ) - raise RuntimeError( - "API rate limit reached and no other API Token available. " - "For large organizations, configure multiple API Tokens " - "(comma separated list) to avoid this issue." + "apirequest:_next_apitoken:current API Token is %s", + masked, ) + new_index = self._apitoken_index + 1 + if new_index >= len(self._apitoken): + new_index = 0 + if self._apitoken_index != new_index: + self._apitoken_index = new_index + self._session.headers.update( + {"Authorization": "Token " + self._apitoken[self._apitoken_index]} + ) + masked = _apitoken_sanitizer(self._apitoken[self._apitoken_index]) + logger.debug( + "apirequest:_next_apitoken:new API Token is %s", + masked, + ) + else: + logger.critical(" /!\\ API TOKEN CRITICAL ERROR /!\\") + logger.critical( + " There is no other API Token to use and the API" + " Request limit has been reached for the current one" + ) + logger.critical( + " For large organization, it is recommended to configure" + " multiple API Tokens (comma separated list) to avoid this issue" + ) + raise RuntimeError( + "API rate limit reached and no other API Token available. " + "For large organizations, configure multiple API Tokens " + "(comma separated list) to avoid this issue." + ) + def _gen_query(self, query: dict[str, str] | None) -> str: if not query: return "" @@ -344,6 +368,7 @@ def mist_post_file( multipart_form_data, ) generated_multipart_form_data: dict[str, Any] = {} + opened_files: list = [] for key in multipart_form_data: logger.debug( "apirequest:mist_post_file:multipart_form_data:%s = %s", @@ -358,6 +383,7 @@ def mist_post_file( multipart_form_data[key], ) f = open(multipart_form_data[key], "rb") + opened_files.append(f) generated_multipart_form_data[key] = ( os.path.basename(multipart_form_data[key]), f, @@ -392,4 +418,8 @@ def _do_post_file(): ) return resp - return self._request_with_retry("mist_post_file", _do_post_file, url) + try: + return self._request_with_retry("mist_post_file", _do_post_file, url) + finally: + for f in opened_files: + f.close() diff --git a/src/mistapi/__api_response.py b/src/mistapi/__api_response.py index b30d0d5..7547965 100644 --- a/src/mistapi/__api_response.py +++ b/src/mistapi/__api_response.py @@ -51,7 +51,7 @@ def __init__( console.debug(f"Response Status Code: {response.status_code}") try: - self.raw_data = str(response.content) + self.raw_data = response.text self.data = response.json() self._check_next() logger.debug("apiresponse:__init__:HTTP response processed") diff --git a/src/mistapi/__api_session.py b/src/mistapi/__api_session.py index e84a191..f561a31 100644 --- a/src/mistapi/__api_session.py +++ b/src/mistapi/__api_session.py @@ -23,7 +23,7 @@ from dotenv import load_dotenv from requests import Session -from mistapi.__api_request import APIRequest +from mistapi.__api_request import APIRequest, _apitoken_sanitizer from mistapi.__api_response import APIResponse from mistapi.__logger import console as CONSOLE from mistapi.__logger import logger as LOGGER @@ -277,10 +277,10 @@ def _load_keyring(self, keyring_service) -> None: if isinstance(mist_apitoken, str): for token in mist_apitoken.split(","): token = token.strip() + masked = _apitoken_sanitizer(token) LOGGER.info( - "apisession:_load_keyring: Found MIST_APITOKEN=%s...%s", - token[:4], - token[-4:], + "apisession:_load_keyring: Found MIST_APITOKEN=%s", + masked, ) self.set_api_token(mist_apitoken) mist_user = keyring.get_password(keyring_service, "MIST_USER") @@ -526,6 +526,7 @@ def set_api_token(self, apitoken: str, validate: bool = True) -> None: def _get_api_token_data(self, apitoken) -> tuple[str | None, list | None]: token_privileges = [] token_type = "org" # nosec bandit B105 + masked = _apitoken_sanitizer(apitoken) try: url = f"https://{self._cloud_uri}/api/v1/self" headers = {"Authorization": "Token " + apitoken} @@ -536,9 +537,8 @@ def _get_api_token_data(self, apitoken) -> tuple[str | None, list | None]: ) data_json = data.json() LOGGER.debug( - "apisession:_get_api_token_data:info retrieved for token %s...%s", - apitoken[:4], - apitoken[-4:], + "apisession:_get_api_token_data:info retrieved for token %s", + masked, ) except requests.exceptions.ProxyError as proxy_error: LOGGER.critical("apisession:_get_api_token_data:proxy not valid...") @@ -554,10 +554,8 @@ def _get_api_token_data(self, apitoken) -> tuple[str | None, list | None]: ) from connexion_error except Exception: LOGGER.error( - "apisession:_get_api_token_data:" - "unable to retrieve info for token %s...%s", - apitoken[:4], - apitoken[-4:], + "apisession:_get_api_token_data:unable to retrieve info for token %s", + masked, ) LOGGER.error( "apirequest:_get_api_token_data: Exception occurred", exc_info=True @@ -566,20 +564,17 @@ def _get_api_token_data(self, apitoken) -> tuple[str | None, list | None]: if data.status_code == 401: LOGGER.critical( - "apisession:_get_api_token_data:" - "invalid API Token %s...%s: status code %s", - apitoken[:4], - apitoken[-4:], + "apisession:_get_api_token_data:invalid API Token %s: status code %s", + masked, data.status_code, ) CONSOLE.critical( - "Invalid API Token %s...%s: status code %s\r\n", - apitoken[:4], - apitoken[-4:], + "Invalid API Token %s: status code %s\r\n", + masked, data.status_code, ) raise ValueError( - f"Invalid API Token {apitoken[:4]}...{apitoken[-4:]}: status code {data.status_code}" + f"Invalid API Token {masked}: status code {data.status_code}" ) if data_json.get("email"): @@ -604,11 +599,10 @@ def _get_api_token_data(self, apitoken) -> tuple[str | None, list | None]: LOGGER.error( "apisession:_check_api_tokens:" "unable to process privileges %s for the %s " - "token %s...%s", + "token %s", priv, token_type, - apitoken[:4], - apitoken[-4:], + masked, ) return (token_type, token_privileges) @@ -624,34 +618,34 @@ def _check_api_tokens(self, apitokens) -> list[str]: else: primary_token_privileges: list[str] = [] primary_token_type: str | None = "" - primary_token_value: str = "" + primary_masked: str | None = "" for token in apitokens: - not_sensitive_data = f"{token[:4]}...{token[-4:]}" + masked = _apitoken_sanitizer(token) if token in valid_api_tokens: LOGGER.info( "apisession:_check_api_tokens:API Token %s is already valid", - not_sensitive_data, + masked, ) continue (token_type, token_privileges) = self._get_api_token_data(token) if token_type is None or token_privileges is None: LOGGER.error( "apisession:_check_api_tokens:API Token %s is not valid", - not_sensitive_data, + masked, ) LOGGER.error( "API Token %s is not valid and will not be used", - not_sensitive_data, + masked, ) elif len(primary_token_privileges) == 0 and token_privileges: primary_token_privileges = token_privileges primary_token_type = token_type - primary_token_value = not_sensitive_data + primary_masked = masked valid_api_tokens.append(token) LOGGER.info( "apisession:_check_api_tokens:" "API Token %s set as primary for comparison", - not_sensitive_data, + masked, ) elif primary_token_privileges == token_privileges: valid_api_tokens.append(token) @@ -660,23 +654,19 @@ def _check_api_tokens(self, apitokens) -> list[str]: "%s API Token %s has same privileges as " "the %s API Token %s", token_type, - not_sensitive_data, + masked, primary_token_type, - primary_token_value, + primary_masked, ) else: LOGGER.error( "apisession:_check_api_tokens:" "%s API Token %s has different privileges " - "than the %s API Token %s", + "than the %s API Token %s and will not be used", token_type, - not_sensitive_data, + masked, primary_token_type, - primary_token_value, - ) - LOGGER.error( - "API Token %s has different privileges and will not be used", - not_sensitive_data, + primary_masked, ) return valid_api_tokens diff --git a/src/mistapi/__init__.py b/src/mistapi/__init__.py index d211453..f6fca30 100644 --- a/src/mistapi/__init__.py +++ b/src/mistapi/__init__.py @@ -17,6 +17,8 @@ from mistapi.__version import __author__ as __author__ from mistapi.__version import __version__ as __version__ +import asyncio as _asyncio +from collections.abc import Callable as _Callable from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -41,3 +43,43 @@ def __getattr__(name: str): globals()[name] = module return module raise AttributeError(f"module 'mistapi' has no attribute {name!r}") + + +async def arun(func: _Callable, *args, **kwargs): + """ + Run any sync mistapi function without blocking the event loop. + + Wraps the function call in ``asyncio.to_thread()`` so the blocking + HTTP request runs in a thread pool while the event loop continues. + + EXAMPLE + ----------- + :: + + import asyncio + import mistapi + from mistapi.api.v1.sites import devices + + async def main(): + session = mistapi.APISession(env_file="~/.mist_env") + session.login() + + response = await mistapi.arun( + devices.listSiteDevices, session, site_id + ) + print(response.data) + + asyncio.run(main()) + + PARAMS + ----------- + func : callable + Any sync mistapi API function. + *args, **kwargs + Arguments forwarded to *func*. + + RETURNS + ----------- + The return value of *func* (typically ``APIResponse``). + """ + return await _asyncio.to_thread(func, *args, **kwargs) diff --git a/src/mistapi/__version.py b/src/mistapi/__version.py index 4e16e40..adc4c80 100644 --- a/src/mistapi/__version.py +++ b/src/mistapi/__version.py @@ -1,2 +1,2 @@ -__version__ = "0.61.0" +__version__ = "0.61.1" __author__ = "Thomas Munzer " diff --git a/src/mistapi/api/v1/sites/sle.py b/src/mistapi/api/v1/sites/sle.py index 71c2676..353b59f 100644 --- a/src/mistapi/api/v1/sites/sle.py +++ b/src/mistapi/api/v1/sites/sle.py @@ -19,7 +19,7 @@ @deprecation.deprecated( deprecated_in="0.59.2", removed_in="0.65.0", - current_version="0.61.0", + current_version="0.61.1", details="function replaced with getSiteSleClassifierSummaryTrend", ) def getSiteSleClassifierDetails( @@ -691,7 +691,7 @@ def listSiteSleImpactedWirelessClients( @deprecation.deprecated( deprecated_in="0.59.2", removed_in="0.65.0", - current_version="0.61.0", + current_version="0.61.1", details="function replaced with getSiteSleSummaryTrend", ) def getSiteSleSummary( diff --git a/src/mistapi/device_utils/__init__.py b/src/mistapi/device_utils/__init__.py index dea7134..17d2d42 100644 --- a/src/mistapi/device_utils/__init__.py +++ b/src/mistapi/device_utils/__init__.py @@ -18,7 +18,7 @@ -------------------------------------- Import device-specific modules for a clean, organized API: - from mistapi.utils import ap, ex, srx, ssr + from mistapi.device_utils import ap, ex, srx, ssr # Use device-specific utilities ap.ping(session, site_id, device_id, host) @@ -30,15 +30,6 @@ - ex: Juniper EX Switches - srx: Juniper SRX Firewalls - ssr: Juniper Session Smart Routers - -Function-Based Modules (Legacy) ---------------------------------- -Original organization by function type (still available): - - from mistapi.utils import arp, bgp, dhcp, mac, port, routes, tools - -Available modules: arp, bgp, bpdu, dhcp, dns, dot1x, mac, policy, port, routes, - service_path, tools """ # Device-specific modules (recommended) diff --git a/src/mistapi/device_utils/__tools/__common.py b/src/mistapi/device_utils/__tools/__common.py new file mode 100644 index 0000000..fdee41f --- /dev/null +++ b/src/mistapi/device_utils/__tools/__common.py @@ -0,0 +1,20 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from enum import Enum + + +class Node(Enum): + """Node enum for specifying which node to target on dual-node devices.""" + + NODE0 = "node0" + NODE1 = "node1" diff --git a/src/mistapi/device_utils/__tools/__ws_wrapper.py b/src/mistapi/device_utils/__tools/__ws_wrapper.py index a1d8d8a..aac0c91 100644 --- a/src/mistapi/device_utils/__tools/__ws_wrapper.py +++ b/src/mistapi/device_utils/__tools/__ws_wrapper.py @@ -1,12 +1,158 @@ import json +import queue +import re import threading -from collections.abc import Callable +from collections.abc import Callable, Generator from enum import Enum from mistapi import APISession from mistapi.__api_response import APIResponse as _APIResponse from mistapi.__logger import logger as LOGGER +# Matches ANSI CSI sequences, OSC sequences, and character set designations +_ANSI_ESCAPE_RE = re.compile( + r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\][^\x07]*\x07|\x1b[()][A-B0-2]" +) + +# Detects VT100 cursor positioning / clear-screen (triggers screen-buffer mode) +_SCREEN_MODE_RE = re.compile(r"\x1b\[[\d;]*H|\x1b\[2J") + + +class _VT100Screen: + """Minimal VT100 terminal emulator for rendering screen-based output. + + Handles the subset of VT100 sequences used by Junos ``top`` and + ``monitor interface`` commands: cursor positioning, screen/line + clearing, and cursor movement. SGR (colors), scroll regions, and + mode changes are silently ignored. + """ + + def __init__(self, rows: int = 24, cols: int = 80) -> None: + self.rows = rows + self.cols = cols + self.cursor_row = 0 + self.cursor_col = 0 + self.grid: list[list[str]] = [[" "] * cols for _ in range(rows)] + + def feed(self, text: str) -> None: + """Process *text* (may contain VT100 sequences) into the screen buffer.""" + i = 0 + n = len(text) + while i < n: + ch = text[i] + + if ch == "\x1b" and i + 1 < n: + nxt = text[i + 1] + if nxt == "[": + # CSI sequence: \x1b[ + j = i + 2 + params = "" + while j < n and text[j] in "0123456789;": + params += text[j] + j += 1 + if j < n: + self._handle_csi(params, text[j]) + i = j + 1 + else: + i = j + continue + if nxt in "()": + # Character-set designation – skip 3 bytes + i += 3 if i + 2 < n else n + continue + if nxt == "]": + # OSC sequence – skip until BEL + j = i + 2 + while j < n and text[j] != "\x07": + j += 1 + i = j + 1 + continue + # Unknown escape – skip \x1b and the next char + i += 2 + continue + + if ch == "\r": + self.cursor_col = 0 + i += 1 + continue + + if ch == "\n": + self.cursor_row += 1 + self.cursor_col = 0 + if self.cursor_row >= self.rows: + self.grid.pop(0) + self.grid.append([" "] * self.cols) + self.cursor_row = self.rows - 1 + i += 1 + continue + + if ch == "\x00": + i += 1 + continue + + # Printable character + if 0 <= self.cursor_row < self.rows and 0 <= self.cursor_col < self.cols: + self.grid[self.cursor_row][self.cursor_col] = ch + self.cursor_col += 1 + i += 1 + + # ------------------------------------------------------------------ + def _handle_csi(self, params: str, cmd: str) -> None: + nums = [] + for p in params.split(";") if params else []: + try: + nums.append(int(p)) + except ValueError: + nums.append(0) + + if cmd in ("H", "f"): # Cursor position + row = (nums[0] - 1) if nums else 0 + col = (nums[1] - 1) if len(nums) > 1 else 0 + self.cursor_row = max(0, min(row, self.rows - 1)) + self.cursor_col = max(0, min(col, self.cols - 1)) + elif cmd == "A": # Cursor up + self.cursor_row = max(0, self.cursor_row - (nums[0] if nums else 1)) + elif cmd == "B": # Cursor down + self.cursor_row = min( + self.rows - 1, self.cursor_row + (nums[0] if nums else 1) + ) + elif cmd == "C": # Cursor forward + self.cursor_col = min( + self.cols - 1, self.cursor_col + (nums[0] if nums else 1) + ) + elif cmd == "D": # Cursor back + self.cursor_col = max(0, self.cursor_col - (nums[0] if nums else 1)) + elif cmd == "J": # Erase in display + n = nums[0] if nums else 0 + if n == 2: + self.grid = [[" "] * self.cols for _ in range(self.rows)] + self.cursor_row = 0 + self.cursor_col = 0 + elif n == 0: + for c in range(self.cursor_col, self.cols): + self.grid[self.cursor_row][c] = " " + for r in range(self.cursor_row + 1, self.rows): + self.grid[r] = [" "] * self.cols + elif cmd == "K": # Erase in line + n = nums[0] if nums else 0 + if n == 0: + for c in range(self.cursor_col, self.cols): + self.grid[self.cursor_row][c] = " " + elif n == 1: + for c in range(self.cursor_col + 1): + self.grid[self.cursor_row][c] = " " + elif n == 2: + self.grid[self.cursor_row] = [" "] * self.cols + # SGR (m), scroll region (r), mode set/reset (l, h) – ignore + + # ------------------------------------------------------------------ + def render(self) -> str: + """Return screen content as text with trailing whitespace trimmed.""" + lines = ["".join(row).rstrip() for row in self.grid] + while lines and not lines[-1]: + lines.pop() + return "\n".join(lines) + class TimerAction(Enum): """ @@ -30,19 +176,111 @@ class Timer(Enum): class UtilResponse: """ - A simple class to encapsulate the response from utility WebSocket functions. - This class can be extended in the future to include additional metadata or helper methods. + Encapsulates the response from device utility functions. + + Returned immediately by tool functions. When a WebSocket stream is + involved, data is collected in the background. Use ``receive()``, + ``wait()``, or the ``on_message`` callback to consume results. + + USAGE PATTERNS + ----------- + Callback style (on_message passed at call time):: + + response = ex.ping(session, site_id, device_id, host="8.8.8.8", + on_message=lambda msg: print(msg)) + do_other_work() + response.wait() + print(response.ws_data) + + Generator style:: + + response = ex.ping(session, site_id, device_id, host="8.8.8.8") + for msg in response.receive(): + print(msg) + + Context manager:: + + with ex.ping(session, site_id, device_id, host="8.8.8.8") as response: + for msg in response.receive(): + print(msg) + + Async await:: + + response = ex.ping(session, site_id, device_id, host="8.8.8.8") + await response + print(response.ws_data) """ def __init__( self, - api_response: _APIResponse, + api_response: _APIResponse | None = None, ) -> None: self.trigger_api_response = api_response - # This can be set to True if the WebSocket connection was successfully initiated self.ws_required: bool = False self.ws_data: list[str] = [] self.ws_raw_events: list[str] = [] + self._queue: queue.Queue[str | None] = queue.Queue() + self._closed = threading.Event() + self._await_timeout: float | None = None + if api_response is not None: + # api_response provided → done immediately, no WS needed + self._closed.set() + # api_response is None → _closed stays unset (in-progress, waiting for WS) + self._disconnect_fn: Callable[[], None] | None = None + + @property + def done(self) -> bool: + """True if data collection is complete (or no WS was needed).""" + return self._closed.is_set() + + def wait(self, timeout: float | None = None) -> "UtilResponse": + """Block until data collection is complete. Returns self.""" + self._closed.wait(timeout=timeout) + return self + + def receive(self) -> Generator[str, None, None]: + """ + Blocking generator that yields each processed message as it arrives. + + Mirrors ``_MistWebsocket.receive()``. Exits cleanly when the + WebSocket connection closes or ``disconnect()`` is called. + """ + while True: + try: + item = self._queue.get(timeout=0.1) + except queue.Empty: + if self._closed.is_set() and self._queue.empty(): + break + continue + if item is None: + break + yield item + + def disconnect(self) -> None: + """Stop the WebSocket connection early.""" + fn = self._disconnect_fn + if fn is not None: + fn() + + def __enter__(self) -> "UtilResponse": + return self + + def __exit__(self, *args) -> None: + self.disconnect() + + def __await__(self): + """Allow ``result = await response`` in async contexts.""" + import asyncio + + async def _await_impl(): + timed_out = not await asyncio.to_thread( + self._closed.wait, self._await_timeout + ) + if timed_out: + LOGGER.warning("await timed out after %s seconds", self._await_timeout) + return self + + return _await_impl().__await__() class WebSocketWrapper: @@ -54,13 +292,13 @@ class WebSocketWrapper: def __init__( self, - apissession: APISession, + apisession: APISession, util_response: UtilResponse, timeout: int = 10, max_duration: int = 60, on_message: Callable[[dict], None] | None = None, ) -> None: - self.apissession = apissession + self.apisession = apisession self.util_response = util_response self.timers = { Timer.TIMEOUT.value: { @@ -83,8 +321,14 @@ def __init__( self.session_id: str | None = None self.capture_id: str | None = None self._on_message_cb = on_message - self._closed = threading.Event() + self._screen: _VT100Screen | None = None + self._screen_mode: bool = False + self._extract_trigger_ids() + def _extract_trigger_ids(self): + """Extract session_id and capture_id from the trigger API response.""" + if not self.util_response.trigger_api_response: + return LOGGER.debug( "trigger response: %s", self.util_response.trigger_api_response.data ) @@ -107,7 +351,9 @@ def _on_open(self): def _on_close(self, code, msg): LOGGER.info("WebSocket closed: %s - %s", code, msg) - self._closed.set() + self._stop_all_timers() + self.util_response._queue.put(None) # sentinel for receive() + self.util_response._closed.set() # signal completion ########################################################################## ## Helper methods for managing timers @@ -153,11 +399,13 @@ def _handle_message(self, msg): self._timeout_handler(Timer.FIRST_MESSAGE_TIMEOUT, TimerAction.START) elif self._extract_session_id(msg): # Stop the first message timeout timer on receiving the first message - self._timeout_handler(Timer.FIRST_MESSAGE_TIMEOUT, TimerAction.STOP) + if self.timers[Timer.FIRST_MESSAGE_TIMEOUT.value]["thread"]: + self._timeout_handler(Timer.FIRST_MESSAGE_TIMEOUT, TimerAction.STOP) LOGGER.debug("data: %s", msg) raw = self._extract_raw(msg) if raw: self.data.append(raw) + self.util_response._queue.put(raw) # feed receive() generator if self._on_message_cb: self._on_message_cb(raw) self._timeout_handler(Timer.TIMEOUT, TimerAction.RESET) @@ -202,14 +450,15 @@ def _extract_session_id(self, message) -> bool: return True return False - def _extract_raw(self, message): + def _extract_raw(self, message, root: bool = True): """ Extracts the raw message from the given message. This method is designed to handle messages that may have the raw message nested at different levels. Handles both command events (with "raw" field) and pcap events (with "pcap_dict" field). """ - self.raw_events.append(message) + if root: + self.raw_events.append(message) event = message if isinstance(event, str): try: @@ -219,11 +468,22 @@ def _extract_raw(self, message): return None if isinstance(event, dict): if event.get("event") == "data" and event.get("data"): - return self._extract_raw(event["data"]) + return self._extract_raw(event["data"], root=False) if "raw" in event: self.received_messages += 1 - LOGGER.debug("Extracted raw message: %s", event["raw"]) - return event["raw"] + raw_value = event["raw"] + if isinstance(raw_value, str): + # Detect screen-mode (cursor positioning / clear-screen) + if not self._screen_mode and _SCREEN_MODE_RE.search(raw_value): + self._screen_mode = True + self._screen = _VT100Screen() + if self._screen_mode and self._screen is not None: + self._screen.feed(raw_value) + raw_value = self._screen.render() + else: + raw_value = _ANSI_ESCAPE_RE.sub("", raw_value) + LOGGER.debug("Extracted raw message: %s", raw_value) + return raw_value if "pcap_dict" in event: self.received_messages += 1 LOGGER.debug("Extracted pcap data: %s", event["pcap_dict"]) @@ -234,7 +494,11 @@ def _extract_raw(self, message): ## WebSocket connection management def start(self, ws) -> UtilResponse: """ - Start the WS connection, block until closed, return UtilResponse. + Start the WS connection in the background and return immediately. + + The returned ``UtilResponse`` collects data as it streams in. Use + ``response.receive()``, ``response.wait()``, or the ``on_message`` + callback to consume results. PARAMS ----------- @@ -246,9 +510,90 @@ def start(self, ws) -> UtilResponse: ws.on_error(lambda error: LOGGER.error("Error: %s", error)) ws.on_close(self._on_close) ws.on_open(self._on_open) - ws.connect(run_in_background=False) # blocks until _on_close fires - self._stop_all_timers() + + # Wire up UtilResponse before starting WS + # _closed is already unset (in-progress) from UtilResponse.__init__ self.util_response.ws_required = True - self.util_response.ws_data = self.data + self.util_response.ws_data = self.data # live list reference self.util_response.ws_raw_events = self.raw_events + self.util_response._await_timeout = ( + self.timers[Timer.MAX_DURATION.value]["duration"] + 10 + ) + self.util_response._disconnect_fn = ws.disconnect + + ws.connect(run_in_background=True) # non-blocking + return self.util_response + + def start_with_trigger( + self, + trigger_fn: Callable, + ws_factory_fn: Callable | None = None, + ) -> UtilResponse: + """ + Run the trigger API call and optionally start a WebSocket stream. + + If ``ws_factory_fn`` is provided, the trigger and WebSocket setup + run in a background thread (non-blocking). If ``ws_factory_fn`` is + ``None``, the trigger runs synchronously so that + ``trigger_api_response`` is immediately available on the returned + ``UtilResponse``. + + PARAMS + ----------- + trigger_fn : Callable + A zero-argument callable that performs the REST API trigger and + returns an ``APIResponse``. + ws_factory_fn : Callable, optional + A one-argument callable that receives the trigger ``APIResponse`` + and returns a WebSocket channel object (e.g. ``DeviceCmdEvents``). + If ``None``, no WebSocket is started and the ``UtilResponse`` + completes as soon as the trigger finishes. + """ + if ws_factory_fn is None: + return self._trigger_only(trigger_fn) + + def _run(): + try: + trigger = trigger_fn() + self.util_response.trigger_api_response = trigger + if trigger.status_code == 200: + LOGGER.info("Trigger succeeded: %s", trigger.data) + self._extract_trigger_ids() + ws = ws_factory_fn(trigger) + if ws: + self.start(ws) + return # start() / _on_close manages _closed + LOGGER.error("WS factory returned None") + else: + LOGGER.error( + "Failed to trigger command: %s - %s", + trigger.status_code, + trigger.data, + ) + except Exception as e: + LOGGER.error("Error during trigger: %s", e) + # Mark done (failure or WS factory returned None) + self.util_response._queue.put(None) + self.util_response._closed.set() + + threading.Thread(target=_run, daemon=True).start() + return self.util_response + + def _trigger_only(self, trigger_fn: Callable) -> UtilResponse: + """Run a trigger-only command synchronously (no WebSocket needed).""" + try: + trigger = trigger_fn() + self.util_response.trigger_api_response = trigger + if trigger.status_code == 200: + LOGGER.info("Trigger succeeded: %s", trigger.data) + else: + LOGGER.error( + "Failed to trigger command: %s - %s", + trigger.status_code, + trigger.data, + ) + except Exception as e: + LOGGER.error("Error during trigger: %s", e) + self.util_response._queue.put(None) + self.util_response._closed.set() return self.util_response diff --git a/src/mistapi/device_utils/__tools/arp.py b/src/mistapi/device_utils/__tools/arp.py index f9b3d6d..02f839f 100644 --- a/src/mistapi/device_utils/__tools/arp.py +++ b/src/mistapi/device_utils/__tools/arp.py @@ -11,24 +11,17 @@ """ from collections.abc import Callable -from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__common import Node from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper from mistapi.websockets.sites import DeviceCmdEvents -class Node(Enum): - """Node Enum for specifying node information in ARP commands.""" - - NODE0 = "node0" - NODE1 = "node1" - - def retrieve_ap_arp_table( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, node: Node | None = None, @@ -42,7 +35,7 @@ def retrieve_ap_arp_table( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -61,35 +54,30 @@ def retrieve_ap_arp_table( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - # AP is returning RAW data - # SWITCH is returning ??? - # GATEWAY is returning JSON + LOGGER.debug( + "Initiating ARP table retrieval for device %s with node %s and timeout %s", + device_id, + node, + timeout, + ) body: dict[str, str | list | int] = {} if node: body["node"] = node.value - trigger = devices.arpFromDevice( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.arpFromDevice( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Show ARP command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" - ) # Give the show ARP command a moment to take effect - return util_response def retrieve_ssr_arp_table( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, node: Node | None = None, @@ -103,7 +91,7 @@ def retrieve_ssr_arp_table( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -122,35 +110,30 @@ def retrieve_ssr_arp_table( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - # AP is returning RAW data - # SWITCH is returning ??? - # GATEWAY is returning JSON + LOGGER.debug( + "Initiating SSR ARP table retrieval for device %s with node %s and timeout %s", + device_id, + node, + timeout, + ) body: dict[str, str | list | int] = {} if node: body["node"] = node.value - trigger = devices.arpFromDevice( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.arpFromDevice( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Show ARP command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" - ) # Give the show ARP command a moment to take effect - return util_response def retrieve_junos_arp_table( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, ip: str | None = None, @@ -167,7 +150,7 @@ def retrieve_junos_arp_table( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -190,6 +173,15 @@ def retrieve_junos_arp_table( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating Junos ARP table retrieval for device %s with IP filter %s, port filter %s, " + "VRF filter %s, and timeout %s", + device_id, + ip, + port_id, + vrf, + timeout, + ) body: dict[str, str | list | int] = {"duration": 1, "interval": 1} if ip: body["ip"] = ip @@ -197,22 +189,14 @@ def retrieve_junos_arp_table( body["vrf"] = vrf if port_id: body["port_id"] = port_id - trigger = devices.showSiteDeviceArpTable( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.showSiteDeviceArpTable( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Show ARP command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" - ) # Give the show ARP command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/bgp.py b/src/mistapi/device_utils/__tools/bgp.py index f545c57..74f8a5c 100644 --- a/src/mistapi/device_utils/__tools/bgp.py +++ b/src/mistapi/device_utils/__tools/bgp.py @@ -20,7 +20,7 @@ def summary( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, timeout=5, @@ -34,7 +34,7 @@ def summary( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -49,22 +49,20 @@ def summary( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating BGP summary retrieval for device %s with timeout %s", + device_id, + timeout, + ) body: dict[str, str | list | int] = {"protocol": "bgp"} - trigger = devices.showSiteDeviceBgpSummary( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.showSiteDeviceBgpSummary( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"BGP summary command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger BGP summary command: {trigger.status_code} - {trigger.data}" - ) # Give the BGP summary command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/bpdu.py b/src/mistapi/device_utils/__tools/bpdu.py index 0bdf96b..125e1f7 100644 --- a/src/mistapi/device_utils/__tools/bpdu.py +++ b/src/mistapi/device_utils/__tools/bpdu.py @@ -13,11 +13,11 @@ from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper -async def clear_error( - apissession: _APISession, +def clear_error( + apisession: _APISession, site_id: str, device_id: str, port_ids: list[str], @@ -29,6 +29,8 @@ async def clear_error( PARAMS ----------- + apisession: mistapi.APISession + The API session to use for the request. site_id : str UUID of the site where the switch is located. device_id : str @@ -42,20 +44,15 @@ async def clear_error( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - + LOGGER.debug( + "Initiating clear BPDU error command for device %s on ports %s", + device_id, + port_ids, + ) body: dict[str, str | list | int] = {"ports": port_ids} - trigger = devices.clearBpduErrorsFromPortsOnSwitch( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper(apisession, util_response).start_with_trigger( + trigger_fn=lambda: devices.clearBpduErrorsFromPortsOnSwitch( + apisession, site_id=site_id, device_id=device_id, body=body + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear BPDU error command triggered for device {device_id}") - else: - LOGGER.error( - f"Failed to trigger clear BPDU error command: {trigger.status_code} - {trigger.data}" - ) # Give the clear BPDU error command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/dhcp.py b/src/mistapi/device_utils/__tools/dhcp.py index e0ed5f0..29afd6e 100644 --- a/src/mistapi/device_utils/__tools/dhcp.py +++ b/src/mistapi/device_utils/__tools/dhcp.py @@ -11,24 +11,17 @@ """ from collections.abc import Callable -from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__common import Node from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper from mistapi.websockets.sites import DeviceCmdEvents -class Node(Enum): - """Node Enum for specifying node information in DHCP commands.""" - - NODE0 = "node0" - NODE1 = "node1" - - def release_dhcp_leases( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, macs: list[str] | None = None, @@ -57,7 +50,7 @@ def release_dhcp_leases( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -82,6 +75,16 @@ def release_dhcp_leases( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating DHCP lease release for device %s with MACs %s, network %s, node %s, port ID %s, " + "and timeout %s", + device_id, + macs, + network, + node, + port_id, + timeout, + ) body: dict[str, str | list | int] = {} if macs: body["macs"] = macs @@ -91,28 +94,21 @@ def release_dhcp_leases( body["node"] = node.value if port_id: body["port_id"] = port_id - trigger = devices.releaseSiteDeviceDhcpLease( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.releaseSiteDeviceDhcpLease( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Release DHCP leases command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger release DHCP leases command: {trigger.status_code} - {trigger.data}" - ) # Give the release DHCP leases command a moment to take effect - return util_response def retrieve_dhcp_leases( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, network: str, @@ -127,20 +123,18 @@ def retrieve_dhcp_leases( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. device_id : str UUID of the device to retrieve DHCP leases from. network : str - Network to release DHCP leases for. + Network to retrieve DHCP leases for. node : Node, optional - Node information for the DHCP lease release command. - port_id : str, optional - Port ID to release DHCP leases for. + Node information for the DHCP lease retrieval command. timeout : int, optional - Timeout for the release DHCP leases command in seconds. + Timeout for the retrieve DHCP leases command in seconds. on_message : Callable, optional Callback invoked with each extracted raw message as it arrives. @@ -149,24 +143,24 @@ def retrieve_dhcp_leases( UtilResponse A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating DHCP lease retrieval for device %s with network %s, node %s, and timeout %s", + device_id, + network, + node, + timeout, + ) body: dict[str, str | list | int] = {"network": network} if node: body["node"] = node.value - trigger = devices.showSiteDeviceDhcpLeases( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.showSiteDeviceDhcpLeases( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Retrieve DHCP leases command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger retrieve DHCP leases command: {trigger.status_code} - {trigger.data}" - ) # Give the release DHCP leases command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/dns.py b/src/mistapi/device_utils/__tools/dns.py index 4f5cca2..94cd0cd 100644 --- a/src/mistapi/device_utils/__tools/dns.py +++ b/src/mistapi/device_utils/__tools/dns.py @@ -1,28 +1,27 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- +# """ +# -------------------------------------------------------------------------------- +# ------------------------- Mist API Python CLI Session -------------------------- - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python +# Written by: Thomas Munzer (tmunzer@juniper.net) +# Github : https://github.com/tmunzer/mistapi_python - This package is licensed under the MIT License. +# This package is licensed under the MIT License. --------------------------------------------------------------------------------- -""" +# -------------------------------------------------------------------------------- +# """ -from enum import Enum +# from collections.abc import Callable +# from mistapi import APISession as _APISession +# from mistapi.api.v1.sites import devices +# from mistapi.device_utils.__tools.__common import Node +# from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +# from mistapi.websockets.sites import DeviceCmdEvents -class Node(Enum): - """Node Enum for specifying node information in DNS commands.""" - NODE0 = "node0" - NODE1 = "node1" - - -## NO DATA +# ## NO DATA # def test_resolution( -# apissession: _APISession, +# apisession: _APISession, # site_id: str, # device_id: str, # node: Node | None = None, @@ -37,7 +36,7 @@ class Node(Enum): # PARAMS # ----------- -# apissession : _APISession +# apisession: mistapi.APISession # The API session to use for the request. # site_id : str # UUID of the site where the gateway is located. @@ -63,22 +62,14 @@ class Node(Enum): # body["node"] = node.value # if hostname: # body["hostname"] = hostname -# trigger = devices.testSiteSsrDnsResolution( -# apissession, -# site_id=site_id, -# device_id=device_id, -# body=body, +# util_response = UtilResponse() +# return WebSocketWrapper( +# apisession, util_response, timeout=timeout, on_message=on_message +# ).start_with_trigger( +# trigger_fn=lambda: devices.testSiteSsrDnsResolution( +# apisession, site_id=site_id, device_id=device_id, body=body +# ), +# ws_factory_fn=lambda _trigger: DeviceCmdEvents( +# apisession, site_id=site_id, device_ids=[device_id] +# ), # ) -# util_response = UtilResponse(trigger) -# if trigger.status_code == 200: -# LOGGER.info(trigger.data) -# print(f"SSR DNS resolution command triggered for device {device_id}") -# ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) -# util_response = WebSocketWrapper( -# apissession, util_response, timeout=timeout, on_message=on_message -# ).start(ws) -# else: -# LOGGER.error( -# f"Failed to trigger SSR DNS resolution command: {trigger.status_code} - {trigger.data}" -# ) # Give the SSR DNS resolution command a moment to take effect -# return util_response diff --git a/src/mistapi/device_utils/__tools/dot1x.py b/src/mistapi/device_utils/__tools/dot1x.py index 537e65d..59590b1 100644 --- a/src/mistapi/device_utils/__tools/dot1x.py +++ b/src/mistapi/device_utils/__tools/dot1x.py @@ -13,11 +13,11 @@ from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper -async def clear_sessions( - apissession: _APISession, +def clear_sessions( + apisession: _APISession, site_id: str, device_id: str, port_ids: list[str], @@ -29,6 +29,8 @@ async def clear_sessions( PARAMS ----------- + apisession: mistapi.APISession + The API session to use for the request. site_id : str UUID of the site where the switch is located. device_id : str @@ -42,19 +44,15 @@ async def clear_sessions( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating clear dot1x sessions command for device %s on ports %s", + device_id, + port_ids, + ) body: dict[str, str | list | int] = {"ports": port_ids} - trigger = devices.clearAllLearnedMacsFromPortOnSwitch( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper(apisession, util_response).start_with_trigger( + trigger_fn=lambda: devices.clearSiteDeviceDot1xSession( + apisession, site_id=site_id, device_id=device_id, body=body + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear learned MACs command triggered for device {device_id}") - else: - LOGGER.error( - f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" - ) # Give the clear learned MACs command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/mac.py b/src/mistapi/device_utils/__tools/mac.py index d68441a..6616976 100644 --- a/src/mistapi/device_utils/__tools/mac.py +++ b/src/mistapi/device_utils/__tools/mac.py @@ -20,13 +20,12 @@ def clear_mac_table( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, mac_address: str | None = None, port_id: str | None = None, vlan_id: str | None = None, - # timeout=30, ) -> UtilResponse: """ DEVICES: EX @@ -35,7 +34,7 @@ def clear_mac_table( PARAMS ----------- - apissession : _APISession + apisession : mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -54,9 +53,14 @@ def clear_mac_table( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - # AP is returning RAW data - # SWITCH is returning ??? - # GATEWAY is returning JSON + LOGGER.debug( + "Initiating clear MAC table command for device %s with MAC address filter %s, " + "port filter %s, and VLAN filter %s", + device_id, + mac_address, + port_id, + vlan_id, + ) body: dict[str, str | list | int] = {} if mac_address: body["mac_address"] = mac_address @@ -64,28 +68,16 @@ def clear_mac_table( body["port_id"] = port_id if vlan_id: body["vlan_id"] = vlan_id - trigger = devices.clearSiteDeviceMacTable( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper(apisession, util_response).start_with_trigger( + trigger_fn=lambda: devices.clearSiteDeviceMacTable( + apisession, site_id=site_id, device_id=device_id, body=body + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear MAC Table command triggered for device {device_id}") - # util_response = WebSocketWrapper( - # apissession, util_response, timeout=timeout, on_message=on_message - # ).start(ws) - else: - LOGGER.error( - f"Failed to trigger clear MAC Table command: {trigger.status_code} - {trigger.data}" - ) # Give the clear MAC Table command a moment to take effect - return util_response def retrieve_mac_table( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, mac_address: str | None = None, @@ -101,20 +93,20 @@ def retrieve_mac_table( PARAMS ----------- - apissession : _APISession + apisession : mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. device_id : str - UUID of the device to retrieve the ARP table from. + UUID of the device to retrieve the MAC table from. mac_address : str, optional - MAC address to filter the ARP table retrieval. + MAC address to filter the MAC table retrieval. port_id : str, optional - Port ID to filter the ARP table retrieval. + Port ID to filter the MAC table retrieval. vlan_id : str, optional - VLAN ID to filter the ARP table retrieval. + VLAN ID to filter the MAC table retrieval. timeout : int, optional - Timeout for the ARP table retrieval command in seconds. + Timeout for the MAC table retrieval command in seconds. on_message : Callable, optional Callback invoked with each extracted raw message as it arrives. @@ -124,9 +116,15 @@ def retrieve_mac_table( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - # AP is returning RAW data - # SWITCH is returning ??? - # GATEWAY is returning JSON + LOGGER.debug( + "Initiating MAC table retrieval for device %s with MAC address filter %s, port filter %s, " + "VLAN filter %s, and timeout %s", + device_id, + mac_address, + port_id, + vlan_id, + timeout, + ) body: dict[str, str | list | int] = {} if mac_address: body["mac_address"] = mac_address @@ -134,29 +132,21 @@ def retrieve_mac_table( body["port_id"] = port_id if vlan_id: body["vlan_id"] = vlan_id - trigger = devices.showSiteDeviceMacTable( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.showSiteDeviceMacTable( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Show MAC Table command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger show MAC Table command: {trigger.status_code} - {trigger.data}" - ) # Give the show ARP command a moment to take effect - return util_response def clear_learned_mac( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, port_ids: list[str], @@ -168,6 +158,8 @@ def clear_learned_mac( PARAMS ----------- + apisession: mistapi.APISession + The API session to use for the request. site_id : str UUID of the site where the device is located. device_id : str @@ -181,19 +173,15 @@ def clear_learned_mac( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating clear learned MACs command for device %s on ports %s", + device_id, + port_ids, + ) body: dict[str, str | list | int] = {"ports": port_ids} - trigger = devices.clearSiteDeviceDot1xSession( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper(apisession, util_response).start_with_trigger( + trigger_fn=lambda: devices.clearAllLearnedMacsFromPortOnSwitch( + apisession, site_id=site_id, device_id=device_id, body=body + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear learned MACs command triggered for device {device_id}") - else: - LOGGER.error( - f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" - ) # Give the clear learned MACs command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/miscellaneous.py b/src/mistapi/device_utils/__tools/miscellaneous.py index ccc1bc4..6bb9611 100644 --- a/src/mistapi/device_utils/__tools/miscellaneous.py +++ b/src/mistapi/device_utils/__tools/miscellaneous.py @@ -4,18 +4,12 @@ from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__common import Node from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper from mistapi.websockets.session import SessionWithUrl from mistapi.websockets.sites import DeviceCmdEvents -class Node(Enum): - """Node Enum for specifying node information in commands.""" - - NODE0 = "node0" - NODE1 = "node1" - - class TracerouteProtocol(Enum): """Enum for specifying protocol in traceroute command.""" @@ -24,7 +18,7 @@ class TracerouteProtocol(Enum): def ping( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, host: str, @@ -43,7 +37,7 @@ def ping( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -70,6 +64,17 @@ def ping( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating ping command for device %s to host %s with count %s, node %s, size %s, " + "VRF %s, and timeout %s", + device_id, + host, + count, + node, + size, + vrf, + timeout, + ) body: dict[str, str | list | int] = {} if count: body["count"] = count @@ -81,29 +86,22 @@ def ping( body["size"] = size if vrf: body["vrf"] = vrf - trigger = devices.pingFromDevice( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.pingFromDevice( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Ping command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger ping command: {trigger.status_code} - {trigger.data}" - ) # Give the ping command a moment to take effect - return util_response ## NO DATA # def service_ping( -# apissession: _APISession, +# apisession: _APISession, # site_id: str, # device_id: str, # host: str, @@ -122,7 +120,7 @@ def ping( # PARAMS # ----------- -# apissession : _APISession +# apisession: mistapi.APISession # The API session to use for the request. # site_id : str # UUID of the site where the device is located. @@ -165,7 +163,7 @@ def ping( # if service: # body["service"] = service # trigger = devices.servicePingFromSsr( -# apissession, +# apisession, # site_id=site_id, # device_id=device_id, # body=body, @@ -173,9 +171,9 @@ def ping( # util_response = UtilResponse(trigger) # if trigger.status_code == 200: # LOGGER.info(f"Service Ping command triggered for device {device_id}") -# ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) +# ws = DeviceCmdEvents(apisession, site_id=site_id, device_ids=[device_id]) # util_response = WebSocketWrapper( -# apissession, util_response, timeout, on_message=on_message +# apisession, util_response, timeout, on_message=on_message # ).start(ws) # else: # LOGGER.error( @@ -185,7 +183,7 @@ def ping( def traceroute( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, host: str, @@ -202,7 +200,7 @@ def traceroute( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -225,33 +223,35 @@ def traceroute( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating traceroute command for device %s to host %s with protocol %s, port %s, " + "and timeout %s", + device_id, + host, + protocol, + port, + timeout, + ) body: dict[str, str | list | int] = {"host": host} if protocol: body["protocol"] = protocol.value if port: body["port"] = port - trigger = devices.tracerouteFromDevice( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.tracerouteFromDevice( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Traceroute command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger traceroute command: {trigger.status_code} - {trigger.data}" - ) # Give the traceroute command a moment to take effect - return util_response def monitor_traffic( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, port_id: str | None = None, @@ -269,7 +269,7 @@ def monitor_traffic( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -288,77 +288,80 @@ def monitor_traffic( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating monitor traffic command for device %s on port %s with timeout %s", + device_id, + port_id, + timeout, + ) body: dict[str, str | int] = {"duration": 60} if port_id: body["port"] = port_id - trigger = devices.monitorSiteDeviceTraffic( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Monitor traffic command triggered for device {device_id}") - ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: + + def _ws_factory(trigger): + if isinstance(trigger.data, dict) and "url" in trigger.data: + return SessionWithUrl(apisession, url=trigger.data.get("url", "")) LOGGER.error( - f"Failed to trigger monitor traffic command: {trigger.status_code} - {trigger.data}" - ) # Give the monitor traffic command a moment to take effect - return util_response + "Monitor traffic command did not return a valid URL: %s", trigger.data + ) + return None + + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.monitorSiteDeviceTraffic( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=_ws_factory, + ) -## NO DATA -# def srx_top_command( -# apissession: _APISession, -# site_id: str, -# device_id: str, -# timeout=10, -# on_message: Callable[[dict], None] | None = None, -# ) -> UtilResponse: -# """ -# DEVICE: SRX +# NO DATA +def top_command( + apisession: _APISession, + site_id: str, + device_id: str, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: EX, SRX -# For SRX Only. Initiates a top command on the device and streams the results. + Initiates a top command on the device and streams the results. -# PARAMS -# ----------- -# apissession : _APISession -# The API session to use for the request. -# site_id : str -# UUID of the site where the device is located. -# device_id : str -# UUID of the device to run the top command on. -# timeout : int, optional -# Timeout for the top command in seconds. -# on_message : Callable, optional -# Callback invoked with each extracted raw message as it arrives. + PARAMS + ----------- + apisession: mistapi.APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run the top command on. + timeout : int, optional + Timeout for the top command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. -# RETURNS -# ----------- -# UtilResponse -# A UtilResponse object containing the API response and a list of raw messages received -# from the WebSocket stream. -# """ -# trigger = devices.runSiteSrxTopCommand( -# apissession, -# site_id=site_id, -# device_id=device_id, -# ) -# util_response = UtilResponse(trigger) -# if trigger.status_code == 200: -# LOGGER.info(trigger.data) -# print(f"Top command triggered for device {device_id}") -# ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) -# util_response = WebSocketWrapper( -# apissession, util_response, timeout=timeout, on_message=on_message -# ).start(ws) -# else: -# LOGGER.error( -# f"Failed to trigger top command: {trigger.status_code} - {trigger.data}" -# ) # Give the top command a moment to take effect -# return util_response + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + def _ws_factory(trigger): + if isinstance(trigger.data, dict) and "url" in trigger.data: + return SessionWithUrl(apisession, url=trigger.data.get("url", "")) + LOGGER.error("Top command did not return a valid URL: %s", trigger.data) + return None + + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.runSiteSrxTopCommand( + apisession, site_id=site_id, device_id=device_id + ), + ws_factory_fn=_ws_factory, + ) diff --git a/src/mistapi/device_utils/__tools/ospf.py b/src/mistapi/device_utils/__tools/ospf.py index 09eda9e..de07592 100644 --- a/src/mistapi/device_utils/__tools/ospf.py +++ b/src/mistapi/device_utils/__tools/ospf.py @@ -11,24 +11,17 @@ """ from collections.abc import Callable -from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__common import Node from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper from mistapi.websockets.sites import DeviceCmdEvents -class Node(Enum): - """Node Enum for specifying node information in OSPF commands.""" - - NODE0 = "node0" - NODE1 = "node1" - - def show_database( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, node: Node | None = None, @@ -45,7 +38,7 @@ def show_database( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -66,6 +59,14 @@ def show_database( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating show OSPF database command for device %s with node %s, self_originate %s, " + "and VRF %s", + device_id, + node, + self_originate, + vrf, + ) body: dict[str, str | list | int] = {} if node: body["node"] = node.value @@ -73,28 +74,21 @@ def show_database( body["self_originate"] = self_originate if vrf: body["vrf"] = vrf - trigger = devices.showSiteGatewayOspfDatabase( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.showSiteGatewayOspfDatabase( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"OSPF database command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger OSPF database command: {trigger.status_code} - {trigger.data}" - ) # Give the OSPF database command a moment to take effect - return util_response def show_interfaces( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, node: Node | None = None, @@ -111,7 +105,7 @@ def show_interfaces( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -132,6 +126,14 @@ def show_interfaces( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating show OSPF interfaces command for device %s with node %s, port_id %s, " + "and VRF %s", + device_id, + node, + port_id, + vrf, + ) body: dict[str, str | list | int] = {} if node: body["node"] = node.value @@ -139,28 +141,21 @@ def show_interfaces( body["port_id"] = port_id if vrf: body["vrf"] = vrf - trigger = devices.showSiteGatewayOspfInterfaces( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.showSiteGatewayOspfInterfaces( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"OSPF interfaces command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger OSPF interfaces command: {trigger.status_code} - {trigger.data}" - ) # Give the OSPF interfaces command a moment to take effect - return util_response def show_neighbors( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, neighbor: str | None = None, @@ -178,7 +173,7 @@ def show_neighbors( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -201,6 +196,15 @@ def show_neighbors( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating show OSPF neighbors command for device %s with neighbor %s, node %s, " + "port_id %s, and VRF %s", + device_id, + neighbor, + node, + port_id, + vrf, + ) body: dict[str, str | list | int] = {} if node: body["node"] = node.value @@ -210,28 +214,21 @@ def show_neighbors( body["vrf"] = vrf if neighbor: body["neighbor"] = neighbor - trigger = devices.showSiteGatewayOspfNeighbors( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.showSiteGatewayOspfNeighbors( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"OSPF neighbors command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger OSPF neighbors command: {trigger.status_code} - {trigger.data}" - ) # Give the OSPF neighbors command a moment to take effect - return util_response def show_summary( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, node: Node | None = None, @@ -247,7 +244,7 @@ def show_summary( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -266,26 +263,25 @@ def show_summary( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating show OSPF summary command for device %s with node %s, and VRF %s", + device_id, + node, + vrf, + ) body: dict[str, str | list | int] = {} if node: body["node"] = node.value if vrf: body["vrf"] = vrf - trigger = devices.showSiteGatewayOspfSummary( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.showSiteGatewayOspfSummary( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"OSPF summary command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger OSPF summary command: {trigger.status_code} - {trigger.data}" - ) # Give the OSPF summary command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/policy.py b/src/mistapi/device_utils/__tools/policy.py index 2d57303..891788d 100644 --- a/src/mistapi/device_utils/__tools/policy.py +++ b/src/mistapi/device_utils/__tools/policy.py @@ -13,11 +13,11 @@ from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper -async def clear_hit_count( - apissession: _APISession, +def clear_hit_count( + apisession: _APISession, site_id: str, device_id: str, policy_name: str, @@ -29,7 +29,7 @@ async def clear_hit_count( PARAMS ----------- - apissession : _APISession + apisession : _APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -43,20 +43,17 @@ async def clear_hit_count( UtilResponse A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - trigger = devices.clearSiteDevicePolicyHitCount( - apissession, - site_id=site_id, - device_id=device_id, - body={"policy_name": policy_name}, + LOGGER.debug( + "Initiating clear policy hit count command for device %s and policy %s", + device_id, + policy_name, + ) + util_response = UtilResponse() + return WebSocketWrapper(apisession, util_response).start_with_trigger( + trigger_fn=lambda: devices.clearSiteDevicePolicyHitCount( + apisession, + site_id=site_id, + device_id=device_id, + body={"policy_name": policy_name}, + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Clear policy hit count command triggered for device {device_id}") - # util_response = await WebSocketWrapper( - # apissession, util_response, timeout=timeout - # ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger clear policy hit count command: {trigger.status_code} - {trigger.data}" - ) # Give the clear policy hit count command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/port.py b/src/mistapi/device_utils/__tools/port.py index d0c150e..5365edd 100644 --- a/src/mistapi/device_utils/__tools/port.py +++ b/src/mistapi/device_utils/__tools/port.py @@ -20,11 +20,11 @@ def bounce( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, port_ids: list[str], - timeout=60, + timeout=5, on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ @@ -52,33 +52,30 @@ def bounce( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating bounce command for device %s on ports %s with timeout %s", + device_id, + port_ids, + timeout, + ) body: dict[str, str | list | int] = {} if port_ids: body["ports"] = port_ids - trigger = devices.bounceDevicePort( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.bounceDevicePort( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info( - f"Bounce command triggered for ports {port_ids} on device {device_id}" - ) - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger bounce command: {trigger.status_code} - {trigger.data}" - ) # Give the bounce command a moment to take effect - return util_response def cable_test( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, port_id: str, @@ -92,7 +89,7 @@ def cable_test( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the switch is located. @@ -100,7 +97,7 @@ def cable_test( UUID of the switch to perform the cable test on. port_id : str Port ID to perform the cable test on. - timeout : int, optional + timeout : int, default 10 Timeout for the cable test command in seconds. on_message : Callable, optional Callback invoked with each extracted raw message as it arrives. @@ -111,23 +108,21 @@ def cable_test( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating cable test for device %s on port %s with timeout %s", + device_id, + port_id, + timeout, + ) body: dict[str, str | list | int] = {"port": port_id} - trigger = devices.cableTestFromSwitch( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.cableTestFromSwitch( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Cable test command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger cable test command: {trigger.status_code} - {trigger.data}" - ) # Give the cable test command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/remote_capture.py b/src/mistapi/device_utils/__tools/remote_capture.py index 90a438d..7e628a7 100644 --- a/src/mistapi/device_utils/__tools/remote_capture.py +++ b/src/mistapi/device_utils/__tools/remote_capture.py @@ -47,13 +47,13 @@ def _build_pcap_body( if tcpdump_expression is not None: port_entry["tcpdump_expression"] = tcpdump_expression body[device_key][mac]["ports"][port_id] = port_entry - if tcpdump_expression: + if tcpdump_expression is not None: body["tcpdump_expression"] = tcpdump_expression return body def ap_remote_pcap_wireless( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, band: str, @@ -73,7 +73,7 @@ def ap_remote_pcap_wireless( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -105,6 +105,12 @@ def ap_remote_pcap_wireless( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating remote pcap for device %s on band %s with timeout %s", + device_id, + band, + timeout, + ) body: dict[str, str | int] = { "band": band, "duration": duration, @@ -119,28 +125,19 @@ def ap_remote_pcap_wireless( body["ap_mac"] = ap_mac if tcpdump_expression: body["tcpdump_expression"] = tcpdump_expression - trigger = pcaps.startSitePacketCapture( - apissession, - site_id=site_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: pcaps.startSitePacketCapture( + apisession, site_id=site_id, body=body + ), + ws_factory_fn=lambda _trigger: PcapEvents(apisession, site_id=site_id), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Remote pcap command triggered for device {device_id}") - ws = PcapEvents(apissession, site_id=site_id) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" - ) # Give the remote pcap command a moment to take effect - return util_response def ap_remote_pcap_wired( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, tcpdump_expression: str | None = None, @@ -157,7 +154,7 @@ def ap_remote_pcap_wired( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -183,6 +180,11 @@ def ap_remote_pcap_wired( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating remote pcap for device %s with timeout %s", + device_id, + timeout, + ) body: dict[str, str | int] = { "duration": duration, "max_pkt_len": max_pkt_len, @@ -192,28 +194,19 @@ def ap_remote_pcap_wired( } if tcpdump_expression: body["tcpdump_expression"] = tcpdump_expression - trigger = pcaps.startSitePacketCapture( - apissession, - site_id=site_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: pcaps.startSitePacketCapture( + apisession, site_id=site_id, body=body + ), + ws_factory_fn=lambda _trigger: PcapEvents(apisession, site_id=site_id), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Remote pcap command triggered for device {device_id}") - ws = PcapEvents(apissession, site_id=site_id) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" - ) # Give the remote pcap command a moment to take effect - return util_response def srx_remote_pcap( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, port_ids: list[str], @@ -231,7 +224,7 @@ def srx_remote_pcap( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -259,6 +252,12 @@ def srx_remote_pcap( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating remote pcap for device %s on ports %s with timeout %s", + device_id, + port_ids, + timeout, + ) body = _build_pcap_body( device_id, port_ids, @@ -269,28 +268,19 @@ def srx_remote_pcap( max_pkt_len, num_packets, ) - trigger = pcaps.startSitePacketCapture( - apissession, - site_id=site_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: pcaps.startSitePacketCapture( + apisession, site_id=site_id, body=body + ), + ws_factory_fn=lambda _trigger: PcapEvents(apisession, site_id=site_id), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Remote pcap command triggered for device {device_id}") - ws = PcapEvents(apissession, site_id=site_id) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" - ) # Give the remote pcap command a moment to take effect - return util_response def ssr_remote_pcap( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, port_ids: list[str], @@ -308,7 +298,7 @@ def ssr_remote_pcap( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -336,6 +326,12 @@ def ssr_remote_pcap( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating remote pcap for device %s on ports %s with timeout %s", + device_id, + port_ids, + timeout, + ) body = _build_pcap_body( device_id, port_ids, @@ -347,28 +343,19 @@ def ssr_remote_pcap( num_packets, raw=False, ) - trigger = pcaps.startSitePacketCapture( - apissession, - site_id=site_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: pcaps.startSitePacketCapture( + apisession, site_id=site_id, body=body + ), + ws_factory_fn=lambda _trigger: PcapEvents(apisession, site_id=site_id), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Remote pcap command triggered for device {device_id}") - ws = PcapEvents(apissession, site_id=site_id) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" - ) # Give the remote pcap command a moment to take effect - return util_response def ex_remote_pcap( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, port_ids: list[str], @@ -386,7 +373,7 @@ def ex_remote_pcap( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the device is located. @@ -414,6 +401,12 @@ def ex_remote_pcap( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating remote pcap for device %s on ports %s with timeout %s", + device_id, + port_ids, + timeout, + ) body = _build_pcap_body( device_id, port_ids, @@ -424,21 +417,12 @@ def ex_remote_pcap( max_pkt_len, num_packets, ) - trigger = pcaps.startSitePacketCapture( - apissession, - site_id=site_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: pcaps.startSitePacketCapture( + apisession, site_id=site_id, body=body + ), + ws_factory_fn=lambda _trigger: PcapEvents(apisession, site_id=site_id), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Remote pcap command triggered for device {device_id}") - ws = PcapEvents(apissession, site_id=site_id) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" - ) # Give the remote pcap command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/routes.py b/src/mistapi/device_utils/__tools/routes.py index 6022f02..265f548 100644 --- a/src/mistapi/device_utils/__tools/routes.py +++ b/src/mistapi/device_utils/__tools/routes.py @@ -16,16 +16,14 @@ from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__common import Node from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper from mistapi.websockets.sites import DeviceCmdEvents -class Node(Enum): - NODE0 = "node0" - NODE1 = "node1" - - class RouteProtocol(Enum): + """RouteProtocol Enum for specifying route protocol information in show routes command.""" + ANY = "any" BGP = "bgp" DIRECT = "direct" @@ -35,7 +33,7 @@ class RouteProtocol(Enum): def show( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, node: Node | None = None, @@ -53,7 +51,7 @@ def show( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the gateway is located. @@ -80,7 +78,16 @@ def show( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - + LOGGER.debug( + "Initiating show routes command for device %s with node %s, prefix %s, protocol %s, " + "route_type %s, and VRF %s", + device_id, + node, + prefix, + protocol, + route_type, + vrf, + ) body: dict[str, str | list | int] = {} if node: body["node"] = node.value @@ -92,22 +99,14 @@ def show( body["route_type"] = route_type if vrf: body["vrf"] = vrf - trigger = devices.showSiteSsrAndSrxRoutes( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.showSiteSsrAndSrxRoutes( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Device Routes command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger Device Routes command: {trigger.status_code} - {trigger.data}" - ) # Give the Device Routes command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/service_path.py b/src/mistapi/device_utils/__tools/service_path.py index 5f53fc0..ea34d6e 100644 --- a/src/mistapi/device_utils/__tools/service_path.py +++ b/src/mistapi/device_utils/__tools/service_path.py @@ -11,24 +11,17 @@ """ from collections.abc import Callable -from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__common import Node from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper from mistapi.websockets.sites import DeviceCmdEvents -class Node(Enum): - """Node Enum for specifying node information in service path commands.""" - - NODE0 = "node0" - NODE1 = "node1" - - def show_service_path( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, node: Node | None = None, @@ -43,7 +36,7 @@ def show_service_path( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the gateway is located. @@ -64,27 +57,27 @@ def show_service_path( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ + LOGGER.debug( + "Initiating show service path command for device %s with node %s, service name %s, " + "and timeout %s", + device_id, + node, + service_name, + timeout, + ) body: dict[str, str | list | int] = {} if node: body["node"] = node.value if service_name: body["service_name"] = service_name - trigger = devices.showSiteSsrServicePath( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.showSiteSsrServicePath( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"SSR service path command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger SSR service path command: {trigger.status_code} - {trigger.data}" - ) # Give the SSR service path command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/sessions.py b/src/mistapi/device_utils/__tools/sessions.py index c019f10..0d14399 100644 --- a/src/mistapi/device_utils/__tools/sessions.py +++ b/src/mistapi/device_utils/__tools/sessions.py @@ -11,24 +11,17 @@ """ from collections.abc import Callable -from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__common import Node from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper from mistapi.websockets.sites import DeviceCmdEvents -class Node(Enum): - """Node Enum for specifying node information in session commands.""" - - NODE0 = "node0" - NODE1 = "node1" - - def clear( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, node: Node | None = None, @@ -45,22 +38,20 @@ def clear( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the gateway is located. device_id : str - UUID of the gateway to perform the show routes command on. + UUID of the gateway to perform clear sessions command on. node : Node, optional - Node information for the show routes command. - prefix : str, optional - Prefix to filter the routes. - protocol : RouteProtocol, optional - Protocol to filter the routes. - route_type : str, optional - Type of the route to filter. + Node information for the clear sessions command. + service_name : str, optional + Name of the service to filter the sessions. + service_ids : list[str], optional + List of service IDs to filter the sessions. vrf : str, optional - VRF to filter the routes. + VRF to filter the sessions. timeout : int, optional Timeout for the command in seconds. on_message : Callable, optional @@ -72,7 +63,16 @@ def clear( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - + LOGGER.debug( + "Initiating clear sessions command for device %s with node %s, service name %s, " + "service IDs %s, VRF %s, and timeout %s", + device_id, + node, + service_name, + service_ids, + vrf, + timeout, + ) body: dict[str, str | list | int] = {} if node: body["node"] = node.value @@ -82,29 +82,21 @@ def clear( body["service_ids"] = service_ids if vrf: body["vrf"] = vrf - trigger = devices.clearSiteDeviceSession( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.clearSiteDeviceSession( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Device Sessions command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" - ) # Give the Device Sessions command a moment to take effect - return util_response def show( - apissession: _APISession, + apisession: _APISession, site_id: str, device_id: str, node: Node | None = None, @@ -120,7 +112,7 @@ def show( PARAMS ----------- - apissession : _APISession + apisession: mistapi.APISession The API session to use for the request. site_id : str UUID of the site where the gateway is located. @@ -143,7 +135,15 @@ def show( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - + LOGGER.debug( + "Initiating show sessions command for device %s with node %s, service name %s, " + "service IDs %s, and timeout %s", + device_id, + node, + service_name, + service_ids, + timeout, + ) body: dict[str, str | list | int] = {} if node: body["node"] = node.value @@ -151,22 +151,14 @@ def show( body["service_name"] = service_name if service_ids: body["service_ids"] = service_ids - trigger = devices.showSiteSsrAndSrxSessions( - apissession, - site_id=site_id, - device_id=device_id, - body=body, + util_response = UtilResponse() + return WebSocketWrapper( + apisession, util_response, timeout=timeout, on_message=on_message + ).start_with_trigger( + trigger_fn=lambda: devices.showSiteSsrAndSrxSessions( + apisession, site_id=site_id, device_id=device_id, body=body + ), + ws_factory_fn=lambda _trigger: DeviceCmdEvents( + apisession, site_id=site_id, device_ids=[device_id] + ), ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Device Sessions command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" - ) # Give the Device Sessions command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/__tools/shell.py b/src/mistapi/device_utils/__tools/shell.py new file mode 100644 index 0000000..f81b783 --- /dev/null +++ b/src/mistapi/device_utils/__tools/shell.py @@ -0,0 +1,360 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +Interactive SSH shell over WebSocket for Juniper EX/SRX devices. + +This module provides: +- ``ShellSession`` — a programmatic bidirectional WebSocket session +- ``create_shell_session()`` — factory that calls the shell API and returns + a connected ``ShellSession`` +- ``interactive_shell()`` — convenience function that takes over the terminal + for human SSH access +""" + +import json +import os +import ssl +import sys +import threading +from typing import TYPE_CHECKING + +import websocket + +from mistapi.__logger import logger as LOGGER + +if TYPE_CHECKING: + from mistapi import APISession + + +class ShellSession: + """ + Bidirectional WebSocket session for SSH-over-WebSocket shell access. + + Connects to the WebSocket URL returned by the Mist shell API endpoint + and provides methods to send/receive raw terminal data. + + USAGE PATTERNS + ----------- + Programmatic:: + + session = create_shell_session(apisession, site_id, device_id) + session.send_text("show version\\r\\n") + while session.connected: + data = session.recv() + if data: + print(data.decode("utf-8", errors="replace"), end="") + session.disconnect() + + Context manager:: + + with create_shell_session(apisession, site_id, device_id) as session: + session.send_text("show interfaces terse\\r\\n") + import time; time.sleep(5) + while True: + data = session.recv() + if data is None: + break + print(data.decode("utf-8", errors="replace"), end="") + + Interactive (human at the keyboard):: + + interactive_shell(apisession, site_id, device_id) + """ + + def __init__( + self, + mist_session: "APISession", + ws_url: str, + rows: int = 24, + cols: int = 80, + ) -> None: + """ + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session (used for auth headers/cookies/SSL). + ws_url : str + WebSocket URL from createSiteDeviceShellSession response. + rows : int + Initial terminal row count. + cols : int + Initial terminal column count. + """ + self._mist_session = mist_session + self._ws_url = ws_url + self._rows = rows + self._cols = cols + self._ws: websocket.WebSocket | None = None + + # ------------------------------------------------------------------ + # Auth / SSL helpers (mirrors _MistWebsocket but avoids coupling) + + def _get_headers(self) -> list[str]: + if self._mist_session._apitoken: + token = self._mist_session._apitoken[self._mist_session._apitoken_index] + return [f"Authorization: Token {token}"] + return [] + + def _get_cookie(self) -> str | None: + cookies = self._mist_session._session.cookies + if cookies: + safe = [] + for c in cookies: + has_crlf = ( + "\r" in c.name + or "\n" in c.name + or (c.value and ("\r" in c.value or "\n" in c.value)) + ) + if has_crlf: + LOGGER.warning( + "Skipping cookie %r: contains CRLF characters", + c.name, + ) + continue + safe.append(f"{c.name}={c.value}") + return "; ".join(safe) if safe else None + return None + + def _build_sslopt(self) -> dict: + sslopt: dict = {} + session = self._mist_session._session + if session.verify is False: + sslopt["cert_reqs"] = ssl.CERT_NONE + elif isinstance(session.verify, str): + sslopt["ca_certs"] = session.verify + if session.cert: + if isinstance(session.cert, str): + sslopt["certfile"] = session.cert + elif isinstance(session.cert, tuple): + sslopt["certfile"] = session.cert[0] + if len(session.cert) > 1: + sslopt["keyfile"] = session.cert[1] + return sslopt + + # ------------------------------------------------------------------ + # Lifecycle + + def connect(self) -> None: + """Open the WebSocket connection.""" + LOGGER.info("Connecting to shell WebSocket: %s", self._ws_url) + self._ws = websocket.create_connection( + self._ws_url, + header=self._get_headers(), + cookie=self._get_cookie(), + sslopt=self._build_sslopt(), + ) + self._ws.settimeout(0.1) + self.resize(self._rows, self._cols) + LOGGER.info("Shell WebSocket connected") + + def disconnect(self) -> None: + """Close the WebSocket connection.""" + ws = self._ws + self._ws = None + if ws: + try: + ws.close() + except Exception: + pass + + @property + def connected(self) -> bool: + """True if the WebSocket is currently connected.""" + ws = self._ws + return ws is not None and ws.connected + + # ------------------------------------------------------------------ + # I/O + + def send(self, data: bytes) -> None: + """Send raw bytes (keystrokes) to the device shell.""" + ws = self._ws + if ws and ws.connected: + ws.send_binary(data) + + def send_text(self, text: str) -> None: + """Send a text string as binary data to the device shell.""" + self.send(f"\x00{text}".encode("utf-8")) + + def recv(self, timeout: float = 0.1) -> bytes | None: + """ + Receive raw bytes from the device shell. + + Returns None if no data is available within the timeout, or if + the connection is closed. + """ + ws = self._ws + if not ws or not ws.connected: + return None + old_timeout = ws.gettimeout() + try: + ws.settimeout(timeout) + data = ws.recv() + if isinstance(data, str): + return data.encode("utf-8") + return data + except websocket.WebSocketTimeoutException: + return None + except ( + websocket.WebSocketConnectionClosedException, + ConnectionError, + ): + return None + finally: + ws.settimeout(old_timeout) + + def resize(self, rows: int, cols: int) -> None: + """Send a terminal resize message to the device.""" + self._rows = rows + self._cols = cols + ws = self._ws + if ws and ws.connected: + ws.send(json.dumps({"resize": {"width": cols, "height": rows}})) + + # ------------------------------------------------------------------ + # Context manager + + def __enter__(self) -> "ShellSession": + return self + + def __exit__(self, *args) -> None: + self.disconnect() + + +def create_shell_session( + apisession: "APISession", + site_id: str, + device_id: str, + rows: int = 24, + cols: int = 80, +) -> ShellSession: + """ + Call the shell API and return a connected ShellSession. + + PARAMS + ----------- + apisession : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to connect to. + rows : int + Initial terminal row count. + cols : int + Initial terminal column count. + + RETURNS + ----------- + ShellSession + A connected ShellSession ready for send/recv. + + RAISES + ----------- + RuntimeError + If the API call fails or no WebSocket URL is returned. + """ + from mistapi.api.v1.sites import devices + + response = devices.createSiteDeviceShellSession( + apisession, site_id=site_id, device_id=device_id, body={} + ) + if response.status_code != 200: + raise RuntimeError( + f"Shell API call failed: {response.status_code} - {response.data}" + ) + if not isinstance(response.data, dict) or "url" not in response.data: + raise RuntimeError( + f"Shell API response did not contain a WebSocket URL: {response.data}" + ) + + ws_url = response.data["url"] + session = ShellSession(apisession, ws_url, rows=rows, cols=cols) + session.connect() + return session + + +def interactive_shell( + apisession: "APISession", + site_id: str, + device_id: str, +) -> None: + """ + Launch an interactive SSH shell session to a device. + + Takes over the terminal: captures keystrokes, sends them to the device, + and displays output. Blocks until the connection closes or the user + presses Ctrl+C. + + PARAMS + ----------- + apisession : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to connect to. + """ + from sshkeyboard import listen_keyboard + + try: + cols, rows = os.get_terminal_size() + except OSError: + rows, cols = 24, 80 + + session = create_shell_session(apisession, site_id, device_id, rows=rows, cols=cols) + + def _reader(): + """Background thread: read from WebSocket, write to stdout.""" + while session.connected: + data = session.recv(timeout=0.1) + if data: + sys.stdout.buffer.write(data) + sys.stdout.buffer.flush() + + def _on_key_press(key: str) -> None: + """Handle a key press event from sshkeyboard.""" + if not session.connected: + return + if key == "enter": + k = "\r\n" + elif key == "space": + k = " " + elif key == "tab": + k = "\t" + elif key == "up": + k = "\x1b[A" + elif key == "right": + k = "\x1b[C" + elif key == "down": + k = "\x1b[B" + elif key == "left": + k = "\x1b[D" + elif key == "backspace": + k = "\x7f" + else: + k = key + session.send(f"\x00{k}".encode("utf-8")) + + reader_thread = threading.Thread(target=_reader, daemon=True) + reader_thread.start() + + try: + listen_keyboard( + on_press=_on_key_press, + delay_second_char=0, + delay_other_chars=0, + lower=False, + ) + except KeyboardInterrupt: + pass + finally: + session.disconnect() + reader_thread.join(timeout=2) diff --git a/src/mistapi/device_utils/ap.py b/src/mistapi/device_utils/ap.py index 73e34df..1d43c9e 100644 --- a/src/mistapi/device_utils/ap.py +++ b/src/mistapi/device_utils/ap.py @@ -26,6 +26,6 @@ __all__ = [ "ping", "traceroute", - "TracerouteProtocol", "retrieveArpTable", + "TracerouteProtocol", ] diff --git a/src/mistapi/device_utils/bgp.py b/src/mistapi/device_utils/bgp.py deleted file mode 100644 index f545c57..0000000 --- a/src/mistapi/device_utils/bgp.py +++ /dev/null @@ -1,70 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from collections.abc import Callable - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper -from mistapi.websockets.sites import DeviceCmdEvents - - -def summary( - apissession: _APISession, - site_id: str, - device_id: str, - timeout=5, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICES: EX, SRX, SSR - - Shows BGP summary on a device (EX/ SRX / SSR) and streams the results. - - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to show BGP summary on. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"protocol": "bgp"} - trigger = devices.showSiteDeviceBgpSummary( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"BGP summary command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger BGP summary command: {trigger.status_code} - {trigger.data}" - ) # Give the BGP summary command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/bpdu.py b/src/mistapi/device_utils/bpdu.py deleted file mode 100644 index c565903..0000000 --- a/src/mistapi/device_utils/bpdu.py +++ /dev/null @@ -1,61 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse - - -async def clearError( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], -) -> UtilResponse: - """ - DEVICES: EX - - Clears BPDU error state on the specified ports of a switch. - - PARAMS - ----------- - site_id : str - UUID of the site where the switch is located. - device_id : str - UUID of the switch to clear BPDU errors on. - port_ids : list[str] - List of port IDs to clear BPDU errors on. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - - body: dict[str, str | list | int] = {"ports": port_ids} - trigger = devices.clearBpduErrorsFromPortsOnSwitch( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear BPDU error command triggered for device {device_id}") - else: - LOGGER.error( - f"Failed to trigger clear BPDU error command: {trigger.status_code} - {trigger.data}" - ) # Give the clear BPDU error command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/dhcp.py b/src/mistapi/device_utils/dhcp.py deleted file mode 100644 index c967c34..0000000 --- a/src/mistapi/device_utils/dhcp.py +++ /dev/null @@ -1,172 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from collections.abc import Callable -from enum import Enum - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper -from mistapi.websockets.sites import DeviceCmdEvents - - -class Node(Enum): - """Node Enum for specifying node information in DHCP commands.""" - - NODE0 = "node0" - NODE1 = "node1" - - -def releaseDhcpLeases( - apissession: _APISession, - site_id: str, - device_id: str, - macs: list[str] | None = None, - network: str | None = None, - node: Node | None = None, - port_id: str | None = None, - timeout=5, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICES: EX, SRX, SSR - - Releases DHCP leases on a device (EX/ SRX / SSR) and streams the results. - - valid combinations for EX are: - - network + macs - - network + port_id - - port_id - - valid combinations for SRX / SSR are: - - network - - network + macs - - network + port_id - - port_id - - port_id + macs - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to release DHCP leases on. - macs : list[str], optional - List of MAC addresses to release DHCP leases for. - network : str, optional - Network to release DHCP leases for. - node : Node, optional - Node information for the DHCP lease release command. - port_id : str, optional - Port ID to release DHCP leases for. - timeout : int, optional - Timeout for the release DHCP leases command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if macs: - body["macs"] = macs - if network: - body["network"] = network - if node: - body["node"] = node.value - if port_id: - body["port_id"] = port_id - trigger = devices.releaseSiteDeviceDhcpLease( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Release DHCP leases command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger release DHCP leases command: {trigger.status_code} - {trigger.data}" - ) # Give the release DHCP leases command a moment to take effect - return util_response - - -def retrieveDhcpLeases( - apissession: _APISession, - site_id: str, - device_id: str, - network: str, - node: Node | None = None, - timeout=15, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICES: SRX, SSR - - Retrieves DHCP leases on a gateway (SRX / SSR) and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to retrieve DHCP leases from. - network : str - Network to release DHCP leases for. - node : Node, optional - Node information for the DHCP lease release command. - port_id : str, optional - Port ID to release DHCP leases for. - timeout : int, optional - Timeout for the release DHCP leases command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"network": network} - if node: - body["node"] = node.value - trigger = devices.showSiteDeviceDhcpLeases( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Retrieve DHCP leases command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger retrieve DHCP leases command: {trigger.status_code} - {trigger.data}" - ) # Give the release DHCP leases command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/dot1x.py b/src/mistapi/device_utils/dot1x.py deleted file mode 100644 index af5c322..0000000 --- a/src/mistapi/device_utils/dot1x.py +++ /dev/null @@ -1,60 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse - - -async def clearSessions( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], -) -> UtilResponse: - """ - DEVICES: EX - - Clears dot1x sessions on the specified ports of a switch (EX). - - PARAMS - ----------- - site_id : str - UUID of the site where the switch is located. - device_id : str - UUID of the switch to clear dot1x sessions on. - port_ids : list[str] - List of port IDs to clear dot1x sessions on. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"ports": port_ids} - trigger = devices.clearAllLearnedMacsFromPortOnSwitch( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear learned MACs command triggered for device {device_id}") - else: - LOGGER.error( - f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" - ) # Give the clear learned MACs command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/ex.py b/src/mistapi/device_utils/ex.py index dd3680f..816e32a 100644 --- a/src/mistapi/device_utils/ex.py +++ b/src/mistapi/device_utils/ex.py @@ -40,9 +40,15 @@ from mistapi.device_utils.__tools.mac import clear_mac_table as clearMacTable from mistapi.device_utils.__tools.mac import retrieve_mac_table as retrieveMacTable +# Shell (interactive SSH) +from mistapi.device_utils.__tools.shell import ShellSession +from mistapi.device_utils.__tools.shell import create_shell_session as createShellSession +from mistapi.device_utils.__tools.shell import interactive_shell as interactiveShell + # Tools (ping, monitor traffic) from mistapi.device_utils.__tools.miscellaneous import monitor_traffic as monitorTraffic from mistapi.device_utils.__tools.miscellaneous import ping +from mistapi.device_utils.__tools.miscellaneous import top_command as topCommand # Policy functions from mistapi.device_utils.__tools.policy import clear_hit_count as clearHitCount @@ -72,7 +78,12 @@ # Port "bouncePort", "cableTest", + # Shell + "ShellSession", + "createShellSession", + "interactiveShell", # Tools "monitorTraffic", "ping", + "topCommand", ] diff --git a/src/mistapi/device_utils/ospf.py b/src/mistapi/device_utils/ospf.py deleted file mode 100644 index 4903a52..0000000 --- a/src/mistapi/device_utils/ospf.py +++ /dev/null @@ -1,291 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from collections.abc import Callable -from enum import Enum - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper -from mistapi.websockets.sites import DeviceCmdEvents - - -class Node(Enum): - """Node Enum for specifying node information in OSPF commands.""" - - NODE0 = "node0" - NODE1 = "node1" - - -def showDatabase( - apissession: _APISession, - site_id: str, - device_id: str, - node: Node | None = None, - self_originate: bool | None = None, - vrf: str | None = None, - timeout=5, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICES: SRX, SSR - - Shows OSPF database on a device (SRX / SSR) and streams the results. - - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to show OSPF database on. - node : Node, optional - Node information for the show OSPF database command. - self_originate : bool, optional - Filter for self-originated routes in the OSPF database. - vrf : str, optional - VRF to filter the OSPF database. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if node: - body["node"] = node.value - if self_originate is not None: - body["self_originate"] = self_originate - if vrf: - body["vrf"] = vrf - trigger = devices.showSiteGatewayOspfDatabase( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"OSPF database command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger OSPF database command: {trigger.status_code} - {trigger.data}" - ) # Give the OSPF database command a moment to take effect - return util_response - - -def showInterfaces( - apissession: _APISession, - site_id: str, - device_id: str, - node: Node | None = None, - port_id: str | None = None, - vrf: str | None = None, - timeout=5, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICES: SRX, SSR - - Shows OSPF interfaces on a device (SRX / SSR) and streams the results. - - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to show OSPF interfaces on. - node : Node, optional - Node information for the show OSPF interfaces command. - port_id : str, optional - Port ID to filter the OSPF interfaces. - vrf : str, optional - VRF to filter the OSPF interfaces. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if node: - body["node"] = node.value - if port_id: - body["port_id"] = port_id - if vrf: - body["vrf"] = vrf - trigger = devices.showSiteGatewayOspfInterfaces( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"OSPF interfaces command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger OSPF interfaces command: {trigger.status_code} - {trigger.data}" - ) # Give the OSPF interfaces command a moment to take effect - return util_response - - -def showNeighbors( - apissession: _APISession, - site_id: str, - device_id: str, - neighbor: str | None = None, - node: Node | None = None, - port_id: str | None = None, - vrf: str | None = None, - timeout=5, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICES: SRX, SSR - - Shows OSPF neighbors on a device (SRX / SSR) and streams the results. - - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to show OSPF neighbors on. - neighbor : str, optional - Neighbor IP address to filter the OSPF neighbors. - node : Node, optional - Node information for the show OSPF neighbors command. - port_id : str, optional - Port ID to filter the OSPF neighbors. - vrf : str, optional - VRF to filter the OSPF neighbors. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if node: - body["node"] = node.value - if port_id: - body["port_id"] = port_id - if vrf: - body["vrf"] = vrf - if neighbor: - body["neighbor"] = neighbor - trigger = devices.showSiteGatewayOspfNeighbors( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"OSPF neighbors command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger OSPF neighbors command: {trigger.status_code} - {trigger.data}" - ) # Give the OSPF neighbors command a moment to take effect - return util_response - - -def showSummary( - apissession: _APISession, - site_id: str, - device_id: str, - node: Node | None = None, - vrf: str | None = None, - timeout=5, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICES: SRX, SSR - - Shows OSPF summary on a device (SRX / SSR) and streams the results. - - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to show OSPF summary on. - node : Node, optional - Node information for the show OSPF summary command. - vrf : str, optional - VRF to filter the OSPF summary. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if node: - body["node"] = node.value - if vrf: - body["vrf"] = vrf - trigger = devices.showSiteGatewayOspfSummary( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"OSPF summary command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger OSPF summary command: {trigger.status_code} - {trigger.data}" - ) # Give the OSPF summary command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/policy.py b/src/mistapi/device_utils/policy.py deleted file mode 100644 index ba8d606..0000000 --- a/src/mistapi/device_utils/policy.py +++ /dev/null @@ -1,62 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse - - -async def clearHitCount( - apissession: _APISession, - site_id: str, - device_id: str, - policy_name: str, -) -> UtilResponse: - """ - DEVICE: EX - - Clears the policy hit count on a device. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to clear the policy hit count on. - policy_name : str - Name of the policy to clear the hit count for. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - trigger = devices.clearSiteDevicePolicyHitCount( - apissession, - site_id=site_id, - device_id=device_id, - body={"policy_name": policy_name}, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Clear policy hit count command triggered for device {device_id}") - # util_response = await WebSocketWrapper( - # apissession, util_response, timeout=timeout - # ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger clear policy hit count command: {trigger.status_code} - {trigger.data}" - ) # Give the clear policy hit count command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/port.py b/src/mistapi/device_utils/port.py deleted file mode 100644 index 5757c0f..0000000 --- a/src/mistapi/device_utils/port.py +++ /dev/null @@ -1,133 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from collections.abc import Callable - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper -from mistapi.websockets.sites import DeviceCmdEvents - - -def bounce( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], - timeout=60, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICE: EX, SRX, SSR - - Initiates a bounce command on the specified ports of a device (EX / SRX / SSR) and streams - the results. - - PARAMS - ----------- - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to perform the bounce command on. - port_ids : list[str] - List of port IDs to bounce. - timeout : int, default 5 - Timeout for the bounce command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if port_ids: - body["ports"] = port_ids - trigger = devices.bounceDevicePort( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info( - f"Bounce command triggered for ports {port_ids} on device {device_id}" - ) - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger bounce command: {trigger.status_code} - {trigger.data}" - ) # Give the bounce command a moment to take effect - return util_response - - -def cableTest( - apissession: _APISession, - site_id: str, - device_id: str, - port_id: str, - timeout=10, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICES: EX - - Initiates a cable test on a switch port and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the switch is located. - device_id : str - UUID of the switch to perform the cable test on. - port_id : str - Port ID to perform the cable test on. - timeout : int, optional - Timeout for the cable test command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"port": port_id} - trigger = devices.cableTestFromSwitch( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Cable test command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger cable test command: {trigger.status_code} - {trigger.data}" - ) # Give the cable test command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/service_path.py b/src/mistapi/device_utils/service_path.py deleted file mode 100644 index 2973c23..0000000 --- a/src/mistapi/device_utils/service_path.py +++ /dev/null @@ -1,90 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from collections.abc import Callable -from enum import Enum - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper -from mistapi.websockets.sites import DeviceCmdEvents - - -class Node(Enum): - """Node Enum for specifying node information in service path commands.""" - - NODE0 = "node0" - NODE1 = "node1" - - -def showServicePath( - apissession: _APISession, - site_id: str, - device_id: str, - node: Node | None = None, - service_name: str | None = None, - timeout: int = 5, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICES: SSR - - Initiates a show service path command on the gateway and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the gateway is located. - device_id : str - UUID of the gateway to perform the show service path command on. - node : Node, optional - Node information for the show service path command. - service_name : str, optional - Name of the service to show the path for. - timeout : int, optional - Timeout for the command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if node: - body["node"] = node.value - if service_name: - body["service_name"] = service_name - trigger = devices.showSiteSsrServicePath( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"SSR service path command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger SSR service path command: {trigger.status_code} - {trigger.data}" - ) # Give the SSR service path command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/sessions.py b/src/mistapi/device_utils/sessions.py deleted file mode 100644 index c019f10..0000000 --- a/src/mistapi/device_utils/sessions.py +++ /dev/null @@ -1,172 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from collections.abc import Callable -from enum import Enum - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper -from mistapi.websockets.sites import DeviceCmdEvents - - -class Node(Enum): - """Node Enum for specifying node information in session commands.""" - - NODE0 = "node0" - NODE1 = "node1" - - -def clear( - apissession: _APISession, - site_id: str, - device_id: str, - node: Node | None = None, - service_name: str | None = None, - service_ids: list[str] | None = None, - vrf: str | None = None, - timeout=2, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICE: SSR, SRX - - Initiates a clear sessions command on the gateway and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the gateway is located. - device_id : str - UUID of the gateway to perform the show routes command on. - node : Node, optional - Node information for the show routes command. - prefix : str, optional - Prefix to filter the routes. - protocol : RouteProtocol, optional - Protocol to filter the routes. - route_type : str, optional - Type of the route to filter. - vrf : str, optional - VRF to filter the routes. - timeout : int, optional - Timeout for the command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - - body: dict[str, str | list | int] = {} - if node: - body["node"] = node.value - if service_name: - body["service_name"] = service_name - if service_ids: - body["service_ids"] = service_ids - if vrf: - body["vrf"] = vrf - trigger = devices.clearSiteDeviceSession( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Device Sessions command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" - ) # Give the Device Sessions command a moment to take effect - return util_response - - -def show( - apissession: _APISession, - site_id: str, - device_id: str, - node: Node | None = None, - service_name: str | None = None, - service_ids: list[str] | None = None, - timeout=2, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICE: SSR, SRX - - Initiates a show sessions command on the gateway and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the gateway is located. - device_id : str - UUID of the gateway to perform the show sessions command on. - node : Node, optional - Node information for the show sessions command. - service_name : str, optional - Name of the service to filter the sessions. - service_ids : list[str], optional - List of service IDs to filter the sessions. - timeout : int, optional - Timeout for the command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - - body: dict[str, str | list | int] = {} - if node: - body["node"] = node.value - if service_name: - body["service_name"] = service_name - if service_ids: - body["service_ids"] = service_ids - trigger = devices.showSiteSsrAndSrxSessions( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Device Sessions command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" - ) # Give the Device Sessions command a moment to take effect - return util_response diff --git a/src/mistapi/device_utils/srx.py b/src/mistapi/device_utils/srx.py index a93f124..bdd1f5f 100644 --- a/src/mistapi/device_utils/srx.py +++ b/src/mistapi/device_utils/srx.py @@ -16,7 +16,7 @@ """ # Re-export shared classes and types -from mistapi.device_utils.__tools.arp import Node +from mistapi.device_utils.__tools.__common import Node # ARP functions from mistapi.device_utils.__tools.arp import ( @@ -30,14 +30,21 @@ from mistapi.device_utils.__tools.dhcp import release_dhcp_leases as releaseDhcpLeases from mistapi.device_utils.__tools.dhcp import retrieve_dhcp_leases as retrieveDhcpLeases +# Shell (interactive SSH) +from mistapi.device_utils.__tools.shell import ShellSession +from mistapi.device_utils.__tools.shell import create_shell_session as createShellSession +from mistapi.device_utils.__tools.shell import interactive_shell as interactiveShell + # Tools (ping, monitor traffic) from mistapi.device_utils.__tools.miscellaneous import monitor_traffic as monitorTraffic from mistapi.device_utils.__tools.miscellaneous import ping +from mistapi.device_utils.__tools.miscellaneous import top_command as topCommand # OSPF functions -from mistapi.device_utils.__tools.ospf import show_database as showDatabase -from mistapi.device_utils.__tools.ospf import show_interfaces as showInterfaces -from mistapi.device_utils.__tools.ospf import show_neighbors as showNeighbors +from mistapi.device_utils.__tools.ospf import show_database as retrieveOspfDatabase +from mistapi.device_utils.__tools.ospf import show_interfaces as retrieveOspfInterfaces +from mistapi.device_utils.__tools.ospf import show_neighbors as retrieveOspfNeighbors +from mistapi.device_utils.__tools.ospf import show_summary as retrieveOspfSummary # Port functions from mistapi.device_utils.__tools.port import bounce as bouncePort @@ -45,6 +52,10 @@ # Route functions from mistapi.device_utils.__tools.routes import show as retrieveRoutes +# Sessions functions +from mistapi.device_utils.__tools.sessions import clear as clearSessions +from mistapi.device_utils.__tools.sessions import show as retrieveSessions + __all__ = [ # Classes/Enums "Node", @@ -56,14 +67,23 @@ "releaseDhcpLeases", "retrieveDhcpLeases", # OSPF - "showDatabase", - "showNeighbors", - "showInterfaces", + "retrieveOspfDatabase", + "retrieveOspfNeighbors", + "retrieveOspfInterfaces", + "retrieveOspfSummary", # Port "bouncePort", # Routes "retrieveRoutes", + # Sessions + "retrieveSessions", + "clearSessions", + # Shell + "ShellSession", + "createShellSession", + "interactiveShell", # Tools "monitorTraffic", "ping", + "topCommand", ] diff --git a/src/mistapi/device_utils/ssr.py b/src/mistapi/device_utils/ssr.py index 0af8df1..6c1f64b 100644 --- a/src/mistapi/device_utils/ssr.py +++ b/src/mistapi/device_utils/ssr.py @@ -16,7 +16,7 @@ """ # Re-export shared classes and types -from mistapi.device_utils.__tools.arp import Node +from mistapi.device_utils.__tools.__common import Node # ARP functions from mistapi.device_utils.__tools.arp import ( @@ -36,9 +36,10 @@ # DNS functions # from mistapi.utils.dns import test_resolution as test_dns_resolution # OSPF functions -from mistapi.device_utils.__tools.ospf import show_database as showDatabase -from mistapi.device_utils.__tools.ospf import show_interfaces as showInterfaces -from mistapi.device_utils.__tools.ospf import show_neighbors as showNeighbors +from mistapi.device_utils.__tools.ospf import show_database as retrieveOspfDatabase +from mistapi.device_utils.__tools.ospf import show_interfaces as retrieveOspfInterfaces +from mistapi.device_utils.__tools.ospf import show_neighbors as retrieveOspfNeighbors +from mistapi.device_utils.__tools.ospf import show_summary as retrieveOspfSummary # Port functions from mistapi.device_utils.__tools.port import bounce as bouncePort @@ -51,6 +52,10 @@ show_service_path as showServicePath, ) +# Sessions functions +from mistapi.device_utils.__tools.sessions import clear as clearSessions +from mistapi.device_utils.__tools.sessions import show as retrieveSessions + __all__ = [ # Classes/Enums "Node", @@ -64,15 +69,19 @@ # DNS # "test_dns_resolution", # OSPF - "showDatabase", - "showNeighbors", - "showInterfaces", + "retrieveOspfDatabase", + "retrieveOspfNeighbors", + "retrieveOspfInterfaces", + "retrieveOspfSummary", # Port "bouncePort", # Routes "retrieveRoutes", # Service Path "showServicePath", + # Sessions + "retrieveSessions", + "clearSessions", # Tools "ping", ] diff --git a/src/mistapi/device_utils/tools.py b/src/mistapi/device_utils/tools.py deleted file mode 100644 index 8a95822..0000000 --- a/src/mistapi/device_utils/tools.py +++ /dev/null @@ -1,789 +0,0 @@ -from collections.abc import Callable -from enum import Enum - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices, pcaps -from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper -from mistapi.websockets.session import SessionWithUrl -from mistapi.websockets.sites import DeviceCmdEvents, PcapEvents - - -class Node(Enum): - """Node Enum for specifying node information in commands.""" - - NODE0 = "node0" - NODE1 = "node1" - - -class TracerouteProtocol(Enum): - """Enum for specifying protocol in traceroute command.""" - - ICMP = "icmp" - UDP = "udp" - - -def _build_pcap_body( - device_id: str, - port_ids: list[str], - device_key: str, - device_type: str, - tcpdump_expression: str | None, - duration: int, - max_pkt_len: int, - num_packets: int, - raw: bool | None = None, -) -> dict: - """Build the request body for remote pcap commands (SRX, SSR, EX).""" - mac = device_id.split("-")[-1] - body: dict = { - "duration": duration, - "max_pkt_len": max_pkt_len, - "num_packets": num_packets, - device_key: {mac: {"ports": {}}}, - "type": device_type, - "format": "stream", - } - if raw is not None: - body["raw"] = raw - for port_id in port_ids: - port_entry: dict = {} - if tcpdump_expression is not None: - port_entry["tcpdump_expression"] = tcpdump_expression - body[device_key][mac]["ports"][port_id] = port_entry - if tcpdump_expression: - body["tcpdump_expression"] = tcpdump_expression - return body - - -def ping( - apissession: _APISession, - site_id: str, - device_id: str, - host: str, - count: int | None = None, - node: Node | None = None, - size: int | None = None, - vrf: str | None = None, - timeout: int = 3, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICES: AP, EX, SRX, SSR - - Initiates a ping command from a device (AP / EX/ SRX / SSR) to a specified host and - streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to initiate the ping from. - host : str - The host to ping. - count : int, optional - Number of ping requests to send. - node : None, optional - Node information for the ping command. - size : int, optional - Size of the ping packet. - vrf : str, optional - VRF to use for the ping command. - timeout : int, optional - Timeout for the ping command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if count: - body["count"] = count - if host: - body["host"] = host - if node: - body["node"] = node.value - if size: - body["size"] = size - if vrf: - body["vrf"] = vrf - trigger = devices.pingFromDevice( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Ping command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger ping command: {trigger.status_code} - {trigger.data}" - ) # Give the ping command a moment to take effect - return util_response - - -## NO DATA -# def service_ping( -# apissession: _APISession, -# site_id: str, -# device_id: str, -# host: str, -# service: str, -# tenant: str, -# count: int | None = None, -# node: None | None = None, -# size: int | None = None, -# timeout: int = 3, -# on_message: Callable[[dict], None] | None = None, -# ) -> UtilResponse: -# """ -# DEVICES: SSR - -# Initiates a service ping command from a SSR to a specified host and streams the results. - -# PARAMS -# ----------- -# apissession : _APISession -# The API session to use for the request. -# site_id : str -# UUID of the site where the device is located. -# device_id : str -# UUID of the device to initiate the ping from. -# host : str -# The host to ping. -# service : str -# The service to ping. -# tenant : str -# Tenant to use for the ping command. -# count : int, optional -# Number of ping requests to send. -# node : None, optional -# Node information for the ping command. -# size : int, optional -# Size of the ping packet. -# timeout : int, optional -# Timeout for the ping command in seconds. -# on_message : Callable, optional -# Callback invoked with each extracted raw message as it arrives. - -# RETURNS -# ----------- -# UtilResponse -# A UtilResponse object containing the API response and a list of raw messages received -# from the WebSocket stream. -# """ -# body: dict[str, str | list | int] = {} -# if count: -# body["count"] = count -# if host: -# body["host"] = host -# if node: -# body["node"] = node.value -# if size: -# body["size"] = size -# if tenant: -# body["tenant"] = tenant -# if service: -# body["service"] = service -# trigger = devices.servicePingFromSsr( -# apissession, -# site_id=site_id, -# device_id=device_id, -# body=body, -# ) -# util_response = UtilResponse(trigger) -# if trigger.status_code == 200: -# LOGGER.info(f"Service Ping command triggered for device {device_id}") -# ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) -# util_response = WebSocketWrapper( -# apissession, util_response, timeout, on_message=on_message -# ).start(ws) -# else: -# LOGGER.error( -# f"Failed to trigger Service Ping command: {trigger.status_code} - {trigger.data}" -# ) # Give the ping command a moment to take effect -# return util_response - - -def traceroute( - apissession: _APISession, - site_id: str, - device_id: str, - host: str, - protocol: TracerouteProtocol = TracerouteProtocol.ICMP, - port: int | None = None, - timeout: int = 10, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICES: AP, EX, SRX, SSR - - Initiates a traceroute command from a device (AP / EX/ SRX / SSR) to a specified host and - streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to initiate the traceroute from. - host : str - The host to traceroute. - protocol : TracerouteProtocol, optional - Protocol to use for the traceroute command (icmp or udp). - port : int, optional - Port to use for UDP traceroute. - timeout : int, optional - Timeout for the traceroute command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"host": host} - if protocol: - body["protocol"] = protocol.value - if port: - body["port"] = port - trigger = devices.tracerouteFromDevice( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Traceroute command triggered for device {device_id}") - ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) - util_response = WebSocketWrapper( - apissession, util_response, timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger traceroute command: {trigger.status_code} - {trigger.data}" - ) # Give the traceroute command a moment to take effect - return util_response - - -def monitorTraffic( - apissession: _APISession, - site_id: str, - device_id: str, - port_id: str | None = None, - timeout=30, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICE: EX, SRX - - Initiates a monitor traffic command on the device and streams the results. - - * if `port_id` is provided, JUNOS uses cmd "monitor interface" to monitor traffic on particular - * if `port_id` is not provided, JUNOS uses cmd "monitor interface traffic" to monitor traffic - on all ports - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to monitor traffic on. - port_id : str, optional - Port ID to filter the traffic. - timeout : int, optional - Timeout for the monitor traffic command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | int] = {"duration": 60} - if port_id: - body["port"] = port_id - trigger = devices.monitorSiteDeviceTraffic( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Monitor traffic command triggered for device {device_id}") - ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger monitor traffic command: {trigger.status_code} - {trigger.data}" - ) # Give the monitor traffic command a moment to take effect - return util_response - - -def apRemotePcapWireless( - apissession: _APISession, - site_id: str, - device_id: str, - band: str, - tcpdump_expression: str | None = None, - ssid: str | None = None, - ap_mac: str | None = None, - duration: int = 600, - max_pkt_len: int = 512, - num_packets: int = 1024, - timeout=10, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICE: AP - - Initiates a remote pcap command on the device and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to run remote pcap on. - band : str - Comma-separated list of radio bands (24, 5, or 6). - tcpdump_expression : str, optional - Tcpdump expression to filter the captured traffic. - e.g. "type mgt or type ctl -vvv -tttt -en" - ssid : str, optional - SSID to filter the wireless traffic. - ap_mac : str, optional - AP MAC address to filter the wireless traffic. - duration : int, optional - Duration of the remote pcap in seconds (default: 600). - max_pkt_len : int, optional - Maximum packet length to capture (default: 512). - num_packets : int, optional - Maximum number of packets to capture (default: 1024). - timeout : int, optional - Timeout for the remote pcap command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | int] = { - "band": band, - "duration": duration, - "max_pkt_len": max_pkt_len, - "num_packets": num_packets, - "type": "radiotap", - "format": "stream", - } - if ssid: - body["ssid"] = ssid - if ap_mac: - body["ap_mac"] = ap_mac - if tcpdump_expression: - body["tcpdump_expression"] = tcpdump_expression - trigger = pcaps.startSitePacketCapture( - apissession, - site_id=site_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Remote pcap command triggered for device {device_id}") - ws = PcapEvents(apissession, site_id=site_id) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" - ) # Give the remote pcap command a moment to take effect - return util_response - - -def apRemotePcapWired( - apissession: _APISession, - site_id: str, - device_id: str, - tcpdump_expression: str | None = None, - duration: int = 600, - max_pkt_len: int = 512, - num_packets: int = 1024, - timeout=10, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICE: AP - - Initiates a remote pcap command on the device and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to run remote pcap on. - tcpdump_expression : str, optional - Tcpdump expression to filter the captured traffic. - e.g. "udp port 67 or udp port 68 -vvv -tttt -en" - duration : int, optional - Duration of the remote pcap in seconds (default: 600). - max_pkt_len : int, optional - Maximum packet length to capture (default: 512). - num_packets : int, optional - Maximum number of packets to capture (default: 1024). - timeout : int, optional - Timeout for the remote pcap command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body: dict[str, str | int] = { - "duration": duration, - "max_pkt_len": max_pkt_len, - "num_packets": num_packets, - "type": "wired", - "format": "stream", - } - if tcpdump_expression: - body["tcpdump_expression"] = tcpdump_expression - trigger = pcaps.startSitePacketCapture( - apissession, - site_id=site_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Remote pcap command triggered for device {device_id}") - ws = PcapEvents(apissession, site_id=site_id) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" - ) # Give the remote pcap command a moment to take effect - return util_response - - -def srxRemotePcap( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], - tcpdump_expression: str | None = None, - duration: int = 600, - max_pkt_len: int = 512, - num_packets: int = 1024, - timeout=10, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICE: SRX - - Initiates a remote pcap command on the device and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to run remote pcap on. - port_ids : list[str] - List of port IDs to monitor. - tcpdump_expression : str, optional - Tcpdump expression to filter the captured traffic. - e.g. "udp port 67 or udp port 68 -vvv -tttt -en" - duration : int, optional - Duration of the remote pcap in seconds (default: 600). - max_pkt_len : int, optional - Maximum packet length to capture (default: 512). - num_packets : int, optional - Maximum number of packets to capture (default: 1024). - timeout : int, optional - Timeout for the remote pcap command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body = _build_pcap_body( - device_id, - port_ids, - "gateways", - "gateway", - tcpdump_expression, - duration, - max_pkt_len, - num_packets, - ) - trigger = pcaps.startSitePacketCapture( - apissession, - site_id=site_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Remote pcap command triggered for device {device_id}") - ws = PcapEvents(apissession, site_id=site_id) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" - ) # Give the remote pcap command a moment to take effect - return util_response - - -def ssrRemotePcap( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], - tcpdump_expression: str | None = None, - duration: int = 600, - max_pkt_len: int = 512, - num_packets: int = 1024, - timeout=10, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICE: SSR - - Initiates a remote pcap command on the device and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to run remote pcap on. - port_ids : list[str] - List of port IDs to monitor. - tcpdump_expression : str, optional - Tcpdump expression to filter the captured traffic. - e.g. "udp port 67 or udp port 68 -vvv -tttt -en" - duration : int, optional - Duration of the remote pcap in seconds (default: 600). - max_pkt_len : int, optional - Maximum packet length to capture (default: 512). - num_packets : int, optional - Maximum number of packets to capture (default: 1024). - timeout : int, optional - Timeout for the remote pcap command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body = _build_pcap_body( - device_id, - port_ids, - "gateways", - "gateway", - tcpdump_expression, - duration, - max_pkt_len, - num_packets, - raw=False, - ) - trigger = pcaps.startSitePacketCapture( - apissession, - site_id=site_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Remote pcap command triggered for device {device_id}") - ws = PcapEvents(apissession, site_id=site_id) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" - ) # Give the remote pcap command a moment to take effect - return util_response - - -def exRemotePcap( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], - tcpdump_expression: str | None = None, - duration: int = 600, - max_pkt_len: int = 512, - num_packets: int = 1024, - timeout=10, - on_message: Callable[[dict], None] | None = None, -) -> UtilResponse: - """ - DEVICE: EX - - Initiates a remote pcap command on the device and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to run remote pcap on. - port_ids : list[str] - List of port IDs to monitor. - tcpdump_expression : str, optional - Tcpdump expression to filter the captured traffic. - e.g. "udp port 67 or udp port 68 -vvv -tttt -en" - duration : int, optional - Duration of the remote pcap in seconds (default: 600). - max_pkt_len : int, optional - Maximum packet length to capture (default: 512). - num_packets : int, optional - Maximum number of packets to capture (default: 1024). - timeout : int, optional - Timeout for the remote pcap command in seconds. - on_message : Callable, optional - Callback invoked with each extracted raw message as it arrives. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received - from the WebSocket stream. - """ - body = _build_pcap_body( - device_id, - port_ids, - "switches", - "switch", - tcpdump_expression, - duration, - max_pkt_len, - num_packets, - ) - trigger = pcaps.startSitePacketCapture( - apissession, - site_id=site_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Remote pcap command triggered for device {device_id}") - ws = PcapEvents(apissession, site_id=site_id) - util_response = WebSocketWrapper( - apissession, util_response, timeout=timeout, on_message=on_message - ).start(ws) - else: - LOGGER.error( - f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" - ) # Give the remote pcap command a moment to take effect - return util_response - - -## NO DATA -# def srx_top_command( -# apissession: _APISession, -# site_id: str, -# device_id: str, -# timeout=10, -# on_message: Callable[[dict], None] | None = None, -# ) -> UtilResponse: -# """ -# DEVICE: SRX - -# For SRX Only. Initiates a top command on the device and streams the results. - -# PARAMS -# ----------- -# apissession : _APISession -# The API session to use for the request. -# site_id : str -# UUID of the site where the device is located. -# device_id : str -# UUID of the device to run the top command on. -# timeout : int, optional -# Timeout for the top command in seconds. -# on_message : Callable, optional -# Callback invoked with each extracted raw message as it arrives. - -# RETURNS -# ----------- -# UtilResponse -# A UtilResponse object containing the API response and a list of raw messages received -# from the WebSocket stream. -# """ -# trigger = devices.runSiteSrxTopCommand( -# apissession, -# site_id=site_id, -# device_id=device_id, -# ) -# util_response = UtilResponse(trigger) -# if trigger.status_code == 200: -# LOGGER.info(trigger.data) -# print(f"Top command triggered for device {device_id}") -# ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) -# util_response = WebSocketWrapper( -# apissession, util_response, timeout=timeout, on_message=on_message -# ).start(ws) -# else: -# LOGGER.error( -# f"Failed to trigger top command: {trigger.status_code} - {trigger.data}" -# ) # Give the top command a moment to take effect -# return util_response diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py index 237b056..a60d736 100644 --- a/src/mistapi/websockets/__ws_client.py +++ b/src/mistapi/websockets/__ws_client.py @@ -139,10 +139,12 @@ def _handle_open(self, ws: websocket.WebSocketApp) -> None: if self._on_open_cb: self._on_open_cb() - def _handle_message(self, ws: websocket.WebSocketApp, message: str) -> None: + def _handle_message(self, ws: websocket.WebSocketApp, message: str | bytes) -> None: + if isinstance(message, bytes): + message = message.replace(b"\x00", b"").decode("utf-8", errors="replace") try: data = json.loads(message) - except json.JSONDecodeError: + except (json.JSONDecodeError, TypeError): data = {"raw": message} self._queue.put(data) if self._on_message_cb: diff --git a/src/mistapi/websockets/session.py b/src/mistapi/websockets/session.py index 8b87801..4cff41f 100644 --- a/src/mistapi/websockets/session.py +++ b/src/mistapi/websockets/session.py @@ -65,9 +65,13 @@ def __init__( ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + self._url = url super().__init__( mist_session, - channels=[url], + channels=[], ping_interval=ping_interval, ping_timeout=ping_timeout, ) + + def _build_ws_url(self) -> str: + return self._url diff --git a/tests/unit/test_api_response.py b/tests/unit/test_api_response.py index 1c283fa..65d95a5 100644 --- a/tests/unit/test_api_response.py +++ b/tests/unit/test_api_response.py @@ -37,6 +37,7 @@ def _make_mock_response(status_code=200, data=None, headers=None, json_raises=Fa else: payload = data if data is not None else {} mock.content = json.dumps(payload).encode() + mock.text = json.dumps(payload) mock.json.return_value = payload return mock @@ -74,12 +75,12 @@ def test_200_response_with_json(self, api_response_factory): assert resp.url == "https://api.mist.com/api/v1/test" def test_raw_data_is_string_of_content(self): - """raw_data should be str(response.content).""" + """raw_data should be response.text.""" data = {"key": "value"} mock = _make_mock_response(data=data) resp = APIResponse(response=mock, url="https://host/api/v1/x") - assert resp.raw_data == str(json.dumps(data).encode()) + assert resp.raw_data == json.dumps(data) def test_proxy_error_true(self): """proxy_error=True should be stored on the instance.""" diff --git a/tests/unit/test_shell.py b/tests/unit/test_shell.py new file mode 100644 index 0000000..73c9ef6 --- /dev/null +++ b/tests/unit/test_shell.py @@ -0,0 +1,338 @@ +# tests/unit/test_shell.py +""" +Unit tests for ShellSession and create_shell_session. +""" + +import json +from unittest.mock import Mock, patch + +import pytest +import websocket + +from mistapi.api.v1.sites import devices as devices_module +from mistapi.device_utils.__tools import shell as shell_module +from mistapi.device_utils.__tools.shell import ShellSession, create_shell_session + + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def mock_apisession(): + session = Mock() + session._apitoken = ["test-token-abc123"] + session._apitoken_index = 0 + session._session = Mock() + session._session.verify = True + session._session.cert = None + session._session.cookies = [] + return session + + +@pytest.fixture +def shell_session(mock_apisession): + return ShellSession(mock_apisession, "wss://example.com/shell") + + +@pytest.fixture +def mock_ws(): + ws = Mock() + ws.connected = True + ws.gettimeout.return_value = 0.1 + return ws + + +# ------------------------------------------------------------------ +# Auth helpers +# ------------------------------------------------------------------ + + +class TestAuthHelpers: + """Tests for auth/SSL helper methods.""" + + def test_get_headers_with_api_token(self, shell_session) -> None: + headers = shell_session._get_headers() + assert headers == ["Authorization: Token test-token-abc123"] + + def test_get_headers_without_api_token(self, mock_apisession) -> None: + mock_apisession._apitoken = [] + session = ShellSession(mock_apisession, "wss://example.com/shell") + assert session._get_headers() == [] + + def test_get_cookie_with_cookies(self, mock_apisession) -> None: + cookie1 = Mock() + cookie1.name = "session_id" + cookie1.value = "abc123" + cookie2 = Mock() + cookie2.name = "csrf" + cookie2.value = "xyz789" + mock_apisession._session.cookies = [cookie1, cookie2] + session = ShellSession(mock_apisession, "wss://example.com/shell") + assert session._get_cookie() == "session_id=abc123; csrf=xyz789" + + def test_get_cookie_without_cookies(self, shell_session) -> None: + assert shell_session._get_cookie() is None + + def test_get_cookie_skips_crlf(self, mock_apisession) -> None: + bad_cookie = Mock() + bad_cookie.name = "bad\rcookie" + bad_cookie.value = "val" + mock_apisession._session.cookies = [bad_cookie] + session = ShellSession(mock_apisession, "wss://example.com/shell") + assert session._get_cookie() is None + + def test_build_sslopt_verify_false(self, mock_apisession) -> None: + import ssl + + mock_apisession._session.verify = False + session = ShellSession(mock_apisession, "wss://example.com/shell") + assert session._build_sslopt()["cert_reqs"] == ssl.CERT_NONE + + def test_build_sslopt_custom_ca(self, mock_apisession) -> None: + mock_apisession._session.verify = "/path/to/ca.pem" + session = ShellSession(mock_apisession, "wss://example.com/shell") + assert session._build_sslopt()["ca_certs"] == "/path/to/ca.pem" + + def test_build_sslopt_client_cert_tuple(self, mock_apisession) -> None: + mock_apisession._session.cert = ("/path/cert.pem", "/path/key.pem") + session = ShellSession(mock_apisession, "wss://example.com/shell") + sslopt = session._build_sslopt() + assert sslopt["certfile"] == "/path/cert.pem" + assert sslopt["keyfile"] == "/path/key.pem" + + +# ------------------------------------------------------------------ +# Lifecycle +# ------------------------------------------------------------------ + + +class TestLifecycle: + """Tests for connect/disconnect/connected.""" + + def test_connect_calls_create_connection( + self, shell_session, mock_ws + ) -> None: + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ) as mock_create: + shell_session.connect() + + mock_create.assert_called_once_with( + "wss://example.com/shell", + header=["Authorization: Token test-token-abc123"], + cookie=None, + sslopt={}, + ) + + def test_connect_sends_resize(self, shell_session, mock_ws) -> None: + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ): + shell_session.connect() + + mock_ws.send.assert_called_once_with( + json.dumps({"resize": {"width": 80, "height": 24}}) + ) + + def test_disconnect_closes_ws(self, shell_session, mock_ws) -> None: + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ): + shell_session.connect() + shell_session.disconnect() + + mock_ws.close.assert_called_once() + assert shell_session._ws is None + + def test_disconnect_without_connect_is_safe(self, shell_session) -> None: + shell_session.disconnect() # Should not raise + + def test_connected_false_before_connect(self, shell_session) -> None: + assert shell_session.connected is False + + def test_connected_true_after_connect(self, shell_session, mock_ws) -> None: + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ): + shell_session.connect() + assert shell_session.connected is True + + +# ------------------------------------------------------------------ +# I/O +# ------------------------------------------------------------------ + + +class TestIO: + """Tests for send/recv/resize.""" + + def test_send_binary(self, shell_session, mock_ws) -> None: + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ): + shell_session.connect() + shell_session.send(b"\x00hello") + mock_ws.send_binary.assert_called_once_with(b"\x00hello") + + def test_send_text_prefixes_null(self, shell_session, mock_ws) -> None: + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ): + shell_session.connect() + shell_session.send_text("ls\r\n") + called_data = mock_ws.send_binary.call_args[0][0] + assert called_data == b"\x00ls\r\n" + + def test_recv_returns_bytes(self, shell_session, mock_ws) -> None: + mock_ws.recv.return_value = b"output data" + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ): + shell_session.connect() + result = shell_session.recv() + assert result == b"output data" + + def test_recv_converts_str_to_bytes(self, shell_session, mock_ws) -> None: + mock_ws.recv.return_value = "text output" + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ): + shell_session.connect() + result = shell_session.recv() + assert result == b"text output" + + def test_recv_returns_none_on_timeout(self, shell_session, mock_ws) -> None: + mock_ws.recv.side_effect = websocket.WebSocketTimeoutException() + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ): + shell_session.connect() + result = shell_session.recv() + assert result is None + + def test_recv_returns_none_on_closed(self, shell_session, mock_ws) -> None: + mock_ws.recv.side_effect = websocket.WebSocketConnectionClosedException() + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ): + shell_session.connect() + result = shell_session.recv() + assert result is None + + def test_recv_returns_none_when_not_connected(self, shell_session) -> None: + assert shell_session.recv() is None + + def test_resize_sends_json(self, shell_session, mock_ws) -> None: + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ): + shell_session.connect() + mock_ws.send.reset_mock() # clear initial resize from connect() + + shell_session.resize(40, 120) + mock_ws.send.assert_called_once_with( + json.dumps({"resize": {"width": 120, "height": 40}}) + ) + + +# ------------------------------------------------------------------ +# Context manager +# ------------------------------------------------------------------ + + +class TestContextManager: + """Tests for context manager support.""" + + def test_exit_calls_disconnect(self, mock_apisession, mock_ws) -> None: + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ): + session = ShellSession(mock_apisession, "wss://example.com/shell") + session.connect() + + with session: + pass + + mock_ws.close.assert_called_once() + + +# ------------------------------------------------------------------ +# create_shell_session +# ------------------------------------------------------------------ + + +class TestCreateShellSession: + """Tests for the create_shell_session factory.""" + + def test_happy_path(self, mock_apisession, mock_ws) -> None: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.data = {"url": "wss://example.com/shell/abc"} + + with patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ), patch.object( + devices_module, + "createSiteDeviceShellSession", + return_value=mock_response, + ) as mock_shell_api: + session = create_shell_session(mock_apisession, "site-1", "device-1") + + assert isinstance(session, ShellSession) + mock_shell_api.assert_called_once_with( + mock_apisession, site_id="site-1", device_id="device-1", body={} + ) + + def test_api_failure_raises(self, mock_apisession) -> None: + mock_response = Mock() + mock_response.status_code = 403 + mock_response.data = {"error": "forbidden"} + + with patch.object( + devices_module, + "createSiteDeviceShellSession", + return_value=mock_response, + ): + with pytest.raises(RuntimeError, match="Shell API call failed"): + create_shell_session(mock_apisession, "site-1", "device-1") + + def test_missing_url_raises(self, mock_apisession) -> None: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.data = {"session": "abc"} # no "url" key + + with patch.object( + devices_module, + "createSiteDeviceShellSession", + return_value=mock_response, + ): + with pytest.raises(RuntimeError, match="did not contain a WebSocket URL"): + create_shell_session(mock_apisession, "site-1", "device-1") diff --git a/tests/unit/test_websocket_client.py b/tests/unit/test_websocket_client.py index 5069567..79bb829 100644 --- a/tests/unit/test_websocket_client.py +++ b/tests/unit/test_websocket_client.py @@ -372,6 +372,26 @@ def test_calls_on_message_callback_with_raw_fallback(self, ws_client) -> None: def test_no_error_without_on_message_callback(self, ws_client) -> None: ws_client._handle_message(Mock(), '{"ok": true}') # Should not raise + def test_decodes_binary_frame_to_str(self, ws_client) -> None: + ws_client._handle_message(Mock(), b"hello binary") + item = ws_client._queue.get_nowait() + assert item == {"raw": "hello binary"} + + def test_strips_null_bytes_from_binary(self, ws_client) -> None: + ws_client._handle_message(Mock(), b"\x00hello\x00world") + item = ws_client._queue.get_nowait() + assert item == {"raw": "helloworld"} + + def test_binary_valid_json_is_parsed(self, ws_client) -> None: + ws_client._handle_message(Mock(), b'{"event": "data", "key": "value"}') + item = ws_client._queue.get_nowait() + assert item == {"event": "data", "key": "value"} + + def test_binary_with_invalid_utf8_uses_replacement(self, ws_client) -> None: + ws_client._handle_message(Mock(), b"hello\xff\xfeworld") + item = ws_client._queue.get_nowait() + assert item["raw"] == "hello\ufffd\ufffdworld" + class TestHandleError: """Tests for _handle_error().""" @@ -773,7 +793,8 @@ class TestSessionChannel: def test_session_with_url_channels(self, mock_session) -> None: ws = SessionWithUrl(mock_session, url="wss://example.com/custom") - assert ws._channels == ["wss://example.com/custom"] + assert ws._channels == [] + assert ws._build_ws_url() == "wss://example.com/custom" def test_inherits_from_mist_websocket(self, mock_session) -> None: ws = SessionWithUrl(mock_session, url="wss://example.com/custom") diff --git a/tests/unit/test_ws_wrapper.py b/tests/unit/test_ws_wrapper.py new file mode 100644 index 0000000..fce5e3b --- /dev/null +++ b/tests/unit/test_ws_wrapper.py @@ -0,0 +1,185 @@ +# tests/unit/test_ws_wrapper.py +""" +Unit tests for WebSocketWrapper._extract_raw() — ANSI stripping (stream mode) +and VT100 screen-buffer rendering (screen mode). +""" + +from unittest.mock import Mock + +from mistapi.device_utils.__tools.__ws_wrapper import ( + WebSocketWrapper, + UtilResponse, + _VT100Screen, +) + + +def _make_wrapper(): + """Create a minimal WebSocketWrapper for testing _extract_raw.""" + session = Mock() + util_response = UtilResponse() + return WebSocketWrapper(session, util_response) + + +# ------------------------------------------------------------------ +# Stream-mode tests (no cursor positioning → ANSI strip) +# ------------------------------------------------------------------ +class TestExtractRawStreamMode: + """Messages without cursor positioning use simple ANSI stripping.""" + + def test_preserves_plain_text(self) -> None: + wrapper = _make_wrapper() + msg = {"raw": "64 bytes from 8.8.8.8: icmp_seq=1 ttl=118 time=12.3 ms"} + result = wrapper._extract_raw(msg) + assert result == "64 bytes from 8.8.8.8: icmp_seq=1 ttl=118 time=12.3 ms" + + def test_strips_sgr_color_codes(self) -> None: + wrapper = _make_wrapper() + msg = {"raw": "text\x1b[0m here\x1b[1;32m green"} + result = wrapper._extract_raw(msg) + assert result == "text here green" + + def test_strips_character_set_designations(self) -> None: + wrapper = _make_wrapper() + msg = {"raw": "hello\x1b(Bworld"} + result = wrapper._extract_raw(msg) + assert result == "helloworld" + + def test_nested_data_event_stripped(self) -> None: + wrapper = _make_wrapper() + msg = {"event": "data", "data": {"raw": "text\x1b[0m here"}} + result = wrapper._extract_raw(msg) + assert result == "text here" + + def test_pcap_dict_unaffected(self) -> None: + wrapper = _make_wrapper() + msg = {"pcap_dict": {"packet": "data"}} + result = wrapper._extract_raw(msg) + assert result == {"packet": "data"} + + +# ------------------------------------------------------------------ +# Screen-mode tests (cursor positioning detected → VT100 screen buffer) +# ------------------------------------------------------------------ +class TestExtractRawScreenMode: + """Messages with cursor positioning / clear-screen use the VT100 buffer.""" + + def test_activates_on_cursor_home(self) -> None: + wrapper = _make_wrapper() + msg = {"raw": "\x1b[H\x1b[2JHello World"} + wrapper._extract_raw(msg) + assert wrapper._screen_mode is True + assert wrapper._screen is not None + + def test_renders_clear_screen_then_text(self) -> None: + wrapper = _make_wrapper() + msg = {"raw": "\x1b[H\x1b[2JLine1\r\nLine2"} + result = wrapper._extract_raw(msg) + assert "Line1" in result + assert "Line2" in result + + def test_cursor_positioning_places_text(self) -> None: + wrapper = _make_wrapper() + # Row 1 col 1: "A", then row 2 col 5: "B" + msg = {"raw": "\x1b[1;1HA\x1b[2;5HB"} + result = wrapper._extract_raw(msg) + lines = result.split("\n") + assert lines[0] == "A" + assert lines[1] == " B" + + def test_in_place_update_overwrites(self) -> None: + """Subsequent messages update the screen buffer in place.""" + wrapper = _make_wrapper() + # First message: draw initial screen + wrapper._extract_raw({"raw": "\x1b[H\x1b[2J\x1b[1;1HOLD VALUE"}) + # Second message: overwrite at same position + result = wrapper._extract_raw({"raw": "\x1b[1;1HNEW VALUE"}) + lines = result.split("\n") + assert lines[0] == "NEW VALUE" + assert "OLD" not in lines[0] + + def test_clear_line_erases_content(self) -> None: + wrapper = _make_wrapper() + wrapper._extract_raw({"raw": "\x1b[1;1HFull line of text"}) + result = wrapper._extract_raw({"raw": "\x1b[1;1H\x1b[2K"}) + lines = result.split("\n") + # Line 1 should be empty after clear + assert lines[0] == "" if lines else True + + def test_stream_mode_not_activated_by_plain_text(self) -> None: + wrapper = _make_wrapper() + wrapper._extract_raw({"raw": "no escape codes here"}) + assert wrapper._screen_mode is False + assert wrapper._screen is None + + +# ------------------------------------------------------------------ +# _VT100Screen unit tests +# ------------------------------------------------------------------ +class TestVT100Screen: + """Direct tests for the _VT100Screen class.""" + + def test_simple_text(self) -> None: + s = _VT100Screen(rows=5, cols=20) + s.feed("Hello") + assert s.render() == "Hello" + + def test_newline(self) -> None: + s = _VT100Screen(rows=5, cols=20) + s.feed("A\nB\nC") + assert s.render() == "A\nB\nC" + + def test_carriage_return_newline(self) -> None: + s = _VT100Screen(rows=5, cols=20) + s.feed("A\r\nB") + assert s.render() == "A\nB" + + def test_cursor_position(self) -> None: + s = _VT100Screen(rows=5, cols=20) + s.feed("\x1b[3;10Hx") + lines = s.render().split("\n") + assert len(lines) >= 3 + assert lines[2][9] == "x" + + def test_cursor_home(self) -> None: + s = _VT100Screen(rows=5, cols=20) + s.feed("AAAA\x1b[HB") + lines = s.render().split("\n") + assert lines[0].startswith("BAAA") + + def test_clear_screen(self) -> None: + s = _VT100Screen(rows=5, cols=20) + s.feed("old text\x1b[2Jnew") + rendered = s.render() + assert "old" not in rendered + assert "new" in rendered + + def test_cursor_movement(self) -> None: + s = _VT100Screen(rows=5, cols=20) + s.feed("\x1b[1;1HABCDE") + # Move back 3 + s.feed("\x1b[3DX") + lines = s.render().split("\n") + assert lines[0] == "ABXDE" + + def test_erase_to_end_of_line(self) -> None: + s = _VT100Screen(rows=5, cols=20) + s.feed("Hello World") + s.feed("\x1b[1;6H\x1b[K") # Position at col 6, erase to EOL + assert s.render() == "Hello" + + def test_sgr_ignored(self) -> None: + s = _VT100Screen(rows=5, cols=20) + s.feed("\x1b[1;32mGreen\x1b[0m") + assert s.render() == "Green" + + def test_null_bytes_ignored(self) -> None: + s = _VT100Screen(rows=5, cols=20) + s.feed("\x00Hello\x00") + assert s.render() == "Hello" + + def test_render_trims_trailing_spaces_and_empty_lines(self) -> None: + s = _VT100Screen(rows=5, cols=20) + s.feed("A") + rendered = s.render() + assert rendered == "A" + assert not rendered.endswith(" ") diff --git a/uv.lock b/uv.lock index 11c66d0..4c81864 100644 --- a/uv.lock +++ b/uv.lock @@ -537,7 +537,7 @@ wheels = [ [[package]] name = "mistapi" -version = "0.61.0" +version = "0.61.1" source = { editable = "." } dependencies = [ { name = "deprecation" }, @@ -545,6 +545,7 @@ dependencies = [ { name = "keyring" }, { name = "python-dotenv" }, { name = "requests" }, + { name = "sshkeyboard" }, { name = "tabulate" }, { name = "websocket-client" }, ] @@ -569,6 +570,7 @@ requires-dist = [ { name = "keyring", specifier = ">=24.3.0" }, { name = "python-dotenv", specifier = ">=1.1.0" }, { name = "requests", specifier = ">=2.32.3" }, + { name = "sshkeyboard", specifier = ">=2.3.1" }, { name = "tabulate", specifier = ">=0.9.0" }, { name = "websocket-client", specifier = ">=1.8.0" }, ] @@ -903,6 +905,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/46/f5af3402b579fd5e11573ce652019a67074317e18c1935cc0b4ba9b35552/secretstorage-3.5.0-py3-none-any.whl", hash = "sha256:0ce65888c0725fcb2c5bc0fdb8e5438eece02c523557ea40ce0703c266248137", size = 15554, upload-time = "2025-11-23T19:02:51.545Z" }, ] +[[package]] +name = "sshkeyboard" +version = "2.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/7b/d78e6ade4bb4680d0a610ed02047c4de04db62de8864193bf842c59c47cb/sshkeyboard-2.3.1.tar.gz", hash = "sha256:3273be5b2fde7f8d2ea075d40e1981104ac0928d7b77a848756f08d0e66a3d9f", size = 20296, upload-time = "2021-10-28T11:37:08.07Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/b6/d24c6184348a91386e5ca5fe3e46668ef978dd694ff98574ef0dd5904522/sshkeyboard-2.3.1-py3-none-any.whl", hash = "sha256:05ec2cb116bd9c4a7c17d7add4a5af74e12eb2add24c110608cb90f6812692f9", size = 12870, upload-time = "2021-10-28T11:37:05.566Z" }, +] + [[package]] name = "tabulate" version = "0.9.0"