Add profile-scoped feedback endpoint
New POST /api/profiles/{name}/feedback accepts explicit vibe text and
records feedback against a named profile. GET history endpoint added too.
Scoring now filters feedback by profile_name for profile-aware playlists.
Migration 005 adds profile_name column and makes playlist_id nullable.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
30
alembic/versions/005_add_profile_feedback.py
Normal file
30
alembic/versions/005_add_profile_feedback.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Add profile_name to feedback_events, make playlist_id nullable
|
||||
|
||||
Revision ID: 005
|
||||
Revises: 004
|
||||
Create Date: 2026-02-23
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "005"
|
||||
down_revision: Union[str, None] = "004"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("feedback_events", sa.Column("profile_name", sa.Text, nullable=True))
|
||||
op.create_index("ix_feedback_events_profile_name", "feedback_events", ["profile_name"])
|
||||
|
||||
# Make playlist_id nullable (profile-scoped feedback doesn't require a playlist)
|
||||
op.alter_column("feedback_events", "playlist_id", existing_type=sa.BigInteger, nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("feedback_events", "playlist_id", existing_type=sa.BigInteger, nullable=False)
|
||||
op.drop_index("ix_feedback_events_profile_name", table_name="feedback_events")
|
||||
op.drop_column("feedback_events", "profile_name")
|
||||
@@ -129,6 +129,7 @@ async def get_history(
|
||||
"events": [
|
||||
{
|
||||
"id": e.id,
|
||||
"profile_name": e.profile_name,
|
||||
"playlist_id": e.playlist_id,
|
||||
"track_id": e.track_id,
|
||||
"artist": e.track.artist,
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy import delete, func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from haunt_fm.db import get_session
|
||||
from haunt_fm.models.track import (
|
||||
FeedbackEvent,
|
||||
ListenEvent,
|
||||
Profile,
|
||||
SpeakerProfileMapping,
|
||||
TasteProfile,
|
||||
)
|
||||
from haunt_fm.services.feedback import record_profile_feedback
|
||||
|
||||
router = APIRouter(prefix="/api/profiles")
|
||||
|
||||
@@ -23,6 +26,13 @@ class SetSpeakersRequest(BaseModel):
|
||||
speakers: list[str]
|
||||
|
||||
|
||||
class ProfileFeedbackRequest(BaseModel):
|
||||
track_id: int
|
||||
signal: str
|
||||
vibe: str
|
||||
playlist_id: int | None = None
|
||||
|
||||
|
||||
async def _get_profile_or_404(session: AsyncSession, name: str) -> Profile:
|
||||
result = await session.execute(select(Profile).where(Profile.name == name))
|
||||
profile = result.scalar_one_or_none()
|
||||
@@ -204,3 +214,93 @@ async def get_speakers(name: str, session: AsyncSession = Depends(get_session)):
|
||||
.where(SpeakerProfileMapping.profile_id == profile.id)
|
||||
)
|
||||
return {"profile": name, "speakers": [r.speaker_name for r in result]}
|
||||
|
||||
|
||||
@router.post("/{name}/feedback")
|
||||
async def submit_profile_feedback(
|
||||
name: str, req: ProfileFeedbackRequest, session: AsyncSession = Depends(get_session)
|
||||
):
|
||||
"""Submit feedback for a track scoped to a named profile with explicit vibe text."""
|
||||
from haunt_fm.services.embedding import embed_text, is_model_loaded, load_model
|
||||
|
||||
if not is_model_loaded():
|
||||
load_model()
|
||||
vibe_embedding = embed_text(req.vibe)
|
||||
|
||||
try:
|
||||
event = await record_profile_feedback(
|
||||
session,
|
||||
profile_name=name,
|
||||
track_id=req.track_id,
|
||||
signal=req.signal,
|
||||
vibe_text=req.vibe,
|
||||
vibe_embedding=vibe_embedding.tolist(),
|
||||
playlist_id=req.playlist_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return {
|
||||
"id": event.id,
|
||||
"profile_name": event.profile_name,
|
||||
"playlist_id": event.playlist_id,
|
||||
"track_id": event.track_id,
|
||||
"signal": event.signal,
|
||||
"signal_weight": event.signal_weight,
|
||||
"vibe_text": event.vibe_text,
|
||||
"created_at": event.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{name}/feedback/history")
|
||||
async def get_profile_feedback_history(
|
||||
name: str,
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
track_id: int | None = Query(default=None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get feedback history scoped to a profile."""
|
||||
# Verify profile exists
|
||||
await _get_profile_or_404(session, name)
|
||||
|
||||
# For "default", include events where profile_name IS NULL or "default"
|
||||
if name == "default":
|
||||
profile_filter = or_(
|
||||
FeedbackEvent.profile_name.is_(None),
|
||||
FeedbackEvent.profile_name == "default",
|
||||
)
|
||||
else:
|
||||
profile_filter = FeedbackEvent.profile_name == name
|
||||
|
||||
query = (
|
||||
select(FeedbackEvent)
|
||||
.options(joinedload(FeedbackEvent.track))
|
||||
.where(profile_filter)
|
||||
.order_by(FeedbackEvent.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
if track_id is not None:
|
||||
query = query.where(FeedbackEvent.track_id == track_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
events = result.scalars().unique().all()
|
||||
|
||||
return {
|
||||
"profile": name,
|
||||
"events": [
|
||||
{
|
||||
"id": e.id,
|
||||
"profile_name": e.profile_name,
|
||||
"playlist_id": e.playlist_id,
|
||||
"track_id": e.track_id,
|
||||
"artist": e.track.artist,
|
||||
"title": e.track.title,
|
||||
"signal": e.signal,
|
||||
"signal_weight": e.signal_weight,
|
||||
"vibe_text": e.vibe_text,
|
||||
"created_at": e.created_at.isoformat(),
|
||||
}
|
||||
for e in events
|
||||
],
|
||||
"count": len(events),
|
||||
}
|
||||
|
||||
@@ -34,6 +34,6 @@ async def recommendations(
|
||||
)
|
||||
|
||||
# Apply feedback adjustments (re-ranks based on contextual feedback)
|
||||
results = await apply_feedback_adjustments(session, results, vibe_embedding)
|
||||
results = await apply_feedback_adjustments(session, results, vibe_embedding, profile_name=profile or "default")
|
||||
|
||||
return {"recommendations": results, "count": len(results), "vibe": vibe, "alpha": effective_alpha, "profile": profile or "default"}
|
||||
|
||||
@@ -140,8 +140,9 @@ class FeedbackEvent(Base):
|
||||
__tablename__ = "feedback_events"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||
playlist_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("playlists.id"), nullable=False)
|
||||
playlist_id: Mapped[int | None] = mapped_column(BigInteger, ForeignKey("playlists.id"), nullable=True)
|
||||
track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False)
|
||||
profile_name: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
vibe_embedding = mapped_column(Vector(512), nullable=False)
|
||||
vibe_text: Mapped[str | None] = mapped_column(Text)
|
||||
signal: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
@@ -150,7 +151,8 @@ class FeedbackEvent(Base):
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_feedback_events_track_id", "track_id"),
|
||||
Index("ix_feedback_events_profile_name", "profile_name"),
|
||||
)
|
||||
|
||||
playlist: Mapped[Playlist] = relationship(back_populates="feedback_events")
|
||||
playlist: Mapped[Playlist | None] = relationship(back_populates="feedback_events")
|
||||
track: Mapped[Track] = relationship()
|
||||
|
||||
@@ -4,8 +4,10 @@ import numpy as np
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sqlalchemy import or_
|
||||
|
||||
from haunt_fm.config import settings
|
||||
from haunt_fm.models.track import FeedbackEvent, Playlist, Track
|
||||
from haunt_fm.models.track import FeedbackEvent, Playlist, Profile, Track
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,6 +53,55 @@ async def record_feedback(
|
||||
return event
|
||||
|
||||
|
||||
async def record_profile_feedback(
|
||||
session: AsyncSession,
|
||||
profile_name: str,
|
||||
track_id: int,
|
||||
signal: str,
|
||||
vibe_text: str,
|
||||
vibe_embedding: list[float],
|
||||
playlist_id: int | None = None,
|
||||
) -> FeedbackEvent:
|
||||
"""Record a feedback event scoped to a named profile with explicit vibe."""
|
||||
if signal not in VALID_SIGNALS:
|
||||
raise ValueError(f"Invalid signal '{signal}'. Must be one of: {', '.join(sorted(VALID_SIGNALS))}")
|
||||
|
||||
# Verify profile exists
|
||||
result = await session.execute(select(Profile).where(Profile.name == profile_name))
|
||||
if result.scalar_one_or_none() is None:
|
||||
raise ValueError(f"Profile '{profile_name}' not found")
|
||||
|
||||
# Verify track exists
|
||||
track = await session.get(Track, track_id)
|
||||
if track is None:
|
||||
raise ValueError(f"Track {track_id} not found")
|
||||
|
||||
# Verify playlist if provided
|
||||
if playlist_id is not None:
|
||||
playlist = await session.get(Playlist, playlist_id)
|
||||
if playlist is None:
|
||||
raise ValueError(f"Playlist {playlist_id} not found")
|
||||
|
||||
weight = settings.feedback_signal_weights[signal]
|
||||
|
||||
event = FeedbackEvent(
|
||||
playlist_id=playlist_id,
|
||||
track_id=track_id,
|
||||
profile_name=profile_name,
|
||||
vibe_embedding=vibe_embedding,
|
||||
vibe_text=vibe_text,
|
||||
signal=signal,
|
||||
signal_weight=weight,
|
||||
)
|
||||
session.add(event)
|
||||
await session.commit()
|
||||
await session.refresh(event)
|
||||
|
||||
logger.info("Recorded %s feedback for track %d, profile '%s' (vibe: %s)",
|
||||
signal, track_id, profile_name, vibe_text)
|
||||
return event
|
||||
|
||||
|
||||
def compute_contextual_score(
|
||||
events: list[FeedbackEvent],
|
||||
current_vibe_embedding: np.ndarray,
|
||||
@@ -139,20 +190,30 @@ async def apply_feedback_adjustments(
|
||||
session: AsyncSession,
|
||||
recommendations: list[dict],
|
||||
current_vibe_embedding: np.ndarray | None,
|
||||
profile_name: str = "default",
|
||||
) -> list[dict]:
|
||||
"""Adjust recommendation scores based on contextual feedback.
|
||||
|
||||
Fetches feedback events for the recommended tracks, computes contextual
|
||||
scores, adds them to similarity, and re-sorts.
|
||||
|
||||
When profile_name is "default", includes events where profile_name IS NULL
|
||||
or profile_name = "default" (backward compatible).
|
||||
"""
|
||||
if current_vibe_embedding is None or not recommendations:
|
||||
return recommendations
|
||||
|
||||
track_ids = [r["track_id"] for r in recommendations]
|
||||
|
||||
result = await session.execute(
|
||||
select(FeedbackEvent).where(FeedbackEvent.track_id.in_(track_ids))
|
||||
query = select(FeedbackEvent).where(FeedbackEvent.track_id.in_(track_ids))
|
||||
if profile_name == "default":
|
||||
query = query.where(
|
||||
or_(FeedbackEvent.profile_name.is_(None), FeedbackEvent.profile_name == "default")
|
||||
)
|
||||
else:
|
||||
query = query.where(FeedbackEvent.profile_name == profile_name)
|
||||
|
||||
result = await session.execute(query)
|
||||
events = list(result.scalars().all())
|
||||
|
||||
if not events:
|
||||
|
||||
@@ -61,7 +61,7 @@ async def generate_playlist(
|
||||
)
|
||||
|
||||
# Apply feedback adjustments (re-ranks based on contextual feedback)
|
||||
recs = await apply_feedback_adjustments(session, recs, vibe_embedding)
|
||||
recs = await apply_feedback_adjustments(session, recs, vibe_embedding, profile_name=profile_name)
|
||||
|
||||
new_tracks = [(r["track_id"], r.get("adjusted_score", r["similarity"])) for r in recs[:new_count]]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user