"""HTTP endpoints for submitting, retracting, scoring and listing track feedback events."""

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from haunt_fm.db import get_session
from haunt_fm.models.track import FeedbackEvent, Track
from haunt_fm.services.feedback import compute_contextual_score, delete_feedback, record_feedback

router = APIRouter(prefix="/api/feedback")


# Request body for POST /api/feedback.
class FeedbackRequest(BaseModel):
    playlist_id: int
    track_id: int
    signal: str  # validated by record_feedback, which raises ValueError when it rejects input


# Request body for POST /api/feedback/score.
class ScoreRequest(BaseModel):
    track_id: int
    vibe_text: str


def _serialize_event(event: FeedbackEvent) -> dict:
    """Render a feedback event as the JSON payload returned by submit and retract.

    Args:
        event: The feedback event to render.

    Returns:
        A dict with id, playlist_id, track_id, signal, signal_weight, vibe_text
        and an ISO-8601 created_at string.
    """
    return {
        "id": event.id,
        "playlist_id": event.playlist_id,
        "track_id": event.track_id,
        "signal": event.signal,
        "signal_weight": event.signal_weight,
        "vibe_text": event.vibe_text,
        "created_at": event.created_at.isoformat(),
    }


@router.post("")
async def submit_feedback(req: FeedbackRequest, session: AsyncSession = Depends(get_session)):
    """Submit feedback for a track in the context of a playlist's vibe."""
    try:
        event = await record_feedback(session, req.playlist_id, req.track_id, req.signal)
    except ValueError as e:
        # record_feedback reports rejected input as ValueError; surface it as a 400
        # while keeping the original exception chained for logs/tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return _serialize_event(event)


@router.delete("/{event_id}")
async def retract_feedback(event_id: int, session: AsyncSession = Depends(get_session)):
    """Delete a feedback event entirely, removing its influence on scoring."""
    event = await delete_feedback(session, event_id)
    if event is None:
        raise HTTPException(status_code=404, detail=f"Feedback event {event_id} not found")
    # Echo back the removed event so the caller can confirm what was retracted.
    return _serialize_event(event)


@router.post("/score")
async def get_score(req: ScoreRequest, session: AsyncSession = Depends(get_session)):
    """Compute the contextual feedback score for a track given a vibe description."""
    # Imported lazily, presumably so numpy and the embedding model are only
    # pulled in when scoring is actually requested.
    import numpy as np

    from haunt_fm.services.embedding import embed_text, is_model_loaded, load_model

    # NOTE(review): load_model() and embed_text() are synchronous and run on the
    # event loop, stalling other requests while they execute. Moving them to a
    # worker thread would need a guard against concurrent double-loading.
    if not is_model_loaded():
        load_model()
    vibe_embedding = embed_text(req.vibe_text)

    result = await session.execute(
        select(FeedbackEvent).where(FeedbackEvent.track_id == req.track_id)
    )
    events = list(result.scalars().all())
    score = compute_contextual_score(events, vibe_embedding)

    # Breakdown: cosine similarity between each stored event's vibe embedding and
    # the requested vibe, so callers can see which events relate to this vibe.
    # asarray keeps an ndarray as-is and makes list-shaped embeddings divisible.
    query_vec = np.asarray(vibe_embedding)
    query_norm = np.linalg.norm(query_vec)
    # Normalise the query once; a zero (or NaN) norm means every similarity is 0.
    query_unit = query_vec / query_norm if query_norm > 0 else None

    breakdown = []
    for event in events:
        event_vec = np.array(event.vibe_embedding, dtype=np.float32)
        event_norm = np.linalg.norm(event_vec)
        if query_unit is not None and event_norm > 0:
            cos_sim = float(np.dot(event_vec / event_norm, query_unit))
        else:
            cos_sim = 0.0
        breakdown.append({
            "id": event.id,
            "signal": event.signal,
            "signal_weight": event.signal_weight,
            "vibe_text": event.vibe_text,
            "cosine_similarity": round(cos_sim, 4),
            "created_at": event.created_at.isoformat(),
        })

    return {
        "track_id": req.track_id,
        "vibe_text": req.vibe_text,
        "contextual_score": round(score, 4),
        "event_count": len(events),
        "breakdown": breakdown,
    }


@router.get("/history")
async def get_history(
    limit: int = Query(default=50, ge=1, le=200),
    track_id: int | None = Query(default=None),
    session: AsyncSession = Depends(get_session),
):
    """Get recent feedback events, optionally filtered by track."""
    query = (
        select(FeedbackEvent)
        .options(joinedload(FeedbackEvent.track))
        # id breaks ties between events sharing a created_at so results are stable.
        .order_by(FeedbackEvent.created_at.desc(), FeedbackEvent.id.desc())
        .limit(limit)
    )
    if track_id is not None:
        # Select is generative: a WHERE added after .limit() still filters in SQL
        # before the LIMIT is applied.
        query = query.where(FeedbackEvent.track_id == track_id)

    result = await session.execute(query)
    events = result.scalars().unique().all()

    items = []
    for e in events:
        # joinedload uses an outer join unless configured otherwise, so an event
        # whose track row is missing yields e.track is None; avoid a 500 there.
        track = e.track
        items.append({
            "id": e.id,
            "playlist_id": e.playlist_id,
            "track_id": e.track_id,
            "artist": track.artist if track is not None else None,
            "title": track.title if track is not None else None,
            "signal": e.signal,
            "signal_weight": e.signal_weight,
            "vibe_text": e.vibe_text,
            "created_at": e.created_at.isoformat(),
        })

    return {
        "events": items,
        "count": len(items),
    }