AI generated first iteration
This commit is contained in:
1
app/memory/__init__.py
Normal file
1
app/memory/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Memory module exports."""
|
||||
58
app/memory/database.py
Normal file
58
app/memory/database.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Async database configuration and initialization."""
|
||||
|
||||
import logging
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.config import settings
|
||||
from app.memory.models import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global engine and session factory
|
||||
engine = None
|
||||
async_session_factory = None
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""Initialize the database engine and create all tables."""
|
||||
global engine, async_session_factory
|
||||
|
||||
try:
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=settings.DEBUG,
|
||||
future=True,
|
||||
)
|
||||
|
||||
async_session_factory = sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_session() -> AsyncSession:
|
||||
"""Get an async database session."""
|
||||
if async_session_factory is None:
|
||||
raise RuntimeError("Database not initialized. Call init_db() first.")
|
||||
|
||||
async with async_session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def close_db() -> None:
|
||||
"""Close the database connection."""
|
||||
global engine
|
||||
if engine:
|
||||
await engine.dispose()
|
||||
logger.info("Database connection closed")
|
||||
84
app/memory/models.py
Normal file
84
app/memory/models.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Memory and database models."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from sqlalchemy import Column, String, DateTime, Text, Integer, Boolean, Enum as SQLEnum
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class StreamSession(Base):
|
||||
"""Represents a single stream session."""
|
||||
|
||||
__tablename__ = "stream_sessions"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
channel_name = Column(String, nullable=False, index=True)
|
||||
started_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
ended_at = Column(DateTime, nullable=True)
|
||||
theme = Column(Text, nullable=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
|
||||
class ChatMessage(Base):
|
||||
"""Represents a chat message from the stream."""
|
||||
|
||||
__tablename__ = "chat_messages"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
session_id = Column(String, nullable=False, index=True)
|
||||
username = Column(String, nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
is_bot = Column(Boolean, default=False)
|
||||
is_moderator = Column(Boolean, default=False)
|
||||
|
||||
|
||||
class AgentActionType(str, Enum):
|
||||
"""Types of actions the agent can take."""
|
||||
|
||||
RESPONSE = "response"
|
||||
FLAG_SUSPICIOUS = "flag_suspicious"
|
||||
ARCHIVE_CLIP = "archive_clip"
|
||||
RECORD_SEED = "record_seed"
|
||||
UPDATE_THEME = "update_theme"
|
||||
|
||||
|
||||
class AgentAction(Base):
|
||||
"""Records of agent actions taken during a session."""
|
||||
|
||||
__tablename__ = "agent_actions"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
session_id = Column(String, nullable=False, index=True)
|
||||
action_type = Column(SQLEnum(AgentActionType), nullable=False)
|
||||
mode = Column(String, nullable=False) # hearthkeeper, steward, warden, etc.
|
||||
triggered_by_message_id = Column(String, nullable=True)
|
||||
description = Column(Text, nullable=False)
|
||||
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class ClipCandidate(Base):
|
||||
"""Stores potential clip candidates from stream chat."""
|
||||
|
||||
__tablename__ = "clip_candidates"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
session_id = Column(String, nullable=False, index=True)
|
||||
message_id = Column(String, nullable=False)
|
||||
reason = Column(Text, nullable=False)
|
||||
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class BlogSeed(Base):
|
||||
"""Stores potential blog post topics/seeds from stream."""
|
||||
|
||||
__tablename__ = "blog_seeds"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
session_id = Column(String, nullable=False, index=True)
|
||||
topic = Column(String, nullable=False)
|
||||
description = Column(Text, nullable=False)
|
||||
related_messages = Column(Text, nullable=True) # JSON array of message IDs
|
||||
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
190
app/memory/repository.py
Normal file
190
app/memory/repository.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Data access layer for database operations."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.memory.models import (
|
||||
StreamSession,
|
||||
ChatMessage,
|
||||
AgentAction,
|
||||
ClipCandidate,
|
||||
BlogSeed,
|
||||
AgentActionType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Repository:
|
||||
"""Repository for all database operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
# Stream Session operations
|
||||
|
||||
async def create_session(self, channel_name: str) -> str:
|
||||
"""Create a new stream session."""
|
||||
session_id = str(uuid.uuid4())
|
||||
session = StreamSession(
|
||||
id=session_id,
|
||||
channel_name=channel_name,
|
||||
started_at=datetime.utcnow(),
|
||||
)
|
||||
self.session.add(session)
|
||||
await self.session.commit()
|
||||
logger.info(f"Created session {session_id} for {channel_name}")
|
||||
return session_id
|
||||
|
||||
async def end_session(self, session_id: str) -> None:
|
||||
"""End a stream session."""
|
||||
stmt = (
|
||||
update(StreamSession)
|
||||
.where(StreamSession.id == session_id)
|
||||
.values(ended_at=datetime.utcnow(), is_active=False)
|
||||
)
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
logger.info(f"Ended session {session_id}")
|
||||
|
||||
async def get_session(self, session_id: str) -> StreamSession | None:
|
||||
"""Retrieve a session by ID."""
|
||||
stmt = select(StreamSession).where(StreamSession.id == session_id)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
# Chat Message operations
|
||||
|
||||
async def add_chat_message(
|
||||
self,
|
||||
session_id: str,
|
||||
username: str,
|
||||
content: str,
|
||||
is_bot: bool = False,
|
||||
is_moderator: bool = False,
|
||||
) -> str:
|
||||
"""Add a chat message to the database."""
|
||||
message_id = str(uuid.uuid4())
|
||||
message = ChatMessage(
|
||||
id=message_id,
|
||||
session_id=session_id,
|
||||
username=username,
|
||||
content=content,
|
||||
is_bot=is_bot,
|
||||
is_moderator=is_moderator,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
self.session.add(message)
|
||||
await self.session.commit()
|
||||
logger.debug(f"Stored chat message from {username}")
|
||||
return message_id
|
||||
|
||||
async def get_recent_messages(
|
||||
self, session_id: str, limit: int = 50
|
||||
) -> list[ChatMessage]:
|
||||
"""Get recent chat messages from a session."""
|
||||
stmt = (
|
||||
select(ChatMessage)
|
||||
.where(ChatMessage.session_id == session_id)
|
||||
.order_by(ChatMessage.timestamp.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
# Agent Action operations
|
||||
|
||||
async def record_action(
|
||||
self,
|
||||
session_id: str,
|
||||
action_type: AgentActionType,
|
||||
mode: str,
|
||||
description: str,
|
||||
triggered_by_message_id: str | None = None,
|
||||
) -> str:
|
||||
"""Record an agent action."""
|
||||
action_id = str(uuid.uuid4())
|
||||
action = AgentAction(
|
||||
id=action_id,
|
||||
session_id=session_id,
|
||||
action_type=action_type,
|
||||
mode=mode,
|
||||
triggered_by_message_id=triggered_by_message_id,
|
||||
description=description,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
self.session.add(action)
|
||||
await self.session.commit()
|
||||
logger.info(f"Recorded agent action: {action_type} via {mode}")
|
||||
return action_id
|
||||
|
||||
async def get_session_actions(self, session_id: str) -> list[AgentAction]:
|
||||
"""Get all actions from a session."""
|
||||
stmt = (
|
||||
select(AgentAction)
|
||||
.where(AgentAction.session_id == session_id)
|
||||
.order_by(AgentAction.timestamp.asc())
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
# Clip Candidate operations
|
||||
|
||||
async def add_clip_candidate(
|
||||
self, session_id: str, message_id: str, reason: str
|
||||
) -> str:
|
||||
"""Add a clip candidate."""
|
||||
candidate_id = str(uuid.uuid4())
|
||||
candidate = ClipCandidate(
|
||||
id=candidate_id,
|
||||
session_id=session_id,
|
||||
message_id=message_id,
|
||||
reason=reason,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
self.session.add(candidate)
|
||||
await self.session.commit()
|
||||
logger.info(f"Added clip candidate: {reason}")
|
||||
return candidate_id
|
||||
|
||||
async def get_clip_candidates(self, session_id: str) -> list[ClipCandidate]:
|
||||
"""Get all clip candidates from a session."""
|
||||
stmt = (
|
||||
select(ClipCandidate)
|
||||
.where(ClipCandidate.session_id == session_id)
|
||||
.order_by(ClipCandidate.timestamp.asc())
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
# Blog Seed operations
|
||||
|
||||
async def add_blog_seed(
|
||||
self, session_id: str, topic: str, description: str
|
||||
) -> str:
|
||||
"""Add a blog post seed."""
|
||||
seed_id = str(uuid.uuid4())
|
||||
seed = BlogSeed(
|
||||
id=seed_id,
|
||||
session_id=session_id,
|
||||
topic=topic,
|
||||
description=description,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
self.session.add(seed)
|
||||
await self.session.commit()
|
||||
logger.info(f"Added blog seed: {topic}")
|
||||
return seed_id
|
||||
|
||||
async def get_blog_seeds(self, session_id: str) -> list[BlogSeed]:
|
||||
"""Get all blog seeds from a session."""
|
||||
stmt = (
|
||||
select(BlogSeed)
|
||||
.where(BlogSeed.session_id == session_id)
|
||||
.order_by(BlogSeed.timestamp.asc())
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
Reference in New Issue
Block a user