# Source code for axiom.backends.universal

"""
axiom/backends/universal.py

Universal OpenAI-compatible API client for Axiom AI.
Supports any local/remote backend that implements the OpenAI /v1/chat/completions API
(e.g., LM Studio, KoboldCPP, Ollama, standard OpenAI, etc.).
"""

from __future__ import annotations

import json
from typing import Callable, Iterator

import httpx

from axiom.backends.base import LLMBackend, LLMConnectionError, LLMMessage, LLMResponse
from axiom.backends.transport import IPv4FirstTransport
from axiom.logger import logger

# Default request timeout, in seconds.
_DEFAULT_TIMEOUT: float = 600.0

# Dedicated per-address connect timeout, in seconds. Without it the scalar
# 600s above also governed the TCP connect phase, and a broken IPv6 route
# stalled for minutes (axiom/backends/transport.py has the long story).
_CONNECT_TIMEOUT: float = 5.0

# HTTP statuses meaning "this key is dead or out of budget": the same request
# is worth retrying with the next key of the pool (TICKET-062: shared beta
# keys).
#   401 revoked · 402 payment required · 403 suspended · 429 rate/quota exhausted
_KEY_ROTATION_STATUSES: frozenset[int] = frozenset((401, 402, 403, 429))

# Reasoning models (gpt-oss, OpenAI o-series, DeepSeek R/v4, QwQ, *-thinking)
# first spend tokens on a hidden chain-of-thought (`reasoning_content`) and
# only then write the answer into `content`. Because `max_tokens` caps the
# total, a budget sized for plain models (150-600 in the arbitrator) is
# consumed by the reasoning and `content` never appears (TICKET-066).
# Billing follows actual usage rather than the cap, so raising the floor costs
# nothing when the model stops early — it only prevents that truncation.
# deepseek-v4-flash confirmed reasoning by live probe on Fireworks, 2026-06-12.
_REASONING_MODEL_HINTS: tuple[str, ...] = (
    "gpt-oss",
    "deepseek-r",
    "deepseek-v4",
    "qwq",
    "-thinking",
    "-reasoning",
)

# Minimum `max_tokens` requested from a reasoning model.
_REASONING_TOKEN_FLOOR: int = 2048


def _is_reasoning_model(model_name: str) -> bool:
    """Guess from the model id whether the model is a reasoning model.

    Only the part of the id after the last '/' is inspected, and the
    comparison is case-insensitive.
    """
    short_id = model_name.lower().rpartition("/")[2]
    # OpenAI o-series ids lead with the family name (o1-mini, o3, o4-mini…).
    for family in ("o1", "o3", "o4"):
        if short_id.startswith(family):
            return True
    for hint in _REASONING_MODEL_HINTS:
        if hint in short_id:
            return True
    return False


