Files
ws-sanctum-chronicler/app/memory/repository.py

340 lines
11 KiB
Python

"""Data access layer for database operations."""
import json
import logging
import uuid
from datetime import datetime, timezone

from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.memory.models import (
    StreamSession,
    StreamDashboard,
    ChatMessage,
    AgentAction,
    ClipCandidate,
    BlogSeed,
    AgentActionType,
)
# Module-level logger named after this module, per the standard logging
# convention; the repository methods below log through it.
logger = logging.getLogger(__name__)
class Repository:
    """Repository for all database operations.

    Wraps a single ``AsyncSession``. Every write method commits immediately;
    if a commit fails, the session is rolled back before the exception is
    re-raised, so the repository can keep being used for later calls.

    All timestamps written here are naive ``datetime`` values in UTC (no
    tzinfo attached).
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    # Internal helpers

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Same value and representation as the deprecated (Python 3.12+)
        ``datetime.utcnow()``, so newly written timestamps stay comparable
        with naive-UTC values written by earlier versions of this code.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def _commit(self) -> None:
        """Commit the current transaction, rolling back if the commit fails.

        Without the rollback, a failed commit (for example a constraint
        violation during flush) leaves the session requiring an explicit
        ``rollback()`` before any further use. The original exception is
        re-raised unchanged.
        """
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    async def _query_messages(
        self,
        session_id: str,
        *,
        human_only: bool = False,
        since: datetime | None = None,
        limit: int | None = None,
    ) -> list[ChatMessage]:
        """Fetch a session's chat messages, newest first, with optional filters.

        Args:
            session_id: Session whose messages are returned.
            human_only: If True, exclude messages flagged ``is_bot``.
            since: If given, only messages with ``timestamp >= since``.
            limit: If given, return at most this many messages.
        """
        stmt = select(ChatMessage).where(ChatMessage.session_id == session_id)
        if human_only:
            stmt = stmt.where(ChatMessage.is_bot.is_(False))
        if since is not None:
            stmt = stmt.where(ChatMessage.timestamp >= since)
        stmt = stmt.order_by(ChatMessage.timestamp.desc())
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    # Stream Session operations

    async def create_session(self, channel_name: str) -> str:
        """Create and persist a new stream session.

        Args:
            channel_name: Name of the channel being streamed.

        Returns:
            The new session's ID (a random UUID4 string).
        """
        session_id = str(uuid.uuid4())
        self.session.add(
            StreamSession(
                id=session_id,
                channel_name=channel_name,
                started_at=self._utcnow(),
            )
        )
        await self._commit()
        logger.info("Created session %s for %s", session_id, channel_name)
        return session_id

    async def end_session(self, session_id: str) -> None:
        """Mark a stream session as ended.

        Sets ``ended_at`` to the current UTC time and ``is_active`` to False.
        An unknown ``session_id`` is not an error; a warning is logged instead.
        """
        stmt = (
            update(StreamSession)
            .where(StreamSession.id == session_id)
            .values(ended_at=self._utcnow(), is_active=False)
        )
        result = await self.session.execute(stmt)
        await self._commit()
        if result.rowcount == 0:
            # No row matched; don't report the session as ended.
            logger.warning("No session %s found to end", session_id)
        else:
            logger.info("Ended session %s", session_id)

    async def get_session(self, session_id: str) -> StreamSession | None:
        """Return the session with ``session_id``, or None if none exists."""
        stmt = select(StreamSession).where(StreamSession.id == session_id)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def get_active_sessions(self) -> list[StreamSession]:
        """Return sessions still marked active, earliest ``started_at`` first."""
        stmt = (
            select(StreamSession)
            .where(StreamSession.is_active.is_(True))
            .order_by(StreamSession.started_at.asc())
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    # Stream Dashboard operations

    async def upsert_dashboard(
        self,
        session_id: str,
        raw_markdown: str,
        stream_title: str | None = None,
        game: str | None = None,
        mood: str | None = None,
        go_live_notification: str | None = None,
        social_post: str | None = None,
        session_goals: list[str] | None = None,
        content_angle: str | None = None,
    ) -> StreamDashboard:
        """Create or replace the stream dashboard for a session.

        On update every content field is overwritten, so an omitted optional
        argument clears the previously stored value. ``session_goals`` is
        stored as a JSON array string; ``None`` is stored as ``"[]"``.

        Returns:
            The saved dashboard, refreshed from the database after commit.
        """
        dashboard = await self.get_dashboard(session_id)
        now = self._utcnow()
        fields = {
            "raw_markdown": raw_markdown,
            "stream_title": stream_title,
            "game": game,
            "mood": mood,
            "go_live_notification": go_live_notification,
            "social_post": social_post,
            "session_goals": json.dumps(session_goals or []),
            "content_angle": content_angle,
            "updated_at": now,
        }
        if dashboard is None:
            dashboard = StreamDashboard(
                session_id=session_id, created_at=now, **fields
            )
            self.session.add(dashboard)
        else:
            for name, value in fields.items():
                setattr(dashboard, name, value)
        await self._commit()
        # Ensure the returned instance has loaded attributes. If the session
        # factory uses SQLAlchemy's default expire_on_commit=True, commit()
        # expires them, and the implicit lazy reload on attribute access is
        # not possible under asyncio.
        await self.session.refresh(dashboard)
        logger.info("Saved dashboard for session %s", session_id)
        return dashboard

    async def get_dashboard(self, session_id: str) -> StreamDashboard | None:
        """Return the dashboard for ``session_id``, or None if none exists."""
        stmt = select(StreamDashboard).where(StreamDashboard.session_id == session_id)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    @staticmethod
    def serialize_dashboard(dashboard: StreamDashboard | None) -> dict | None:
        """Serialize a dashboard model into a plain dict.

        ``session_goals`` is decoded from its stored JSON string and is always
        a list: malformed JSON, or JSON that is not an array, yields ``[]``.
        Timestamps are rendered with ``isoformat()``.

        Returns:
            None if ``dashboard`` is None, otherwise the serialized dict.
        """
        if dashboard is None:
            return None
        session_goals: list = []
        if dashboard.session_goals:
            try:
                decoded = json.loads(dashboard.session_goals)
            except json.JSONDecodeError:
                logger.warning(
                    "Invalid session_goals JSON for session %s; using []",
                    dashboard.session_id,
                )
                decoded = []
            # Only a JSON array is a valid goals payload; anything else is
            # treated as "no goals" so callers always receive a list.
            if isinstance(decoded, list):
                session_goals = decoded
        return {
            "session_id": dashboard.session_id,
            "raw_markdown": dashboard.raw_markdown,
            "stream_title": dashboard.stream_title,
            "game": dashboard.game,
            "mood": dashboard.mood,
            "go_live_notification": dashboard.go_live_notification,
            "social_post": dashboard.social_post,
            "session_goals": session_goals,
            "content_angle": dashboard.content_angle,
            "created_at": dashboard.created_at.isoformat(),
            "updated_at": dashboard.updated_at.isoformat(),
        }

    # Chat Message operations

    async def add_chat_message(
        self,
        session_id: str,
        username: str,
        content: str,
        is_bot: bool = False,
        is_moderator: bool = False,
    ) -> str:
        """Store a chat message and return its generated ID (a UUID4 string)."""
        message_id = str(uuid.uuid4())
        self.session.add(
            ChatMessage(
                id=message_id,
                session_id=session_id,
                username=username,
                content=content,
                is_bot=is_bot,
                is_moderator=is_moderator,
                timestamp=self._utcnow(),
            )
        )
        await self._commit()
        # Lazy %-args: this runs for every chat message, so avoid formatting
        # the string when debug logging is disabled.
        logger.debug("Stored chat message from %s", username)
        return message_id

    async def get_recent_messages(
        self, session_id: str, limit: int = 50
    ) -> list[ChatMessage]:
        """Return up to ``limit`` of a session's messages, newest first."""
        return await self._query_messages(session_id, limit=limit)

    async def get_recent_human_messages(
        self, session_id: str, limit: int = 50
    ) -> list[ChatMessage]:
        """Return up to ``limit`` non-bot messages from a session, newest first."""
        return await self._query_messages(session_id, human_only=True, limit=limit)

    async def count_messages(self, session_id: str) -> int:
        """Count chat messages stored for a session (bot messages included)."""
        stmt = select(func.count()).select_from(ChatMessage).where(
            ChatMessage.session_id == session_id
        )
        result = await self.session.execute(stmt)
        return result.scalar_one()

    async def get_messages_since(
        self, session_id: str, since: datetime
    ) -> list[ChatMessage]:
        """Return a session's messages with ``timestamp >= since``, newest first.

        NOTE(review): stored timestamps are naive UTC, so ``since`` is
        presumably expected to be naive UTC as well — confirm with callers.
        """
        return await self._query_messages(session_id, since=since)

    async def get_human_messages_since(
        self, session_id: str, since: datetime
    ) -> list[ChatMessage]:
        """Return non-bot messages with ``timestamp >= since``, newest first.

        NOTE(review): stored timestamps are naive UTC, so ``since`` is
        presumably expected to be naive UTC as well — confirm with callers.
        """
        return await self._query_messages(session_id, human_only=True, since=since)

    # Agent Action operations

    async def record_action(
        self,
        session_id: str,
        action_type: AgentActionType,
        mode: str,
        description: str,
        triggered_by_message_id: str | None = None,
    ) -> str:
        """Record an agent action and return its generated ID (a UUID4 string).

        Args:
            session_id: Session the action belongs to.
            action_type: Kind of action taken.
            mode: Mode the agent was in when acting (stored as given).
            description: Free-text description of the action.
            triggered_by_message_id: Optional ID of the chat message that
                prompted the action.
        """
        action_id = str(uuid.uuid4())
        self.session.add(
            AgentAction(
                id=action_id,
                session_id=session_id,
                action_type=action_type,
                mode=mode,
                triggered_by_message_id=triggered_by_message_id,
                description=description,
                timestamp=self._utcnow(),
            )
        )
        await self._commit()
        logger.info("Recorded agent action: %s via %s", action_type, mode)
        return action_id

    async def get_session_actions(self, session_id: str) -> list[AgentAction]:
        """Return all actions from a session, oldest first."""
        stmt = (
            select(AgentAction)
            .where(AgentAction.session_id == session_id)
            .order_by(AgentAction.timestamp.asc())
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    # Clip Candidate operations

    async def add_clip_candidate(
        self, session_id: str, message_id: str, reason: str
    ) -> str:
        """Store a clip candidate tied to a chat message; return its new ID."""
        candidate_id = str(uuid.uuid4())
        self.session.add(
            ClipCandidate(
                id=candidate_id,
                session_id=session_id,
                message_id=message_id,
                reason=reason,
                timestamp=self._utcnow(),
            )
        )
        await self._commit()
        logger.info("Added clip candidate: %s", reason)
        return candidate_id

    async def get_clip_candidates(self, session_id: str) -> list[ClipCandidate]:
        """Return all clip candidates from a session, oldest first."""
        stmt = (
            select(ClipCandidate)
            .where(ClipCandidate.session_id == session_id)
            .order_by(ClipCandidate.timestamp.asc())
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    # Blog Seed operations

    async def add_blog_seed(
        self, session_id: str, topic: str, description: str
    ) -> str:
        """Store a blog post seed for a session; return its new ID."""
        seed_id = str(uuid.uuid4())
        self.session.add(
            BlogSeed(
                id=seed_id,
                session_id=session_id,
                topic=topic,
                description=description,
                timestamp=self._utcnow(),
            )
        )
        await self._commit()
        logger.info("Added blog seed: %s", topic)
        return seed_id

    async def get_blog_seeds(self, session_id: str) -> list[BlogSeed]:
        """Return all blog seeds from a session, oldest first."""
        stmt = (
            select(BlogSeed)
            .where(BlogSeed.session_id == session_id)
            .order_by(BlogSeed.timestamp.asc())
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())