# Source: ai/backend/src/trigger/context.py (62 lines, 1.9 KiB, Python)

"""
Execution context tracking using Python's contextvars.
Each execution gets a unique seq number that propagates through all async calls,
allowing us to track which execution made which changes for conflict detection.
"""
import logging
from contextvars import ContextVar, Token
from dataclasses import dataclass, field
from typing import Optional
# Module-level logger used for debug tracing of context changes and snapshot reads.
logger = logging.getLogger(__name__)
# Context variables - automatically propagate through async call chains
# Holds the ExecutionContext of the currently running execution; reads return
# None (the declared default) when no context has been set.
# NOTE(review): per standard contextvars semantics, a new asyncio task sees the
# value current when it was created, but a later set() inside that task does
# not change the parent's value. The ExecutionContext object itself is shared
# by reference, so mutations to it (e.g. record_snapshot) are visible to every
# holder of that same object.
_execution_context: ContextVar[Optional["ExecutionContext"]] = ContextVar(
    "execution_context", default=None
)
@dataclass
class ExecutionContext:
    """
    Execution context for a single trigger execution.

    An instance is meant to be installed in the module's ``ContextVar`` (see
    ``set_execution_context``) so that code running later in the same
    execution, including awaited coroutines, can look it up. It carries the
    execution's seq number and records, per store, which snapshot seq the
    execution read.

    Note: the instance is mutable and shared by reference, so snapshot
    records made by any code holding this object update the same dict.
    """

    seq: int
    """Sequential execution number - determines commit order"""

    trigger_name: str
    """Name/type of trigger being executed"""

    # default_factory gives every instance its own dict (no shared mutable default).
    snapshot_seqs: dict[str, int] = field(default_factory=dict)
    """Store name -> seq number of snapshot that was read"""

    def record_snapshot(self, store_name: str, snapshot_seq: int) -> None:
        """
        Record that this execution read ``store_name`` at ``snapshot_seq``.

        Args:
            store_name: Name of the store that was read.
            snapshot_seq: Seq number of the snapshot that was read.

        A repeated read of the same store overwrites the earlier entry, so
        only the most recently recorded seq per store is kept.
        """
        self.snapshot_seqs[store_name] = snapshot_seq
        # Lazy %-style arguments: the message is only formatted when a DEBUG
        # record is actually emitted, keeping this cheap on the read path.
        logger.debug("Seq %s: Read %s at seq %s", self.seq, store_name, snapshot_seq)

    def __str__(self) -> str:
        # Compact form for log lines; the dataclass-generated __repr__ (which
        # also includes snapshot_seqs) is left untouched.
        return f"ExecutionContext(seq={self.seq}, trigger={self.trigger_name})"
def get_execution_context() -> Optional[ExecutionContext]:
    """
    Look up the ExecutionContext of the running execution.

    Returns:
        The context currently stored in the execution-context variable, or
        ``None`` when no execution context is active.
    """
    current = _execution_context.get()
    return current
def set_execution_context(ctx: ExecutionContext) -> Token:
    """
    Set the execution context for the current async task.

    Args:
        ctx: The context to make current.

    Returns:
        The ``contextvars.Token`` produced by the underlying ``ContextVar.set``.
        Passing it to ``_execution_context.reset(token)`` restores whatever
        context (possibly ``None``) was current before this call. Callers
        that ignore the return value behave exactly as before.
    """
    token = _execution_context.set(ctx)
    # Lazy %-style argument: str(ctx) is only computed if DEBUG is emitted.
    logger.debug("Set execution context: %s", ctx)
    return token
def clear_execution_context(token: Optional[Token] = None) -> None:
    """
    Clear the execution context.

    Args:
        token: Optional ``contextvars.Token`` obtained when the execution
            context variable was set. When given, the variable is reset to the
            value it held before that set (for example an enclosing
            execution's context) instead of being forced to ``None``. When
            omitted, the context is set to ``None`` exactly as before.

    Raises:
        ValueError: From ``ContextVar.reset`` if ``token`` belongs to another
            variable or was created in a different Context.
        RuntimeError: From ``ContextVar.reset`` if ``token`` was already used.
    """
    if token is None:
        _execution_context.set(None)
    else:
        _execution_context.reset(token)