127 lines
3.9 KiB
Python
127 lines
3.9 KiB
Python
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||
|
|
from pydantic import BaseModel
|
||
|
|
from sqlalchemy import select
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
from sqlalchemy.orm import joinedload
|
||
|
|
|
||
|
|
from haunt_fm.db import get_session
|
||
|
|
from haunt_fm.models.track import FeedbackEvent, Track
|
||
|
|
from haunt_fm.services.feedback import compute_contextual_score, record_feedback
|
||
|
|
|
||
|
|
# Every endpoint in this module is registered under this path prefix.
_FEEDBACK_PREFIX = "/api/feedback"
router = APIRouter(prefix=_FEEDBACK_PREFIX)
|
||
|
|
|
||
|
|
|
||
|
|
class FeedbackRequest(BaseModel):
    # Request body for the feedback-submission endpoint (submit_feedback).
    # Playlist whose vibe supplies the context for this piece of feedback.
    playlist_id: int
    # Track the feedback is about.
    track_id: int
    # Feedback signal name. NOTE(review): accepted values are not visible
    # here; record_feedback may reject input by raising ValueError, which
    # submit_feedback turns into an HTTP 400.
    signal: str
|
||
|
|
|
||
|
|
|
||
|
|
class ScoreRequest(BaseModel):
    # Request body for the contextual-score endpoint (get_score).
    # Track whose recorded feedback events are scored.
    track_id: int
    # Free-text vibe description; it is embedded and compared against the
    # vibe embedding stored on each of the track's feedback events.
    vibe_text: str
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("")
async def submit_feedback(req: FeedbackRequest, session: AsyncSession = Depends(get_session)):
    """Submit feedback for a track in the context of a playlist's vibe.

    Delegates persistence to ``record_feedback`` and echoes the resulting
    event back to the client.

    Args:
        req: Playlist id, track id and signal name for the feedback.
        session: Database session injected by FastAPI.

    Returns:
        A dict describing the recorded event: its id, playlist/track ids,
        signal, signal weight and vibe text as set on the returned event,
        plus ``created_at`` as an ISO-8601 string.

    Raises:
        HTTPException: 400 when ``record_feedback`` raises ``ValueError``;
            the error message is passed through as the response detail.
    """
    try:
        event = await record_feedback(session, req.playlist_id, req.track_id, req.signal)
    except ValueError as exc:
        # Chain the original error so server-side tracebacks show it as the
        # explicit cause rather than "during handling, another exception".
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return {
        "id": event.id,
        "playlist_id": event.playlist_id,
        "track_id": event.track_id,
        "signal": event.signal,
        "signal_weight": event.signal_weight,
        "vibe_text": event.vibe_text,
        "created_at": event.created_at.isoformat(),
    }
|
||
|
|
|
||
|
|
|
||
|
|
def _cosine_similarity(event_embedding, vibe_unit) -> float:
    """Return the cosine similarity between a stored event embedding and a unit query vector.

    Args:
        event_embedding: The event's stored vibe embedding (anything
            ``np.array`` accepts; it is converted to float32).
        vibe_unit: The query vibe embedding already divided by its norm, or
            ``None`` when the query vector has zero norm.

    Returns:
        The similarity as a plain ``float``, or ``0.0`` when either vector
        has a zero (or non-finite) norm and the similarity is undefined.
    """
    import numpy as np

    if vibe_unit is None:
        return 0.0
    event_vec = np.array(event_embedding, dtype=np.float32)
    event_norm = np.linalg.norm(event_vec)
    # `> 0` is False for NaN as well, so degenerate vectors fall through to 0.0.
    if event_norm > 0:
        return float(np.dot(event_vec / event_norm, vibe_unit))
    return 0.0


