AI generated first iteration

This commit is contained in:
2026-05-11 15:01:55 -05:00
parent af3e282fda
commit 412d7caec3
28 changed files with 2094 additions and 157 deletions

190
app/memory/repository.py Normal file
View File

@@ -0,0 +1,190 @@
"""Data access layer for database operations."""
import logging
import uuid
from datetime import datetime
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.memory.models import (
StreamSession,
ChatMessage,
AgentAction,
ClipCandidate,
BlogSeed,
AgentActionType,
)
logger = logging.getLogger(__name__)
class Repository:
"""Repository for all database operations."""
def __init__(self, session: AsyncSession):
self.session = session
# Stream Session operations
async def create_session(self, channel_name: str) -> str:
"""Create a new stream session."""
session_id = str(uuid.uuid4())
session = StreamSession(
id=session_id,
channel_name=channel_name,
started_at=datetime.utcnow(),
)
self.session.add(session)
await self.session.commit()
logger.info(f"Created session {session_id} for {channel_name}")
return session_id
async def end_session(self, session_id: str) -> None:
"""End a stream session."""
stmt = (
update(StreamSession)
.where(StreamSession.id == session_id)
.values(ended_at=datetime.utcnow(), is_active=False)
)
await self.session.execute(stmt)
await self.session.commit()
logger.info(f"Ended session {session_id}")
async def get_session(self, session_id: str) -> StreamSession | None:
"""Retrieve a session by ID."""
stmt = select(StreamSession).where(StreamSession.id == session_id)
result = await self.session.execute(stmt)
return result.scalars().first()
# Chat Message operations
async def add_chat_message(
self,
session_id: str,
username: str,
content: str,
is_bot: bool = False,
is_moderator: bool = False,
) -> str:
"""Add a chat message to the database."""
message_id = str(uuid.uuid4())
message = ChatMessage(
id=message_id,
session_id=session_id,
username=username,
content=content,
is_bot=is_bot,
is_moderator=is_moderator,
timestamp=datetime.utcnow(),
)
self.session.add(message)
await self.session.commit()
logger.debug(f"Stored chat message from {username}")
return message_id
async def get_recent_messages(
self, session_id: str, limit: int = 50
) -> list[ChatMessage]:
"""Get recent chat messages from a session."""
stmt = (
select(ChatMessage)
.where(ChatMessage.session_id == session_id)
.order_by(ChatMessage.timestamp.desc())
.limit(limit)
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
# Agent Action operations
async def record_action(
self,
session_id: str,
action_type: AgentActionType,
mode: str,
description: str,
triggered_by_message_id: str | None = None,
) -> str:
"""Record an agent action."""
action_id = str(uuid.uuid4())
action = AgentAction(
id=action_id,
session_id=session_id,
action_type=action_type,
mode=mode,
triggered_by_message_id=triggered_by_message_id,
description=description,
timestamp=datetime.utcnow(),
)
self.session.add(action)
await self.session.commit()
logger.info(f"Recorded agent action: {action_type} via {mode}")
return action_id
async def get_session_actions(self, session_id: str) -> list[AgentAction]:
"""Get all actions from a session."""
stmt = (
select(AgentAction)
.where(AgentAction.session_id == session_id)
.order_by(AgentAction.timestamp.asc())
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
# Clip Candidate operations
async def add_clip_candidate(
self, session_id: str, message_id: str, reason: str
) -> str:
"""Add a clip candidate."""
candidate_id = str(uuid.uuid4())
candidate = ClipCandidate(
id=candidate_id,
session_id=session_id,
message_id=message_id,
reason=reason,
timestamp=datetime.utcnow(),
)
self.session.add(candidate)
await self.session.commit()
logger.info(f"Added clip candidate: {reason}")
return candidate_id
async def get_clip_candidates(self, session_id: str) -> list[ClipCandidate]:
"""Get all clip candidates from a session."""
stmt = (
select(ClipCandidate)
.where(ClipCandidate.session_id == session_id)
.order_by(ClipCandidate.timestamp.asc())
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
# Blog Seed operations
async def add_blog_seed(
self, session_id: str, topic: str, description: str
) -> str:
"""Add a blog post seed."""
seed_id = str(uuid.uuid4())
seed = BlogSeed(
id=seed_id,
session_id=session_id,
topic=topic,
description=description,
timestamp=datetime.utcnow(),
)
self.session.add(seed)
await self.session.commit()
logger.info(f"Added blog seed: {topic}")
return seed_id
async def get_blog_seeds(self, session_id: str) -> list[BlogSeed]:
"""Get all blog seeds from a session."""
stmt = (
select(BlogSeed)
.where(BlogSeed.session_id == session_id)
.order_by(BlogSeed.timestamp.asc())
)
result = await self.session.execute(stmt)
return list(result.scalars().all())