409 lines
15 KiB
Python
409 lines
15 KiB
Python
"""Agent Orchestrator - Routes messages and manages agent modes."""
|
|
|
|
import logging
import math
from datetime import datetime, timedelta

from app.agent.modes.hearthkeeper import HearthkeeperMode
from app.agent.modes.librarian import LibrarianMode
from app.agent.modes.scribe import ScribeMode
from app.agent.modes.steward import StewardMode
from app.agent.modes.warden import WardenMode
from app.agent.policies import (
    ChatActivityPolicy,
    ResponseSuppression,
    SuspiciousContentPolicy,
)
from app.config import settings
from app.llm.client import LLMClient
from app.memory.database import get_session
from app.memory.models import AgentActionType
from app.memory.repository import Repository
from app.twitch.chat import send_chat_message
|
|
|
|
# Module-level logger named after this module; AgentOrchestrator uses it for
# session lifecycle, message-processing, and background-loop diagnostics.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AgentOrchestrator:
|
|
"""
|
|
Main orchestrator for agent behavior.
|
|
|
|
Routes chat messages to appropriate modes and manages responses.
|
|
Implements policies for when to speak, when to stay silent,
|
|
and how to flag suspicious content.
|
|
"""
|
|
|
|
def __init__(self, loop_interval_seconds: float = 60.0):
|
|
"""Initialize the orchestrator and all modes."""
|
|
self.llm_client = LLMClient()
|
|
self.loop_interval_seconds = loop_interval_seconds
|
|
|
|
# Initialize modes
|
|
self.hearthkeeper = HearthkeeperMode(self.llm_client)
|
|
self.steward = StewardMode(self.llm_client)
|
|
self.warden = WardenMode(self.llm_client)
|
|
self.librarian = LibrarianMode(self.llm_client)
|
|
self.scribe = ScribeMode(self.llm_client)
|
|
|
|
# Initialize policies
|
|
self.chat_activity = ChatActivityPolicy(inactivity_threshold_minutes=15)
|
|
self.response_suppression = ResponseSuppression()
|
|
self.suspicious_content = SuspiciousContentPolicy()
|
|
|
|
# Track active sessions
|
|
self.active_sessions: dict[str, dict] = {}
|
|
|
|
logger.info("AgentOrchestrator initialized with all modes and policies")
|
|
|
|
async def start_session(self, channel_name: str) -> str:
|
|
"""
|
|
Start a new stream session.
|
|
|
|
Args:
|
|
channel_name: Twitch channel name
|
|
|
|
Returns:
|
|
Session ID
|
|
"""
|
|
session_id: str | None = None
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
session_id = await repo.create_session(channel_name)
|
|
|
|
if session_id is None:
|
|
raise RuntimeError("Failed to create stream session")
|
|
|
|
self.active_sessions[session_id] = {
|
|
"channel_name": channel_name,
|
|
"started_at": datetime.utcnow(),
|
|
"message_count": 0,
|
|
"theme": None,
|
|
"last_hearthkeeper_prompt_at": None,
|
|
}
|
|
self.chat_activity.record_activity(session_id)
|
|
|
|
logger.info(f"Started session {session_id} for {channel_name}")
|
|
return session_id
|
|
|
|
async def restore_active_sessions(self) -> int:
|
|
"""Restore active sessions from the database after app startup."""
|
|
restored_count = 0
|
|
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
sessions = await repo.get_active_sessions()
|
|
|
|
for session in sessions:
|
|
recent_messages = await repo.get_recent_messages(session.id, limit=1)
|
|
message_count = await repo.count_messages(session.id)
|
|
last_activity_at = (
|
|
recent_messages[0].timestamp if recent_messages else session.started_at
|
|
)
|
|
|
|
self.active_sessions[session.id] = {
|
|
"channel_name": session.channel_name,
|
|
"started_at": session.started_at,
|
|
"message_count": message_count,
|
|
"theme": session.theme,
|
|
"last_hearthkeeper_prompt_at": None,
|
|
}
|
|
self.chat_activity.record_activity(session.id, occurred_at=last_activity_at)
|
|
restored_count += 1
|
|
|
|
logger.info(f"Restored {restored_count} active sessions")
|
|
return restored_count
|
|
|
|
async def end_session(self, session_id: str) -> None:
|
|
"""
|
|
End a stream session and trigger ledger generation.
|
|
|
|
Args:
|
|
session_id: Session ID
|
|
"""
|
|
if session_id not in self.active_sessions:
|
|
logger.warning(f"Session {session_id} not found")
|
|
return
|
|
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
await repo.end_session(session_id)
|
|
|
|
del self.active_sessions[session_id]
|
|
logger.info(f"Ended session {session_id}")
|
|
|
|
async def handle_chat_message(
|
|
self,
|
|
session_id: str,
|
|
username: str,
|
|
message: str,
|
|
) -> dict:
|
|
"""
|
|
Process a chat message and determine agent response.
|
|
|
|
Args:
|
|
session_id: Session ID
|
|
username: Username of message sender
|
|
message: Message content
|
|
|
|
Returns:
|
|
Response dict with agent_response, actions_taken, etc.
|
|
"""
|
|
if session_id not in self.active_sessions:
|
|
logger.warning(f"Session {session_id} not found")
|
|
return {"agent_response": None, "actions_taken": []}
|
|
|
|
session_info = self.active_sessions[session_id]
|
|
actions = []
|
|
agent_response = None
|
|
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
|
|
# Store the message
|
|
message_id = await repo.add_chat_message(
|
|
session_id=session_id,
|
|
username=username,
|
|
content=message,
|
|
is_bot=False,
|
|
)
|
|
|
|
# Record activity
|
|
self.chat_activity.record_activity(session_id)
|
|
session_info["message_count"] += 1
|
|
session_info["last_hearthkeeper_prompt_at"] = None
|
|
|
|
# 1. Warden always analyzes (passive mode)
|
|
warden_result = await self.warden.analyze_message(message)
|
|
if warden_result["is_suspicious"]:
|
|
actions.append(f"WARDEN_FLAG: {warden_result['severity']}")
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
await repo.record_action(
|
|
session_id=session_id,
|
|
action_type=AgentActionType.FLAG_SUSPICIOUS,
|
|
mode="warden",
|
|
description=f"Detected: {warden_result['patterns_detected']}",
|
|
triggered_by_message_id=message_id,
|
|
)
|
|
|
|
# 2. Check if we should suppress responses due to active chat
|
|
recent_messages = []
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
recent_messages = await repo.get_messages_since(
|
|
session_id=session_id,
|
|
since=datetime.utcnow() - timedelta(minutes=1),
|
|
)
|
|
|
|
if self.response_suppression.should_suppress_response(len(recent_messages)):
|
|
logger.debug("Response suppressed due to active chat")
|
|
return {
|
|
"agent_response": agent_response,
|
|
"actions_taken": actions,
|
|
}
|
|
|
|
# 3. Librarian: Archive important messages (passive)
|
|
if len(message) > 50: # Archive longer messages
|
|
await self.librarian.archive_message(message_id, message, username)
|
|
|
|
logger.info(
|
|
f"Message processed. Session: {session_id}, Actions: {actions}"
|
|
)
|
|
|
|
return {
|
|
"agent_response": agent_response,
|
|
"actions_taken": actions,
|
|
}
|
|
|
|
async def emit_agent_response(
|
|
self,
|
|
session_id: str,
|
|
message: str,
|
|
mode: str,
|
|
triggered_by_message_id: str | None = None,
|
|
) -> dict:
|
|
"""Send an agent response through the outbound chat boundary."""
|
|
session_info = self.active_sessions.get(session_id)
|
|
if not session_info:
|
|
logger.warning(f"Session {session_id} not found")
|
|
return {"sent": False, "reason": "session_not_found"}
|
|
|
|
channel_name = session_info["channel_name"]
|
|
sent = await send_chat_message(channel_name=channel_name, message=message)
|
|
bot_username = settings.TWITCH_BOT_USERNAME or "sanctum_chronicler"
|
|
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
bot_message_id = await repo.add_chat_message(
|
|
session_id=session_id,
|
|
username=bot_username,
|
|
content=message,
|
|
is_bot=True,
|
|
)
|
|
action_id = await repo.record_action(
|
|
session_id=session_id,
|
|
action_type=AgentActionType.RESPONSE,
|
|
mode=mode,
|
|
description=message,
|
|
triggered_by_message_id=triggered_by_message_id,
|
|
)
|
|
|
|
session_info["message_count"] += 1
|
|
logger.info(
|
|
f"Agent response emitted. Session: {session_id}, Mode: {mode}, Sent: {sent}"
|
|
)
|
|
return {
|
|
"sent": sent,
|
|
"channel": channel_name,
|
|
"message_id": bot_message_id,
|
|
"action_id": action_id,
|
|
}
|
|
|
|
async def run_hearthkeeper_loop_test(
|
|
self,
|
|
session_id: str,
|
|
inactive_minutes: int = 16,
|
|
) -> dict:
|
|
"""Exercise the quiet-chat loop and verify it prompts exactly once."""
|
|
session_info = self.active_sessions.get(session_id)
|
|
if not session_info:
|
|
return {"passed": False, "reason": "session_not_found"}
|
|
|
|
simulated_activity_at = datetime.utcnow() - timedelta(minutes=inactive_minutes)
|
|
self.chat_activity.record_activity(
|
|
session_id=session_id,
|
|
occurred_at=simulated_activity_at,
|
|
)
|
|
session_info["last_hearthkeeper_prompt_at"] = None
|
|
|
|
before_count = await self._count_response_actions(
|
|
session_id=session_id,
|
|
mode="hearthkeeper",
|
|
)
|
|
first_tick = await self._tick_session(session_id)
|
|
second_tick = await self._tick_session(session_id)
|
|
after_count = await self._count_response_actions(
|
|
session_id=session_id,
|
|
mode="hearthkeeper",
|
|
)
|
|
prompts_created = after_count - before_count
|
|
|
|
return {
|
|
"passed": (
|
|
prompts_created == 1
|
|
and first_tick is not None
|
|
and second_tick is None
|
|
),
|
|
"session_id": session_id,
|
|
"inactive_minutes": inactive_minutes,
|
|
"prompts_created": prompts_created,
|
|
"first_tick": first_tick,
|
|
"second_tick": second_tick,
|
|
}
|
|
|
|
async def _count_response_actions(self, session_id: str, mode: str) -> int:
|
|
"""Count response actions for a mode in a session."""
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
actions = await repo.get_session_actions(session_id)
|
|
return sum(
|
|
1
|
|
for action in actions
|
|
if action.action_type == AgentActionType.RESPONSE
|
|
and action.mode == mode
|
|
)
|
|
return 0
|
|
|
|
def set_loop_interval(self, interval_seconds: float) -> None:
|
|
"""Update how frequently the background agent loop runs."""
|
|
if interval_seconds < 1:
|
|
raise ValueError("Loop interval must be at least 1 second")
|
|
self.loop_interval_seconds = interval_seconds
|
|
|
|
def get_loop_status(self) -> dict:
|
|
"""Get background loop configuration and current session count."""
|
|
return {
|
|
"interval_seconds": self.loop_interval_seconds,
|
|
"active_session_count": len(self.active_sessions),
|
|
}
|
|
|
|
async def tick(self) -> list[dict]:
|
|
"""Evaluate active sessions for time-based agent behavior."""
|
|
results = []
|
|
for session_id in list(self.active_sessions.keys()):
|
|
result = await self._tick_session(session_id)
|
|
if result:
|
|
results.append(result)
|
|
return results
|
|
|
|
async def _tick_session(self, session_id: str) -> dict | None:
|
|
"""Evaluate a single active session during the background loop."""
|
|
session_info = self.active_sessions.get(session_id)
|
|
if not session_info:
|
|
return None
|
|
|
|
recent_messages = []
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
recent_messages = await repo.get_messages_since(
|
|
session_id=session_id,
|
|
since=datetime.utcnow() - timedelta(minutes=1),
|
|
)
|
|
|
|
if self.response_suppression.should_suppress_response(len(recent_messages)):
|
|
return {
|
|
"session_id": session_id,
|
|
"actions_taken": [],
|
|
"agent_response": None,
|
|
"reason": "active_chat",
|
|
}
|
|
|
|
if not self.chat_activity.should_hearthkeeper_prompt(session_id):
|
|
return None
|
|
|
|
last_activity_at = self.chat_activity.last_activity_at(session_id)
|
|
last_prompt_at = session_info.get("last_hearthkeeper_prompt_at")
|
|
if last_prompt_at and last_activity_at and last_prompt_at >= last_activity_at:
|
|
return None
|
|
|
|
try:
|
|
agent_response = await self.hearthkeeper.generate_prompt(
|
|
theme=session_info.get("theme")
|
|
)
|
|
session_info["last_hearthkeeper_prompt_at"] = datetime.utcnow()
|
|
delivery = await self.emit_agent_response(
|
|
session_id=session_id,
|
|
message=agent_response,
|
|
mode="hearthkeeper",
|
|
)
|
|
|
|
return {
|
|
"session_id": session_id,
|
|
"actions_taken": ["HEARTHKEEPER_PROMPT"],
|
|
"agent_response": agent_response,
|
|
"delivery": delivery,
|
|
"reason": "inactive_chat",
|
|
}
|
|
except Exception as e:
|
|
logger.error(f"Error in Hearthkeeper loop: {e}")
|
|
return {
|
|
"session_id": session_id,
|
|
"actions_taken": [],
|
|
"agent_response": None,
|
|
"reason": "hearthkeeper_error",
|
|
}
|
|
|
|
async def get_session_status(self, session_id: str) -> dict:
|
|
"""Get status of a session."""
|
|
if session_id not in self.active_sessions:
|
|
return {}
|
|
|
|
session = self.active_sessions[session_id]
|
|
|
|
return {
|
|
"session_id": session_id,
|
|
"channel_name": session["channel_name"],
|
|
"message_count": session["message_count"],
|
|
"uptime_seconds": (datetime.utcnow() - session["started_at"]).total_seconds(),
|
|
"theme": session.get("theme"),
|
|
}
|