Files
haunt-fm/src/haunt_fm/api/feedback.py
Thomas Hallock af6159a297 Add automatic skip detection for playlist playback
Background poller monitors HA media_player state during playlist sessions.
When a track transition occurs and the previous track was played < 40% of
its duration, automatically records "skip" feedback. Also includes the
previously uncommitted delete_feedback endpoint.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 09:17:52 -06:00

145 lines
4.5 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from haunt_fm.db import get_session
from haunt_fm.models.track import FeedbackEvent, Track
from haunt_fm.services.feedback import compute_contextual_score, delete_feedback, record_feedback
# Router for the feedback endpoints; every route registered below gets the
# /api/feedback path prefix.
router = APIRouter(prefix="/api/feedback")
class FeedbackRequest(BaseModel):
    """Request body for ``POST /api/feedback``.

    All three fields are passed straight through to ``record_feedback``; the
    submit endpoint maps any ``ValueError`` it raises to an HTTP 400.
    """

    playlist_id: int
    track_id: int
    # Feedback signal name. Which values are accepted is decided by
    # record_feedback, not validated here (presumably it rejects unknown
    # signals with ValueError -- confirm in services.feedback).
    signal: str
class ScoreRequest(BaseModel):
    """Request body for ``POST /api/feedback/score``.

    The score endpoint embeds ``vibe_text`` and compares it against the vibe
    embeddings stored on ``track_id``'s past feedback events.
    """

    track_id: int
    # Free-text vibe description to score the track against.
    vibe_text: str
def _serialize_event(event: FeedbackEvent) -> dict[str, object]:
    """Return the JSON shape of a feedback event used by submit and delete.

    Keeping a single serializer guarantees both endpoints return the same
    fields in the same format.
    """
    return {
        "id": event.id,
        "playlist_id": event.playlist_id,
        "track_id": event.track_id,
        "signal": event.signal,
        "signal_weight": event.signal_weight,
        "vibe_text": event.vibe_text,
        "created_at": event.created_at.isoformat(),
    }


@router.post("")
async def submit_feedback(req: FeedbackRequest, session: AsyncSession = Depends(get_session)):
    """Submit feedback for a track in the context of a playlist's vibe.

    Returns the stored event. If ``record_feedback`` raises ``ValueError``,
    responds with HTTP 400 and the error message as ``detail``.
    """
    try:
        event = await record_feedback(session, req.playlist_id, req.track_id, req.signal)
    except ValueError as e:
        # Explicit chaining marks the ValueError as the direct cause of the
        # HTTP error instead of an exception raised "during handling".
        raise HTTPException(status_code=400, detail=str(e)) from e
    return _serialize_event(event)


@router.delete("/{event_id}")
async def retract_feedback(event_id: int, session: AsyncSession = Depends(get_session)):
    """Delete a feedback event entirely, removing its influence on scoring.

    Returns the deleted event's fields. Responds with HTTP 404 when
    ``delete_feedback`` finds no event with ``event_id``.
    """
    event = await delete_feedback(session, event_id)
    if event is None:
        raise HTTPException(status_code=404, detail=f"Feedback event {event_id} not found")
    return _serialize_event(event)
@router.post("/score")
async def get_score(req: ScoreRequest, session: AsyncSession = Depends(get_session)):
    """Compute the contextual feedback score for a track given a vibe description.

    Embeds ``req.vibe_text`` (loading the embedding model first if it is not
    loaded yet), fetches every feedback event recorded for ``req.track_id``
    and returns:

    * ``contextual_score`` -- ``compute_contextual_score`` over those events,
      rounded to 4 decimals;
    * ``breakdown`` -- one entry per event with the cosine similarity between
      the event's stored vibe embedding and the requested vibe.
    """
    # Deferred imports: only this endpoint in the module needs numpy and the
    # embedding service (presumably kept lazy so importing the router is cheap).
    import numpy as np

    from haunt_fm.services.embedding import embed_text, is_model_loaded, load_model

    if not is_model_loaded():
        load_model()
    vibe_embedding = embed_text(req.vibe_text)

    result = await session.execute(
        select(FeedbackEvent).where(FeedbackEvent.track_id == req.track_id)
    )
    events = list(result.scalars().all())
    score = compute_contextual_score(events, vibe_embedding)

    # The requested vibe is the same for every event: normalise it once rather
    # than recomputing its norm and unit vector inside the loop. asarray also
    # makes the division safe if embed_text hands back a plain sequence.
    vibe_vec = np.asarray(vibe_embedding)
    vibe_norm = np.linalg.norm(vibe_vec)
    vibe_unit = vibe_vec / vibe_norm if vibe_norm > 0 else None

    breakdown = []
    for event in events:
        event_emb = np.array(event.vibe_embedding, dtype=np.float32)
        event_norm = np.linalg.norm(event_emb)
        # A zero vector has no direction; report 0.0 instead of dividing by zero.
        if vibe_unit is not None and event_norm > 0:
            cos_sim = float(np.dot(event_emb / event_norm, vibe_unit))
        else:
            cos_sim = 0.0
        breakdown.append({
            "id": event.id,
            "signal": event.signal,
            "signal_weight": event.signal_weight,
            "vibe_text": event.vibe_text,
            "cosine_similarity": round(cos_sim, 4),
            "created_at": event.created_at.isoformat(),
        })

    return {
        "track_id": req.track_id,
        "vibe_text": req.vibe_text,
        "contextual_score": round(score, 4),
        "event_count": len(events),
        "breakdown": breakdown,
    }
@router.get("/history")
async def get_history(
    limit: int = Query(default=50, ge=1, le=200),
    track_id: int | None = Query(default=None),
    session: AsyncSession = Depends(get_session),
):
    """List the most recent feedback events, newest first.

    At most ``limit`` events (1-200, default 50) are returned. When
    ``track_id`` is given, only that track's events are listed. Each entry
    includes the track's artist and title next to the feedback fields.
    """
    stmt = select(FeedbackEvent).options(joinedload(FeedbackEvent.track))
    if track_id is not None:
        stmt = stmt.where(FeedbackEvent.track_id == track_id)
    # Filter first, then sort and cap. SQLAlchemy always renders WHERE before
    # ORDER BY / LIMIT, so call order does not change the generated SQL.
    stmt = stmt.order_by(FeedbackEvent.created_at.desc()).limit(limit)

    rows = (await session.execute(stmt)).scalars().unique().all()

    serialized = []
    for fb in rows:
        track = fb.track  # eager-loaded via joinedload above
        serialized.append(
            {
                "id": fb.id,
                "playlist_id": fb.playlist_id,
                "track_id": fb.track_id,
                "artist": track.artist,
                "title": track.title,
                "signal": fb.signal,
                "signal_weight": fb.signal_weight,
                "vibe_text": fb.vibe_text,
                "created_at": fb.created_at.isoformat(),
            }
        )

    return {"events": serialized, "count": len(rows)}