"""Agent Orchestrator - Routes messages and manages agent modes."""

import logging
import uuid
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession

from app.agent.policies import (
    ChatActivityPolicy,
    ResponseSuppression,
    SuspiciousContentPolicy,
)
from app.agent.modes.hearthkeeper import HearthkeeperMode
from app.agent.modes.steward import StewardMode
from app.agent.modes.warden import WardenMode
from app.agent.modes.librarian import LibrarianMode
from app.agent.modes.scribe import ScribeMode
from app.llm.client import LLMClient
from app.memory.database import async_session_factory
from app.memory.models import AgentActionType
from app.memory.repository import Repository

# Module-level logger used by the orchestrator below.
logger = logging.getLogger(__name__)


class AgentOrchestrator:
    """
    Main orchestrator for agent behavior.

    Routes chat messages to appropriate modes and manages responses.
    Implements policies for when to speak, when to stay silent,
    and how to flag suspicious content.

    Per-session bookkeeping is kept in memory in ``active_sessions``, keyed by
    the session ID returned from :meth:`start_session`.
    """

    # Messages strictly longer than this many characters are archived by the Librarian.
    LIBRARIAN_MIN_LENGTH = 50
    # Number of recent messages fetched to feed the response-suppression policy.
    SUPPRESSION_WINDOW = 10

    def __init__(self, inactivity_threshold_minutes: int = 15):
        """
        Initialize the orchestrator and all modes.

        Args:
            inactivity_threshold_minutes: Forwarded to ``ChatActivityPolicy``.
                Defaults to 15, the value previously hard-coded here.
        """
        self.llm_client = LLMClient()

        # Initialize modes (all share the same LLM client instance)
        self.hearthkeeper = HearthkeeperMode(self.llm_client)
        self.steward = StewardMode(self.llm_client)
        self.warden = WardenMode(self.llm_client)
        self.librarian = LibrarianMode(self.llm_client)
        self.scribe = ScribeMode(self.llm_client)

        # Initialize policies
        self.chat_activity = ChatActivityPolicy(
            inactivity_threshold_minutes=inactivity_threshold_minutes
        )
        self.response_suppression = ResponseSuppression()
        self.suspicious_content = SuspiciousContentPolicy()

        # In-memory state of sessions started by this orchestrator, keyed by session ID.
        self.active_sessions: dict[str, dict] = {}

        logger.info("AgentOrchestrator initialized with all modes and policies")

    async def start_session(self, channel_name: str) -> str:
        """
        Start a new stream session.

        Creates a session record through ``Repository.create_session`` and
        registers in-memory tracking state (start time, message count, theme).

        Args:
            channel_name: Twitch channel name

        Returns:
            Session ID: a freshly generated UUID4 string, used as the key for
            all other orchestrator calls.
        """
        session_id = str(uuid.uuid4())

        async with async_session_factory() as db_session:
            repo = Repository(db_session)
            # NOTE(review): ``session_id`` is not passed to ``create_session`` and
            # its return value is discarded, yet this same ``session_id`` is later
            # given to ``add_chat_message`` / ``record_action`` / ``end_session``.
            # Confirm the repository ends up using this ID; otherwise the DB rows
            # and the in-memory session will not line up.
            await repo.create_session(channel_name)

        self.active_sessions[session_id] = {
            "channel_name": channel_name,
            "started_at": datetime.utcnow(),
            "message_count": 0,
            "theme": None,
        }

        logger.info("Started session %s for %s", session_id, channel_name)
        return session_id

    async def end_session(self, session_id: str) -> None:
        """
        End a stream session.

        Calls ``Repository.end_session`` and then drops the in-memory state.
        Unknown session IDs are logged and ignored.

        NOTE(review): an earlier docstring said this triggers ledger generation,
        but nothing here calls ``self.scribe``. Confirm whether
        ``Repository.end_session`` is responsible for that.

        Args:
            session_id: Session ID returned by :meth:`start_session`
        """
        if session_id not in self.active_sessions:
            logger.warning("Session %s not found", session_id)
            return

        async with async_session_factory() as db_session:
            repo = Repository(db_session)
            await repo.end_session(session_id)

        del self.active_sessions[session_id]
        logger.info("Ended session %s", session_id)

    async def _is_response_suppressed(self, session_id: str) -> bool:
        """Return True when the suppression policy says the agent should stay quiet."""
        async with async_session_factory() as db_session:
            repo = Repository(db_session)
            recent_messages = await repo.get_recent_messages(
                session_id, limit=self.SUPPRESSION_WINDOW
            )
        # Only the count (capped by SUPPRESSION_WINDOW) is given to the policy.
        return bool(
            self.response_suppression.should_suppress_response(len(recent_messages))
        )

    async def handle_chat_message(
        self,
        session_id: str,
        username: str,
        message: str,
    ) -> dict:
        """
        Process a chat message and determine agent response.

        Pipeline:
          1. Persist the message and record chat activity.
          2. Warden analyzes every message; suspicious ones are recorded as a
             FLAG_SUSPICIOUS action.
          3. Librarian archives longer messages. Like the Warden this is
             passive, so it runs even when the agent decides to stay silent.
          4. If the suppression policy says chat is busy, no response is
             produced; otherwise the Hearthkeeper may generate a prompt.

        Args:
            session_id: Session ID
            username: Username of message sender
            message: Message content

        Returns:
            Dict with ``agent_response`` (the Hearthkeeper prompt, or None) and
            ``actions_taken`` (list of action labels).
        """
        session_info = self.active_sessions.get(session_id)
        if session_info is None:
            logger.warning("Session %s not found", session_id)
            return {"agent_response": None, "actions_taken": []}

        actions: list[str] = []
        agent_response = None

        # Store the message
        async with async_session_factory() as db_session:
            repo = Repository(db_session)
            message_id = await repo.add_chat_message(
                session_id=session_id,
                username=username,
                content=message,
                is_bot=False,
            )

        # Record activity
        self.chat_activity.record_activity(session_id)
        session_info["message_count"] += 1

        # 1. Warden always analyzes (passive mode)
        warden_result = await self.warden.analyze_message(message)
        if warden_result["is_suspicious"]:
            actions.append(f"WARDEN_FLAG: {warden_result['severity']}")
            async with async_session_factory() as db_session:
                repo = Repository(db_session)
                await repo.record_action(
                    session_id=session_id,
                    action_type=AgentActionType.FLAG_SUSPICIOUS,
                    mode="warden",
                    description=f"Detected: {warden_result['patterns_detected']}",
                    triggered_by_message_id=message_id,
                )

        # 2. Librarian: archive longer messages (passive). This previously ran
        # only when no suppression occurred, so busy chats were never archived.
        if len(message) > self.LIBRARIAN_MIN_LENGTH:
            await self.librarian.archive_message(message_id, message, username)

        # 3. Decide whether the agent may speak at all.
        if await self._is_response_suppressed(session_id):
            logger.debug("Response suppressed due to active chat")
        # NOTE(review): activity was recorded just above, so if
        # ``should_hearthkeeper_prompt`` measures time since the last activity
        # this check may never be true here — confirm ChatActivityPolicy semantics.
        elif self.chat_activity.should_hearthkeeper_prompt(session_id):
            # 4. Hearthkeeper: generate prompt if chat inactive (best effort).
            try:
                agent_response = await self.hearthkeeper.generate_prompt(
                    theme=session_info.get("theme")
                )
                actions.append("HEARTHKEEPER_PROMPT")
                async with async_session_factory() as db_session:
                    repo = Repository(db_session)
                    await repo.record_action(
                        session_id=session_id,
                        action_type=AgentActionType.RESPONSE,
                        mode="hearthkeeper",
                        description=agent_response,
                    )
            except Exception as e:
                # Keep the traceback; a Hearthkeeper failure must not break chat handling.
                logger.exception("Error in Hearthkeeper: %s", e)

        logger.info("Message processed. Session: %s, Actions: %s", session_id, actions)

        return {
            "agent_response": agent_response,
            "actions_taken": actions,
        }

    async def get_session_status(self, session_id: str) -> dict:
        """
        Get status of a session.

        Returns an empty dict for unknown session IDs. Otherwise returns the
        session ID, channel name, message count, uptime in seconds (measured
        against ``datetime.utcnow()``), and current theme.
        """
        session = self.active_sessions.get(session_id)
        if session is None:
            return {}

        return {
            "session_id": session_id,
            "channel_name": session["channel_name"],
            "message_count": session["message_count"],
            "uptime_seconds": (datetime.utcnow() - session["started_at"]).total_seconds(),
            "theme": session.get("theme"),
        }