class UniversalClient(LLMBackend):
    """OpenAI-compatible LLM client using httpx.

    Args:
        base_url: The base URL (e.g., http://localhost:1234/v1).
        api_key: Optional API key for authorization.
        model_name: The model identifier to request.
        extra_headers: Optional headers merged into every request. Lets a
            provider use a non-Bearer auth scheme (e.g. Anthropic's
            x-api-key + anthropic-version) — pass api_key="" then.
        max_stop_sequences: Optional cap on the number of stop sequences
            sent (OpenAI rejects more than 4; most providers have no limit).
        fallback_api_keys: Optional pool of spare Bearer keys. When a request
            fails with an auth/quota status (401/402/403/429), the client
            switches to the next key — stickily — and retries, until the
            pool is exhausted (TICKET-062: shared beta keys). Only
            meaningful with Bearer auth (`api_key`).
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model_name: str,
        extra_headers: dict[str, str] | None = None,
        max_stop_sequences: int | None = None,
        fallback_api_keys: list[str] | None = None,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model_name = model_name
        # Copied so later mutation of the caller's dict cannot leak in.
        self.extra_headers = dict(extra_headers) if extra_headers else {}
        self.max_stop_sequences = max_stop_sequences
        # Ordered, deduplicated key pool; index 0 is the active key.
        # NOTE(review): with api_key="" and spare keys supplied, index 0 holds
        # the first spare while the active key is still empty, so the first
        # rotation skips that spare — confirm whether this combination is
        # meant to be supported.
        pool = [api_key] + [k for k in (fallback_api_keys or []) if k]
        self._api_keys = list(dict.fromkeys(k for k in pool if k)) or [api_key]
        self._key_index = 0
        self._client = httpx.Client(
            base_url=self.base_url,
            headers=self._get_headers(),
            timeout=httpx.Timeout(_DEFAULT_TIMEOUT, connect=_CONNECT_TIMEOUT),
            transport=IPv4FirstTransport(),
        )

    def _rotate_key(self) -> bool:
        """Switch to the next key of the pool (sticky). False when exhausted."""
        if self._key_index + 1 >= len(self._api_keys):
            return False
        self._key_index += 1
        self.api_key = self._api_keys[self._key_index]
        # Update the live client so every later request uses the new key.
        self._client.headers["Authorization"] = f"Bearer {self.api_key}"
        logger.warning(
            "LLM API key rejected/exhausted — switching to spare key %d/%d.",
            self._key_index + 1,
            len(self._api_keys),
        )
        return True

    def _send_with_rotation(self, send: Callable[[], httpx.Response]) -> httpx.Response:
        """Run `send()`, retrying on the next key after an auth/quota status.

        Returns the first response whose status is not a rotation status, or
        the last rejected response once the key pool is exhausted.
        """
        while True:
            response = send()
            if response.status_code in _KEY_ROTATION_STATUSES and self._rotate_key():
                continue
            return response

    def _get_headers(self) -> dict[str, str]:
        """Build the default request headers (extra_headers applied last)."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        headers.update(self.extra_headers)
        return headers

    @staticmethod
    def _format_status_error(exc: httpx.HTTPStatusError) -> str:
        """Build an error message that includes the provider's response body.

        A bare "400 Bad Request" hides the actionable cause (unknown model,
        rejected parameter…) — cloud providers put it in the JSON body.
        """
        try:
            exc.response.read()  # streamed responses: body not read yet
        except Exception:
            # Best effort: the body may be unavailable (e.g. stream closed).
            pass
        try:
            # Collapse whitespace and cap the length to keep the message short.
            body = " ".join(exc.response.text.split())[:300]
        except Exception:
            body = ""
        message = (
            f"LLM API error {exc.response.status_code} from "
            f"{exc.request.url}: {body or exc}"
        )
        if exc.response.status_code == 404 and "chat/completions" in str(exc.request.url):
            # Cloud providers answer 404 on generation when the model id is
            # unknown/retired (the server itself is fine).
            message += (
                " — a 404 here usually means the configured model does not "
                "exist on this provider. Check the Model field in Settings."
            )
        return message

    def _get_payload(
        self,
        messages: list[LLMMessage],
        stream: bool,
        temperature: float = 0.7,
        top_p: float = 1.0,
        response_format: str | None = None,
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
    ) -> dict:
        """Build the JSON body for a /chat/completions request.

        `max_tokens` defaults to 1024 when falsy and is raised to
        `_REASONING_TOKEN_FLOOR` for reasoning models. The built-in stop
        sequences are always sent, followed by any caller-supplied ones
        (deduplicated, order preserved, optionally capped).
        """
        budget = max_tokens if max_tokens else 1024
        if _is_reasoning_model(self.model_name):
            budget = max(budget, _REASONING_TOKEN_FLOOR)
        payload = {
            "model": self.model_name,
            "messages": messages,
            "stream": stream,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": budget,
        }
        if "gpt-oss" in self.model_name.lower():
            # Shorter chain-of-thought = lower latency for the same answers.
            # Fireworks accepts the parameter (probed live 2026-06-12).
            payload["reasoning_effort"] = "low"
        # Merge stop sequences
        stops = ["</s>", "<|im_end|>", "\n===", "\n###", "\nUser:", "\nPlayer:", "\n[User]"]
        if stop_sequences:
            stops.extend(stop_sequences)
        stops = list(dict.fromkeys(stops))
        if self.max_stop_sequences is not None:
            # NOTE(review): the cap keeps the leading built-in stops, so with a
            # small cap caller-supplied sequences are the first to be dropped —
            # confirm this priority is intended.
            stops = stops[: self.max_stop_sequences]
        payload["stop"] = stops
        if response_format == "json":
            payload["response_format"] = {"type": "json_object"}
        return payload

    def complete(
        self,
        messages: list[LLMMessage],
        stream: bool = False,
        temperature: float = 0.7,
        top_p: float = 1.0,
        response_format: str | None = None,
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Send a list of messages and return a fully assembled LLMResponse.

        `stream` is not used: the request is always sent non-streaming.

        Raises:
            LLMConnectionError: on an HTTP error status, a transport failure,
                or a response body that does not have the expected shape.
        """
        payload = self._get_payload(
            messages,
            stream=False,
            temperature=temperature,
            top_p=top_p,
            response_format=response_format,
            stop_sequences=stop_sequences,
            max_tokens=max_tokens,
        )
        try:
            response = self._send_with_rotation(
                lambda: self._client.post("/chat/completions", json=payload)
            )
            response.raise_for_status()
            data = response.json()
            choice = data["choices"][0]
            # Reasoning models (gpt-oss, o-series…) put their chain-of-thought
            # in `reasoning_content` and the answer in `content`. When the token
            # budget is spent on reasoning, `content` is null or absent — that
            # is an empty generation, NOT a malformed response, so tolerate it
            # instead of raising KeyError (TICKET-066). A null `message` is
            # treated the same way.
            message = choice.get("message") or {}
            raw_text = message.get("content") or ""
            reason = choice.get("finish_reason", "stop")
            narrative, tool_call = self.parse_tool_call(raw_text)
            return LLMResponse(
                narrative_text=narrative,
                tool_call=tool_call,
                finish_reason=reason,
            )
        except httpx.HTTPStatusError as exc:
            raise LLMConnectionError(self._format_status_error(exc)) from exc
        except httpx.HTTPError as exc:
            raise LLMConnectionError(f"Universal API unreachable: {exc}") from exc
        except (KeyError, IndexError, TypeError, json.JSONDecodeError) as exc:
            # IndexError: empty `choices`; TypeError: null `choices` or a
            # non-object JSON body.
            raise LLMConnectionError(f"Unexpected response format: {exc}") from exc

    def stream_tokens(
        self,
        messages: list[LLMMessage],
        temperature: float = 0.7,
        top_p: float = 1.0,
        response_format: str | None = None,
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
    ) -> Iterator[str]:
        """Yield individual tokens as they arrive via SSE.

        Malformed or content-less chunks are skipped; iteration ends at the
        `[DONE]` sentinel or when the server closes the stream.

        Raises:
            LLMConnectionError: on an HTTP error status or a transport failure.
        """
        payload = self._get_payload(
            messages,
            stream=True,
            temperature=temperature,
            top_p=top_p,
            response_format=response_format,
            stop_sequences=stop_sequences,
            max_tokens=max_tokens,
        )
        try:
            while True:
                with self._client.stream("POST", "/chat/completions", json=payload) as response:
                    # The status arrives before any token: rotating the key
                    # here can never drop already-yielded content.
                    if response.status_code in _KEY_ROTATION_STATUSES and self._rotate_key():
                        continue
                    if response.status_code >= 400:
                        # Buffer the error body while the stream is still
                        # open: leaving this `with` block closes the response,
                        # after which the provider's explanation could no
                        # longer be read for the error message.
                        response.read()
                    response.raise_for_status()
                    for line in response.iter_lines():
                        # SSE permits "data:" with or without a space before
                        # the value; blank lines and other fields are skipped.
                        if not line.startswith("data:"):
                            continue
                        data_str = line[len("data:"):].strip()
                        if data_str == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data_str)
                            delta = chunk["choices"][0].get("delta") or {}
                            content = delta.get("content", "")
                        except (
                            json.JSONDecodeError,
                            KeyError,
                            IndexError,
                            TypeError,
                            AttributeError,
                        ):
                            continue
                        # Yield outside the try so an exception thrown into
                        # the generator by the consumer is never swallowed.
                        if content:
                            yield content
                    return
        except httpx.HTTPStatusError as exc:
            raise LLMConnectionError(self._format_status_error(exc)) from exc
        except httpx.HTTPError as exc:
            raise LLMConnectionError(f"Universal API streaming error: {exc}") from exc

    def is_available(self) -> bool:
        """Check if the backend is reachable.

        True only when GET /models answers 200; any exception yields False.
        """
        try:
            response = self._send_with_rotation(
                lambda: self._client.get("/models", timeout=5.0)
            )
            return response.status_code == 200
        except Exception:
            # Deliberate best-effort probe: never raise to the caller.
            return False

    def list_models(self) -> list[str]:
        """Return the model ids exposed by the server's /models endpoint.

        Empty list on any error — callers use this as a best-effort check
        (e.g. the settings dialog verifying the configured model exists).
        """
        try:
            response = self._send_with_rotation(
                lambda: self._client.get("/models", timeout=10.0)
            )
            response.raise_for_status()
            data = response.json().get("data", [])
            return [m["id"] for m in data if isinstance(m, dict) and "id" in m]
        except Exception:
            # Deliberate best-effort lookup: never raise to the caller.
            return []