AI generated first iteration
This commit is contained in:
211
app/agent/orchestrator.py
Normal file
211
app/agent/orchestrator.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Agent Orchestrator - Routes messages and manages agent modes."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agent.policies import (
|
||||
ChatActivityPolicy,
|
||||
ResponseSuppression,
|
||||
SuspiciousContentPolicy,
|
||||
)
|
||||
from app.agent.modes.hearthkeeper import HearthkeeperMode
|
||||
from app.agent.modes.steward import StewardMode
|
||||
from app.agent.modes.warden import WardenMode
|
||||
from app.agent.modes.librarian import LibrarianMode
|
||||
from app.agent.modes.scribe import ScribeMode
|
||||
from app.llm.client import LLMClient
|
||||
from app.memory.database import async_session_factory
|
||||
from app.memory.models import AgentActionType
|
||||
from app.memory.repository import Repository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentOrchestrator:
|
||||
"""
|
||||
Main orchestrator for agent behavior.
|
||||
|
||||
Routes chat messages to appropriate modes and manages responses.
|
||||
Implements policies for when to speak, when to stay silent,
|
||||
and how to flag suspicious content.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the orchestrator and all modes."""
|
||||
self.llm_client = LLMClient()
|
||||
|
||||
# Initialize modes
|
||||
self.hearthkeeper = HearthkeeperMode(self.llm_client)
|
||||
self.steward = StewardMode(self.llm_client)
|
||||
self.warden = WardenMode(self.llm_client)
|
||||
self.librarian = LibrarianMode(self.llm_client)
|
||||
self.scribe = ScribeMode(self.llm_client)
|
||||
|
||||
# Initialize policies
|
||||
self.chat_activity = ChatActivityPolicy(inactivity_threshold_minutes=15)
|
||||
self.response_suppression = ResponseSuppression()
|
||||
self.suspicious_content = SuspiciousContentPolicy()
|
||||
|
||||
# Track active sessions
|
||||
self.active_sessions: dict[str, dict] = {}
|
||||
|
||||
logger.info("AgentOrchestrator initialized with all modes and policies")
|
||||
|
||||
async def start_session(self, channel_name: str) -> str:
|
||||
"""
|
||||
Start a new stream session.
|
||||
|
||||
Args:
|
||||
channel_name: Twitch channel name
|
||||
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
async with async_session_factory() as db_session:
|
||||
repo = Repository(db_session)
|
||||
await repo.create_session(channel_name)
|
||||
|
||||
self.active_sessions[session_id] = {
|
||||
"channel_name": channel_name,
|
||||
"started_at": datetime.utcnow(),
|
||||
"message_count": 0,
|
||||
"theme": None,
|
||||
}
|
||||
|
||||
logger.info(f"Started session {session_id} for {channel_name}")
|
||||
return session_id
|
||||
|
||||
async def end_session(self, session_id: str) -> None:
|
||||
"""
|
||||
End a stream session and trigger ledger generation.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
"""
|
||||
if session_id not in self.active_sessions:
|
||||
logger.warning(f"Session {session_id} not found")
|
||||
return
|
||||
|
||||
async with async_session_factory() as db_session:
|
||||
repo = Repository(db_session)
|
||||
await repo.end_session(session_id)
|
||||
|
||||
del self.active_sessions[session_id]
|
||||
logger.info(f"Ended session {session_id}")
|
||||
|
||||
async def handle_chat_message(
|
||||
self,
|
||||
session_id: str,
|
||||
username: str,
|
||||
message: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Process a chat message and determine agent response.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
username: Username of message sender
|
||||
message: Message content
|
||||
|
||||
Returns:
|
||||
Response dict with agent_response, actions_taken, etc.
|
||||
"""
|
||||
if session_id not in self.active_sessions:
|
||||
logger.warning(f"Session {session_id} not found")
|
||||
return {"agent_response": None, "actions_taken": []}
|
||||
|
||||
session_info = self.active_sessions[session_id]
|
||||
actions = []
|
||||
agent_response = None
|
||||
|
||||
async with async_session_factory() as db_session:
|
||||
repo = Repository(db_session)
|
||||
|
||||
# Store the message
|
||||
message_id = await repo.add_chat_message(
|
||||
session_id=session_id,
|
||||
username=username,
|
||||
content=message,
|
||||
is_bot=False,
|
||||
)
|
||||
|
||||
# Record activity
|
||||
self.chat_activity.record_activity(session_id)
|
||||
session_info["message_count"] += 1
|
||||
|
||||
# 1. Warden always analyzes (passive mode)
|
||||
warden_result = await self.warden.analyze_message(message)
|
||||
if warden_result["is_suspicious"]:
|
||||
actions.append(f"WARDEN_FLAG: {warden_result['severity']}")
|
||||
async with async_session_factory() as db_session:
|
||||
repo = Repository(db_session)
|
||||
await repo.record_action(
|
||||
session_id=session_id,
|
||||
action_type=AgentActionType.FLAG_SUSPICIOUS,
|
||||
mode="warden",
|
||||
description=f"Detected: {warden_result['patterns_detected']}",
|
||||
triggered_by_message_id=message_id,
|
||||
)
|
||||
|
||||
# 2. Check if we should suppress responses due to active chat
|
||||
recent_messages = []
|
||||
async with async_session_factory() as db_session:
|
||||
repo = Repository(db_session)
|
||||
recent_messages = await repo.get_recent_messages(session_id, limit=10)
|
||||
|
||||
if self.response_suppression.should_suppress_response(len(recent_messages)):
|
||||
logger.debug("Response suppressed due to active chat")
|
||||
return {
|
||||
"agent_response": agent_response,
|
||||
"actions_taken": actions,
|
||||
}
|
||||
|
||||
# 3. Hearthkeeper: Generate prompt if chat inactive
|
||||
if self.chat_activity.should_hearthkeeper_prompt(session_id):
|
||||
try:
|
||||
agent_response = await self.hearthkeeper.generate_prompt(
|
||||
theme=session_info.get("theme")
|
||||
)
|
||||
actions.append("HEARTHKEEPER_PROMPT")
|
||||
async with async_session_factory() as db_session:
|
||||
repo = Repository(db_session)
|
||||
await repo.record_action(
|
||||
session_id=session_id,
|
||||
action_type=AgentActionType.RESPONSE,
|
||||
mode="hearthkeeper",
|
||||
description=agent_response,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Hearthkeeper: {e}")
|
||||
|
||||
# 4. Librarian: Archive important messages (passive)
|
||||
if len(message) > 50: # Archive longer messages
|
||||
await self.librarian.archive_message(message_id, message, username)
|
||||
|
||||
logger.info(
|
||||
f"Message processed. Session: {session_id}, Actions: {actions}"
|
||||
)
|
||||
|
||||
return {
|
||||
"agent_response": agent_response,
|
||||
"actions_taken": actions,
|
||||
}
|
||||
|
||||
async def get_session_status(self, session_id: str) -> dict:
|
||||
"""Get status of a session."""
|
||||
if session_id not in self.active_sessions:
|
||||
return {}
|
||||
|
||||
session = self.active_sessions[session_id]
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"channel_name": session["channel_name"],
|
||||
"message_count": session["message_count"],
|
||||
"uptime_seconds": (datetime.utcnow() - session["started_at"]).total_seconds(),
|
||||
"theme": session.get("theme"),
|
||||
}
|
||||
Reference in New Issue
Block a user