Files
haunt-fm/src/haunt_fm/api/feedback.py

127 lines
3.9 KiB
Python
Raw Normal View History

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from haunt_fm.db import get_session
from haunt_fm.models.track import FeedbackEvent, Track
from haunt_fm.services.feedback import compute_contextual_score, record_feedback
# Router for the feedback endpoints; every route below is mounted under the
# "/api/feedback" prefix (e.g. "/api/feedback/score", "/api/feedback/history").
router = APIRouter(prefix="/api/feedback")
class FeedbackRequest(BaseModel):
    """Request body for ``POST /api/feedback``.

    The fields are forwarded unchanged to ``record_feedback``; if that service
    raises ``ValueError`` the endpoint responds with HTTP 400.
    """

    # Playlist whose vibe provides the context for this feedback.
    playlist_id: int
    # Track the feedback is about.
    track_id: int
    # Feedback signal name. Which values are accepted is decided by
    # record_feedback (not visible here) — presumably a fixed vocabulary; confirm.
    signal: str
class ScoreRequest(BaseModel):
    """Request body for ``POST /api/feedback/score``."""

    # Track whose stored feedback events are scored.
    track_id: int
    # Free-text vibe description; it is embedded and compared against the
    # vibe embedding stored on each of the track's feedback events.
    vibe_text: str
@router.post("")
async def submit_feedback(req: FeedbackRequest, session: AsyncSession = Depends(get_session)):
    """Submit feedback for a track in the context of a playlist's vibe.

    Delegates to ``record_feedback`` and returns the stored event as a
    JSON-serializable dict (``created_at`` rendered in ISO-8601 form).

    Raises:
        HTTPException: status 400 when ``record_feedback`` raises
            ``ValueError``; the exception message becomes the response detail.
    """
    try:
        event = await record_feedback(session, req.playlist_id, req.track_id, req.signal)
    except ValueError as e:
        # Chain explicitly so tracebacks record the ValueError as the direct
        # cause of the 400 rather than as an incidental error during handling.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {
        "id": event.id,
        "playlist_id": event.playlist_id,
        "track_id": event.track_id,
        "signal": event.signal,
        "signal_weight": event.signal_weight,
        "vibe_text": event.vibe_text,
        "created_at": event.created_at.isoformat(),
    }
@router.post("/score")
async def get_score(req: ScoreRequest, session: AsyncSession = Depends(get_session)):
    """Compute the contextual feedback score for a track given a vibe description.

    The request's vibe text is embedded, every stored feedback event for the
    track is loaded, and ``compute_contextual_score`` combines them into one
    score. The response also carries a per-event breakdown containing the
    cosine similarity between each event's stored vibe embedding and the
    requested vibe (0.0 when either vector has a zero or undefined norm).
    """
    import numpy as np

    from haunt_fm.services.embedding import embed_text, is_model_loaded, load_model

    # NOTE(review): load_model/embed_text are synchronous calls inside an async
    # handler, so the event loop is blocked while they run. Consider offloading
    # them to a worker thread once their thread-safety is confirmed.
    if not is_model_loaded():
        load_model()
    vibe_embedding = embed_text(req.vibe_text)

    result = await session.execute(
        select(FeedbackEvent).where(FeedbackEvent.track_id == req.track_id)
    )
    events = list(result.scalars().all())
    score = compute_contextual_score(events, vibe_embedding)

    # The query vibe is the same for every event, so normalize it once instead
    # of recomputing its norm on each loop iteration. asarray is a no-op for an
    # ndarray and also tolerates a plain list from embed_text.
    vibe_vec = np.asarray(vibe_embedding)
    vibe_norm = np.linalg.norm(vibe_vec)
    vibe_unit = vibe_vec / vibe_norm if vibe_norm > 0 else None

    # Breakdown of contributing events, so callers can see what drives the score.
    breakdown = []
    for event in events:
        event_vec = np.array(event.vibe_embedding, dtype=np.float32)
        event_norm = np.linalg.norm(event_vec)
        # `> 0` is also False for a NaN norm, which yields a similarity of 0.0.
        if vibe_unit is not None and event_norm > 0:
            cos_sim = float(np.dot(event_vec / event_norm, vibe_unit))
        else:
            cos_sim = 0.0
        breakdown.append({
            "id": event.id,
            "signal": event.signal,
            "signal_weight": event.signal_weight,
            "vibe_text": event.vibe_text,
            "cosine_similarity": round(cos_sim, 4),
            "created_at": event.created_at.isoformat(),
        })

    return {
        "track_id": req.track_id,
        "vibe_text": req.vibe_text,
        # float() guards against numpy scalar types (e.g. np.float32), which are
        # not Python floats and may not JSON-serialize in the response.
        "contextual_score": round(float(score), 4),
        "event_count": len(events),
        "breakdown": breakdown,
    }
@router.get("/history")
async def get_history(
    limit: int = Query(default=50, ge=1, le=200),
    track_id: int | None = Query(default=None),
    session: AsyncSession = Depends(get_session),
):
    """Get recent feedback events, optionally filtered by track.

    Events are returned newest first (by ``created_at``), at most ``limit`` of
    them, each with the related track's artist and title eagerly loaded.
    The response is ``{"events": [...], "count": <number of events>}``.
    """
    stmt = select(FeedbackEvent).options(joinedload(FeedbackEvent.track))
    if track_id is not None:
        stmt = stmt.where(FeedbackEvent.track_id == track_id)
    # Newest first, capped at the requested page size.
    stmt = stmt.order_by(FeedbackEvent.created_at.desc()).limit(limit)

    # unique() de-duplicates entities produced by the joined eager load.
    rows = (await session.execute(stmt)).scalars().unique().all()

    def _as_dict(ev):
        # Flatten one event plus its related track into a JSON-friendly dict.
        track = ev.track
        return {
            "id": ev.id,
            "playlist_id": ev.playlist_id,
            "track_id": ev.track_id,
            "artist": track.artist,
            "title": track.title,
            "signal": ev.signal,
            "signal_weight": ev.signal_weight,
            "vibe_text": ev.vibe_text,
            "created_at": ev.created_at.isoformat(),
        }

    serialized = [_as_dict(ev) for ev in rows]
    return {"events": serialized, "count": len(rows)}