108 lines
3.6 KiB
Python
108 lines
3.6 KiB
Python
import asyncio
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Any
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class Message(BaseModel):
    """One entry in a conversation log.

    Attributes:
        role: Who produced the message; expected to be "user" or
            "assistant", although any string is accepted.
        content: The message text.
        timestamp: When the message was created; defaults to
            ``datetime.utcnow()`` (a naive datetime in UTC).
        channel_id: The channel the message is associated with, if any.
        metadata: Arbitrary extra key/value data attached to the message.
    """

    role: str = Field(...)
    content: str = Field(...)
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    channel_id: Optional[str] = Field(default=None)
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class UserSession:
    """Manages a user's conversation session with the agent.

    A session tracks:
    - Active communication channels
    - Conversation history
    - In-flight agent tasks (for interruption)
    - User metadata

    ``created_at`` and ``last_activity`` are naive datetimes produced by
    ``datetime.utcnow()``.
    """

    def __init__(self, session_id: str, user_id: str):
        self.session_id = session_id
        self.user_id = user_id
        # Channel ids in attachment order; add_channel keeps them unique.
        self.active_channels: List[str] = []
        self.conversation_history: List[Message] = []
        # The agent task that interrupt() can cancel; None when idle.
        self.current_task: Optional[asyncio.Task] = None
        self.metadata: Dict[str, Any] = {}
        self.created_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()

    def add_channel(self, channel_id: str) -> None:
        """Attach a channel to this session (no-op if already attached).

        Args:
            channel_id: Identifier of the channel to attach.
        """
        if channel_id not in self.active_channels:
            self.active_channels.append(channel_id)
            self.last_activity = datetime.utcnow()

    def remove_channel(self, channel_id: str) -> None:
        """Detach a channel from this session (no-op if not attached).

        Args:
            channel_id: Identifier of the channel to detach.
        """
        if channel_id in self.active_channels:
            self.active_channels.remove(channel_id)
            self.last_activity = datetime.utcnow()

    def add_message(self, role: str, content: str, channel_id: Optional[str] = None, **kwargs) -> None:
        """Append a message to the conversation history.

        Args:
            role: Author of the message (e.g. "user" or "assistant").
            content: Message text.
            channel_id: Channel the message is associated with, if any.
            **kwargs: Stored verbatim as the message's ``metadata``.
        """
        message = Message(
            role=role,
            content=content,
            channel_id=channel_id,
            metadata=kwargs,
        )
        self.conversation_history.append(message)
        self.last_activity = datetime.utcnow()

    def get_history(self, limit: Optional[int] = None) -> List["Message"]:
        """Get conversation history, oldest message first.

        Args:
            limit: Maximum number of most recent messages to return.
                ``None`` returns the whole history; ``0`` or a negative
                value returns an empty list.

        Returns:
            A new list of Message objects. Mutating it does not affect the
            session's stored history.
        """
        if limit is None:
            return list(self.conversation_history)
        if limit <= 0:
            # history[-0:] would be the entire list, and a negative limit
            # would silently drop the oldest messages instead.
            return []
        return self.conversation_history[-limit:]

    async def interrupt(self) -> bool:
        """Interrupt the current agent task if one is running.

        Requests cancellation of ``current_task`` and waits for it to finish.

        Returns:
            True if a task was interrupted, False otherwise (no task set, or
            the task had already finished).

        Raises:
            Exception: Whatever the task raised, if it ended with an error
                other than cancellation (for example, while unwinding).
            asyncio.CancelledError: If the coroutine awaiting interrupt()
                is itself cancelled while waiting.
        """
        task = self.current_task
        if task is None or task.done():
            return False

        task.cancel()
        # asyncio.wait() returns once the task is done. It does not re-raise
        # the task's outcome, and it does not forward a cancellation of
        # *this* coroutine into the task. So a CancelledError raised here can
        # only mean our caller was cancelled, and it correctly propagates.
        await asyncio.wait({task})

        # Another coroutine may have installed a new task while we waited;
        # only clear the slot if it still holds the task we stopped.
        if self.current_task is task:
            self.current_task = None

        if not task.cancelled():
            # The task either suppressed the cancellation or failed while
            # unwinding; surface a real failure to the caller.
            error = task.exception()
            if error is not None:
                raise error
        return True

    def set_task(self, task: asyncio.Task) -> None:
        """Set the current agent task.

        A previously set task is replaced, not cancelled.
        """
        self.current_task = task

    def is_busy(self) -> bool:
        """Check if the agent is currently processing a request."""
        return self.current_task is not None and not self.current_task.done()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize session to dict.

        ``active_channels`` and ``metadata`` are (shallow) copies, so editing
        the returned dict does not mutate the session.
        """
        return {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "active_channels": list(self.active_channels),
            "message_count": len(self.conversation_history),
            "is_busy": self.is_busy(),
            "created_at": self.created_at.isoformat(),
            "last_activity": self.last_activity.isoformat(),
            "metadata": dict(self.metadata),
        }
|