Source code for axiom.memory

"""
llm_engine/vector_memory.py

Local vector-database memory for Axiom AI narrative chunks.

Every piece of narrative embedded here carries a `turn_id` metadata tag.
This enables the surgical rollback required by the Checkpoint system:
when the player rewinds to turn N, all chunks with turn_id > N are
permanently deleted so they cannot bleed into the rebuilt timeline.

Backend: ChromaDB (persistent, local)
Embedding model: sentence-transformers all-MiniLM-L6-v2 (fully offline)

Collection layout
-----------------
Collection name : "narrative_memory"
Document        : the text chunk
Metadata fields : save_id (str), turn_id (int), chunk_type (str)
ID              : UUID string, generated per chunk
"""

import uuid
from typing import Any

# Optional import: if chromadb is not installed the name is bound to None instead
try:
    from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
except ImportError:
    SentenceTransformerEmbeddingFunction = None


# Name of the ChromaDB collection that VectorMemory opens (or creates).
_COLLECTION_NAME: str = "narrative_memory"
# sentence-transformers model name handed to the embedding function.
_EMBEDDING_MODEL: str = "all-MiniLM-L6-v2"

# Flipped to True by _warn_runtime_unavailable_once() so that the
# "semantic memory disabled" warning is logged a single time.
_runtime_warned: bool = False


def _warn_runtime_unavailable_once(exc: BaseException) -> None:
    """Emit the "semantic memory disabled" warning at most one time.

    The first call sets the module-level ``_runtime_warned`` flag and reports
    ``exc`` through ``axiom.logger``; every later call is a no-op.

    A typical trigger on Windows is torch failing to load its native
    libraries (WinError 126) because the Microsoft Visual C++ Redistributable
    is not installed.

    Args:
        exc: The exception that stopped the embedding runtime from loading.
            It is interpolated into the warning text.
    """
    global _runtime_warned
    if not _runtime_warned:
        # Mark as warned *before* logging, so a failing logger cannot make
        # this fire again on the next call.
        _runtime_warned = True
        message = (
            "Semantic memory disabled: the embedding runtime could not load (%s). "
            "Gameplay continues without long-term narrative recall. On Windows this "
            "usually means the Microsoft Visual C++ Redistributable (x64) is missing."
        )
        try:
            from axiom.logger import logger

            logger.warning(message, exc)
        except Exception:
            # Deliberately best-effort: reporting a degraded feature must
            # never raise into the caller.
            pass


def preload_embedding_runtime() -> bool:
    """Load torch's native runtime on the *calling* (main) thread.

    The sentence-transformers embedding model is loaded and used on worker
    threads (VectorInitWorker / NarrativeWorker). The first encode lazily
    pulls in ``torch._dynamo`` → ``triton``, which ``dlopen()``s
    ``libtriton.so``. Performing that ``dlopen`` from a secondary thread while
    Qt is running segfaults (native crash, no Python traceback). Importing it
    once here, on the main thread at startup, makes the later cross-thread
    use safe.

    Call this from the GUI/CLI entry point *before* any worker thread touches
    VectorMemory. Idempotent; never raises.

    Returns:
        True if the runtime was pre-loaded, False if torch is unavailable
        (e.g. headless test stubs).
    """
    try:
        import torch  # noqa: F401 — heavy import, front-loaded on purpose
        import torch._dynamo  # noqa: F401 — triggers the libtriton.so dlopen here
    except Exception:
        # torch is absent or its internal layout changed. The worst case is
        # the behaviour from before this pre-load existed, so degrade
        # silently rather than block startup.
        return False
    return True
class _EmbeddingSingleton:
    """Cache for the embedding function so the heavy model loads only once.

    The transformer model is built on the first ``get()`` call and the same
    object is handed back on every later call in the session.
    """

    _instance = None

    @classmethod
    def get(cls):
        """Return the shared embedding function, creating it on first use.

        Any exception raised by the second (online) construction attempt
        propagates to the caller, and nothing is cached in that case.
        """
        if cls._instance is not None:
            return cls._instance

        from chromadb.utils.embedding_functions import (
            SentenceTransformerEmbeddingFunction,
        )

        try:
            # Model already cached → load it WITHOUT a network round-trip.
            # sentence-transformers otherwise sends a HEAD request to the HF
            # Hub on every load to check for updates; on hosts with broken
            # IPv6 routing to huggingface.co that request stalls ~90s (the
            # same root cause as the Gemini IPv4FirstTransport fix). Because
            # it runs on the *first turn of every session*, the narrative
            # never seems to arrive. local_files_only skips the check.
            embedder = SentenceTransformerEmbeddingFunction(
                model_name=_EMBEDDING_MODEL, local_files_only=True
            )
        except Exception:
            # First-ever launch (or a cleared cache): the model is not on
            # disk yet, so allow the one-time online download. Every later
            # session then takes the offline fast path above.
            embedder = SentenceTransformerEmbeddingFunction(
                model_name=_EMBEDDING_MODEL
            )

        cls._instance = embedder
        return cls._instance
