Implement runtime agent loop and container hygiene
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
"""Agent Orchestrator - Routes messages and manages agent modes."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.agent.policies import (
|
||||
ChatActivityPolicy,
|
||||
@@ -16,7 +14,7 @@ from app.agent.modes.warden import WardenMode
|
||||
from app.agent.modes.librarian import LibrarianMode
|
||||
from app.agent.modes.scribe import ScribeMode
|
||||
from app.llm.client import LLMClient
|
||||
from app.memory.database import async_session_factory
|
||||
from app.memory.database import get_session
|
||||
from app.memory.models import AgentActionType
|
||||
from app.memory.repository import Repository
|
||||
|
||||
@@ -32,9 +30,10 @@ class AgentOrchestrator:
|
||||
and how to flag suspicious content.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, loop_interval_seconds: float = 60.0):
|
||||
"""Initialize the orchestrator and all modes."""
|
||||
self.llm_client = LLMClient()
|
||||
self.loop_interval_seconds = loop_interval_seconds
|
||||
|
||||
# Initialize modes
|
||||
self.hearthkeeper = HearthkeeperMode(self.llm_client)
|
||||
@@ -63,18 +62,22 @@ class AgentOrchestrator:
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
async with async_session_factory() as db_session:
|
||||
session_id: str | None = None
|
||||
async for db_session in get_session():
|
||||
repo = Repository(db_session)
|
||||
await repo.create_session(channel_name)
|
||||
session_id = await repo.create_session(channel_name)
|
||||
|
||||
if session_id is None:
|
||||
raise RuntimeError("Failed to create stream session")
|
||||
|
||||
self.active_sessions[session_id] = {
|
||||
"channel_name": channel_name,
|
||||
"started_at": datetime.utcnow(),
|
||||
"message_count": 0,
|
||||
"theme": None,
|
||||
"last_hearthkeeper_prompt_at": None,
|
||||
}
|
||||
self.chat_activity.record_activity(session_id)
|
||||
|
||||
logger.info(f"Started session {session_id} for {channel_name}")
|
||||
return session_id
|
||||
@@ -90,7 +93,7 @@ class AgentOrchestrator:
|
||||
logger.warning(f"Session {session_id} not found")
|
||||
return
|
||||
|
||||
async with async_session_factory() as db_session:
|
||||
async for db_session in get_session():
|
||||
repo = Repository(db_session)
|
||||
await repo.end_session(session_id)
|
||||
|
||||
@@ -122,7 +125,7 @@ class AgentOrchestrator:
|
||||
actions = []
|
||||
agent_response = None
|
||||
|
||||
async with async_session_factory() as db_session:
|
||||
async for db_session in get_session():
|
||||
repo = Repository(db_session)
|
||||
|
||||
# Store the message
|
||||
@@ -136,12 +139,13 @@ class AgentOrchestrator:
|
||||
# Record activity
|
||||
self.chat_activity.record_activity(session_id)
|
||||
session_info["message_count"] += 1
|
||||
session_info["last_hearthkeeper_prompt_at"] = None
|
||||
|
||||
# 1. Warden always analyzes (passive mode)
|
||||
warden_result = await self.warden.analyze_message(message)
|
||||
if warden_result["is_suspicious"]:
|
||||
actions.append(f"WARDEN_FLAG: {warden_result['severity']}")
|
||||
async with async_session_factory() as db_session:
|
||||
async for db_session in get_session():
|
||||
repo = Repository(db_session)
|
||||
await repo.record_action(
|
||||
session_id=session_id,
|
||||
@@ -153,9 +157,12 @@ class AgentOrchestrator:
|
||||
|
||||
# 2. Check if we should suppress responses due to active chat
|
||||
recent_messages = []
|
||||
async with async_session_factory() as db_session:
|
||||
async for db_session in get_session():
|
||||
repo = Repository(db_session)
|
||||
recent_messages = await repo.get_recent_messages(session_id, limit=10)
|
||||
recent_messages = await repo.get_messages_since(
|
||||
session_id=session_id,
|
||||
since=datetime.utcnow() - timedelta(minutes=1),
|
||||
)
|
||||
|
||||
if self.response_suppression.should_suppress_response(len(recent_messages)):
|
||||
logger.debug("Response suppressed due to active chat")
|
||||
@@ -164,25 +171,7 @@ class AgentOrchestrator:
|
||||
"actions_taken": actions,
|
||||
}
|
||||
|
||||
# 3. Hearthkeeper: Generate prompt if chat inactive
|
||||
if self.chat_activity.should_hearthkeeper_prompt(session_id):
|
||||
try:
|
||||
agent_response = await self.hearthkeeper.generate_prompt(
|
||||
theme=session_info.get("theme")
|
||||
)
|
||||
actions.append("HEARTHKEEPER_PROMPT")
|
||||
async with async_session_factory() as db_session:
|
||||
repo = Repository(db_session)
|
||||
await repo.record_action(
|
||||
session_id=session_id,
|
||||
action_type=AgentActionType.RESPONSE,
|
||||
mode="hearthkeeper",
|
||||
description=agent_response,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Hearthkeeper: {e}")
|
||||
|
||||
# 4. Librarian: Archive important messages (passive)
|
||||
# 3. Librarian: Archive important messages (passive)
|
||||
if len(message) > 50: # Archive longer messages
|
||||
await self.librarian.archive_message(message_id, message, username)
|
||||
|
||||
@@ -195,6 +184,88 @@ class AgentOrchestrator:
|
||||
"actions_taken": actions,
|
||||
}
|
||||
|
||||
def set_loop_interval(self, interval_seconds: float) -> None:
|
||||
"""Update how frequently the background agent loop runs."""
|
||||
if interval_seconds < 1:
|
||||
raise ValueError("Loop interval must be at least 1 second")
|
||||
self.loop_interval_seconds = interval_seconds
|
||||
|
||||
def get_loop_status(self) -> dict:
|
||||
"""Get background loop configuration and current session count."""
|
||||
return {
|
||||
"interval_seconds": self.loop_interval_seconds,
|
||||
"active_session_count": len(self.active_sessions),
|
||||
}
|
||||
|
||||
async def tick(self) -> list[dict]:
|
||||
"""Evaluate active sessions for time-based agent behavior."""
|
||||
results = []
|
||||
for session_id in list(self.active_sessions.keys()):
|
||||
result = await self._tick_session(session_id)
|
||||
if result:
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
async def _tick_session(self, session_id: str) -> dict | None:
|
||||
"""Evaluate a single active session during the background loop."""
|
||||
session_info = self.active_sessions.get(session_id)
|
||||
if not session_info:
|
||||
return None
|
||||
|
||||
recent_messages = []
|
||||
async for db_session in get_session():
|
||||
repo = Repository(db_session)
|
||||
recent_messages = await repo.get_messages_since(
|
||||
session_id=session_id,
|
||||
since=datetime.utcnow() - timedelta(minutes=1),
|
||||
)
|
||||
|
||||
if self.response_suppression.should_suppress_response(len(recent_messages)):
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"actions_taken": [],
|
||||
"agent_response": None,
|
||||
"reason": "active_chat",
|
||||
}
|
||||
|
||||
if not self.chat_activity.should_hearthkeeper_prompt(session_id):
|
||||
return None
|
||||
|
||||
last_activity_at = self.chat_activity.last_activity_at(session_id)
|
||||
last_prompt_at = session_info.get("last_hearthkeeper_prompt_at")
|
||||
if last_prompt_at and last_activity_at and last_prompt_at >= last_activity_at:
|
||||
return None
|
||||
|
||||
try:
|
||||
agent_response = await self.hearthkeeper.generate_prompt(
|
||||
theme=session_info.get("theme")
|
||||
)
|
||||
session_info["last_hearthkeeper_prompt_at"] = datetime.utcnow()
|
||||
|
||||
async for db_session in get_session():
|
||||
repo = Repository(db_session)
|
||||
await repo.record_action(
|
||||
session_id=session_id,
|
||||
action_type=AgentActionType.RESPONSE,
|
||||
mode="hearthkeeper",
|
||||
description=agent_response,
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"actions_taken": ["HEARTHKEEPER_PROMPT"],
|
||||
"agent_response": agent_response,
|
||||
"reason": "inactive_chat",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Hearthkeeper loop: {e}")
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"actions_taken": [],
|
||||
"agent_response": None,
|
||||
"reason": "hearthkeeper_error",
|
||||
}
|
||||
|
||||
async def get_session_status(self, session_id: str) -> dict:
|
||||
"""Get status of a session."""
|
||||
if session_id not in self.active_sessions:
|
||||
|
||||
@@ -19,9 +19,13 @@ class ChatActivityPolicy:
|
||||
self.inactivity_threshold = timedelta(minutes=inactivity_threshold_minutes)
|
||||
self.last_message_time: dict[str, datetime] = {}
|
||||
|
||||
def record_activity(self, session_id: str) -> None:
|
||||
def record_activity(self, session_id: str, occurred_at: datetime | None = None) -> None:
|
||||
"""Record that chat activity occurred."""
|
||||
self.last_message_time[session_id] = datetime.utcnow()
|
||||
self.last_message_time[session_id] = occurred_at or datetime.utcnow()
|
||||
|
||||
def last_activity_at(self, session_id: str) -> datetime | None:
|
||||
"""Get the most recent chat activity time for a session."""
|
||||
return self.last_message_time.get(session_id)
|
||||
|
||||
def minutes_since_activity(self, session_id: str) -> int:
|
||||
"""Get minutes since last chat message."""
|
||||
|
||||
Reference in New Issue
Block a user