Add automatic skip detection for playlist playback
Background poller monitors HA media_player state during playlist sessions. When a track transition occurs and the previous track was played < 40% of its duration, automatically records "skip" feedback. Also includes the previously uncommitted delete_feedback endpoint. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -6,7 +6,7 @@ from sqlalchemy.orm import joinedload
|
||||
|
||||
from haunt_fm.db import get_session
|
||||
from haunt_fm.models.track import FeedbackEvent, Track
|
||||
from haunt_fm.services.feedback import compute_contextual_score, record_feedback
|
||||
from haunt_fm.services.feedback import compute_contextual_score, delete_feedback, record_feedback
|
||||
|
||||
router = APIRouter(prefix="/api/feedback")
|
||||
|
||||
@@ -41,6 +41,24 @@ async def submit_feedback(req: FeedbackRequest, session: AsyncSession = Depends(
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{event_id}")
|
||||
async def retract_feedback(event_id: int, session: AsyncSession = Depends(get_session)):
|
||||
"""Delete a feedback event entirely, removing its influence on scoring."""
|
||||
event = await delete_feedback(session, event_id)
|
||||
if event is None:
|
||||
raise HTTPException(status_code=404, detail=f"Feedback event {event_id} not found")
|
||||
|
||||
return {
|
||||
"id": event.id,
|
||||
"playlist_id": event.playlist_id,
|
||||
"track_id": event.track_id,
|
||||
"signal": event.signal,
|
||||
"signal_weight": event.signal_weight,
|
||||
"vibe_text": event.vibe_text,
|
||||
"created_at": event.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/score")
|
||||
async def get_score(req: ScoreRequest, session: AsyncSession = Depends(get_session)):
|
||||
"""Compute the contextual feedback score for a track given a vibe description."""
|
||||
|
||||
@@ -62,6 +62,7 @@ async def generate(req: GenerateRequest, session: AsyncSession = Depends(get_ses
|
||||
"artist": t.artist,
|
||||
"title": t.title,
|
||||
"album": t.album,
|
||||
"duration_ms": t.duration_ms,
|
||||
"is_known": pt.is_known,
|
||||
"similarity_score": pt.similarity_score,
|
||||
}
|
||||
@@ -72,6 +73,11 @@ async def generate(req: GenerateRequest, session: AsyncSession = Depends(get_ses
|
||||
if req.auto_play and req.speaker_entity:
|
||||
await play_playlist_on_speaker(track_list, req.speaker_entity)
|
||||
|
||||
# Register with skip detector for automatic skip feedback
|
||||
from haunt_fm.services.skip_detector import register_session
|
||||
|
||||
register_session(req.speaker_entity, playlist.id, track_list)
|
||||
|
||||
return {
|
||||
"playlist_id": playlist.id,
|
||||
"name": playlist.name,
|
||||
|
||||
@@ -16,6 +16,8 @@ from haunt_fm.config import settings
|
||||
from haunt_fm.services.embedding import is_model_loaded
|
||||
from haunt_fm.services.embedding_worker import is_running as is_worker_running
|
||||
from haunt_fm.services.embedding_worker import last_processed as worker_last_processed
|
||||
from haunt_fm.services.skip_detector import get_sessions as get_skip_sessions
|
||||
from haunt_fm.services.skip_detector import is_running as is_skip_detector_running
|
||||
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
@@ -103,6 +105,25 @@ async def status(session: AsyncSession = Depends(get_session)):
|
||||
"total_generated": total_playlists,
|
||||
"last_generated": last_playlist.isoformat() if last_playlist else None,
|
||||
},
|
||||
"skip_detector": {
|
||||
"running": is_skip_detector_running(),
|
||||
"active_sessions": len(get_skip_sessions()),
|
||||
"sessions": [
|
||||
{
|
||||
"speaker_entity": entity,
|
||||
"playlist_id": s.playlist_id,
|
||||
"current_position": s.current_position,
|
||||
"total_tracks": len(s.tracks),
|
||||
"current_track": (
|
||||
f"{s.tracks[s.current_position]['artist']} - {s.tracks[s.current_position]['title']}"
|
||||
if s.current_position < len(s.tracks)
|
||||
else None
|
||||
),
|
||||
"last_activity": s.last_activity_at.isoformat(),
|
||||
}
|
||||
for entity, s in get_skip_sessions().items()
|
||||
],
|
||||
},
|
||||
},
|
||||
"dependencies": {
|
||||
"lastfm_api": "configured" if settings.lastfm_api_key else "not_configured",
|
||||
|
||||
@@ -25,7 +25,14 @@ class Settings(BaseSettings):
|
||||
|
||||
# Feedback
|
||||
feedback_overlap_threshold: float = 0.85
|
||||
feedback_signal_weights: dict = {"up": 1.0, "down": -1.0, "skip": -0.3}
|
||||
feedback_signal_weights: dict = {"up": 1.0, "down": -1.0, "skip": -0.3, "neutral": 0.0}
|
||||
|
||||
# Skip detection
|
||||
skip_detector_enabled: bool = True
|
||||
skip_detector_poll_interval_seconds: float = 3.0
|
||||
skip_detector_skip_threshold: float = 0.4 # < 40% played = skip
|
||||
skip_detector_session_timeout_minutes: int = 30
|
||||
skip_detector_min_track_duration_ms: int = 30000 # ignore tracks < 30s
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -27,13 +27,22 @@ async def lifespan(app: FastAPI):
|
||||
worker_task = asyncio.create_task(run_worker())
|
||||
logger.info("Embedding worker task created")
|
||||
|
||||
# Start skip detector in background
|
||||
skip_detector_task = None
|
||||
if settings.skip_detector_enabled:
|
||||
from haunt_fm.services.skip_detector import run_skip_detector
|
||||
|
||||
skip_detector_task = asyncio.create_task(run_skip_detector())
|
||||
logger.info("Skip detector task created")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
if worker_task:
|
||||
worker_task.cancel()
|
||||
for task in [worker_task, skip_detector_task]:
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await worker_task
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("haunt-fm shut down")
|
||||
|
||||
@@ -122,6 +122,19 @@ def compute_contextual_score(
|
||||
return score
|
||||
|
||||
|
||||
async def delete_feedback(session: AsyncSession, event_id: int) -> FeedbackEvent | None:
|
||||
"""Delete a feedback event by ID. Returns the event if found, None otherwise."""
|
||||
event = await session.get(FeedbackEvent, event_id)
|
||||
if event is None:
|
||||
return None
|
||||
|
||||
await session.delete(event)
|
||||
await session.commit()
|
||||
|
||||
logger.info("Deleted feedback event %d (signal=%s, track=%d)", event.id, event.signal, event.track_id)
|
||||
return event
|
||||
|
||||
|
||||
async def apply_feedback_adjustments(
|
||||
session: AsyncSession,
|
||||
recommendations: list[dict],
|
||||
|
||||
@@ -33,6 +33,27 @@ async def is_ha_reachable() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def get_speaker_state(speaker_entity: str) -> dict | None:
|
||||
"""Get current playback state from a HA media_player entity.
|
||||
|
||||
Returns dict with state, media_title, media_artist, media_duration, media_position
|
||||
or None if unreachable.
|
||||
"""
|
||||
try:
|
||||
data = await _ha_request("GET", f"/api/states/{speaker_entity}")
|
||||
attrs = data.get("attributes", {})
|
||||
return {
|
||||
"state": data.get("state"),
|
||||
"media_title": attrs.get("media_title"),
|
||||
"media_artist": attrs.get("media_artist"),
|
||||
"media_duration": attrs.get("media_duration"),
|
||||
"media_position": attrs.get("media_position"),
|
||||
}
|
||||
except Exception:
|
||||
logger.debug("Failed to get state for %s", speaker_entity)
|
||||
return None
|
||||
|
||||
|
||||
async def play_media_on_speaker(
|
||||
media_content_id: str,
|
||||
speaker_entity: str,
|
||||
|
||||
297
src/haunt_fm/services/skip_detector.py
Normal file
297
src/haunt_fm/services/skip_detector.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""Background skip detection for haunt-fm playlists.
|
||||
|
||||
Polls Home Assistant media_player state to detect when tracks are skipped
|
||||
(played < threshold % of duration) and automatically records skip feedback.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import difflib
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from haunt_fm.config import settings
|
||||
from haunt_fm.db import async_session
|
||||
from haunt_fm.services.feedback import record_feedback
|
||||
from haunt_fm.services.music_assistant import get_speaker_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_running = False
|
||||
_sessions: dict[str, "PlaylistSession"] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlaylistSession:
|
||||
playlist_id: int
|
||||
speaker_entity: str
|
||||
tracks: list[dict] # each: {track_id, artist, title, duration_ms, position}
|
||||
current_position: int = 0
|
||||
current_track_started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
last_activity_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
last_ha_media_title: str | None = None
|
||||
last_ha_media_artist: str | None = None
|
||||
|
||||
|
||||
def is_running() -> bool:
|
||||
return _running
|
||||
|
||||
|
||||
def get_sessions() -> dict[str, "PlaylistSession"]:
|
||||
return dict(_sessions)
|
||||
|
||||
|
||||
def register_session(speaker_entity: str, playlist_id: int, tracks: list[dict]) -> None:
|
||||
"""Register a new playlist session for skip detection.
|
||||
|
||||
Called from playlists.py after auto_play succeeds.
|
||||
"""
|
||||
session = PlaylistSession(
|
||||
playlist_id=playlist_id,
|
||||
speaker_entity=speaker_entity,
|
||||
tracks=tracks,
|
||||
)
|
||||
_sessions[speaker_entity] = session
|
||||
logger.info(
|
||||
"Skip detector: registered session for %s (playlist %d, %d tracks)",
|
||||
speaker_entity,
|
||||
playlist_id,
|
||||
len(tracks),
|
||||
)
|
||||
|
||||
|
||||
# --- String normalization and fuzzy matching ---
|
||||
|
||||
_PAREN_SUFFIX = re.compile(r"\s*\(.*\)\s*$")
|
||||
_EXTRA_WHITESPACE = re.compile(r"\s+")
|
||||
|
||||
|
||||
def _normalize(s: str) -> str:
|
||||
"""Normalize a string for fuzzy comparison."""
|
||||
s = s.lower().strip()
|
||||
s = _PAREN_SUFFIX.sub("", s) # strip "(Remastered)", "(Live)", etc.
|
||||
s = _EXTRA_WHITESPACE.sub(" ", s)
|
||||
return s
|
||||
|
||||
|
||||
def _fuzzy_score(a: str, b: str) -> float:
|
||||
"""Fuzzy similarity score between two strings (0-1)."""
|
||||
return difflib.SequenceMatcher(None, _normalize(a), _normalize(b)).ratio()
|
||||
|
||||
|
||||
def _match_track(
|
||||
ha_title: str, ha_artist: str, tracks: list[dict], hint_position: int
|
||||
) -> int | None:
|
||||
"""Find the best matching track in the playlist for the current HA playback.
|
||||
|
||||
Returns the track's position index, or None if no match found.
|
||||
Checks hint_position+1 first (sequential playback is the common case).
|
||||
"""
|
||||
best_idx = None
|
||||
best_score = 0.0
|
||||
threshold = 0.5
|
||||
|
||||
# Check the next sequential position first (most common case)
|
||||
candidates = []
|
||||
next_pos = hint_position + 1
|
||||
if next_pos < len(tracks):
|
||||
candidates.append(next_pos)
|
||||
# Then check all other positions
|
||||
candidates.extend(i for i in range(len(tracks)) if i != next_pos)
|
||||
|
||||
for idx in candidates:
|
||||
t = tracks[idx]
|
||||
title_score = _fuzzy_score(ha_title, t["title"])
|
||||
artist_score = _fuzzy_score(ha_artist, t["artist"])
|
||||
combined = title_score * 0.7 + artist_score * 0.3
|
||||
|
||||
if combined > best_score:
|
||||
best_score = combined
|
||||
best_idx = idx
|
||||
|
||||
# If sequential position matches well, use it immediately
|
||||
if idx == next_pos and combined >= threshold:
|
||||
return idx
|
||||
|
||||
if best_score >= threshold:
|
||||
return best_idx
|
||||
return None
|
||||
|
||||
|
||||
# --- Core polling loop ---
|
||||
|
||||
|
||||
async def _evaluate_skip(
|
||||
session: PlaylistSession, track: dict, elapsed_ms: float
|
||||
) -> None:
|
||||
"""Check if a track was skipped and record feedback if so."""
|
||||
duration_ms = track.get("duration_ms")
|
||||
if duration_ms is None or duration_ms <= 0:
|
||||
logger.debug("Skip eval: no duration for track %d, skipping evaluation", track["track_id"])
|
||||
return
|
||||
|
||||
if duration_ms < settings.skip_detector_min_track_duration_ms:
|
||||
logger.debug(
|
||||
"Skip eval: track %d too short (%dms), ignoring",
|
||||
track["track_id"],
|
||||
duration_ms,
|
||||
)
|
||||
return
|
||||
|
||||
fraction_played = elapsed_ms / duration_ms
|
||||
if fraction_played < settings.skip_detector_skip_threshold:
|
||||
logger.info(
|
||||
"Skip detected: '%s - %s' played %.0f%% (track_id=%d, playlist_id=%d)",
|
||||
track["artist"],
|
||||
track["title"],
|
||||
fraction_played * 100,
|
||||
track["track_id"],
|
||||
session.playlist_id,
|
||||
)
|
||||
try:
|
||||
async with async_session() as db:
|
||||
await record_feedback(
|
||||
db, session.playlist_id, track["track_id"], "skip"
|
||||
)
|
||||
logger.info(
|
||||
"Skip feedback recorded for track %d in playlist %d",
|
||||
track["track_id"],
|
||||
session.playlist_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
# Playlist without vibe_embedding — can't record contextual feedback
|
||||
logger.warning(
|
||||
"Could not record skip feedback for track %d: %s",
|
||||
track["track_id"],
|
||||
e,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error recording skip feedback for track %d", track["track_id"]
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Track '%s - %s' played %.0f%%, not a skip",
|
||||
track["artist"],
|
||||
track["title"],
|
||||
fraction_played * 100,
|
||||
)
|
||||
|
||||
|
||||
async def _poll_all_sessions() -> None:
|
||||
"""Single poll cycle across all active sessions."""
|
||||
now = datetime.now(timezone.utc)
|
||||
to_remove = []
|
||||
|
||||
for entity, session in list(_sessions.items()):
|
||||
# Check session timeout
|
||||
idle_minutes = (now - session.last_activity_at).total_seconds() / 60
|
||||
if idle_minutes > settings.skip_detector_session_timeout_minutes:
|
||||
logger.info(
|
||||
"Skip detector: session for %s timed out (idle %.0f min)",
|
||||
entity,
|
||||
idle_minutes,
|
||||
)
|
||||
to_remove.append(entity)
|
||||
continue
|
||||
|
||||
# Poll HA for current state
|
||||
state = await get_speaker_state(entity)
|
||||
if state is None:
|
||||
continue # HA unreachable, try again next cycle
|
||||
|
||||
ha_title = state.get("media_title")
|
||||
ha_artist = state.get("media_artist")
|
||||
|
||||
# No media info — speaker might be idle
|
||||
if not ha_title or not ha_artist:
|
||||
continue
|
||||
|
||||
session.last_activity_at = now
|
||||
|
||||
# Check if track changed
|
||||
if (
|
||||
ha_title == session.last_ha_media_title
|
||||
and ha_artist == session.last_ha_media_artist
|
||||
):
|
||||
# Same track still playing — no transition
|
||||
continue
|
||||
|
||||
# Track transition detected
|
||||
logger.debug(
|
||||
"Track transition on %s: '%s - %s' → '%s - %s'",
|
||||
entity,
|
||||
session.last_ha_media_artist,
|
||||
session.last_ha_media_title,
|
||||
ha_artist,
|
||||
ha_title,
|
||||
)
|
||||
|
||||
# Evaluate previous track for skip (if we had one)
|
||||
if session.last_ha_media_title is not None:
|
||||
prev_track = (
|
||||
session.tracks[session.current_position]
|
||||
if session.current_position < len(session.tracks)
|
||||
else None
|
||||
)
|
||||
if prev_track:
|
||||
elapsed_ms = (now - session.current_track_started_at).total_seconds() * 1000
|
||||
await _evaluate_skip(session, prev_track, elapsed_ms)
|
||||
|
||||
# Match new track to playlist
|
||||
matched_pos = _match_track(
|
||||
ha_title, ha_artist, session.tracks, session.current_position
|
||||
)
|
||||
|
||||
if matched_pos is None:
|
||||
logger.info(
|
||||
"Skip detector: '%s - %s' on %s doesn't match playlist — killing session",
|
||||
ha_artist,
|
||||
ha_title,
|
||||
entity,
|
||||
)
|
||||
to_remove.append(entity)
|
||||
continue
|
||||
|
||||
# Update session state
|
||||
session.current_position = matched_pos
|
||||
session.current_track_started_at = now
|
||||
session.last_ha_media_title = ha_title
|
||||
session.last_ha_media_artist = ha_artist
|
||||
|
||||
logger.debug(
|
||||
"Skip detector: %s now at position %d (%s - %s)",
|
||||
entity,
|
||||
matched_pos,
|
||||
session.tracks[matched_pos]["artist"],
|
||||
session.tracks[matched_pos]["title"],
|
||||
)
|
||||
|
||||
for entity in to_remove:
|
||||
_sessions.pop(entity, None)
|
||||
|
||||
|
||||
async def run_skip_detector() -> None:
|
||||
"""Background loop that polls HA and detects skips."""
|
||||
global _running
|
||||
|
||||
if not settings.skip_detector_enabled:
|
||||
logger.info("Skip detector disabled")
|
||||
return
|
||||
|
||||
_running = True
|
||||
logger.info("Skip detector started (poll interval: %.1fs)", settings.skip_detector_poll_interval_seconds)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if _sessions:
|
||||
await _poll_all_sessions()
|
||||
except Exception:
|
||||
logger.exception("Skip detector poll error")
|
||||
|
||||
await asyncio.sleep(settings.skip_detector_poll_interval_seconds)
|
||||
finally:
|
||||
_running = False
|
||||
logger.info("Skip detector stopped")
|
||||
Reference in New Issue
Block a user