[docs] class VectorMemory: """Local semantic memory store backed by ChromaDB. Args: persist_dir: Filesystem path where ChromaDB will store its data. Created automatically if it does not exist. """ def __init__(self, persist_dir: str) -> None: self._persist_dir = persist_dir self._chroma_client = None self._collection = None # Set True if the embedding runtime can't load (e.g. torch native libs # missing on Windows). Semantic memory then degrades to a no-op so the # game stays playable instead of crashing every turn. self._disabled = False def _ensure_connected(self) -> None: """Lazy-init ChromaDB only when first used. If the embedding runtime is unavailable, the store is marked disabled rather than raising: callers then get empty results / no-ops. """ if self._collection is not None or self._disabled: return try: import chromadb self._chroma_client = chromadb.PersistentClient(path=self._persist_dir) self._collection = self._chroma_client.get_or_create_collection( name=_COLLECTION_NAME, embedding_function=_EmbeddingSingleton.get(), ) except Exception as exc: self._disabled = True _warn_runtime_unavailable_once(exc) # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------
[docs] def embed_chunk( self, save_id: str, turn_id: int, text: str, chunk_type: str = "narrative", ) -> str: """Embed a text chunk and store it with turn_id metadata.""" if not text or not text.strip(): raise ValueError("Cannot embed empty or whitespace-only text.") self._ensure_connected() if self._disabled: return "" doc_id = str(uuid.uuid4()) self._collection.add( documents=[text], metadatas=[{ "save_id": save_id, "turn_id": turn_id, "chunk_type": chunk_type, }], ids=[doc_id], ) return doc_id
[docs] def query( self, save_id: str, query_text: str, k: int = 5, current_turn_id: int | None = None, max_turn_id: int | None = None, ) -> list[dict[str, Any]]: """Retrieve the top-k most relevant chunks using Time-Weighted search.""" if not query_text or not query_text.strip(): raise ValueError("Query text must not be empty.") self._ensure_connected() if self._disabled: return [] # Fetch more candidates than k to allow for re-ranking candidate_count = max(k * 3, 20) # Build filter condition where_cond: dict[str, Any] = {"save_id": save_id} if max_turn_id is not None: where_cond = { "$and": [ {"save_id": {"$eq": save_id}}, {"turn_id": {"$lte": max_turn_id}} ] } # How many docs exist for this save. Each save has its OWN collection # (persist_dir is per save_id), so count() is exactly this save's chunk # count — a cheap metadata read, instead of get() which used to # materialise every chunk's document + metadata just to size the query. available = self._collection.count() if available == 0: return [] fetch_k = min(candidate_count, available) results = self._collection.query( query_texts=[query_text], n_results=fetch_k, where=where_cond, ) candidates: list[dict[str, Any]] = [] documents = results.get("documents", [[]])[0] metadatas = results.get("metadatas", [[]])[0] distances = results.get("distances", [[]])[0] for doc, meta, dist in zip(documents, metadatas, distances): turn_id = int(meta.get("turn_id", 0)) chunk_type = str(meta.get("chunk_type", "narrative")) # 1. Semantic Score (0.0 to 1.0, 1.0 is perfect match) # ChromaDB cosine distance: 0.0 is perfect, 2.0 is opposite semantic_score = max(0.0, 1.0 - (float(dist) / 2.0)) # 2. 
Recency Weight (0.1 to 1.0) if current_turn_id is None or chunk_type == "lore" or turn_id == 0: time_weight = 1.0 else: # Linear decay: Lose 1% weight per turn of age, cap at 10% age = max(0, current_turn_id - turn_id) time_weight = max(0.1, 1.0 - (age * 0.01)) final_score = semantic_score * time_weight candidates.append({ "text": doc, "turn_id": turn_id, "chunk_type": chunk_type, "distance": float(dist), "score": final_score }) # Sort by final score descending and take top k candidates.sort(key=lambda x: x["score"], reverse=True) return candidates[:k]
[docs] def rollback(self, save_id: str, target_turn_id: int) -> int: """Delete all chunks for a save with turn_id strictly greater than target.""" self._ensure_connected() if self._disabled: return 0 # ChromaDB's $gt operator requires a numeric type result = self._collection.get( where={ "$and": [ {"save_id": {"$eq": save_id}}, {"turn_id": {"$gt": target_turn_id}}, ] } ) ids_to_delete: list[str] = result["ids"] if ids_to_delete: self._collection.delete(ids=ids_to_delete) return len(ids_to_delete)
[docs] def update_turn_narrative( self, save_id: str, turn_id: int, new_text: str, chunk_type: str = "narrative", ) -> None: """Delete existing chunks for this turn and embed the new text. """ self._ensure_connected() if self._disabled: return result = self._collection.get( where={ "$and": [ {"save_id": {"$eq": save_id}}, {"turn_id": {"$eq": turn_id}}, {"chunk_type": {"$eq": chunk_type}}, ] } ) ids_to_delete: list[str] = result["ids"] if ids_to_delete: self._collection.delete(ids=ids_to_delete) if new_text and new_text.strip(): self.embed_chunk(save_id, turn_id, new_text, chunk_type)