Files
ws-sanctum-chronicler/app/memory/repository.py

206 lines
6.3 KiB
Python

"""Data access layer for database operations."""
import logging
import uuid
from datetime import datetime, timezone

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.memory.models import (
    AgentAction,
    AgentActionType,
    BlogSeed,
    ChatMessage,
    ClipCandidate,
    StreamSession,
)
logger = logging.getLogger(__name__)


class Repository:
    """Repository for all database operations.

    Wraps a single ``AsyncSession``. Every write method commits immediately.
    If a commit fails, the session is rolled back before the exception
    propagates, so the same ``Repository`` stays usable afterwards.
    """

    def __init__(self, session: AsyncSession):
        # All queries and writes go through this async SQLAlchemy session.
        self.session = session

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Drop-in replacement for ``datetime.utcnow()``, which is deprecated
        since Python 3.12. The tzinfo is stripped so stored timestamps keep
        exactly the same naive-UTC representation as before.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def _commit(self) -> None:
        """Commit the current transaction, rolling back if the commit fails.

        Without the rollback, a failed flush/commit leaves the session in an
        inactive state and every subsequent call through it would fail too.

        Raises:
            Exception: whatever ``commit()`` raised, re-raised after rollback.
        """
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    async def _add_and_commit(self, entity: object) -> None:
        """Stage ``entity`` for insertion and commit it."""
        self.session.add(entity)
        await self._commit()

    async def _fetch_all(self, stmt) -> list:
        """Execute a SELECT statement and return every scalar row as a list."""
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    # ------------------------------------------------------------------
    # Stream Session operations
    # ------------------------------------------------------------------

    async def create_session(self, channel_name: str) -> str:
        """Create and persist a new stream session.

        Args:
            channel_name: Name of the channel the session belongs to.

        Returns:
            The generated session id (a UUID4 string).
        """
        session_id = str(uuid.uuid4())
        stream_session = StreamSession(
            id=session_id,
            channel_name=channel_name,
            started_at=self._utcnow(),
        )
        await self._add_and_commit(stream_session)
        logger.info("Created session %s for %s", session_id, channel_name)
        return session_id

    async def end_session(self, session_id: str) -> None:
        """End a stream session.

        Sets ``ended_at`` to the current UTC time and ``is_active`` to False
        for the session with ``session_id``. An unknown id is not an error,
        but it is logged as a warning instead of being reported as ended.
        """
        stmt = (
            update(StreamSession)
            .where(StreamSession.id == session_id)
            .values(ended_at=self._utcnow(), is_active=False)
        )
        result = await self.session.execute(stmt)
        # Capture the match count before committing. DB-API drivers may
        # report -1 when the count is unavailable, so only an explicit 0 is
        # treated as "no such session".
        matched = result.rowcount
        await self._commit()
        if matched == 0:
            logger.warning("No session %s found to end", session_id)
        else:
            logger.info("Ended session %s", session_id)

    async def get_session(self, session_id: str) -> StreamSession | None:
        """Return the session with ``session_id``, or None if none exists."""
        stmt = select(StreamSession).where(StreamSession.id == session_id)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    # ------------------------------------------------------------------
    # Chat Message operations
    # ------------------------------------------------------------------

    async def add_chat_message(
        self,
        session_id: str,
        username: str,
        content: str,
        is_bot: bool = False,
        is_moderator: bool = False,
    ) -> str:
        """Store a chat message for a session.

        Args:
            session_id: Id of the stream session the message belongs to.
            username: Author of the message.
            content: Message text.
            is_bot: Whether the author is a bot.
            is_moderator: Whether the author is a moderator.

        Returns:
            The generated message id (a UUID4 string).
        """
        message_id = str(uuid.uuid4())
        message = ChatMessage(
            id=message_id,
            session_id=session_id,
            username=username,
            content=content,
            is_bot=is_bot,
            is_moderator=is_moderator,
            timestamp=self._utcnow(),
        )
        await self._add_and_commit(message)
        # Lazy %-formatting: this runs for every chat message, and the string
        # is only built when DEBUG logging is actually enabled.
        logger.debug("Stored chat message from %s", username)
        return message_id

    async def get_recent_messages(
        self, session_id: str, limit: int = 50
    ) -> list[ChatMessage]:
        """Return up to ``limit`` messages from a session, newest first."""
        stmt = (
            select(ChatMessage)
            .where(ChatMessage.session_id == session_id)
            .order_by(ChatMessage.timestamp.desc())
            .limit(limit)
        )
        return await self._fetch_all(stmt)

    async def get_messages_since(
        self, session_id: str, since: datetime
    ) -> list[ChatMessage]:
        """Return messages recorded at or after ``since``, newest first.

        NOTE(review): timestamps are written as naive UTC (see ``_utcnow``),
        so ``since`` presumably needs to be naive UTC too. Confirm against
        the column type and the callers.
        """
        stmt = (
            select(ChatMessage)
            .where(
                ChatMessage.session_id == session_id,
                ChatMessage.timestamp >= since,
            )
            .order_by(ChatMessage.timestamp.desc())
        )
        return await self._fetch_all(stmt)

    # ------------------------------------------------------------------
    # Agent Action operations
    # ------------------------------------------------------------------

    async def record_action(
        self,
        session_id: str,
        action_type: AgentActionType,
        mode: str,
        description: str,
        triggered_by_message_id: str | None = None,
    ) -> str:
        """Record an action taken by the agent.

        Args:
            session_id: Id of the stream session the action belongs to.
            action_type: Kind of action performed.
            mode: Mode the action was performed in.
            description: Human-readable description of the action.
            triggered_by_message_id: Optional id of the chat message that
                triggered the action.

        Returns:
            The generated action id (a UUID4 string).
        """
        action_id = str(uuid.uuid4())
        action = AgentAction(
            id=action_id,
            session_id=session_id,
            action_type=action_type,
            mode=mode,
            triggered_by_message_id=triggered_by_message_id,
            description=description,
            timestamp=self._utcnow(),
        )
        await self._add_and_commit(action)
        logger.info("Recorded agent action: %s via %s", action_type, mode)
        return action_id

    async def get_session_actions(self, session_id: str) -> list[AgentAction]:
        """Return all agent actions from a session, oldest first."""
        stmt = (
            select(AgentAction)
            .where(AgentAction.session_id == session_id)
            .order_by(AgentAction.timestamp.asc())
        )
        return await self._fetch_all(stmt)

    # ------------------------------------------------------------------
    # Clip Candidate operations
    # ------------------------------------------------------------------

    async def add_clip_candidate(
        self, session_id: str, message_id: str, reason: str
    ) -> str:
        """Store a clip candidate linked to a chat message.

        Returns:
            The generated candidate id (a UUID4 string).
        """
        candidate_id = str(uuid.uuid4())
        candidate = ClipCandidate(
            id=candidate_id,
            session_id=session_id,
            message_id=message_id,
            reason=reason,
            timestamp=self._utcnow(),
        )
        await self._add_and_commit(candidate)
        logger.info("Added clip candidate: %s", reason)
        return candidate_id

    async def get_clip_candidates(self, session_id: str) -> list[ClipCandidate]:
        """Return all clip candidates from a session, oldest first."""
        stmt = (
            select(ClipCandidate)
            .where(ClipCandidate.session_id == session_id)
            .order_by(ClipCandidate.timestamp.asc())
        )
        return await self._fetch_all(stmt)

    # ------------------------------------------------------------------
    # Blog Seed operations
    # ------------------------------------------------------------------

    async def add_blog_seed(
        self, session_id: str, topic: str, description: str
    ) -> str:
        """Store a blog post seed for a session.

        Returns:
            The generated seed id (a UUID4 string).
        """
        seed_id = str(uuid.uuid4())
        seed = BlogSeed(
            id=seed_id,
            session_id=session_id,
            topic=topic,
            description=description,
            timestamp=self._utcnow(),
        )
        await self._add_and_commit(seed)
        logger.info("Added blog seed: %s", topic)
        return seed_id

    async def get_blog_seeds(self, session_id: str) -> list[BlogSeed]:
        """Return all blog seeds from a session, oldest first."""
        stmt = (
            select(BlogSeed)
            .where(BlogSeed.session_id == session_id)
            .order_by(BlogSeed.timestamp.asc())
        )
        return await self._fetch_all(stmt)