585 lines
22 KiB
Python
585 lines
22 KiB
Python
"""Agent Orchestrator - Routes messages and manages agent modes."""
|
|
|
|
import logging
|
|
from collections.abc import Awaitable, Callable
|
|
from datetime import datetime, timedelta
|
|
|
|
from app.agent.policies import (
|
|
ChatActivityPolicy,
|
|
ResponseSuppression,
|
|
SuspiciousContentPolicy,
|
|
)
|
|
from app.agent.modes.hearthkeeper import HearthkeeperMode
|
|
from app.agent.modes.steward import StewardMode
|
|
from app.agent.modes.warden import WardenMode
|
|
from app.agent.modes.librarian import LibrarianMode
|
|
from app.agent.modes.scribe import ScribeMode
|
|
from app.config import settings
|
|
from app.llm.client import LLMClient
|
|
from app.memory.database import get_session
|
|
from app.memory.models import AgentActionType
|
|
from app.memory.repository import Repository
|
|
from app.twitch.chat import send_chat_message
|
|
|
|
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AgentOrchestrator:
|
|
"""
|
|
Main orchestrator for agent behavior.
|
|
|
|
Routes chat messages to appropriate modes and manages responses.
|
|
Implements policies for when to speak, when to stay silent,
|
|
and how to flag suspicious content.
|
|
"""
|
|
|
|
def __init__(self, loop_interval_seconds: float = 60.0):
|
|
"""Initialize the orchestrator and all modes."""
|
|
self.llm_client = LLMClient()
|
|
self.loop_interval_seconds = loop_interval_seconds
|
|
self.hearthkeeper_prompt_interval = timedelta(
|
|
minutes=settings.HEARTHKEEPER_PROMPT_INTERVAL_MINUTES
|
|
)
|
|
|
|
# Initialize modes
|
|
self.hearthkeeper = HearthkeeperMode(self.llm_client)
|
|
self.steward = StewardMode(self.llm_client)
|
|
self.warden = WardenMode(self.llm_client)
|
|
self.librarian = LibrarianMode(self.llm_client)
|
|
self.scribe = ScribeMode(self.llm_client)
|
|
|
|
# Initialize policies
|
|
self.chat_activity = ChatActivityPolicy(inactivity_threshold_minutes=15)
|
|
self.response_suppression = ResponseSuppression()
|
|
self.suspicious_content = SuspiciousContentPolicy()
|
|
|
|
# Track active sessions
|
|
self.active_sessions: dict[str, dict] = {}
|
|
self.chat_interaction_gate: Callable[[str], Awaitable[bool]] | None = None
|
|
|
|
logger.info("AgentOrchestrator initialized with all modes and policies")
|
|
|
|
def set_chat_interaction_gate(
|
|
self,
|
|
gate: Callable[[str], Awaitable[bool]] | None,
|
|
) -> None:
|
|
"""Set an async gate that must pass before the agent can post to chat."""
|
|
self.chat_interaction_gate = gate
|
|
|
|
async def can_interact_with_chat(self, channel_name: str) -> bool:
|
|
"""Return whether outbound chat interaction is currently allowed."""
|
|
if not self.chat_interaction_gate:
|
|
return True
|
|
try:
|
|
return await self.chat_interaction_gate(channel_name)
|
|
except Exception as e:
|
|
logger.warning("Chat interaction gate failed closed: %s", e)
|
|
return False
|
|
|
|
async def start_session(self, channel_name: str) -> str:
|
|
"""
|
|
Start a new stream session.
|
|
|
|
Args:
|
|
channel_name: Twitch channel name
|
|
|
|
Returns:
|
|
Session ID
|
|
"""
|
|
session_id: str | None = None
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
session_id = await repo.create_session(channel_name)
|
|
|
|
if session_id is None:
|
|
raise RuntimeError("Failed to create stream session")
|
|
|
|
self.active_sessions[session_id] = {
|
|
"channel_name": channel_name,
|
|
"started_at": datetime.utcnow(),
|
|
"message_count": 0,
|
|
"theme": None,
|
|
"dashboard": None,
|
|
"last_hearthkeeper_prompt_at": None,
|
|
}
|
|
self.chat_activity.record_activity(session_id)
|
|
|
|
logger.info(f"Started session {session_id} for {channel_name}")
|
|
return session_id
|
|
|
|
async def ensure_single_active_session_for_channel(self, channel_name: str) -> str:
|
|
"""Return one active session for a channel and end duplicate active sessions."""
|
|
normalized_channel = channel_name.lower()
|
|
matching_sessions = [
|
|
(session_id, session)
|
|
for session_id, session in self.active_sessions.items()
|
|
if session.get("channel_name", "").lower() == normalized_channel
|
|
]
|
|
|
|
if not matching_sessions:
|
|
return await self.start_session(channel_name)
|
|
|
|
keep_session_id, _ = max(
|
|
matching_sessions,
|
|
key=lambda item: item[1].get("started_at", datetime.min),
|
|
)
|
|
duplicate_session_ids = [
|
|
session_id
|
|
for session_id, _ in matching_sessions
|
|
if session_id != keep_session_id
|
|
]
|
|
|
|
if duplicate_session_ids:
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
for session_id in duplicate_session_ids:
|
|
await repo.end_session(session_id)
|
|
|
|
for session_id in duplicate_session_ids:
|
|
self.active_sessions.pop(session_id, None)
|
|
self.chat_activity.clear_activity(session_id)
|
|
|
|
logger.warning(
|
|
"Ended %s duplicate active session(s) for channel %s; keeping %s",
|
|
len(duplicate_session_ids),
|
|
channel_name,
|
|
keep_session_id,
|
|
)
|
|
|
|
return keep_session_id
|
|
|
|
async def restore_active_sessions(self) -> int:
|
|
"""Restore active sessions from the database after app startup."""
|
|
restored_count = 0
|
|
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
sessions = await repo.get_active_sessions()
|
|
|
|
for session in sessions:
|
|
recent_messages = await repo.get_recent_human_messages(
|
|
session.id,
|
|
limit=1,
|
|
)
|
|
message_count = await repo.count_messages(session.id)
|
|
dashboard = await repo.get_dashboard(session.id)
|
|
last_hearthkeeper_action = await repo.get_latest_action(
|
|
session_id=session.id,
|
|
action_type=AgentActionType.RESPONSE,
|
|
mode="hearthkeeper",
|
|
)
|
|
last_activity_at = (
|
|
recent_messages[0].timestamp if recent_messages else session.started_at
|
|
)
|
|
|
|
self.active_sessions[session.id] = {
|
|
"channel_name": session.channel_name,
|
|
"started_at": session.started_at,
|
|
"message_count": message_count,
|
|
"theme": session.theme,
|
|
"dashboard": Repository.serialize_dashboard(dashboard),
|
|
"last_hearthkeeper_prompt_at": (
|
|
last_hearthkeeper_action.timestamp
|
|
if last_hearthkeeper_action
|
|
else None
|
|
),
|
|
}
|
|
self.chat_activity.record_activity(session.id, occurred_at=last_activity_at)
|
|
restored_count += 1
|
|
|
|
logger.info(f"Restored {restored_count} active sessions")
|
|
return restored_count
|
|
|
|
async def end_session(self, session_id: str) -> None:
|
|
"""
|
|
End a stream session and trigger ledger generation.
|
|
|
|
Args:
|
|
session_id: Session ID
|
|
"""
|
|
if session_id not in self.active_sessions:
|
|
logger.warning(f"Session {session_id} not found")
|
|
return
|
|
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
await repo.end_session(session_id)
|
|
|
|
del self.active_sessions[session_id]
|
|
logger.info(f"Ended session {session_id}")
|
|
|
|
async def handle_chat_message(
|
|
self,
|
|
session_id: str,
|
|
username: str,
|
|
message: str,
|
|
) -> dict:
|
|
"""
|
|
Process a chat message and determine agent response.
|
|
|
|
Args:
|
|
session_id: Session ID
|
|
username: Username of message sender
|
|
message: Message content
|
|
|
|
Returns:
|
|
Response dict with agent_response, actions_taken, etc.
|
|
"""
|
|
if session_id not in self.active_sessions:
|
|
logger.warning(f"Session {session_id} not found")
|
|
return {"agent_response": None, "actions_taken": []}
|
|
|
|
session_info = self.active_sessions[session_id]
|
|
actions = []
|
|
agent_response = None
|
|
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
|
|
# Store the message
|
|
message_id = await repo.add_chat_message(
|
|
session_id=session_id,
|
|
username=username,
|
|
content=message,
|
|
is_bot=False,
|
|
)
|
|
|
|
# Record activity
|
|
self.chat_activity.record_activity(session_id)
|
|
session_info["message_count"] += 1
|
|
session_info["last_hearthkeeper_prompt_at"] = None
|
|
|
|
# 1. Warden always analyzes (passive mode)
|
|
warden_result = await self.warden.analyze_message(message)
|
|
if warden_result["is_suspicious"]:
|
|
actions.append(f"WARDEN_FLAG: {warden_result['severity']}")
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
await repo.record_action(
|
|
session_id=session_id,
|
|
action_type=AgentActionType.FLAG_SUSPICIOUS,
|
|
mode="warden",
|
|
description=f"Detected: {warden_result['patterns_detected']}",
|
|
triggered_by_message_id=message_id,
|
|
)
|
|
|
|
# 2. Check if we should suppress responses due to active chat
|
|
recent_messages = []
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
recent_messages = await repo.get_human_messages_since(
|
|
session_id=session_id,
|
|
since=datetime.utcnow() - timedelta(minutes=1),
|
|
)
|
|
|
|
if self.response_suppression.should_suppress_response(len(recent_messages)):
|
|
logger.debug("Response suppressed due to active chat")
|
|
return {
|
|
"agent_response": agent_response,
|
|
"actions_taken": actions,
|
|
}
|
|
|
|
# 3. Librarian: Archive important messages (passive)
|
|
if len(message) > 50: # Archive longer messages
|
|
await self.librarian.archive_message(message_id, message, username)
|
|
|
|
logger.info(
|
|
f"Message processed. Session: {session_id}, Actions: {actions}"
|
|
)
|
|
|
|
return {
|
|
"agent_response": agent_response,
|
|
"actions_taken": actions,
|
|
}
|
|
|
|
async def emit_agent_response(
|
|
self,
|
|
session_id: str,
|
|
message: str,
|
|
mode: str,
|
|
triggered_by_message_id: str | None = None,
|
|
) -> dict:
|
|
"""Send an agent response through the outbound chat boundary."""
|
|
session_info = self.active_sessions.get(session_id)
|
|
if not session_info:
|
|
logger.warning(f"Session {session_id} not found")
|
|
return {"sent": False, "reason": "session_not_found"}
|
|
|
|
channel_name = session_info["channel_name"]
|
|
if not await self.can_interact_with_chat(channel_name):
|
|
logger.info(
|
|
"Agent response suppressed because stream is not live. Session: %s",
|
|
session_id,
|
|
)
|
|
return {"sent": False, "reason": "stream_not_live"}
|
|
|
|
sent = await send_chat_message(channel_name=channel_name, message=message)
|
|
bot_username = settings.TWITCH_BOT_USERNAME or "sanctum_chronicler"
|
|
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
bot_message_id = await repo.add_chat_message(
|
|
session_id=session_id,
|
|
username=bot_username,
|
|
content=message,
|
|
is_bot=True,
|
|
)
|
|
action_id = await repo.record_action(
|
|
session_id=session_id,
|
|
action_type=AgentActionType.RESPONSE,
|
|
mode=mode,
|
|
description=message,
|
|
triggered_by_message_id=triggered_by_message_id,
|
|
)
|
|
|
|
session_info["message_count"] += 1
|
|
logger.info(
|
|
f"Agent response emitted. Session: {session_id}, Mode: {mode}, Sent: {sent}"
|
|
)
|
|
return {
|
|
"sent": sent,
|
|
"channel": channel_name,
|
|
"message_id": bot_message_id,
|
|
"action_id": action_id,
|
|
}
|
|
|
|
async def run_hearthkeeper_loop_test(
|
|
self,
|
|
session_id: str,
|
|
inactive_minutes: int = 16,
|
|
) -> dict:
|
|
"""Exercise the quiet-chat loop and verify it prompts exactly once."""
|
|
session_info = self.active_sessions.get(session_id)
|
|
if not session_info:
|
|
return {"passed": False, "reason": "session_not_found"}
|
|
|
|
simulated_activity_at = datetime.utcnow() - timedelta(minutes=inactive_minutes)
|
|
self.chat_activity.record_activity(
|
|
session_id=session_id,
|
|
occurred_at=simulated_activity_at,
|
|
)
|
|
session_info["last_hearthkeeper_prompt_at"] = None
|
|
|
|
before_count = await self._count_response_actions(
|
|
session_id=session_id,
|
|
mode="hearthkeeper",
|
|
)
|
|
first_tick = await self._tick_session(session_id)
|
|
second_tick = await self._tick_session(session_id)
|
|
after_count = await self._count_response_actions(
|
|
session_id=session_id,
|
|
mode="hearthkeeper",
|
|
)
|
|
session_info["last_hearthkeeper_prompt_at"] = (
|
|
datetime.utcnow() - self.hearthkeeper_prompt_interval
|
|
)
|
|
third_tick = await self._tick_session(session_id)
|
|
final_count = await self._count_response_actions(
|
|
session_id=session_id,
|
|
mode="hearthkeeper",
|
|
)
|
|
prompts_created = after_count - before_count
|
|
prompts_created_after_interval = final_count - before_count
|
|
|
|
return {
|
|
"passed": (
|
|
prompts_created == 1
|
|
and first_tick is not None
|
|
and second_tick is None
|
|
and third_tick is not None
|
|
and prompts_created_after_interval == 2
|
|
),
|
|
"session_id": session_id,
|
|
"inactive_minutes": inactive_minutes,
|
|
"prompts_created": prompts_created,
|
|
"prompts_created_after_interval": prompts_created_after_interval,
|
|
"first_tick": first_tick,
|
|
"second_tick": second_tick,
|
|
"third_tick_after_interval": third_tick,
|
|
}
|
|
|
|
async def _count_response_actions(self, session_id: str, mode: str) -> int:
|
|
"""Count response actions for a mode in a session."""
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
actions = await repo.get_session_actions(session_id)
|
|
return sum(
|
|
1
|
|
for action in actions
|
|
if action.action_type == AgentActionType.RESPONSE
|
|
and action.mode == mode
|
|
)
|
|
return 0
|
|
|
|
def set_loop_interval(self, interval_seconds: float) -> None:
|
|
"""Update how frequently the background agent loop runs."""
|
|
if interval_seconds < 1:
|
|
raise ValueError("Loop interval must be at least 1 second")
|
|
self.loop_interval_seconds = interval_seconds
|
|
|
|
def get_loop_status(self) -> dict:
|
|
"""Get background loop configuration and current session count."""
|
|
return {
|
|
"interval_seconds": self.loop_interval_seconds,
|
|
"hearthkeeper_prompt_interval_minutes": int(
|
|
self.hearthkeeper_prompt_interval.total_seconds() / 60
|
|
),
|
|
"active_session_count": len(self.active_sessions),
|
|
}
|
|
|
|
def get_hearthkeeper_runtime_status(self, session_id: str) -> dict:
|
|
"""Get Hearthkeeper timing status for an active session."""
|
|
session_info = self.active_sessions.get(session_id)
|
|
if not session_info:
|
|
return {}
|
|
|
|
now = datetime.utcnow()
|
|
last_activity_at = self.chat_activity.last_activity_at(session_id)
|
|
last_prompt_at = session_info.get("last_hearthkeeper_prompt_at")
|
|
|
|
next_from_activity = None
|
|
if last_activity_at:
|
|
next_from_activity = last_activity_at + self.chat_activity.inactivity_threshold
|
|
|
|
next_from_prompt = None
|
|
if last_prompt_at:
|
|
next_from_prompt = last_prompt_at + self.hearthkeeper_prompt_interval
|
|
|
|
next_eligible_at = max(
|
|
[candidate for candidate in (next_from_activity, next_from_prompt) if candidate],
|
|
default=None,
|
|
)
|
|
seconds_until_next_prompt = None
|
|
if next_eligible_at:
|
|
seconds_until_next_prompt = max(
|
|
0,
|
|
int((next_eligible_at - now).total_seconds()),
|
|
)
|
|
quiet_enough = bool(
|
|
last_activity_at
|
|
and now - last_activity_at >= self.chat_activity.inactivity_threshold
|
|
)
|
|
prompt_interval_elapsed = bool(
|
|
not last_prompt_at
|
|
or now - last_prompt_at >= self.hearthkeeper_prompt_interval
|
|
)
|
|
|
|
return {
|
|
"last_activity_at": last_activity_at.isoformat() if last_activity_at else None,
|
|
"last_hearthkeeper_prompt_at": (
|
|
last_prompt_at.isoformat() if last_prompt_at else None
|
|
),
|
|
"next_eligible_prompt_at": (
|
|
next_eligible_at.isoformat() if next_eligible_at else None
|
|
),
|
|
"seconds_until_next_prompt": seconds_until_next_prompt,
|
|
"inactivity_threshold_minutes": int(
|
|
self.chat_activity.inactivity_threshold.total_seconds() / 60
|
|
),
|
|
"prompt_interval_minutes": int(
|
|
self.hearthkeeper_prompt_interval.total_seconds() / 60
|
|
),
|
|
"can_prompt_now": quiet_enough and prompt_interval_elapsed,
|
|
}
|
|
|
|
async def tick(self) -> list[dict]:
|
|
"""Evaluate active sessions for time-based agent behavior."""
|
|
results = []
|
|
for session_id in list(self.active_sessions.keys()):
|
|
result = await self._tick_session(session_id)
|
|
if result:
|
|
results.append(result)
|
|
return results
|
|
|
|
async def _tick_session(self, session_id: str) -> dict | None:
|
|
"""Evaluate a single active session during the background loop."""
|
|
session_info = self.active_sessions.get(session_id)
|
|
if not session_info:
|
|
return None
|
|
|
|
active_chat_messages = []
|
|
recent_discussion_messages = []
|
|
async for db_session in get_session():
|
|
repo = Repository(db_session)
|
|
active_chat_messages = await repo.get_human_messages_since(
|
|
session_id=session_id,
|
|
since=datetime.utcnow() - timedelta(minutes=1),
|
|
)
|
|
recent_discussion_messages = await repo.get_recent_human_messages(
|
|
session_id=session_id,
|
|
limit=5,
|
|
)
|
|
|
|
if self.response_suppression.should_suppress_response(len(active_chat_messages)):
|
|
return {
|
|
"session_id": session_id,
|
|
"actions_taken": [],
|
|
"agent_response": None,
|
|
"reason": "active_chat",
|
|
}
|
|
|
|
if not self.chat_activity.should_hearthkeeper_prompt(session_id):
|
|
return None
|
|
|
|
if not await self.can_interact_with_chat(session_info["channel_name"]):
|
|
return None
|
|
|
|
last_activity_at = self.chat_activity.last_activity_at(session_id)
|
|
last_prompt_at = session_info.get("last_hearthkeeper_prompt_at")
|
|
if last_activity_at and last_prompt_at and last_activity_at > last_prompt_at:
|
|
session_info["last_hearthkeeper_prompt_at"] = None
|
|
last_prompt_at = None
|
|
if (
|
|
last_prompt_at
|
|
and datetime.utcnow() - last_prompt_at < self.hearthkeeper_prompt_interval
|
|
):
|
|
return None
|
|
|
|
try:
|
|
recent_discussion = [
|
|
message.content for message in recent_discussion_messages[:5]
|
|
]
|
|
agent_response = await self.hearthkeeper.generate_prompt(
|
|
theme=session_info.get("theme"),
|
|
dashboard=session_info.get("dashboard"),
|
|
recent_discussion=recent_discussion,
|
|
)
|
|
session_info["last_hearthkeeper_prompt_at"] = datetime.utcnow()
|
|
delivery = await self.emit_agent_response(
|
|
session_id=session_id,
|
|
message=agent_response,
|
|
mode="hearthkeeper",
|
|
)
|
|
|
|
return {
|
|
"session_id": session_id,
|
|
"actions_taken": ["HEARTHKEEPER_PROMPT"],
|
|
"agent_response": agent_response,
|
|
"delivery": delivery,
|
|
"reason": "inactive_chat",
|
|
}
|
|
except Exception as e:
|
|
logger.error(f"Error in Hearthkeeper loop: {e}")
|
|
return {
|
|
"session_id": session_id,
|
|
"actions_taken": [],
|
|
"agent_response": None,
|
|
"reason": "hearthkeeper_error",
|
|
}
|
|
|
|
async def get_session_status(self, session_id: str) -> dict:
|
|
"""Get status of a session."""
|
|
if session_id not in self.active_sessions:
|
|
return {}
|
|
|
|
session = self.active_sessions[session_id]
|
|
|
|
return {
|
|
"session_id": session_id,
|
|
"channel_name": session["channel_name"],
|
|
"message_count": session["message_count"],
|
|
"uptime_seconds": (datetime.utcnow() - session["started_at"]).total_seconds(),
|
|
"theme": session.get("theme"),
|
|
"dashboard": session.get("dashboard"),
|
|
}
|