@router.post("/score")
async def get_score(req: ScoreRequest, session: AsyncSession = Depends(get_session)):
    """Compute the contextual feedback score for a track given a vibe description.

    Embeds ``req.vibe_text`` (loading the embedding model first if it is not
    already loaded), fetches every feedback event recorded for
    ``req.track_id`` and scores them with ``compute_contextual_score``.

    Args:
        req: The track id and the vibe text to score against.
        session: Database session injected by FastAPI.

    Returns:
        A dict with ``track_id``, ``vibe_text``, ``contextual_score``
        (rounded to 4 decimals), ``event_count`` and ``breakdown``. The
        breakdown has one entry per event, including the cosine similarity
        (rounded to 4 decimals) between that event's stored vibe embedding
        and the requested vibe. A track with no events yields an empty
        breakdown.
    """
    # Local imports, presumably to keep the embedding stack out of module
    # import time; confirm before hoisting them to the top of the file.
    import numpy as np

    from haunt_fm.services.embedding import embed_text, is_model_loaded, load_model

    # NOTE(review): load_model/embed_text are plain synchronous calls, so any
    # heavy work they do runs on the event loop and stalls other requests.
    # Consider offloading (e.g. asyncio.to_thread) once the embedding service
    # is confirmed thread-safe.
    if not is_model_loaded():
        load_model()
    vibe_embedding = embed_text(req.vibe_text)

    result = await session.execute(
        select(FeedbackEvent).where(FeedbackEvent.track_id == req.track_id)
    )
    events = list(result.scalars().all())

    score = compute_contextual_score(events, vibe_embedding)

    # The query vector is identical for every event, so normalise it once
    # instead of recomputing its norm on each loop iteration.
    vibe_vec = np.asarray(vibe_embedding)
    vibe_norm = np.linalg.norm(vibe_vec)
    vibe_unit = vibe_vec / vibe_norm if vibe_norm > 0 else None

    breakdown = [
        {
            "id": event.id,
            "signal": event.signal,
            "signal_weight": event.signal_weight,
            "vibe_text": event.vibe_text,
            "cosine_similarity": round(_cosine_similarity(event.vibe_embedding, vibe_unit), 4),
            "created_at": event.created_at.isoformat(),
        }
        for event in events
    ]

    return {
        "track_id": req.track_id,
        "vibe_text": req.vibe_text,
        # float() guards against compute_contextual_score returning a NumPy
        # scalar (e.g. float32), which FastAPI's default JSON encoder rejects.
        "contextual_score": round(float(score), 4),
        "event_count": len(events),
        "breakdown": breakdown,
    }
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/history")
async def get_history(
    limit: int = Query(default=50, ge=1, le=200),
    track_id: int | None = Query(default=None),
    session: AsyncSession = Depends(get_session),
):
    """Return recent feedback events, newest first, optionally for a single track.

    Args:
        limit: Maximum number of events to return (validated to 1..200).
        track_id: When given, only events for this track are returned.
        session: Database session injected by FastAPI.

    Returns:
        ``{"events": [...], "count": n}``. Each event includes its ids, the
        related track's artist and title, the signal, signal weight, vibe
        text and ``created_at`` as an ISO-8601 string.
    """
    stmt = select(FeedbackEvent).options(joinedload(FeedbackEvent.track))
    if track_id is not None:
        stmt = stmt.where(FeedbackEvent.track_id == track_id)
    # SQLAlchemy renders WHERE before ORDER BY / LIMIT no matter which order
    # the methods are chained in; filtering first just reads more naturally.
    stmt = stmt.order_by(FeedbackEvent.created_at.desc()).limit(limit)

    rows = (await session.execute(stmt)).scalars().unique().all()

    def serialize(ev):
        # The track relationship was eager-loaded by joinedload above.
        # NOTE(review): assumes every event has a track; a missing one
        # would raise AttributeError here.
        trk = ev.track
        return {
            "id": ev.id,
            "playlist_id": ev.playlist_id,
            "track_id": ev.track_id,
            "artist": trk.artist,
            "title": trk.title,
            "signal": ev.signal,
            "signal_weight": ev.signal_weight,
            "vibe_text": ev.vibe_text,
            "created_at": ev.created_at.isoformat(),
        }

    return {"events": [serialize(ev) for ev in rows], "count": len(rows)}
|