Add automatic skip detection for playlist playback
Background poller monitors HA media_player state during playlist sessions. When a track transition occurs and the previous track was played < 40% of its duration, automatically records "skip" feedback. Also includes the previously uncommitted delete_feedback endpoint. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -6,7 +6,7 @@ from sqlalchemy.orm import joinedload
|
|||||||
|
|
||||||
from haunt_fm.db import get_session
|
from haunt_fm.db import get_session
|
||||||
from haunt_fm.models.track import FeedbackEvent, Track
|
from haunt_fm.models.track import FeedbackEvent, Track
|
||||||
from haunt_fm.services.feedback import compute_contextual_score, record_feedback
|
from haunt_fm.services.feedback import compute_contextual_score, delete_feedback, record_feedback
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/feedback")
|
router = APIRouter(prefix="/api/feedback")
|
||||||
|
|
||||||
@@ -41,6 +41,24 @@ async def submit_feedback(req: FeedbackRequest, session: AsyncSession = Depends(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{event_id}")
|
||||||
|
async def retract_feedback(event_id: int, session: AsyncSession = Depends(get_session)):
|
||||||
|
"""Delete a feedback event entirely, removing its influence on scoring."""
|
||||||
|
event = await delete_feedback(session, event_id)
|
||||||
|
if event is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Feedback event {event_id} not found")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": event.id,
|
||||||
|
"playlist_id": event.playlist_id,
|
||||||
|
"track_id": event.track_id,
|
||||||
|
"signal": event.signal,
|
||||||
|
"signal_weight": event.signal_weight,
|
||||||
|
"vibe_text": event.vibe_text,
|
||||||
|
"created_at": event.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/score")
|
@router.post("/score")
|
||||||
async def get_score(req: ScoreRequest, session: AsyncSession = Depends(get_session)):
|
async def get_score(req: ScoreRequest, session: AsyncSession = Depends(get_session)):
|
||||||
"""Compute the contextual feedback score for a track given a vibe description."""
|
"""Compute the contextual feedback score for a track given a vibe description."""
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ async def generate(req: GenerateRequest, session: AsyncSession = Depends(get_ses
|
|||||||
"artist": t.artist,
|
"artist": t.artist,
|
||||||
"title": t.title,
|
"title": t.title,
|
||||||
"album": t.album,
|
"album": t.album,
|
||||||
|
"duration_ms": t.duration_ms,
|
||||||
"is_known": pt.is_known,
|
"is_known": pt.is_known,
|
||||||
"similarity_score": pt.similarity_score,
|
"similarity_score": pt.similarity_score,
|
||||||
}
|
}
|
||||||
@@ -72,6 +73,11 @@ async def generate(req: GenerateRequest, session: AsyncSession = Depends(get_ses
|
|||||||
if req.auto_play and req.speaker_entity:
|
if req.auto_play and req.speaker_entity:
|
||||||
await play_playlist_on_speaker(track_list, req.speaker_entity)
|
await play_playlist_on_speaker(track_list, req.speaker_entity)
|
||||||
|
|
||||||
|
# Register with skip detector for automatic skip feedback
|
||||||
|
from haunt_fm.services.skip_detector import register_session
|
||||||
|
|
||||||
|
register_session(req.speaker_entity, playlist.id, track_list)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"playlist_id": playlist.id,
|
"playlist_id": playlist.id,
|
||||||
"name": playlist.name,
|
"name": playlist.name,
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ from haunt_fm.config import settings
|
|||||||
from haunt_fm.services.embedding import is_model_loaded
|
from haunt_fm.services.embedding import is_model_loaded
|
||||||
from haunt_fm.services.embedding_worker import is_running as is_worker_running
|
from haunt_fm.services.embedding_worker import is_running as is_worker_running
|
||||||
from haunt_fm.services.embedding_worker import last_processed as worker_last_processed
|
from haunt_fm.services.embedding_worker import last_processed as worker_last_processed
|
||||||
|
from haunt_fm.services.skip_detector import get_sessions as get_skip_sessions
|
||||||
|
from haunt_fm.services.skip_detector import is_running as is_skip_detector_running
|
||||||
|
|
||||||
router = APIRouter(prefix="/api")
|
router = APIRouter(prefix="/api")
|
||||||
|
|
||||||
@@ -103,6 +105,25 @@ async def status(session: AsyncSession = Depends(get_session)):
|
|||||||
"total_generated": total_playlists,
|
"total_generated": total_playlists,
|
||||||
"last_generated": last_playlist.isoformat() if last_playlist else None,
|
"last_generated": last_playlist.isoformat() if last_playlist else None,
|
||||||
},
|
},
|
||||||
|
"skip_detector": {
|
||||||
|
"running": is_skip_detector_running(),
|
||||||
|
"active_sessions": len(get_skip_sessions()),
|
||||||
|
"sessions": [
|
||||||
|
{
|
||||||
|
"speaker_entity": entity,
|
||||||
|
"playlist_id": s.playlist_id,
|
||||||
|
"current_position": s.current_position,
|
||||||
|
"total_tracks": len(s.tracks),
|
||||||
|
"current_track": (
|
||||||
|
f"{s.tracks[s.current_position]['artist']} - {s.tracks[s.current_position]['title']}"
|
||||||
|
if s.current_position < len(s.tracks)
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"last_activity": s.last_activity_at.isoformat(),
|
||||||
|
}
|
||||||
|
for entity, s in get_skip_sessions().items()
|
||||||
|
],
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"lastfm_api": "configured" if settings.lastfm_api_key else "not_configured",
|
"lastfm_api": "configured" if settings.lastfm_api_key else "not_configured",
|
||||||
|
|||||||
@@ -25,7 +25,14 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# Feedback
|
# Feedback
|
||||||
feedback_overlap_threshold: float = 0.85
|
feedback_overlap_threshold: float = 0.85
|
||||||
feedback_signal_weights: dict = {"up": 1.0, "down": -1.0, "skip": -0.3}
|
feedback_signal_weights: dict = {"up": 1.0, "down": -1.0, "skip": -0.3, "neutral": 0.0}
|
||||||
|
|
||||||
|
# Skip detection
|
||||||
|
skip_detector_enabled: bool = True
|
||||||
|
skip_detector_poll_interval_seconds: float = 3.0
|
||||||
|
skip_detector_skip_threshold: float = 0.4 # < 40% played = skip
|
||||||
|
skip_detector_session_timeout_minutes: int = 30
|
||||||
|
skip_detector_min_track_duration_ms: int = 30000 # ignore tracks < 30s
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -27,15 +27,24 @@ async def lifespan(app: FastAPI):
|
|||||||
worker_task = asyncio.create_task(run_worker())
|
worker_task = asyncio.create_task(run_worker())
|
||||||
logger.info("Embedding worker task created")
|
logger.info("Embedding worker task created")
|
||||||
|
|
||||||
|
# Start skip detector in background
|
||||||
|
skip_detector_task = None
|
||||||
|
if settings.skip_detector_enabled:
|
||||||
|
from haunt_fm.services.skip_detector import run_skip_detector
|
||||||
|
|
||||||
|
skip_detector_task = asyncio.create_task(run_skip_detector())
|
||||||
|
logger.info("Skip detector task created")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown
|
# Shutdown
|
||||||
if worker_task:
|
for task in [worker_task, skip_detector_task]:
|
||||||
worker_task.cancel()
|
if task:
|
||||||
try:
|
task.cancel()
|
||||||
await worker_task
|
try:
|
||||||
except asyncio.CancelledError:
|
await task
|
||||||
pass
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
logger.info("haunt-fm shut down")
|
logger.info("haunt-fm shut down")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -122,6 +122,19 @@ def compute_contextual_score(
|
|||||||
return score
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_feedback(session: AsyncSession, event_id: int) -> FeedbackEvent | None:
|
||||||
|
"""Delete a feedback event by ID. Returns the event if found, None otherwise."""
|
||||||
|
event = await session.get(FeedbackEvent, event_id)
|
||||||
|
if event is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
await session.delete(event)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info("Deleted feedback event %d (signal=%s, track=%d)", event.id, event.signal, event.track_id)
|
||||||
|
return event
|
||||||
|
|
||||||
|
|
||||||
async def apply_feedback_adjustments(
|
async def apply_feedback_adjustments(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
recommendations: list[dict],
|
recommendations: list[dict],
|
||||||
|
|||||||
@@ -33,6 +33,27 @@ async def is_ha_reachable() -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def get_speaker_state(speaker_entity: str) -> dict | None:
|
||||||
|
"""Get current playback state from a HA media_player entity.
|
||||||
|
|
||||||
|
Returns dict with state, media_title, media_artist, media_duration, media_position
|
||||||
|
or None if unreachable.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = await _ha_request("GET", f"/api/states/{speaker_entity}")
|
||||||
|
attrs = data.get("attributes", {})
|
||||||
|
return {
|
||||||
|
"state": data.get("state"),
|
||||||
|
"media_title": attrs.get("media_title"),
|
||||||
|
"media_artist": attrs.get("media_artist"),
|
||||||
|
"media_duration": attrs.get("media_duration"),
|
||||||
|
"media_position": attrs.get("media_position"),
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to get state for %s", speaker_entity)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def play_media_on_speaker(
|
async def play_media_on_speaker(
|
||||||
media_content_id: str,
|
media_content_id: str,
|
||||||
speaker_entity: str,
|
speaker_entity: str,
|
||||||
|
|||||||
297
src/haunt_fm/services/skip_detector.py
Normal file
297
src/haunt_fm/services/skip_detector.py
Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
"""Background skip detection for haunt-fm playlists.
|
||||||
|
|
||||||
|
Polls Home Assistant media_player state to detect when tracks are skipped
|
||||||
|
(played < threshold % of duration) and automatically records skip feedback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import difflib
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from haunt_fm.config import settings
|
||||||
|
from haunt_fm.db import async_session
|
||||||
|
from haunt_fm.services.feedback import record_feedback
|
||||||
|
from haunt_fm.services.music_assistant import get_speaker_state
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_running = False
|
||||||
|
_sessions: dict[str, "PlaylistSession"] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PlaylistSession:
|
||||||
|
playlist_id: int
|
||||||
|
speaker_entity: str
|
||||||
|
tracks: list[dict] # each: {track_id, artist, title, duration_ms, position}
|
||||||
|
current_position: int = 0
|
||||||
|
current_track_started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
last_activity_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
last_ha_media_title: str | None = None
|
||||||
|
last_ha_media_artist: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def is_running() -> bool:
|
||||||
|
return _running
|
||||||
|
|
||||||
|
|
||||||
|
def get_sessions() -> dict[str, "PlaylistSession"]:
|
||||||
|
return dict(_sessions)
|
||||||
|
|
||||||
|
|
||||||
|
def register_session(speaker_entity: str, playlist_id: int, tracks: list[dict]) -> None:
|
||||||
|
"""Register a new playlist session for skip detection.
|
||||||
|
|
||||||
|
Called from playlists.py after auto_play succeeds.
|
||||||
|
"""
|
||||||
|
session = PlaylistSession(
|
||||||
|
playlist_id=playlist_id,
|
||||||
|
speaker_entity=speaker_entity,
|
||||||
|
tracks=tracks,
|
||||||
|
)
|
||||||
|
_sessions[speaker_entity] = session
|
||||||
|
logger.info(
|
||||||
|
"Skip detector: registered session for %s (playlist %d, %d tracks)",
|
||||||
|
speaker_entity,
|
||||||
|
playlist_id,
|
||||||
|
len(tracks),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- String normalization and fuzzy matching ---
|
||||||
|
|
||||||
|
_PAREN_SUFFIX = re.compile(r"\s*\(.*\)\s*$")
|
||||||
|
_EXTRA_WHITESPACE = re.compile(r"\s+")
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize(s: str) -> str:
|
||||||
|
"""Normalize a string for fuzzy comparison."""
|
||||||
|
s = s.lower().strip()
|
||||||
|
s = _PAREN_SUFFIX.sub("", s) # strip "(Remastered)", "(Live)", etc.
|
||||||
|
s = _EXTRA_WHITESPACE.sub(" ", s)
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def _fuzzy_score(a: str, b: str) -> float:
|
||||||
|
"""Fuzzy similarity score between two strings (0-1)."""
|
||||||
|
return difflib.SequenceMatcher(None, _normalize(a), _normalize(b)).ratio()
|
||||||
|
|
||||||
|
|
||||||
|
def _match_track(
|
||||||
|
ha_title: str, ha_artist: str, tracks: list[dict], hint_position: int
|
||||||
|
) -> int | None:
|
||||||
|
"""Find the best matching track in the playlist for the current HA playback.
|
||||||
|
|
||||||
|
Returns the track's position index, or None if no match found.
|
||||||
|
Checks hint_position+1 first (sequential playback is the common case).
|
||||||
|
"""
|
||||||
|
best_idx = None
|
||||||
|
best_score = 0.0
|
||||||
|
threshold = 0.5
|
||||||
|
|
||||||
|
# Check the next sequential position first (most common case)
|
||||||
|
candidates = []
|
||||||
|
next_pos = hint_position + 1
|
||||||
|
if next_pos < len(tracks):
|
||||||
|
candidates.append(next_pos)
|
||||||
|
# Then check all other positions
|
||||||
|
candidates.extend(i for i in range(len(tracks)) if i != next_pos)
|
||||||
|
|
||||||
|
for idx in candidates:
|
||||||
|
t = tracks[idx]
|
||||||
|
title_score = _fuzzy_score(ha_title, t["title"])
|
||||||
|
artist_score = _fuzzy_score(ha_artist, t["artist"])
|
||||||
|
combined = title_score * 0.7 + artist_score * 0.3
|
||||||
|
|
||||||
|
if combined > best_score:
|
||||||
|
best_score = combined
|
||||||
|
best_idx = idx
|
||||||
|
|
||||||
|
# If sequential position matches well, use it immediately
|
||||||
|
if idx == next_pos and combined >= threshold:
|
||||||
|
return idx
|
||||||
|
|
||||||
|
if best_score >= threshold:
|
||||||
|
return best_idx
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Core polling loop ---
|
||||||
|
|
||||||
|
|
||||||
|
async def _evaluate_skip(
|
||||||
|
session: PlaylistSession, track: dict, elapsed_ms: float
|
||||||
|
) -> None:
|
||||||
|
"""Check if a track was skipped and record feedback if so."""
|
||||||
|
duration_ms = track.get("duration_ms")
|
||||||
|
if duration_ms is None or duration_ms <= 0:
|
||||||
|
logger.debug("Skip eval: no duration for track %d, skipping evaluation", track["track_id"])
|
||||||
|
return
|
||||||
|
|
||||||
|
if duration_ms < settings.skip_detector_min_track_duration_ms:
|
||||||
|
logger.debug(
|
||||||
|
"Skip eval: track %d too short (%dms), ignoring",
|
||||||
|
track["track_id"],
|
||||||
|
duration_ms,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
fraction_played = elapsed_ms / duration_ms
|
||||||
|
if fraction_played < settings.skip_detector_skip_threshold:
|
||||||
|
logger.info(
|
||||||
|
"Skip detected: '%s - %s' played %.0f%% (track_id=%d, playlist_id=%d)",
|
||||||
|
track["artist"],
|
||||||
|
track["title"],
|
||||||
|
fraction_played * 100,
|
||||||
|
track["track_id"],
|
||||||
|
session.playlist_id,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
await record_feedback(
|
||||||
|
db, session.playlist_id, track["track_id"], "skip"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Skip feedback recorded for track %d in playlist %d",
|
||||||
|
track["track_id"],
|
||||||
|
session.playlist_id,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
# Playlist without vibe_embedding — can't record contextual feedback
|
||||||
|
logger.warning(
|
||||||
|
"Could not record skip feedback for track %d: %s",
|
||||||
|
track["track_id"],
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Error recording skip feedback for track %d", track["track_id"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"Track '%s - %s' played %.0f%%, not a skip",
|
||||||
|
track["artist"],
|
||||||
|
track["title"],
|
||||||
|
fraction_played * 100,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _poll_all_sessions() -> None:
|
||||||
|
"""Single poll cycle across all active sessions."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
to_remove = []
|
||||||
|
|
||||||
|
for entity, session in list(_sessions.items()):
|
||||||
|
# Check session timeout
|
||||||
|
idle_minutes = (now - session.last_activity_at).total_seconds() / 60
|
||||||
|
if idle_minutes > settings.skip_detector_session_timeout_minutes:
|
||||||
|
logger.info(
|
||||||
|
"Skip detector: session for %s timed out (idle %.0f min)",
|
||||||
|
entity,
|
||||||
|
idle_minutes,
|
||||||
|
)
|
||||||
|
to_remove.append(entity)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Poll HA for current state
|
||||||
|
state = await get_speaker_state(entity)
|
||||||
|
if state is None:
|
||||||
|
continue # HA unreachable, try again next cycle
|
||||||
|
|
||||||
|
ha_title = state.get("media_title")
|
||||||
|
ha_artist = state.get("media_artist")
|
||||||
|
|
||||||
|
# No media info — speaker might be idle
|
||||||
|
if not ha_title or not ha_artist:
|
||||||
|
continue
|
||||||
|
|
||||||
|
session.last_activity_at = now
|
||||||
|
|
||||||
|
# Check if track changed
|
||||||
|
if (
|
||||||
|
ha_title == session.last_ha_media_title
|
||||||
|
and ha_artist == session.last_ha_media_artist
|
||||||
|
):
|
||||||
|
# Same track still playing — no transition
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Track transition detected
|
||||||
|
logger.debug(
|
||||||
|
"Track transition on %s: '%s - %s' → '%s - %s'",
|
||||||
|
entity,
|
||||||
|
session.last_ha_media_artist,
|
||||||
|
session.last_ha_media_title,
|
||||||
|
ha_artist,
|
||||||
|
ha_title,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate previous track for skip (if we had one)
|
||||||
|
if session.last_ha_media_title is not None:
|
||||||
|
prev_track = (
|
||||||
|
session.tracks[session.current_position]
|
||||||
|
if session.current_position < len(session.tracks)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if prev_track:
|
||||||
|
elapsed_ms = (now - session.current_track_started_at).total_seconds() * 1000
|
||||||
|
await _evaluate_skip(session, prev_track, elapsed_ms)
|
||||||
|
|
||||||
|
# Match new track to playlist
|
||||||
|
matched_pos = _match_track(
|
||||||
|
ha_title, ha_artist, session.tracks, session.current_position
|
||||||
|
)
|
||||||
|
|
||||||
|
if matched_pos is None:
|
||||||
|
logger.info(
|
||||||
|
"Skip detector: '%s - %s' on %s doesn't match playlist — killing session",
|
||||||
|
ha_artist,
|
||||||
|
ha_title,
|
||||||
|
entity,
|
||||||
|
)
|
||||||
|
to_remove.append(entity)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Update session state
|
||||||
|
session.current_position = matched_pos
|
||||||
|
session.current_track_started_at = now
|
||||||
|
session.last_ha_media_title = ha_title
|
||||||
|
session.last_ha_media_artist = ha_artist
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Skip detector: %s now at position %d (%s - %s)",
|
||||||
|
entity,
|
||||||
|
matched_pos,
|
||||||
|
session.tracks[matched_pos]["artist"],
|
||||||
|
session.tracks[matched_pos]["title"],
|
||||||
|
)
|
||||||
|
|
||||||
|
for entity in to_remove:
|
||||||
|
_sessions.pop(entity, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_skip_detector() -> None:
|
||||||
|
"""Background loop that polls HA and detects skips."""
|
||||||
|
global _running
|
||||||
|
|
||||||
|
if not settings.skip_detector_enabled:
|
||||||
|
logger.info("Skip detector disabled")
|
||||||
|
return
|
||||||
|
|
||||||
|
_running = True
|
||||||
|
logger.info("Skip detector started (poll interval: %.1fs)", settings.skip_detector_poll_interval_seconds)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
if _sessions:
|
||||||
|
await _poll_all_sessions()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Skip detector poll error")
|
||||||
|
|
||||||
|
await asyncio.sleep(settings.skip_detector_poll_interval_seconds)
|
||||||
|
finally:
|
||||||
|
_running = False
|
||||||
|
logger.info("Skip detector stopped")
|
||||||
Reference in New Issue
Block a user