"""Data access layer for database operations."""

import logging
import uuid
from datetime import datetime, timezone

from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.memory.models import (
    StreamSession,
    ChatMessage,
    AgentAction,
    ClipCandidate,
    BlogSeed,
    AgentActionType,
)

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Repository:
    """Repository for all database operations.

    Wraps one ``AsyncSession`` and exposes create/read helpers for stream
    sessions, chat messages, agent actions, clip candidates and blog seeds.
    Every write method commits immediately; when a commit fails the
    transaction is rolled back and the original error is re-raised.
    """

    def __init__(self, session: AsyncSession):
        """Store the session that every operation on this repository uses."""
        self.session = session

    # Internal helpers

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        ``datetime.utcnow()`` is deprecated since Python 3.12. Dropping the
        tzinfo from an aware UTC timestamp yields the same naive value, so
        timestamps keep the form they had before.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def _commit(self) -> None:
        """Commit the current transaction, rolling back if the commit fails."""
        try:
            await self.session.commit()
        except Exception:
            # Reset the failed transaction before propagating the error so
            # the session is not left in a pending-rollback state.
            await self.session.rollback()
            raise

    async def _persist(self, entity: object) -> None:
        """Add ``entity`` to the session and commit it."""
        self.session.add(entity)
        await self._commit()

    async def _fetch_all(self, stmt) -> list:
        """Execute ``stmt`` and return all scalar results as a list."""
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    # Stream Session operations

    async def create_session(self, channel_name: str) -> str:
        """Create a new stream session.

        Args:
            channel_name: Channel name stored on the new session row.

        Returns:
            The generated UUID4 string used as the session id.
        """
        session_id = str(uuid.uuid4())
        await self._persist(
            StreamSession(
                id=session_id,
                channel_name=channel_name,
                started_at=self._utc_now(),
            )
        )
        logger.info("Created session %s for %s", session_id, channel_name)
        return session_id

    async def end_session(self, session_id: str) -> None:
        """End a stream session.

        Sets ``ended_at`` to the current UTC time and ``is_active`` to
        ``False`` on the row whose id equals ``session_id``.
        """
        stmt = (
            update(StreamSession)
            .where(StreamSession.id == session_id)
            .values(ended_at=self._utc_now(), is_active=False)
        )
        await self.session.execute(stmt)
        await self._commit()
        logger.info("Ended session %s", session_id)

    async def get_session(self, session_id: str) -> StreamSession | None:
        """Retrieve a session by ID, or ``None`` when no row matches."""
        stmt = select(StreamSession).where(StreamSession.id == session_id)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def get_active_sessions(self) -> list[StreamSession]:
        """Retrieve sessions still marked active, oldest ``started_at`` first."""
        stmt = (
            select(StreamSession)
            .where(StreamSession.is_active.is_(True))
            .order_by(StreamSession.started_at.asc())
        )
        return await self._fetch_all(stmt)

    # Chat Message operations

    async def add_chat_message(
        self,
        session_id: str,
        username: str,
        content: str,
        is_bot: bool = False,
        is_moderator: bool = False,
    ) -> str:
        """Add a chat message to the database.

        Args:
            session_id: Id of the session the message belongs to.
            username: Name of the message author.
            content: Message text.
            is_bot: Stored on the row as ``is_bot``.
            is_moderator: Stored on the row as ``is_moderator``.

        Returns:
            The generated UUID4 string used as the message id.
        """
        message_id = str(uuid.uuid4())
        await self._persist(
            ChatMessage(
                id=message_id,
                session_id=session_id,
                username=username,
                content=content,
                is_bot=is_bot,
                is_moderator=is_moderator,
                timestamp=self._utc_now(),
            )
        )
        logger.debug("Stored chat message from %s", username)
        return message_id

    async def get_recent_messages(
        self, session_id: str, limit: int = 50
    ) -> list[ChatMessage]:
        """Get up to ``limit`` chat messages from a session, newest first."""
        stmt = (
            select(ChatMessage)
            .where(ChatMessage.session_id == session_id)
            .order_by(ChatMessage.timestamp.desc())
            .limit(limit)
        )
        return await self._fetch_all(stmt)

    async def get_recent_human_messages(
        self, session_id: str, limit: int = 50
    ) -> list[ChatMessage]:
        """Get up to ``limit`` non-bot messages from a session, newest first."""
        stmt = (
            select(ChatMessage)
            .where(
                ChatMessage.session_id == session_id,
                ChatMessage.is_bot.is_(False),
            )
            .order_by(ChatMessage.timestamp.desc())
            .limit(limit)
        )
        return await self._fetch_all(stmt)

    async def count_messages(self, session_id: str) -> int:
        """Count chat messages stored for a session."""
        stmt = (
            select(func.count())
            .select_from(ChatMessage)
            .where(ChatMessage.session_id == session_id)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one()

    async def get_messages_since(
        self, session_id: str, since: datetime
    ) -> list[ChatMessage]:
        """Get messages with ``timestamp >= since``, newest first."""
        stmt = (
            select(ChatMessage)
            .where(
                ChatMessage.session_id == session_id,
                ChatMessage.timestamp >= since,
            )
            .order_by(ChatMessage.timestamp.desc())
        )
        return await self._fetch_all(stmt)

    async def get_human_messages_since(
        self, session_id: str, since: datetime
    ) -> list[ChatMessage]:
        """Get non-bot messages with ``timestamp >= since``, newest first."""
        stmt = (
            select(ChatMessage)
            .where(
                ChatMessage.session_id == session_id,
                ChatMessage.is_bot.is_(False),
                ChatMessage.timestamp >= since,
            )
            .order_by(ChatMessage.timestamp.desc())
        )
        return await self._fetch_all(stmt)

    # Agent Action operations

    async def record_action(
        self,
        session_id: str,
        action_type: AgentActionType,
        mode: str,
        description: str,
        triggered_by_message_id: str | None = None,
    ) -> str:
        """Record an agent action.

        Args:
            session_id: Id of the session the action belongs to.
            action_type: Action type stored on the row.
            mode: Mode string stored on the row.
            description: Free-text description stored on the row.
            triggered_by_message_id: Optional id of the triggering message.

        Returns:
            The generated UUID4 string used as the action id.
        """
        action_id = str(uuid.uuid4())
        await self._persist(
            AgentAction(
                id=action_id,
                session_id=session_id,
                action_type=action_type,
                mode=mode,
                triggered_by_message_id=triggered_by_message_id,
                description=description,
                timestamp=self._utc_now(),
            )
        )
        # format() renders the enum exactly as the previous f-string did,
        # which can differ from the str() that a bare %s would apply.
        logger.info(
            "Recorded agent action: %s via %s", format(action_type), mode
        )
        return action_id

    async def get_session_actions(self, session_id: str) -> list[AgentAction]:
        """Get all actions from a session, oldest first."""
        stmt = (
            select(AgentAction)
            .where(AgentAction.session_id == session_id)
            .order_by(AgentAction.timestamp.asc())
        )
        return await self._fetch_all(stmt)

    # Clip Candidate operations

    async def add_clip_candidate(
        self, session_id: str, message_id: str, reason: str
    ) -> str:
        """Add a clip candidate.

        Returns:
            The generated UUID4 string used as the candidate id.
        """
        candidate_id = str(uuid.uuid4())
        await self._persist(
            ClipCandidate(
                id=candidate_id,
                session_id=session_id,
                message_id=message_id,
                reason=reason,
                timestamp=self._utc_now(),
            )
        )
        logger.info("Added clip candidate: %s", reason)
        return candidate_id

    async def get_clip_candidates(self, session_id: str) -> list[ClipCandidate]:
        """Get all clip candidates from a session, oldest first."""
        stmt = (
            select(ClipCandidate)
            .where(ClipCandidate.session_id == session_id)
            .order_by(ClipCandidate.timestamp.asc())
        )
        return await self._fetch_all(stmt)

    # Blog Seed operations

    async def add_blog_seed(
        self, session_id: str, topic: str, description: str
    ) -> str:
        """Add a blog post seed.

        Returns:
            The generated UUID4 string used as the seed id.
        """
        seed_id = str(uuid.uuid4())
        await self._persist(
            BlogSeed(
                id=seed_id,
                session_id=session_id,
                topic=topic,
                description=description,
                timestamp=self._utc_now(),
            )
        )
        logger.info("Added blog seed: %s", topic)
        return seed_id

    async def get_blog_seeds(self, session_id: str) -> list[BlogSeed]:
        """Get all blog seeds from a session, oldest first."""
        stmt = (
            select(BlogSeed)
            .where(BlogSeed.session_id == session_id)
            .order_by(BlogSeed.timestamp.asc())
        )
        return await self._fetch_all(stmt)