diff --git a/src/haunt_fm/api/feedback.py b/src/haunt_fm/api/feedback.py index 4c6f563..13cf9ad 100644 --- a/src/haunt_fm/api/feedback.py +++ b/src/haunt_fm/api/feedback.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import joinedload from haunt_fm.db import get_session from haunt_fm.models.track import FeedbackEvent, Track -from haunt_fm.services.feedback import compute_contextual_score, record_feedback +from haunt_fm.services.feedback import compute_contextual_score, delete_feedback, record_feedback router = APIRouter(prefix="/api/feedback") @@ -41,6 +41,24 @@ async def submit_feedback(req: FeedbackRequest, session: AsyncSession = Depends( } +@router.delete("/{event_id}") +async def retract_feedback(event_id: int, session: AsyncSession = Depends(get_session)): + """Delete a feedback event entirely, removing its influence on scoring.""" + event = await delete_feedback(session, event_id) + if event is None: + raise HTTPException(status_code=404, detail=f"Feedback event {event_id} not found") + + return { + "id": event.id, + "playlist_id": event.playlist_id, + "track_id": event.track_id, + "signal": event.signal, + "signal_weight": event.signal_weight, + "vibe_text": event.vibe_text, + "created_at": event.created_at.isoformat(), + } + + @router.post("/score") async def get_score(req: ScoreRequest, session: AsyncSession = Depends(get_session)): """Compute the contextual feedback score for a track given a vibe description.""" diff --git a/src/haunt_fm/api/playlists.py b/src/haunt_fm/api/playlists.py index a2d1589..e6a52dc 100644 --- a/src/haunt_fm/api/playlists.py +++ b/src/haunt_fm/api/playlists.py @@ -62,6 +62,7 @@ async def generate(req: GenerateRequest, session: AsyncSession = Depends(get_ses "artist": t.artist, "title": t.title, "album": t.album, + "duration_ms": t.duration_ms, "is_known": pt.is_known, "similarity_score": pt.similarity_score, } @@ -72,6 +73,11 @@ async def generate(req: GenerateRequest, session: AsyncSession = Depends(get_ses if req.auto_play and req.speaker_entity: await play_playlist_on_speaker(track_list, req.speaker_entity) + # Register with skip detector for automatic skip feedback + from haunt_fm.services.skip_detector import register_session + + register_session(req.speaker_entity, playlist.id, track_list) + return { "playlist_id": playlist.id, "name": playlist.name, diff --git a/src/haunt_fm/api/status.py b/src/haunt_fm/api/status.py index 0d623f5..6d94539 100644 --- a/src/haunt_fm/api/status.py +++ b/src/haunt_fm/api/status.py @@ -16,6 +16,8 @@ from haunt_fm.config import settings from haunt_fm.services.embedding import is_model_loaded from haunt_fm.services.embedding_worker import is_running as is_worker_running from haunt_fm.services.embedding_worker import last_processed as worker_last_processed +from haunt_fm.services.skip_detector import get_sessions as get_skip_sessions +from haunt_fm.services.skip_detector import is_running as is_skip_detector_running router = APIRouter(prefix="/api") @@ -103,6 +105,25 @@ async def status(session: AsyncSession = Depends(get_session)): "total_generated": total_playlists, "last_generated": last_playlist.isoformat() if last_playlist else None, }, + "skip_detector": { + "running": is_skip_detector_running(), + "active_sessions": len(get_skip_sessions()), + "sessions": [ + { + "speaker_entity": entity, + "playlist_id": s.playlist_id, + "current_position": s.current_position, + "total_tracks": len(s.tracks), + "current_track": ( + f"{s.tracks[s.current_position]['artist']} - {s.tracks[s.current_position]['title']}" + if s.current_position < len(s.tracks) + else None + ), + "last_activity": s.last_activity_at.isoformat(), + } + for entity, s in get_skip_sessions().items() + ], + }, }, "dependencies": { "lastfm_api": "configured" if settings.lastfm_api_key else "not_configured", diff --git a/src/haunt_fm/config.py b/src/haunt_fm/config.py index 2e35c38..5b46da5 100644 --- a/src/haunt_fm/config.py +++ b/src/haunt_fm/config.py @@ -25,7 +25,14 @@ class Settings(BaseSettings): # Feedback feedback_overlap_threshold: float = 0.85 - feedback_signal_weights: dict = {"up": 1.0, "down": -1.0, "skip": -0.3} + feedback_signal_weights: dict = {"up": 1.0, "down": -1.0, "skip": -0.3, "neutral": 0.0} + + # Skip detection + skip_detector_enabled: bool = True + skip_detector_poll_interval_seconds: float = 3.0 + skip_detector_skip_threshold: float = 0.4 # < 40% played = skip + skip_detector_session_timeout_minutes: int = 30 + skip_detector_min_track_duration_ms: int = 30000 # ignore tracks < 30s settings = Settings() diff --git a/src/haunt_fm/main.py b/src/haunt_fm/main.py index f153fae..69d7639 100644 --- a/src/haunt_fm/main.py +++ b/src/haunt_fm/main.py @@ -27,15 +27,24 @@ async def lifespan(app: FastAPI): worker_task = asyncio.create_task(run_worker()) logger.info("Embedding worker task created") + # Start skip detector in background + skip_detector_task = None + if settings.skip_detector_enabled: + from haunt_fm.services.skip_detector import run_skip_detector + + skip_detector_task = asyncio.create_task(run_skip_detector()) + logger.info("Skip detector task created") + yield # Shutdown - if worker_task: - worker_task.cancel() - try: - await worker_task - except asyncio.CancelledError: - pass + for task in [worker_task, skip_detector_task]: + if task: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass logger.info("haunt-fm shut down") diff --git a/src/haunt_fm/services/feedback.py b/src/haunt_fm/services/feedback.py index 2beb264..6f34c85 100644 --- a/src/haunt_fm/services/feedback.py +++ b/src/haunt_fm/services/feedback.py @@ -122,6 +122,19 @@ def compute_contextual_score( return score +async def delete_feedback(session: AsyncSession, event_id: int) -> FeedbackEvent | None: + """Delete a feedback event by ID. Returns the event if found, None otherwise.""" + event = await session.get(FeedbackEvent, event_id) + if event is None: + return None + + await session.delete(event) + await session.commit() + + logger.info("Deleted feedback event %d (signal=%s, track=%d)", event.id, event.signal, event.track_id) + return event + + async def apply_feedback_adjustments( session: AsyncSession, recommendations: list[dict], diff --git a/src/haunt_fm/services/music_assistant.py b/src/haunt_fm/services/music_assistant.py index d396768..c16dd08 100644 --- a/src/haunt_fm/services/music_assistant.py +++ b/src/haunt_fm/services/music_assistant.py @@ -33,6 +33,27 @@ async def is_ha_reachable() -> bool: return False +async def get_speaker_state(speaker_entity: str) -> dict | None: + """Get current playback state from a HA media_player entity. + + Returns dict with state, media_title, media_artist, media_duration, media_position + or None if unreachable. + """ + try: + data = await _ha_request("GET", f"/api/states/{speaker_entity}") + attrs = data.get("attributes", {}) + return { + "state": data.get("state"), + "media_title": attrs.get("media_title"), + "media_artist": attrs.get("media_artist"), + "media_duration": attrs.get("media_duration"), + "media_position": attrs.get("media_position"), + } + except Exception: + logger.debug("Failed to get state for %s", speaker_entity) + return None + + async def play_media_on_speaker( media_content_id: str, speaker_entity: str, diff --git a/src/haunt_fm/services/skip_detector.py b/src/haunt_fm/services/skip_detector.py new file mode 100644 index 0000000..f29ee2d --- /dev/null +++ b/src/haunt_fm/services/skip_detector.py @@ -0,0 +1,297 @@ +"""Background skip detection for haunt-fm playlists. + +Polls Home Assistant media_player state to detect when tracks are skipped +(played < threshold % of duration) and automatically records skip feedback. +""" + +import asyncio +import difflib +import logging +import re +from dataclasses import dataclass, field +from datetime import datetime, timezone + +from haunt_fm.config import settings +from haunt_fm.db import async_session +from haunt_fm.services.feedback import record_feedback +from haunt_fm.services.music_assistant import get_speaker_state + +logger = logging.getLogger(__name__) + +_running = False +_sessions: dict[str, "PlaylistSession"] = {} + + +@dataclass +class PlaylistSession: + playlist_id: int + speaker_entity: str + tracks: list[dict] # each: {track_id, artist, title, duration_ms, position} + current_position: int = 0 + current_track_started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + last_activity_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + last_ha_media_title: str | None = None + last_ha_media_artist: str | None = None + + +def is_running() -> bool: + return _running + + +def get_sessions() -> dict[str, "PlaylistSession"]: + return dict(_sessions) + + +def register_session(speaker_entity: str, playlist_id: int, tracks: list[dict]) -> None: + """Register a new playlist session for skip detection. + + Called from playlists.py after auto_play succeeds. + """ + session = PlaylistSession( + playlist_id=playlist_id, + speaker_entity=speaker_entity, + tracks=tracks, + ) + _sessions[speaker_entity] = session + logger.info( + "Skip detector: registered session for %s (playlist %d, %d tracks)", + speaker_entity, + playlist_id, + len(tracks), + ) + + +# --- String normalization and fuzzy matching --- + +_PAREN_SUFFIX = re.compile(r"\s*\(.*\)\s*$") +_EXTRA_WHITESPACE = re.compile(r"\s+") + + +def _normalize(s: str) -> str: + """Normalize a string for fuzzy comparison.""" + s = s.lower().strip() + s = _PAREN_SUFFIX.sub("", s) # strip "(Remastered)", "(Live)", etc. + s = _EXTRA_WHITESPACE.sub(" ", s) + return s + + +def _fuzzy_score(a: str, b: str) -> float: + """Fuzzy similarity score between two strings (0-1).""" + return difflib.SequenceMatcher(None, _normalize(a), _normalize(b)).ratio() + + +def _match_track( + ha_title: str, ha_artist: str, tracks: list[dict], hint_position: int +) -> int | None: + """Find the best matching track in the playlist for the current HA playback. + + Returns the track's position index, or None if no match found. + Checks hint_position+1 first (sequential playback is the common case). + """ + best_idx = None + best_score = 0.0 + threshold = 0.5 + + # Check the next sequential position first (most common case) + candidates = [] + next_pos = hint_position + 1 + if next_pos < len(tracks): + candidates.append(next_pos) + # Then check all other positions + candidates.extend(i for i in range(len(tracks)) if i != next_pos) + + for idx in candidates: + t = tracks[idx] + title_score = _fuzzy_score(ha_title, t["title"]) + artist_score = _fuzzy_score(ha_artist, t["artist"]) + combined = title_score * 0.7 + artist_score * 0.3 + + if combined > best_score: + best_score = combined + best_idx = idx + + # If sequential position matches well, use it immediately + if idx == next_pos and combined >= threshold: + return idx + + if best_score >= threshold: + return best_idx + return None + + +# --- Core polling loop --- + + +async def _evaluate_skip( + session: PlaylistSession, track: dict, elapsed_ms: float +) -> None: + """Check if a track was skipped and record feedback if so.""" + duration_ms = track.get("duration_ms") + if duration_ms is None or duration_ms <= 0: + logger.debug("Skip eval: no duration for track %d, skipping evaluation", track["track_id"]) + return + + if duration_ms < settings.skip_detector_min_track_duration_ms: + logger.debug( + "Skip eval: track %d too short (%dms), ignoring", + track["track_id"], + duration_ms, + ) + return + + fraction_played = elapsed_ms / duration_ms + if fraction_played < settings.skip_detector_skip_threshold: + logger.info( + "Skip detected: '%s - %s' played %.0f%% (track_id=%d, playlist_id=%d)", + track["artist"], + track["title"], + fraction_played * 100, + track["track_id"], + session.playlist_id, + ) + try: + async with async_session() as db: + await record_feedback( + db, session.playlist_id, track["track_id"], "skip" + ) + logger.info( + "Skip feedback recorded for track %d in playlist %d", + track["track_id"], + session.playlist_id, + ) + except ValueError as e: + # Playlist without vibe_embedding — can't record contextual feedback + logger.warning( + "Could not record skip feedback for track %d: %s", + track["track_id"], + e, + ) + except Exception: + logger.exception( + "Error recording skip feedback for track %d", track["track_id"] + ) + else: + logger.debug( + "Track '%s - %s' played %.0f%%, not a skip", + track["artist"], + track["title"], + fraction_played * 100, + ) + + +async def _poll_all_sessions() -> None: + """Single poll cycle across all active sessions.""" + now = datetime.now(timezone.utc) + to_remove = [] + + for entity, session in list(_sessions.items()): + # Check session timeout + idle_minutes = (now - session.last_activity_at).total_seconds() / 60 + if idle_minutes > settings.skip_detector_session_timeout_minutes: + logger.info( + "Skip detector: session for %s timed out (idle %.0f min)", + entity, + idle_minutes, + ) + to_remove.append(entity) + continue + + # Poll HA for current state + state = await get_speaker_state(entity) + if state is None: + continue # HA unreachable, try again next cycle + + ha_title = state.get("media_title") + ha_artist = state.get("media_artist") + + # No media info — speaker might be idle + if not ha_title or not ha_artist: + continue + + session.last_activity_at = now + + # Check if track changed + if ( + ha_title == session.last_ha_media_title + and ha_artist == session.last_ha_media_artist + ): + # Same track still playing — no transition + continue + + # Track transition detected + logger.debug( + "Track transition on %s: '%s - %s' → '%s - %s'", + entity, + session.last_ha_media_artist, + session.last_ha_media_title, + ha_artist, + ha_title, + ) + + # Evaluate previous track for skip (if we had one) + if session.last_ha_media_title is not None: + prev_track = ( + session.tracks[session.current_position] + if session.current_position < len(session.tracks) + else None + ) + if prev_track: + elapsed_ms = (now - session.current_track_started_at).total_seconds() * 1000 + await _evaluate_skip(session, prev_track, elapsed_ms) + + # Match new track to playlist + matched_pos = _match_track( + ha_title, ha_artist, session.tracks, session.current_position + ) + + if matched_pos is None: + logger.info( + "Skip detector: '%s - %s' on %s doesn't match playlist — killing session", + ha_artist, + ha_title, + entity, + ) + to_remove.append(entity) + continue + + # Update session state + session.current_position = matched_pos + session.current_track_started_at = now + session.last_ha_media_title = ha_title + session.last_ha_media_artist = ha_artist + + logger.debug( + "Skip detector: %s now at position %d (%s - %s)", + entity, + matched_pos, + session.tracks[matched_pos]["artist"], + session.tracks[matched_pos]["title"], + ) + + for entity in to_remove: + _sessions.pop(entity, None) + + +async def run_skip_detector() -> None: + """Background loop that polls HA and detects skips.""" + global _running + + if not settings.skip_detector_enabled: + logger.info("Skip detector disabled") + return + + _running = True + logger.info("Skip detector started (poll interval: %.1fs)", settings.skip_detector_poll_interval_seconds) + + try: + while True: + try: + if _sessions: + await _poll_all_sessions() + except Exception: + logger.exception("Skip detector poll error") + + await asyncio.sleep(settings.skip_detector_poll_interval_seconds) + finally: + _running = False + logger.info("Skip detector stopped")