Restore active sessions on startup

This commit is contained in:
2026-05-12 08:04:59 -05:00
parent a09197e85a
commit bce93b39e0
3 changed files with 64 additions and 1 deletions

View File

@@ -82,6 +82,34 @@ class AgentOrchestrator:
logger.info(f"Started session {session_id} for {channel_name}") logger.info(f"Started session {session_id} for {channel_name}")
return session_id return session_id
async def restore_active_sessions(self) -> int:
"""Restore active sessions from the database after app startup."""
restored_count = 0
async for db_session in get_session():
repo = Repository(db_session)
sessions = await repo.get_active_sessions()
for session in sessions:
recent_messages = await repo.get_recent_messages(session.id, limit=1)
message_count = await repo.count_messages(session.id)
last_activity_at = (
recent_messages[0].timestamp if recent_messages else session.started_at
)
self.active_sessions[session.id] = {
"channel_name": session.channel_name,
"started_at": session.started_at,
"message_count": message_count,
"theme": session.theme,
"last_hearthkeeper_prompt_at": None,
}
self.chat_activity.record_activity(session.id, occurred_at=last_activity_at)
restored_count += 1
logger.info(f"Restored {restored_count} active sessions")
return restored_count
async def end_session(self, session_id: str) -> None: async def end_session(self, session_id: str) -> None:
""" """
End a stream session and trigger ledger generation. End a stream session and trigger ledger generation.

View File

@@ -51,6 +51,7 @@ async def startup_event():
orchestrator = AgentOrchestrator( orchestrator = AgentOrchestrator(
loop_interval_seconds=settings.AGENT_LOOP_INTERVAL_SECONDS loop_interval_seconds=settings.AGENT_LOOP_INTERVAL_SECONDS
) )
await orchestrator.restore_active_sessions()
agent_loop_task = asyncio.create_task(agent_loop()) agent_loop_task = asyncio.create_task(agent_loop())
logger.info("Application started successfully") logger.info("Application started successfully")
except Exception as e: except Exception as e:
@@ -143,6 +144,22 @@ async def get_ledger(session_id: str) -> dict:
} }
@app.get("/admin/session/status")
async def get_session_status(session_id: str) -> dict:
"""Get status for an active stream session."""
if not orchestrator:
raise HTTPException(status_code=503, detail="Orchestrator not initialized")
status = await orchestrator.get_session_status(session_id)
if not status:
raise HTTPException(status_code=404, detail="Active session not found")
return {
**status,
"timestamp": datetime.utcnow().isoformat(),
}
@app.get("/admin/loop/status") @app.get("/admin/loop/status")
async def get_loop_status() -> dict: async def get_loop_status() -> dict:
"""Get the background agent loop runtime configuration.""" """Get the background agent loop runtime configuration."""

View File

@@ -3,7 +3,7 @@
import logging import logging
import uuid import uuid
from datetime import datetime from datetime import datetime
from sqlalchemy import select, update from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.memory.models import ( from app.memory.models import (
@@ -56,6 +56,16 @@ class Repository:
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return result.scalars().first() return result.scalars().first()
async def get_active_sessions(self) -> list[StreamSession]:
"""Retrieve sessions that are still marked active."""
stmt = (
select(StreamSession)
.where(StreamSession.is_active.is_(True))
.order_by(StreamSession.started_at.asc())
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
# Chat Message operations # Chat Message operations
async def add_chat_message( async def add_chat_message(
@@ -95,6 +105,14 @@ class Repository:
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return list(result.scalars().all()) return list(result.scalars().all())
async def count_messages(self, session_id: str) -> int:
"""Count chat messages stored for a session."""
stmt = select(func.count()).select_from(ChatMessage).where(
ChatMessage.session_id == session_id
)
result = await self.session.execute(stmt)
return result.scalar_one()
async def get_messages_since( async def get_messages_since(
self, session_id: str, since: datetime self, session_id: str, since: datetime
) -> list[ChatMessage]: ) -> list[ChatMessage]: