"""Data access layer for database operations.""" import logging import uuid from datetime import datetime from sqlalchemy import func, select, update from sqlalchemy.ext.asyncio import AsyncSession from app.memory.models import ( StreamSession, ChatMessage, AgentAction, ClipCandidate, BlogSeed, AgentActionType, ) logger = logging.getLogger(__name__) class Repository: """Repository for all database operations.""" def __init__(self, session: AsyncSession): self.session = session # Stream Session operations async def create_session(self, channel_name: str) -> str: """Create a new stream session.""" session_id = str(uuid.uuid4()) session = StreamSession( id=session_id, channel_name=channel_name, started_at=datetime.utcnow(), ) self.session.add(session) await self.session.commit() logger.info(f"Created session {session_id} for {channel_name}") return session_id async def end_session(self, session_id: str) -> None: """End a stream session.""" stmt = ( update(StreamSession) .where(StreamSession.id == session_id) .values(ended_at=datetime.utcnow(), is_active=False) ) await self.session.execute(stmt) await self.session.commit() logger.info(f"Ended session {session_id}") async def get_session(self, session_id: str) -> StreamSession | None: """Retrieve a session by ID.""" stmt = select(StreamSession).where(StreamSession.id == session_id) result = await self.session.execute(stmt) return result.scalars().first() async def get_active_sessions(self) -> list[StreamSession]: """Retrieve sessions that are still marked active.""" stmt = ( select(StreamSession) .where(StreamSession.is_active.is_(True)) .order_by(StreamSession.started_at.asc()) ) result = await self.session.execute(stmt) return list(result.scalars().all()) # Chat Message operations async def add_chat_message( self, session_id: str, username: str, content: str, is_bot: bool = False, is_moderator: bool = False, ) -> str: """Add a chat message to the database.""" message_id = str(uuid.uuid4()) message = ChatMessage( id=message_id, session_id=session_id, username=username, content=content, is_bot=is_bot, is_moderator=is_moderator, timestamp=datetime.utcnow(), ) self.session.add(message) await self.session.commit() logger.debug(f"Stored chat message from {username}") return message_id async def get_recent_messages( self, session_id: str, limit: int = 50 ) -> list[ChatMessage]: """Get recent chat messages from a session.""" stmt = ( select(ChatMessage) .where(ChatMessage.session_id == session_id) .order_by(ChatMessage.timestamp.desc()) .limit(limit) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def get_recent_human_messages( self, session_id: str, limit: int = 50 ) -> list[ChatMessage]: """Get recent non-bot chat messages from a session.""" stmt = ( select(ChatMessage) .where( ChatMessage.session_id == session_id, ChatMessage.is_bot.is_(False), ) .order_by(ChatMessage.timestamp.desc()) .limit(limit) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def count_messages(self, session_id: str) -> int: """Count chat messages stored for a session.""" stmt = select(func.count()).select_from(ChatMessage).where( ChatMessage.session_id == session_id ) result = await self.session.execute(stmt) return result.scalar_one() async def get_messages_since( self, session_id: str, since: datetime ) -> list[ChatMessage]: """Get messages recorded since a specific timestamp.""" stmt = ( select(ChatMessage) .where( ChatMessage.session_id == session_id, ChatMessage.timestamp >= since, ) .order_by(ChatMessage.timestamp.desc()) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def get_human_messages_since( self, session_id: str, since: datetime ) -> list[ChatMessage]: """Get non-bot messages recorded since a specific timestamp.""" stmt = ( select(ChatMessage) .where( ChatMessage.session_id == session_id, ChatMessage.is_bot.is_(False), ChatMessage.timestamp >= since, ) .order_by(ChatMessage.timestamp.desc()) ) result = await self.session.execute(stmt) return list(result.scalars().all()) # Agent Action operations async def record_action( self, session_id: str, action_type: AgentActionType, mode: str, description: str, triggered_by_message_id: str | None = None, ) -> str: """Record an agent action.""" action_id = str(uuid.uuid4()) action = AgentAction( id=action_id, session_id=session_id, action_type=action_type, mode=mode, triggered_by_message_id=triggered_by_message_id, description=description, timestamp=datetime.utcnow(), ) self.session.add(action) await self.session.commit() logger.info(f"Recorded agent action: {action_type} via {mode}") return action_id async def get_session_actions(self, session_id: str) -> list[AgentAction]: """Get all actions from a session.""" stmt = ( select(AgentAction) .where(AgentAction.session_id == session_id) .order_by(AgentAction.timestamp.asc()) ) result = await self.session.execute(stmt) return list(result.scalars().all()) # Clip Candidate operations async def add_clip_candidate( self, session_id: str, message_id: str, reason: str ) -> str: """Add a clip candidate.""" candidate_id = str(uuid.uuid4()) candidate = ClipCandidate( id=candidate_id, session_id=session_id, message_id=message_id, reason=reason, timestamp=datetime.utcnow(), ) self.session.add(candidate) await self.session.commit() logger.info(f"Added clip candidate: {reason}") return candidate_id async def get_clip_candidates(self, session_id: str) -> list[ClipCandidate]: """Get all clip candidates from a session.""" stmt = ( select(ClipCandidate) .where(ClipCandidate.session_id == session_id) .order_by(ClipCandidate.timestamp.asc()) ) result = await self.session.execute(stmt) return list(result.scalars().all()) # Blog Seed operations async def add_blog_seed( self, session_id: str, topic: str, description: str ) -> str: """Add a blog post seed.""" seed_id = str(uuid.uuid4()) seed = BlogSeed( id=seed_id, session_id=session_id, topic=topic, description=description, timestamp=datetime.utcnow(), ) self.session.add(seed) await self.session.commit() logger.info(f"Added blog seed: {topic}") return seed_id async def get_blog_seeds(self, session_id: str) -> list[BlogSeed]: """Get all blog seeds from a session.""" stmt = ( select(BlogSeed) .where(BlogSeed.session_id == session_id) .order_by(BlogSeed.timestamp.asc()) ) result = await self.session.execute(stmt) return list(result.scalars().all())