Initial haunt-fm implementation
Full music recommendation pipeline: listening history capture via webhook, Last.fm candidate discovery, iTunes preview download, CLAP audio embeddings (512-dim), pgvector cosine similarity recommendations, playlist generation with known/new track interleaving, and Music Assistant playback via HA. Includes: FastAPI app, SQLAlchemy models, Alembic migrations, Docker Compose with pgvector/pg17, status dashboard, and all API endpoints. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
23
.env.example
Normal file
23
.env.example
Normal file
@@ -0,0 +1,23 @@
|
||||
# PostgreSQL
|
||||
POSTGRES_DB=hauntfm
|
||||
POSTGRES_USER=hauntfm
|
||||
POSTGRES_PASSWORD=changeme
|
||||
|
||||
# Database URL (constructed in docker-compose, override for local dev)
# NOTE: must carry the HAUNTFM_ prefix — the app (pydantic-settings env_prefix)
# and alembic/env.py both read HAUNTFM_DATABASE_URL; a bare DATABASE_URL is ignored.
HAUNTFM_DATABASE_URL=postgresql+asyncpg://hauntfm:changeme@localhost:5432/hauntfm
|
||||
|
||||
# Last.fm API
|
||||
HAUNTFM_LASTFM_API_KEY=
|
||||
|
||||
# Home Assistant
|
||||
HAUNTFM_HA_URL=http://192.168.86.51:8123
|
||||
HAUNTFM_HA_TOKEN=
|
||||
|
||||
# CLAP model
|
||||
HAUNTFM_MODEL_CACHE_DIR=/app/model-cache
|
||||
HAUNTFM_AUDIO_CACHE_DIR=/app/audio-cache
|
||||
|
||||
# Embedding worker
|
||||
HAUNTFM_EMBEDDING_WORKER_ENABLED=true
|
||||
HAUNTFM_EMBEDDING_BATCH_SIZE=10
|
||||
HAUNTFM_EMBEDDING_INTERVAL_SECONDS=30
|
||||
54
CLAUDE.md
Normal file
54
CLAUDE.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# haunt-fm
|
||||
|
||||
Personal music recommendation service. Captures listening history from Music Assistant, discovers similar tracks via Last.fm, embeds audio with CLAP, and generates playlists.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# On NAS
|
||||
cd /volume1/homes/antialias/projects/haunt-fm
|
||||
docker compose up -d
|
||||
docker compose exec haunt-fm alembic upgrade head
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
- **FastAPI** app with async SQLAlchemy + asyncpg
|
||||
- **PostgreSQL + pgvector** for tracks, embeddings, and vector similarity search
|
||||
- **CLAP model** (laion/larger_clap_music) for 512-dim audio embeddings
|
||||
- **Last.fm API** for track similarity discovery
|
||||
- **iTunes Search API** for 30-second audio previews
|
||||
- **Music Assistant** (via Home Assistant REST API) for playback
|
||||
|
||||
## Key Commands
|
||||
|
||||
```bash
|
||||
# Health check
|
||||
curl http://192.168.86.51:8321/health
|
||||
|
||||
# Log a listen event
|
||||
curl -X POST http://192.168.86.51:8321/api/history/webhook \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"title":"Song","artist":"Artist"}'
|
||||
|
||||
# Run discovery
|
||||
curl -X POST http://192.168.86.51:8321/api/admin/discover -H "Content-Type: application/json" -d '{}'
|
||||
|
||||
# Get recommendations
|
||||
curl http://192.168.86.51:8321/api/recommendations?limit=20
|
||||
|
||||
# Generate and play a playlist
|
||||
curl -X POST http://192.168.86.51:8321/api/playlists/generate \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"total_tracks":20,"known_pct":30,"speaker_entity":"media_player.living_room_2","auto_play":true}'
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All prefixed with `HAUNTFM_`. See `.env.example` for full list.
|
||||
|
||||
## Database
|
||||
|
||||
- Alembic migrations in `alembic/versions/`
|
||||
- Run migrations: `alembic upgrade head`
|
||||
- Schema: tracks, listen_events, track_embeddings, similarity_links, taste_profiles, playlists, playlist_tracks
|
||||
28
Dockerfile
Normal file
28
Dockerfile
Normal file
@@ -0,0 +1,28 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
# System deps for librosa/soundfile (ffmpeg for AAC decoding)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ffmpeg \
|
||||
libsndfile1 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install Python deps
# NOTE(review): `pip install .` with only pyproject.toml copied will likely fail
# under hatchling (no package sources present in this layer) — confirm this
# builds, or install dependencies only here and rely on the later `pip install -e .`.
COPY pyproject.toml .
RUN pip install --no-cache-dir .
|
||||
|
||||
# Copy source
|
||||
COPY alembic.ini .
|
||||
COPY alembic/ alembic/
|
||||
COPY src/ src/
|
||||
|
||||
# Install the project itself
|
||||
RUN pip install --no-cache-dir -e .
|
||||
|
||||
# Create cache directories
|
||||
RUN mkdir -p /app/model-cache /app/audio-cache
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "haunt_fm.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
36
alembic.ini
Normal file
36
alembic.ini
Normal file
@@ -0,0 +1,36 @@
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
69
alembic/env.py
Normal file
69
alembic/env.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Alembic migration environment for haunt-fm.

Runs migrations over the async SQLAlchemy engine. The database URL is taken
from the HAUNTFM_DATABASE_URL environment variable when set, falling back to
the sqlalchemy.url placeholder in alembic.ini.
"""
import asyncio
import os
from logging.config import fileConfig

from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import async_engine_from_config

from haunt_fm.models.base import Base

# Import all models so they register with Base.metadata
from haunt_fm.models.track import (  # noqa: F401
    ListenEvent,
    Playlist,
    PlaylistTrack,
    SimilarityLink,
    TasteProfile,
    Track,
    TrackEmbedding,
)

config = context.config

# Wire up Python logging from the [loggers]/[handlers] sections of alembic.ini.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Autogenerate and offline SQL emission compare against this metadata.
target_metadata = Base.metadata

# Override sqlalchemy.url from environment (docker-compose sets HAUNTFM_DATABASE_URL).
db_url = os.environ.get("HAUNTFM_DATABASE_URL", "")
if db_url:
    config.set_main_option("sqlalchemy.url", db_url)


def run_migrations_offline() -> None:
    """Emit migration SQL without a live DB connection (``alembic upgrade --sql``)."""
    url = config.get_main_option("sqlalchemy.url")
    context.configure(url=url, target_metadata=target_metadata, literal_binds=True)

    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection):
    """Run migrations on a sync-adapted connection (called via run_sync below)."""
    context.configure(connection=connection, target_metadata=target_metadata)
    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    """Create a throwaway async engine, run migrations, and dispose of it."""
    # NullPool: migrations are one-shot; no reason to keep pooled connections.
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    async with connectable.connect() as connection:
        # Alembic's migration machinery is sync; bridge via run_sync.
        await connection.run_sync(do_run_migrations)

    await connectable.dispose()


def run_migrations_online() -> None:
    """Entry point for online mode: drive the async migration coroutine."""
    asyncio.run(run_async_migrations())


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
|
||||
26
alembic/script.py.mako
Normal file
26
alembic/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
121
alembic/versions/001_initial_schema.py
Normal file
121
alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Initial schema
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2026-02-22
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
revision: str = "001"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the full haunt-fm schema: pgvector extension, all tables, HNSW index.

    Fixes over the original:
    - ``sa.Real`` does not exist in SQLAlchemy (the type is ``sa.REAL``); the
      original would raise AttributeError at migration time.
    - ``sa.dialects.postgresql.JSONB`` is not reliably reachable via
      ``import sqlalchemy as sa`` — the dialect subpackage must be imported
      explicitly.
    """
    # Explicit dialect import: `import sqlalchemy as sa` does not load
    # sqlalchemy.dialects.postgresql, so attribute access on it can fail.
    from sqlalchemy.dialects.postgresql import JSONB

    # Enable pgvector extension
    op.execute("CREATE EXTENSION IF NOT EXISTS vector")

    # Tracks: one row per unique song (deduplicated by `fingerprint`).
    op.create_table(
        "tracks",
        sa.Column("id", sa.BigInteger, primary_key=True),
        sa.Column("title", sa.Text, nullable=False),
        sa.Column("artist", sa.Text, nullable=False),
        sa.Column("album", sa.Text),
        sa.Column("fingerprint", sa.Text, unique=True, nullable=False),
        sa.Column("lastfm_url", sa.Text),
        sa.Column("itunes_track_id", sa.BigInteger),
        sa.Column("itunes_preview_url", sa.Text),
        sa.Column("apple_music_id", sa.Text),
        sa.Column("duration_ms", sa.Integer),
        sa.Column("genre", sa.Text),
        # Embedding pipeline state; status.py queries "done"/"pending"/"failed"/"no_preview".
        sa.Column("embedding_status", sa.Text, nullable=False, server_default="pending"),
        sa.Column("embedding_error", sa.Text),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Listen events: one row per play reported by the webhook / backfill.
    op.create_table(
        "listen_events",
        sa.Column("id", sa.BigInteger, primary_key=True),
        sa.Column("track_id", sa.BigInteger, sa.ForeignKey("tracks.id"), nullable=False),
        sa.Column("source", sa.Text, nullable=False, server_default="music_assistant"),
        sa.Column("speaker_name", sa.Text),
        sa.Column("listened_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("duration_played", sa.Integer),
        sa.Column("raw_payload", JSONB),  # was sa.dialects.postgresql.JSONB
    )

    # Track embeddings (512-dim CLAP), one per track.
    op.create_table(
        "track_embeddings",
        sa.Column("id", sa.BigInteger, primary_key=True),
        sa.Column("track_id", sa.BigInteger, sa.ForeignKey("tracks.id"), unique=True, nullable=False),
        sa.Column("embedding", Vector(512), nullable=False),
        sa.Column("model_version", sa.Text, nullable=False, server_default="laion/larger_clap_music"),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    # HNSW index for fast approximate cosine-similarity search.
    op.execute(
        "CREATE INDEX ix_track_embeddings_hnsw ON track_embeddings "
        "USING hnsw (embedding vector_cosine_ops)"
    )

    # Similarity links discovered via Last.fm track.getSimilar.
    op.create_table(
        "similarity_links",
        sa.Column("id", sa.BigInteger, primary_key=True),
        sa.Column("source_track_id", sa.BigInteger, sa.ForeignKey("tracks.id"), nullable=False),
        sa.Column("target_track_id", sa.BigInteger, sa.ForeignKey("tracks.id"), nullable=False),
        sa.Column("lastfm_match", sa.REAL),  # was sa.Real (AttributeError)
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.UniqueConstraint("source_track_id", "target_track_id", name="uq_similarity_link"),
    )

    # Taste profiles: aggregated listener embedding(s); "default" is the one in use.
    op.create_table(
        "taste_profiles",
        sa.Column("id", sa.BigInteger, primary_key=True),
        sa.Column("name", sa.Text, unique=True, nullable=False, server_default="default"),
        sa.Column("embedding", Vector(512), nullable=False),
        sa.Column("track_count", sa.Integer, nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Playlists and their ordered track membership.
    op.create_table(
        "playlists",
        sa.Column("id", sa.BigInteger, primary_key=True),
        sa.Column("name", sa.Text),
        sa.Column("known_pct", sa.Integer, nullable=False),
        sa.Column("total_tracks", sa.Integer, nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    op.create_table(
        "playlist_tracks",
        sa.Column("id", sa.BigInteger, primary_key=True),
        sa.Column("playlist_id", sa.BigInteger, sa.ForeignKey("playlists.id", ondelete="CASCADE"), nullable=False),
        sa.Column("track_id", sa.BigInteger, sa.ForeignKey("tracks.id"), nullable=False),
        sa.Column("position", sa.Integer, nullable=False),
        sa.Column("is_known", sa.Boolean, nullable=False),
        sa.Column("similarity_score", sa.REAL),  # was sa.Real (AttributeError)
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the entire haunt-fm schema in reverse dependency order."""
    # Child tables first so foreign keys never dangle mid-drop.
    for table_name in ("playlist_tracks", "playlists", "taste_profiles", "similarity_links"):
        op.drop_table(table_name)
    # Index before its table, then the remaining tables, then the extension.
    op.execute("DROP INDEX IF EXISTS ix_track_embeddings_hnsw")
    for table_name in ("track_embeddings", "listen_events", "tracks"):
        op.drop_table(table_name)
    op.execute("DROP EXTENSION IF EXISTS vector")
|
||||
53
docker-compose.yml
Normal file
53
docker-compose.yml
Normal file
@@ -0,0 +1,53 @@
|
||||
services:
|
||||
haunt-fm-db:
|
||||
image: pgvector/pgvector:pg17
|
||||
container_name: haunt-fm-db
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
POSTGRES_DB: ${POSTGRES_DB:-hauntfm}
|
||||
POSTGRES_USER: ${POSTGRES_USER:-hauntfm}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
volumes:
|
||||
- haunt-fm-db-data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD", "pg_isready", "-U", "${POSTGRES_USER:-hauntfm}"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
networks:
|
||||
- haunt-fm-internal
|
||||
|
||||
haunt-fm:
|
||||
build: .
|
||||
container_name: haunt-fm
|
||||
restart: unless-stopped
|
||||
env_file: .env
|
||||
environment:
|
||||
HAUNTFM_DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-hauntfm}:${POSTGRES_PASSWORD}@haunt-fm-db:5432/${POSTGRES_DB:-hauntfm}
|
||||
ports:
|
||||
- "8321:8000"
|
||||
volumes:
|
||||
- haunt-fm-model-cache:/app/model-cache
|
||||
- haunt-fm-audio-cache:/app/audio-cache
|
||||
depends_on:
|
||||
haunt-fm-db:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8000/health').raise_for_status()"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
networks:
|
||||
- haunt-fm-internal
|
||||
- webgateway
|
||||
|
||||
networks:
|
||||
haunt-fm-internal:
|
||||
webgateway:
|
||||
external: true
|
||||
|
||||
volumes:
|
||||
haunt-fm-db-data:
|
||||
haunt-fm-model-cache:
|
||||
haunt-fm-audio-cache:
|
||||
31
pyproject.toml
Normal file
31
pyproject.toml
Normal file
@@ -0,0 +1,31 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "haunt-fm"
|
||||
version = "0.1.0"
|
||||
description = "Personal music recommendation service"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi>=0.115",
|
||||
"uvicorn[standard]>=0.34",
|
||||
"sqlalchemy[asyncio]>=2.0",
|
||||
"asyncpg>=0.30",
|
||||
"alembic>=1.14",
|
||||
"pydantic-settings>=2.7",
|
||||
"pgvector>=0.3",
|
||||
"httpx>=0.28",
|
||||
"jinja2>=3.1",
|
||||
"numpy>=1.26",
|
||||
"librosa>=0.10",
|
||||
"transformers>=4.48",
|
||||
"torch>=2.5",
|
||||
"soundfile>=0.12",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.24",
|
||||
]
|
||||
70
scripts/seed_history_from_ma.py
Normal file
70
scripts/seed_history_from_ma.py
Normal file
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
"""One-time backfill: pull recently played tracks from Music Assistant via HA REST API."""
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
HA_URL = os.environ.get("HAUNTFM_HA_URL", "http://192.168.86.51:8123")
|
||||
HA_TOKEN = os.environ.get("HAUNTFM_HA_TOKEN", "")
|
||||
HAUNTFM_URL = os.environ.get("HAUNTFM_URL", "http://localhost:8321")
|
||||
|
||||
|
||||
async def get_recently_played() -> list[dict]:
    """Get recently played items from Music Assistant via HA.

    Queries HA's /api/states and keeps every media_player entity that is
    currently exposing both a title and an artist.
    """
    auth_headers = {
        "Authorization": f"Bearer {HA_TOKEN}",
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient(timeout=30) as client:
        # One call returns the state of every entity in HA.
        response = await client.get(f"{HA_URL}/api/states", headers=auth_headers)
        response.raise_for_status()
        states = response.json()

    tracks: list[dict] = []
    for entity in states:
        if not entity["entity_id"].startswith("media_player."):
            continue
        attributes = entity.get("attributes", {})
        # Only players with actual media info are useful for backfill.
        if not (attributes.get("media_title") and attributes.get("media_artist")):
            continue
        tracks.append({
            "title": attributes["media_title"],
            "artist": attributes["media_artist"],
            "album": attributes.get("media_album_name"),
            "speaker_name": attributes.get("friendly_name"),
            "source": "music_assistant_backfill",
        })

    return tracks
|
||||
|
||||
|
||||
async def send_to_webhook(track: dict):
    """POST one track dict to the haunt-fm history webhook; return its JSON reply."""
    async with httpx.AsyncClient(timeout=10) as client:
        response = await client.post(f"{HAUNTFM_URL}/api/history/webhook", json=track)
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
async def main():
    """Backfill: read current media info from HA and push each track to haunt-fm."""
    if not HA_TOKEN:
        print("Set HAUNTFM_HA_TOKEN environment variable")
        return

    print(f"Fetching from {HA_URL}...")
    tracks = await get_recently_played()
    print(f"Found {len(tracks)} tracks with media info")

    for track in tracks:
        label = f"{track['artist']} - {track['title']}"
        try:
            result = await send_to_webhook(track)
        except Exception as exc:
            # Best-effort backfill: report and keep going.
            print(f" FAIL: {label}: {exc}")
        else:
            print(f" OK: {label} -> track_id={result['track_id']}")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
0
src/haunt_fm/__init__.py
Normal file
0
src/haunt_fm/__init__.py
Normal file
0
src/haunt_fm/api/__init__.py
Normal file
0
src/haunt_fm/api/__init__.py
Normal file
59
src/haunt_fm/api/admin.py
Normal file
59
src/haunt_fm/api/admin.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.db import get_session
|
||||
from haunt_fm.models.track import ListenEvent, Track
|
||||
from haunt_fm.services.lastfm_client import discover_similar_for_track
|
||||
from haunt_fm.services.taste_profile import build_taste_profile
|
||||
|
||||
router = APIRouter(prefix="/api/admin")
|
||||
|
||||
|
||||
class DiscoverRequest(BaseModel):
    """Request body for POST /api/admin/discover."""

    limit: int = 50  # max tracks from history to expand
|
||||
|
||||
|
||||
@router.post("/discover")
async def discover(req: DiscoverRequest, session: AsyncSession = Depends(get_session)):
    """Expand listening history via Last.fm track.getSimilar."""
    # Most-listened tracks first, capped at req.limit.
    most_listened_query = (
        select(Track)
        .join(ListenEvent, ListenEvent.track_id == Track.id)
        .group_by(Track.id)
        .order_by(func.count(ListenEvent.id).desc())
        .limit(req.limit)
    )
    query_result = await session.execute(most_listened_query)
    listened_tracks = query_result.scalars().all()

    total_discovered = 0
    errors: list[dict] = []

    # Best-effort expansion: a failure on one track doesn't stop the rest.
    for track in listened_tracks:
        try:
            total_discovered += await discover_similar_for_track(session, track)
        except Exception as exc:
            errors.append({"track": f"{track.artist} - {track.title}", "error": str(exc)})

    return {
        "tracks_expanded": len(listened_tracks),
        "candidates_discovered": total_discovered,
        "errors": errors,
    }
|
||||
|
||||
|
||||
@router.post("/build-taste-profile")
async def build_profile(session: AsyncSession = Depends(get_session)):
    """Rebuild the taste profile from listened-track embeddings."""
    profile = await build_taste_profile(session)
    if profile is not None:
        return {
            "ok": True,
            "track_count": profile.track_count,
            "updated_at": profile.updated_at.isoformat(),
        }
    return {"ok": False, "error": "No listened tracks with embeddings found"}
|
||||
19
src/haunt_fm/api/health.py
Normal file
19
src/haunt_fm/api/health.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.db import get_session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health")
async def health(session: AsyncSession = Depends(get_session)):
    """Liveness probe: reports whether the database answers a trivial query."""
    try:
        await session.execute(text("SELECT 1"))
    except Exception:
        # Any DB failure (connection, auth, timeout) degrades but doesn't 500.
        return {"status": "degraded", "db_connected": False}
    return {"status": "healthy", "db_connected": True}
|
||||
39
src/haunt_fm/api/history.py
Normal file
39
src/haunt_fm/api/history.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.db import get_session
|
||||
from haunt_fm.services.history_ingest import ingest_listen_event
|
||||
|
||||
router = APIRouter(prefix="/api/history")
|
||||
|
||||
|
||||
class WebhookPayload(BaseModel):
    """Listen event posted to /api/history/webhook (e.g. by an HA automation)."""

    title: str
    artist: str
    album: str | None = None
    speaker_name: str | None = None
    duration_played: int | None = None  # units as sent by the caller — TODO confirm seconds vs ms
    source: str = "music_assistant"
    listened_at: datetime | None = None  # server fills in now(UTC) when omitted
|
||||
|
||||
|
||||
@router.post("/webhook")
async def receive_webhook(payload: WebhookPayload, session: AsyncSession = Depends(get_session)):
    """Ingest a single listen event, returning the resolved track and event ids."""
    # Stamp the time before dumping so the raw_payload snapshot includes it.
    payload.listened_at = payload.listened_at or datetime.now(timezone.utc)

    event = await ingest_listen_event(
        session=session,
        raw_payload=payload.model_dump(mode="json"),
        # Remaining payload fields map 1:1 onto ingest_listen_event's keywords.
        **payload.model_dump(),
    )
    return {"ok": True, "track_id": event.track_id, "event_id": event.id}
|
||||
63
src/haunt_fm/api/playlists.py
Normal file
63
src/haunt_fm/api/playlists.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.db import get_session
|
||||
from haunt_fm.models.track import PlaylistTrack, Track
|
||||
from haunt_fm.services.music_assistant import play_playlist_on_speaker
|
||||
from haunt_fm.services.playlist_generator import generate_playlist
|
||||
|
||||
router = APIRouter(prefix="/api/playlists")
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
    """Request body for POST /api/playlists/generate."""

    total_tracks: int = 20
    known_pct: int = 30  # percentage of tracks drawn from listening history
    name: str | None = None
    speaker_entity: str | None = None  # HA media_player entity id for auto-play
    auto_play: bool = False  # play immediately when a speaker_entity is given
|
||||
|
||||
|
||||
@router.post("/generate")
async def generate(req: GenerateRequest, session: AsyncSession = Depends(get_session)):
    """Generate a playlist and optionally start playback on a speaker."""
    playlist = await generate_playlist(
        session,
        total_tracks=req.total_tracks,
        known_pct=req.known_pct,
        name=req.name,
    )

    # Join back to tracks so the response carries full metadata, in position order.
    membership_stmt = (
        select(PlaylistTrack, Track)
        .join(Track, PlaylistTrack.track_id == Track.id)
        .where(PlaylistTrack.playlist_id == playlist.id)
        .order_by(PlaylistTrack.position)
    )
    rows = (await session.execute(membership_stmt)).all()

    track_list = []
    for playlist_track, track in rows:
        track_list.append({
            "position": playlist_track.position,
            "artist": track.artist,
            "title": track.title,
            "album": track.album,
            "is_known": playlist_track.is_known,
            "similarity_score": playlist_track.similarity_score,
        })

    # Auto-play only when both the flag and a target speaker were supplied.
    if req.auto_play and req.speaker_entity:
        await play_playlist_on_speaker(track_list, req.speaker_entity)

    return {
        "playlist_id": playlist.id,
        "name": playlist.name,
        "total_tracks": playlist.total_tracks,
        "known_pct": playlist.known_pct,
        "tracks": track_list,
        "auto_played": req.auto_play and req.speaker_entity is not None,
    }
|
||||
19
src/haunt_fm/api/recommendations.py
Normal file
19
src/haunt_fm/api/recommendations.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.db import get_session
|
||||
from haunt_fm.services.recommender import get_recommendations
|
||||
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
|
||||
@router.get("/recommendations")
async def recommendations(
    limit: int = Query(default=50, ge=1, le=200),
    include_known: bool = Query(default=False),
    session: AsyncSession = Depends(get_session),
):
    """Return recommended tracks; known (already-listened) tracks excluded by default."""
    results = await get_recommendations(session, limit=limit, exclude_known=not include_known)
    return {"recommendations": results, "count": len(results)}
|
||||
113
src/haunt_fm/api/status.py
Normal file
113
src/haunt_fm/api/status.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.db import get_session
|
||||
from haunt_fm.models.track import (
|
||||
ListenEvent,
|
||||
Playlist,
|
||||
TasteProfile,
|
||||
Track,
|
||||
TrackEmbedding,
|
||||
)
|
||||
from haunt_fm.config import settings
|
||||
from haunt_fm.services.embedding import is_model_loaded
|
||||
from haunt_fm.services.embedding_worker import is_running as is_worker_running
|
||||
from haunt_fm.services.embedding_worker import last_processed as worker_last_processed
|
||||
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
|
||||
@router.get("/status")
async def status(session: AsyncSession = Depends(get_session)):
    """Full pipeline status snapshot for /api/status and the HTML dashboard."""
    # DB connectivity
    try:
        await session.execute(text("SELECT 1"))
        db_connected = True
    except Exception:
        db_connected = False

    # Everything below requires the DB, so short-circuit when it is down.
    if not db_connected:
        return {"healthy": False, "db_connected": False}

    now = datetime.now(timezone.utc)
    day_ago = now - timedelta(days=1)

    # Listen events
    total_events = (await session.execute(select(func.count(ListenEvent.id)))).scalar() or 0
    events_24h = (
        await session.execute(
            select(func.count(ListenEvent.id)).where(ListenEvent.listened_at >= day_ago)
        )
    ).scalar() or 0
    latest_event = (
        await session.execute(select(func.max(ListenEvent.listened_at)))
    ).scalar()

    # Tracks: anything never listened to is attributed to discovery.
    total_tracks = (await session.execute(select(func.count(Track.id)))).scalar() or 0
    from_history = (
        await session.execute(
            select(func.count(func.distinct(ListenEvent.track_id)))
        )
    ).scalar() or 0
    from_discovery = total_tracks - from_history

    # Embeddings — counted from Track.embedding_status, not joined embedding rows.
    def _embedding_count(status_val: str):
        # Builds (does not execute) a count query for one embedding_status value.
        return select(func.count(Track.id)).where(Track.embedding_status == status_val)

    emb_done = (await session.execute(_embedding_count("done"))).scalar() or 0
    emb_pending = (await session.execute(_embedding_count("pending"))).scalar() or 0
    emb_failed = (await session.execute(_embedding_count("failed"))).scalar() or 0
    emb_no_preview = (await session.execute(_embedding_count("no_preview"))).scalar() or 0

    # Taste profile: only the "default" profile is surfaced here.
    taste = (await session.execute(select(TasteProfile).where(TasteProfile.name == "default"))).scalar()

    # Playlists
    total_playlists = (await session.execute(select(func.count(Playlist.id)))).scalar() or 0
    last_playlist = (await session.execute(select(func.max(Playlist.created_at)))).scalar()

    return {
        "healthy": db_connected,
        "db_connected": db_connected,
        "clap_model_loaded": is_model_loaded(),
        "pipeline": {
            "listen_events": {
                "total": total_events,
                "last_24h": events_24h,
                "latest": latest_event.isoformat() if latest_event else None,
            },
            "tracks": {
                "total": total_tracks,
                "from_history": from_history,
                "from_discovery": from_discovery,
            },
            "embeddings": {
                "done": emb_done,
                "pending": emb_pending,
                "failed": emb_failed,
                "no_preview": emb_no_preview,
                "worker_running": is_worker_running(),
                "worker_last_processed": worker_last_processed().isoformat() if worker_last_processed() else None,
            },
            "taste_profile": {
                "exists": taste is not None,
                "track_count": taste.track_count if taste else 0,
                "updated_at": taste.updated_at.isoformat() if taste else None,
            },
            "playlists": {
                "total_generated": total_playlists,
                "last_generated": last_playlist.isoformat() if last_playlist else None,
            },
        },
        "dependencies": {
            "lastfm_api": "configured" if settings.lastfm_api_key else "not_configured",
            "itunes_api": "ok",  # no auth needed
            # NOTE(review): the two fields below only check that an HA token is
            # configured — they do not probe reachability; confirm or rename.
            "ha_reachable": bool(settings.ha_token),
            "music_assistant_reachable": bool(settings.ha_token),
        },
    }
|
||||
23
src/haunt_fm/api/status_page.py
Normal file
23
src/haunt_fm/api/status_page.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.api.status import status as get_status_data
|
||||
from haunt_fm.db import get_session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_template_dir = Path(__file__).parent.parent / "templates"
|
||||
_jinja_env = Environment(loader=FileSystemLoader(str(_template_dir)), autoescape=True)
|
||||
|
||||
|
||||
@router.get("/", response_class=HTMLResponse)
async def status_page(request: Request, session: AsyncSession = Depends(get_session)):
    """Render the HTML status dashboard from the /api/status payload."""
    data = await get_status_data(session)
    rendered_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
    page = _jinja_env.get_template("status.html").render(data=data, now=rendered_at)
    return HTMLResponse(page)
|
||||
27
src/haunt_fm/config.py
Normal file
27
src/haunt_fm/config.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application configuration, read from HAUNTFM_-prefixed environment variables."""

    model_config = {"env_prefix": "HAUNTFM_"}

    # Database
    # NOTE(review): with env_prefix this reads HAUNTFM_DATABASE_URL, but
    # .env.example sets plain DATABASE_URL — confirm docker-compose exports
    # the prefixed name, otherwise the default below is always used.
    database_url: str = "postgresql+asyncpg://hauntfm:changeme@localhost:5432/hauntfm"

    # Last.fm
    lastfm_api_key: str = ""  # empty string = Last.fm discovery not configured

    # Home Assistant
    ha_url: str = "http://192.168.86.51:8123"
    ha_token: str = ""  # long-lived access token; empty = HA calls unauthenticated

    # CLAP model
    model_cache_dir: str = "/app/model-cache"  # HuggingFace model cache
    audio_cache_dir: str = "/app/audio-cache"  # downloaded iTunes previews

    # Embedding worker
    embedding_worker_enabled: bool = True
    embedding_batch_size: int = 10          # tracks pulled per loop iteration
    embedding_interval_seconds: int = 30    # idle sleep when no pending tracks


# Process-wide singleton imported throughout the app.
settings = Settings()
|
||||
13
src/haunt_fm/db.py
Normal file
13
src/haunt_fm/db.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from haunt_fm.config import settings
|
||||
|
||||
# Single process-wide async engine and session factory. expire_on_commit=False
# keeps ORM objects usable after commit (they cross await boundaries here).
engine = create_async_engine(settings.database_url, echo=False)
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)


async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency yielding a request-scoped database session."""
    async with async_session() as session:
        yield session
|
||||
50
src/haunt_fm/main.py
Normal file
50
src/haunt_fm/main.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from haunt_fm.api import admin, health, history, playlists, recommendations, status, status_page
|
||||
from haunt_fm.config import settings
|
||||
|
||||
# Root logging config: single-line timestamped records (docker-logs friendly).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start the embedding worker, cancel it on shutdown.

    Fix: the worker task was fire-and-forget, so an unexpected crash in
    run_worker() went unnoticed until shutdown (when `await worker_task`
    would finally re-raise). A done-callback now logs the failure the
    moment the task dies.
    """
    # Startup
    logger.info("haunt-fm starting up")

    worker_task = None
    if settings.embedding_worker_enabled:
        from haunt_fm.services.embedding_worker import run_worker

        worker_task = asyncio.create_task(run_worker())

        def _log_worker_exit(task: asyncio.Task) -> None:
            # Cancellation is the expected shutdown path; anything else is a crash.
            if not task.cancelled() and task.exception() is not None:
                logger.error("Embedding worker crashed", exc_info=task.exception())

        worker_task.add_done_callback(_log_worker_exit)
        logger.info("Embedding worker task created")

    yield

    # Shutdown: cancel the worker and wait for it to unwind.
    if worker_task:
        worker_task.cancel()
        try:
            await worker_task
        except asyncio.CancelledError:
            pass
    logger.info("haunt-fm shut down")
|
||||
|
||||
|
||||
# FastAPI application; lifespan manages the background embedding worker.
app = FastAPI(title="haunt-fm", lifespan=lifespan)

# Route registration — status_page serves the HTML dashboard at "/",
# the remaining routers expose JSON endpoints.
app.include_router(health.router)
app.include_router(status.router)
app.include_router(status_page.router)
app.include_router(history.router)
app.include_router(admin.router)
app.include_router(recommendations.router)
app.include_router(playlists.router)
|
||||
0
src/haunt_fm/models/__init__.py
Normal file
0
src/haunt_fm/models/__init__.py
Normal file
5
src/haunt_fm/models/base.py
Normal file
5
src/haunt_fm/models/base.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Shared declarative base for all haunt-fm ORM models."""

    pass
|
||||
113
src/haunt_fm/models/track.py
Normal file
113
src/haunt_fm/models/track.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from datetime import datetime

from pgvector.sqlalchemy import Vector
# NOTE(review): confirm `Real` exists in the pinned SQLAlchemy version —
# the generic float type is `Float` and the SQL-standard one is `REAL`.
from sqlalchemy import BigInteger, DateTime, ForeignKey, Index, Integer, Real, Text, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship

from haunt_fm.models.base import Base
|
||||
|
||||
|
||||
class Track(Base):
    """Canonical track record, deduplicated by `fingerprint` (artist::title)."""

    __tablename__ = "tracks"

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    title: Mapped[str] = mapped_column(Text, nullable=False)
    artist: Mapped[str] = mapped_column(Text, nullable=False)
    album: Mapped[str | None] = mapped_column(Text)
    # Lowercased "artist::title" dedup key (see history_ingest.make_fingerprint).
    fingerprint: Mapped[str] = mapped_column(Text, unique=True, nullable=False)
    lastfm_url: Mapped[str | None] = mapped_column(Text)
    itunes_track_id: Mapped[int | None] = mapped_column(BigInteger)
    itunes_preview_url: Mapped[str | None] = mapped_column(Text)
    apple_music_id: Mapped[str | None] = mapped_column(Text)
    duration_ms: Mapped[int | None] = mapped_column(Integer)
    genre: Mapped[str | None] = mapped_column(Text)
    # Worker state machine: pending -> downloading -> done | no_preview | failed
    # (see services.embedding_worker).
    embedding_status: Mapped[str] = mapped_column(Text, nullable=False, default="pending")
    embedding_error: Mapped[str | None] = mapped_column(Text)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    # NOTE(review): these relationships need a ForeignKey on the child tables'
    # track_id columns; without one SQLAlchemy cannot infer the join condition.
    listen_events: Mapped[list["ListenEvent"]] = relationship(back_populates="track")
    embedding: Mapped["TrackEmbedding | None"] = relationship(back_populates="track")
|
||||
|
||||
|
||||
class ListenEvent(Base):
    """One playback event for a track, ingested from the listening webhook.

    Fix: `track_id` now carries ForeignKey("tracks.id") — without it the
    Track.listen_events / ListenEvent.track relationship pair has no join
    condition and SQLAlchemy mapper configuration fails (NoForeignKeysError).
    """

    __tablename__ = "listen_events"

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False)
    source: Mapped[str] = mapped_column(Text, nullable=False, default="music_assistant")
    speaker_name: Mapped[str | None] = mapped_column(Text)
    listened_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
    # NOTE(review): unit (seconds vs ms) not established here — confirm against webhook payload.
    duration_played: Mapped[int | None] = mapped_column(Integer)
    # Original webhook body, kept for debugging / reprocessing.
    raw_payload: Mapped[dict | None] = mapped_column(JSONB)

    track: Mapped[Track] = relationship(back_populates="listen_events")
|
||||
|
||||
|
||||
class TrackEmbedding(Base):
    """512-dim CLAP audio embedding for a track (at most one per track).

    Fix: `track_id` now carries ForeignKey("tracks.id") so the
    Track.embedding relationship can resolve its join condition.
    """

    __tablename__ = "track_embeddings"

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), unique=True, nullable=False)
    embedding = mapped_column(Vector(512), nullable=False)
    model_version: Mapped[str] = mapped_column(Text, nullable=False, default="laion/larger_clap_music")
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())

    __table_args__ = (
        # HNSW index with cosine ops backs the <=> similarity queries in the recommender.
        Index(
            "ix_track_embeddings_embedding_hnsw",
            "embedding",
            postgresql_using="hnsw",
            postgresql_with={"m": 16, "ef_construction": 64},
            postgresql_ops={"embedding": "vector_cosine_ops"},
        ),
    )

    track: Mapped[Track] = relationship(back_populates="embedding")
|
||||
|
||||
|
||||
class SimilarityLink(Base):
    """Directed Last.fm similarity edge between two tracks.

    Fix: both endpoint columns now carry ForeignKeys to tracks.id for
    referential integrity; the mapping is otherwise unchanged.
    """

    __tablename__ = "similarity_links"

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    source_track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False)
    target_track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False)
    # Last.fm "match" score as returned by track.getSimilar.
    lastfm_match: Mapped[float | None] = mapped_column(Real)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())

    __table_args__ = (
        # Each (source, target) pair is linked at most once.
        Index("uq_similarity_link", "source_track_id", "target_track_id", unique=True),
    )
|
||||
|
||||
|
||||
class TasteProfile(Base):
    """Aggregated listening-taste vector (weighted mean of listened-track embeddings).

    Fix: `updated_at` gains onupdate=func.now(). Profiles are rebuilt over
    time, and without onupdate an in-place UPDATE never refreshed the
    timestamp that the status endpoint reports.
    """

    __tablename__ = "taste_profiles"

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    name: Mapped[str] = mapped_column(Text, unique=True, nullable=False, default="default")
    embedding = mapped_column(Vector(512), nullable=False)
    # Number of tracks that contributed to this profile.
    track_count: Mapped[int] = mapped_column(Integer, nullable=False)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
|
||||
|
||||
|
||||
class Playlist(Base):
    """Generated playlist header; the ordered tracks live in PlaylistTrack."""

    __tablename__ = "playlists"

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    name: Mapped[str | None] = mapped_column(Text)
    # Requested percentage (0-100) of known / already-listened tracks.
    known_pct: Mapped[int] = mapped_column(Integer, nullable=False)
    total_tracks: Mapped[int] = mapped_column(Integer, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())

    # NOTE(review): delete-orphan cascade requires PlaylistTrack.playlist_id to
    # be a ForeignKey; without one this relationship has no join condition.
    tracks: Mapped[list["PlaylistTrack"]] = relationship(back_populates="playlist", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class PlaylistTrack(Base):
    """One ordered entry of a playlist.

    Fix: `playlist_id` and `track_id` now carry ForeignKeys — without them
    the Playlist.tracks / PlaylistTrack.playlist and PlaylistTrack.track
    relationships cannot determine their join conditions.
    """

    __tablename__ = "playlist_tracks"

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    playlist_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("playlists.id"), nullable=False)
    track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False)
    position: Mapped[int] = mapped_column(Integer, nullable=False)  # 0-based order in the playlist
    # True = drawn from listening history, False = new recommendation.
    is_known: Mapped[bool] = mapped_column(nullable=False)
    # Cosine similarity for recommended tracks; None for known tracks.
    similarity_score: Mapped[float | None] = mapped_column(Real)

    playlist: Mapped[Playlist] = relationship(back_populates="tracks")
    track: Mapped[Track] = relationship()
|
||||
0
src/haunt_fm/services/__init__.py
Normal file
0
src/haunt_fm/services/__init__.py
Normal file
50
src/haunt_fm/services/embedding.py
Normal file
50
src/haunt_fm/services/embedding.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from haunt_fm.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)

# Lazily-initialized module singletons; None until load_model() runs.
_model = None
_processor = None
|
||||
|
||||
|
||||
def load_model():
    """Load the CLAP model and processor into module globals. Idempotent.

    Fix: dropped an unused local `import torch` — nothing in this function
    references torch (embed_audio imports it where it is actually needed).
    """
    global _model, _processor
    if _model is not None:
        return  # already loaded

    # Deferred import: transformers is heavy and only needed when the worker runs.
    from transformers import ClapModel, ClapProcessor

    logger.info("Loading CLAP model laion/larger_clap_music...")
    cache_dir = settings.model_cache_dir

    _processor = ClapProcessor.from_pretrained("laion/larger_clap_music", cache_dir=cache_dir)
    _model = ClapModel.from_pretrained("laion/larger_clap_music", cache_dir=cache_dir)
    _model.eval()  # inference mode (disables dropout etc.)
    logger.info("CLAP model loaded successfully")
|
||||
|
||||
|
||||
def is_model_loaded() -> bool:
    """Report whether load_model() has completed."""
    return _model is not None
|
||||
|
||||
|
||||
def embed_audio(audio: np.ndarray, sample_rate: int = 48000) -> np.ndarray:
    """Generate an L2-normalized 512-dim CLAP embedding from an audio waveform.

    Args:
        audio: waveform samples (presumably 1-D mono float — confirm against
            ClapProcessor's expected input shape).
        sample_rate: sample rate of `audio` in Hz.

    Raises:
        RuntimeError: if load_model() has not been called yet.
        ValueError: if the model produces a zero-norm embedding.
            (Fix: this was previously a silent division by zero yielding NaNs.)
    """
    import torch

    if _model is None or _processor is None:
        raise RuntimeError("CLAP model not loaded. Call load_model() first.")

    inputs = _processor(audios=audio, sampling_rate=sample_rate, return_tensors="pt")
    with torch.no_grad():
        embeddings = _model.get_audio_features(**inputs)

    # L2-normalize so cosine similarity reduces to a dot product.
    emb = embeddings[0].numpy()
    norm = np.linalg.norm(emb)
    if norm == 0:
        raise ValueError("CLAP produced a zero embedding; cannot normalize")
    return emb / norm
|
||||
140
src/haunt_fm/services/embedding_worker.py
Normal file
140
src/haunt_fm/services/embedding_worker.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.config import settings
|
||||
from haunt_fm.db import async_session
|
||||
from haunt_fm.models.track import Track, TrackEmbedding
|
||||
from haunt_fm.services.embedding import embed_audio, load_model
|
||||
from haunt_fm.services.itunes_client import search_track
|
||||
from haunt_fm.utils.audio import decode_audio, download_preview
|
||||
|
||||
logger = logging.getLogger(__name__)

# Worker status flags, read via is_running() / last_processed().
_running = False
_last_processed: datetime | None = None
|
||||
|
||||
|
||||
def is_running() -> bool:
    """True while run_worker()'s main loop is active."""
    return _running


def last_processed() -> datetime | None:
    """UTC timestamp of the last successfully embedded track, or None."""
    return _last_processed
|
||||
|
||||
|
||||
async def _process_track(session: AsyncSession, track: Track) -> bool:
    """Process a single track: find preview, download, embed, store.

    Returns True on success, False when iTunes has no preview for the track.
    Status transitions are committed step by step so progress survives a
    mid-pipeline exception (the caller marks the track "failed" in that case).
    Raises whatever download/decode/embed raises.
    """
    global _last_processed

    # Mark as downloading
    await session.execute(
        update(Track).where(Track.id == track.id).values(embedding_status="downloading")
    )
    await session.commit()

    # Find iTunes preview (skip the lookup if a preview URL is already stored).
    if not track.itunes_preview_url:
        result = await search_track(track.artist, track.title)
        if result is None:
            # Terminal state: nothing to embed for this track.
            await session.execute(
                update(Track).where(Track.id == track.id).values(embedding_status="no_preview")
            )
            await session.commit()
            return False

        # Persist iTunes metadata immediately so a later failure keeps it.
        await session.execute(
            update(Track)
            .where(Track.id == track.id)
            .values(
                itunes_track_id=result["track_id"],
                itunes_preview_url=result["preview_url"],
                apple_music_id=result["apple_music_id"],
                duration_ms=result.get("duration_ms"),
                genre=result.get("genre"),
            )
        )
        await session.commit()
        preview_url = result["preview_url"]
    else:
        preview_url = track.itunes_preview_url

    # Download and decode
    filename = f"{track.id}.m4a"
    filepath = await download_preview(preview_url, settings.audio_cache_dir, filename)
    audio = decode_audio(filepath)

    # Embed
    embedding = embed_audio(audio)

    # Store embedding and mark done in one transaction.
    track_embedding = TrackEmbedding(
        track_id=track.id,
        embedding=embedding.tolist(),
    )
    session.add(track_embedding)
    await session.execute(
        update(Track).where(Track.id == track.id).values(embedding_status="done")
    )
    await session.commit()

    _last_processed = datetime.now(timezone.utc)
    return True
|
||||
|
||||
|
||||
async def run_worker():
    """Background loop that embeds pending tracks in batches.

    Loads the CLAP model once, then repeatedly pulls up to
    settings.embedding_batch_size tracks with embedding_status="pending"
    (oldest first) and runs each through _process_track. Per-track
    failures mark the track "failed"; loop-level failures back off 10s.

    Fix: the per-track error handler stored str(Exception) — the literal
    class repr "<class 'Exception'>" — in embedding_error instead of the
    actual message. The exception is now bound with `except ... as exc`
    and str(exc) is persisted.
    """
    global _running

    if not settings.embedding_worker_enabled:
        logger.info("Embedding worker disabled")
        return

    # Load model on first run (happens exactly once per process).
    load_model()
    _running = True
    logger.info("Embedding worker started")

    try:
        while True:
            try:
                async with async_session() as session:
                    # Oldest-first batch of tracks awaiting embedding.
                    result = await session.execute(
                        select(Track)
                        .where(Track.embedding_status == "pending")
                        .order_by(Track.created_at)
                        .limit(settings.embedding_batch_size)
                    )
                    tracks = result.scalars().all()

                    if not tracks:
                        # Nothing to do; idle until the next poll.
                        await asyncio.sleep(settings.embedding_interval_seconds)
                        continue

                    for track in tracks:
                        try:
                            await _process_track(session, track)
                            logger.info("Embedded: %s - %s", track.artist, track.title)
                        except Exception as exc:
                            logger.exception("Failed to embed %s - %s", track.artist, track.title)
                            await session.execute(
                                update(Track)
                                .where(Track.id == track.id)
                                .values(
                                    embedding_status="failed",
                                    embedding_error=str(exc),
                                )
                            )
                            await session.commit()

            except Exception:
                # DB connectivity etc. — back off and keep the worker alive.
                logger.exception("Embedding worker error")
                await asyncio.sleep(10)

            await asyncio.sleep(1)  # Brief pause between batches
    finally:
        _running = False
|
||||
64
src/haunt_fm/services/history_ingest.py
Normal file
64
src/haunt_fm/services/history_ingest.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.models.track import ListenEvent, Track
|
||||
|
||||
|
||||
def make_fingerprint(artist: str, title: str) -> str:
    """Build the canonical dedup key for a track: lowercased, trimmed artist::title."""
    parts = (field.lower().strip() for field in (artist, title))
    return "::".join(parts)
|
||||
|
||||
|
||||
async def upsert_track(
    session: AsyncSession,
    title: str,
    artist: str,
    album: str | None = None,
) -> Track:
    """Return the Track matching (artist, title), creating it if absent.

    Also backfills the album on an existing row that lacks one.
    """
    key = make_fingerprint(artist, title)
    existing = (
        await session.execute(select(Track).where(Track.fingerprint == key))
    ).scalar_one_or_none()

    if existing is None:
        existing = Track(
            title=title,
            artist=artist,
            album=album,
            fingerprint=key,
        )
        session.add(existing)
        await session.flush()  # assign the primary key without committing

    if album and not existing.album:
        existing.album = album
        await session.flush()

    return existing
|
||||
|
||||
|
||||
async def ingest_listen_event(
    session: AsyncSession,
    title: str,
    artist: str,
    album: str | None,
    speaker_name: str | None,
    duration_played: int | None,
    source: str,
    listened_at: datetime,
    raw_payload: dict | None = None,
) -> ListenEvent:
    """Record one playback: upsert the track, then persist and return the event."""
    track = await upsert_track(session, title, artist, album)

    event = ListenEvent(
        track_id=track.id,
        listened_at=listened_at,
        source=source,
        speaker_name=speaker_name,
        duration_played=duration_played,
        raw_payload=raw_payload,
    )
    session.add(event)
    await session.commit()
    # Refresh to populate server-generated columns before returning.
    await session.refresh(event)
    return event
|
||||
57
src/haunt_fm/services/itunes_client.py
Normal file
57
src/haunt_fm/services/itunes_client.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ITUNES_SEARCH_URL = "https://itunes.apple.com/search"

# Rate limit: ~20 req/min for iTunes
# Module-level throttle state shared by all callers on the event loop.
_last_request_time = 0.0
_min_interval = 3.0  # 3s between requests
|
||||
|
||||
|
||||
async def _rate_limit():
    """Sleep as needed so successive iTunes calls are at least _min_interval apart.

    Fix: uses asyncio.get_running_loop() instead of asyncio.get_event_loop().
    Inside a coroutine the running loop always exists, and get_event_loop()
    is deprecated for this usage in modern Python.
    """
    global _last_request_time
    loop = asyncio.get_running_loop()
    elapsed = loop.time() - _last_request_time
    if elapsed < _min_interval:
        await asyncio.sleep(_min_interval - elapsed)
    _last_request_time = loop.time()
|
||||
|
||||
|
||||
async def search_track(artist: str, title: str) -> dict | None:
    """Search iTunes for a track and return preview info, or None if not found."""
    await _rate_limit()

    params = {
        "term": f"{artist} {title}",
        "media": "music",
        "entity": "song",
        "limit": 5,
    }
    async with httpx.AsyncClient(timeout=10) as client:
        resp = await client.get(ITUNES_SEARCH_URL, params=params)
        resp.raise_for_status()
        payload = resp.json()

    # Best match heuristic: the first hit that actually carries a preview URL.
    for hit in payload.get("results", []):
        preview = hit.get("previewUrl")
        if preview:
            return {
                "track_id": hit["trackId"],
                "preview_url": preview,
                "apple_music_id": str(hit.get("trackId", "")),
                "duration_ms": hit.get("trackTimeMillis"),
                "genre": hit.get("primaryGenreName"),
            }

    return None
|
||||
106
src/haunt_fm/services/lastfm_client.py
Normal file
106
src/haunt_fm/services/lastfm_client.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.config import settings
|
||||
from haunt_fm.models.track import SimilarityLink, Track
|
||||
from haunt_fm.services.history_ingest import make_fingerprint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LASTFM_API_URL = "https://ws.audioscrobbler.com/2.0/"

# Simple rate limiter: max 5 req/s
# Module-level throttle state shared by all callers on the event loop.
_last_request_time = 0.0
_min_interval = 0.2  # 200ms between requests
|
||||
|
||||
|
||||
async def _rate_limit():
    """Sleep as needed so successive Last.fm calls are at least _min_interval apart.

    Fix: uses asyncio.get_running_loop() instead of asyncio.get_event_loop().
    Inside a coroutine the running loop always exists, and get_event_loop()
    is deprecated for this usage in modern Python.
    """
    global _last_request_time
    loop = asyncio.get_running_loop()
    elapsed = loop.time() - _last_request_time
    if elapsed < _min_interval:
        await asyncio.sleep(_min_interval - elapsed)
    _last_request_time = loop.time()
|
||||
|
||||
|
||||
async def get_similar_tracks(artist: str, title: str, limit: int = 50) -> list[dict]:
    """Call Last.fm track.getSimilar and return a list of similar-track dicts.

    Raises ValueError when no API key is configured.
    """
    if not settings.lastfm_api_key:
        raise ValueError("HAUNTFM_LASTFM_API_KEY not configured")

    await _rate_limit()

    params = {
        "method": "track.getSimilar",
        "artist": artist,
        "track": title,
        "api_key": settings.lastfm_api_key,
        "format": "json",
        "limit": limit,
    }
    async with httpx.AsyncClient(timeout=10) as client:
        resp = await client.get(LASTFM_API_URL, params=params)
        resp.raise_for_status()
        payload = resp.json()

    # NOTE(review): Last.fm is reported to return a bare dict (not a list)
    # when exactly one similar track exists — confirm and normalize if so.
    raw = payload.get("similartracks", {}).get("track", [])
    return [
        {
            "title": entry["name"],
            "artist": entry["artist"]["name"],
            "match": float(entry.get("match", 0)),
            "url": entry.get("url"),
        }
        for entry in raw
    ]
|
||||
|
||||
|
||||
async def discover_similar_for_track(session: AsyncSession, source_track: Track) -> int:
    """Discover similar tracks for a given track, upsert them, and create similarity links.

    Returns the number of newly created SimilarityLink rows. A Last.fm API
    failure is logged and re-raised; the caller decides how to handle it.
    """
    try:
        similar = await get_similar_tracks(source_track.artist, source_track.title)
    except Exception:
        logger.exception("Failed to get similar tracks for %s - %s", source_track.artist, source_track.title)
        raise

    count = 0
    for item in similar:
        fingerprint = make_fingerprint(item["artist"], item["title"])

        # Upsert target track
        result = await session.execute(select(Track).where(Track.fingerprint == fingerprint))
        target = result.scalar_one_or_none()
        if target is None:
            target = Track(
                title=item["title"],
                artist=item["artist"],
                fingerprint=fingerprint,
                lastfm_url=item.get("url"),
            )
            session.add(target)
            # Flush so target.id is assigned before building the link row.
            await session.flush()

        # Upsert similarity link
        existing_link = await session.execute(
            select(SimilarityLink).where(
                SimilarityLink.source_track_id == source_track.id,
                SimilarityLink.target_track_id == target.id,
            )
        )
        if existing_link.scalar_one_or_none() is None:
            link = SimilarityLink(
                source_track_id=source_track.id,
                target_track_id=target.id,
                lastfm_match=item["match"],
            )
            session.add(link)
            count += 1

    # Single commit keeps the whole discovery pass atomic.
    await session.commit()
    return count
|
||||
117
src/haunt_fm/services/music_assistant.py
Normal file
117
src/haunt_fm/services/music_assistant.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
from haunt_fm.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _ha_request(method: str, path: str, **kwargs) -> dict:
    """Make an authenticated request to the Home Assistant REST API.

    Returns the parsed JSON body, or {} for an empty response.
    Raises httpx.HTTPStatusError on non-2xx responses.
    """
    auth_headers = {
        "Authorization": f"Bearer {settings.ha_token}",
        "Content-Type": "application/json",
    }
    url = f"{settings.ha_url}{path}"
    async with httpx.AsyncClient(timeout=10) as client:
        resp = await client.request(method, url, headers=auth_headers, **kwargs)
    resp.raise_for_status()
    return resp.json() if resp.content else {}
|
||||
|
||||
|
||||
async def is_ha_reachable() -> bool:
    """Return True when the Home Assistant API answers a GET /api/ ping."""
    try:
        await _ha_request("GET", "/api/")
    except Exception:
        return False
    return True
|
||||
|
||||
|
||||
async def play_media_on_speaker(
    media_content_id: str,
    speaker_entity: str,
    media_content_type: str = "music",
) -> None:
    """Play a media item on a speaker via HA's media_player.play_media service."""
    payload = {
        "entity_id": speaker_entity,
        "media_content_id": media_content_id,
        "media_content_type": media_content_type,
    }
    await _ha_request("POST", "/api/services/media_player/play_media", json=payload)
    logger.info("Playing %s on %s", media_content_id, speaker_entity)
|
||||
|
||||
|
||||
async def search_and_play(
    artist: str,
    title: str,
    speaker_entity: str,
) -> bool:
    """Search Music Assistant for a track and play it.

    Uses the mass.search service to find the track, then plays it.
    Returns True on a successful search call, False on any failure.

    NOTE(review): speaker_entity is currently unused in the body — the
    playback half of this function appears unfinished.
    """
    search_payload = {
        "name": f"{artist} {title}",
        "media_type": "track",
        "limit": 1,
    }
    try:
        # Use Music Assistant search via HA
        result = await _ha_request("POST", "/api/services/mass/search", json=search_payload)
    except Exception:
        logger.exception("Failed to search MA for %s - %s", artist, title)
        return False
    logger.info("MA search result for '%s - %s': %s", artist, title, result)
    return True
|
||||
|
||||
|
||||
async def play_playlist_on_speaker(
    tracks: list[dict],
    speaker_entity: str,
) -> None:
    """Play a list of tracks on a speaker. Each track dict has 'artist' and 'title'.

    The first track starts playback; subsequent tracks are appended to the
    queue via Music Assistant's `enqueue: "add"` option.

    Fix: the original duplicated the entire service-call block for the
    "first track" and "enqueue" branches, which differed only by the
    presence of the `enqueue` key; the payload is now built once.
    """
    if not tracks:
        return

    for i, track in enumerate(tracks):
        payload = {
            "entity_id": speaker_entity,
            "media_content_id": f"{track['artist']} - {track['title']}",
            "media_content_type": "music",
        }
        if i > 0:
            payload["enqueue"] = "add"  # append instead of replacing the queue
        try:
            await _ha_request(
                "POST",
                "/api/services/media_player/play_media",
                json=payload,
            )
        except Exception:
            # Best-effort per track: a single failure shouldn't abort the rest.
            logger.exception("Failed to enqueue %s - %s", track["artist"], track["title"])
|
||||
106
src/haunt_fm/services/playlist_generator.py
Normal file
106
src/haunt_fm/services/playlist_generator.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.models.track import (
|
||||
ListenEvent,
|
||||
Playlist,
|
||||
PlaylistTrack,
|
||||
Track,
|
||||
)
|
||||
from haunt_fm.services.recommender import get_recommendations
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def generate_playlist(
    session: AsyncSession,
    total_tracks: int = 20,
    known_pct: int = 30,
    name: str | None = None,
) -> Playlist:
    """Generate a playlist mixing known-liked tracks with new recommendations.

    Known tracks are sampled from the most-played listening history
    (oversampled 3x, then randomly sampled); new tracks come from the
    recommender, excluding anything already listened to. Both pools are
    shuffled and interleaved so known tracks land roughly evenly.

    Fix: removed a dead `playlist_items` list that was fully built from the
    same data and then discarded — the interleave step rebuilt
    known_items/new_items from scratch anyway.

    Args:
        total_tracks: Total number of tracks in the playlist
        known_pct: Percentage of tracks from listening history (0-100)
        name: Optional playlist name
    """
    known_count = round(total_tracks * known_pct / 100)
    new_count = total_tracks - known_count

    # Known pool: most-played tracks, oversampled 3x so sampling adds variety.
    known_result = await session.execute(
        select(Track.id)
        .join(ListenEvent, ListenEvent.track_id == Track.id)
        .group_by(Track.id)
        .order_by(func.count(ListenEvent.id).desc())
        .limit(known_count * 3)  # oversample for randomness
    )
    known_pool = [row[0] for row in known_result]
    if len(known_pool) > known_count:
        known_ids = random.sample(known_pool, known_count)
    else:
        # Not enough history: take everything, backfill with recommendations.
        known_ids = known_pool
        known_count = len(known_ids)
        new_count = total_tracks - known_count

    # New recommendations: fetch 2x and keep the top new_count by similarity.
    recs = await get_recommendations(session, limit=new_count * 2, exclude_known=True)
    new_tracks = [(r["track_id"], r["similarity"]) for r in recs[:new_count]]

    # Shuffle each pool, then interleave known tracks roughly evenly.
    known_items = [(tid, True, None) for tid in known_ids]
    new_items = [(tid, False, sim) for tid, sim in new_tracks]
    random.shuffle(known_items)
    random.shuffle(new_items)

    interleaved = []
    ki, ni = 0, 0
    for i in range(len(known_items) + len(new_items)):
        # Place a known track every ~(total/known) slots, or when new tracks run out.
        if ki < len(known_items) and (ni >= len(new_items) or i % max(1, total_tracks // max(1, known_count)) == 0):
            interleaved.append(known_items[ki])
            ki += 1
        elif ni < len(new_items):
            interleaved.append(new_items[ni])
            ni += 1
        elif ki < len(known_items):
            interleaved.append(known_items[ki])
            ki += 1

    # Persist the playlist header, then its ordered tracks.
    playlist = Playlist(
        name=name or f"haunt-fm mix ({len(interleaved)} tracks)",
        known_pct=known_pct,
        total_tracks=len(interleaved),
    )
    session.add(playlist)
    await session.flush()  # assigns playlist.id for the child rows

    for pos, (track_id, is_known, similarity) in enumerate(interleaved):
        session.add(
            PlaylistTrack(
                playlist_id=playlist.id,
                track_id=track_id,
                position=pos,
                is_known=is_known,
                similarity_score=similarity,
            )
        )

    await session.commit()
    await session.refresh(playlist)
    logger.info(
        "Generated playlist '%s' with %d tracks (%d known, %d new)",
        playlist.name, len(interleaved), known_count, len(new_items),
    )
    return playlist
|
||||
71
src/haunt_fm/services/recommender.py
Normal file
71
src/haunt_fm/services/recommender.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.models.track import (
|
||||
ListenEvent,
|
||||
TasteProfile,
|
||||
Track,
|
||||
TrackEmbedding,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_recommendations(
    session: AsyncSession,
    limit: int = 50,
    exclude_known: bool = True,
    profile_name: str = "default",
) -> list[dict]:
    """Get track recommendations ranked by cosine similarity to the taste profile.

    Uses pgvector's cosine-distance operator (<=>); similarity is reported
    as 1 - distance. Returns [] when no taste profile exists.

    Fix: removed an unused `known_ids_subq` SQLAlchemy subquery (the raw SQL
    already excludes listened tracks via NOT IN), and collapsed the two
    near-identical SQL strings into one template with an optional WHERE
    clause — the clause is a fixed string, no user input is interpolated.
    """
    # Load taste profile
    profile = (
        await session.execute(select(TasteProfile).where(TasteProfile.name == profile_name))
    ).scalar_one_or_none()

    if profile is None:
        return []

    # Lower cosine distance = more similar, so ORDER BY distance ascending.
    exclusion = (
        "WHERE te.track_id NOT IN (SELECT track_id FROM listen_events)"
        if exclude_known
        else ""
    )
    query = text(f"""
        SELECT t.id, t.title, t.artist, t.album, t.genre,
               1 - (te.embedding <=> :profile_embedding) AS similarity
        FROM track_embeddings te
        JOIN tracks t ON t.id = te.track_id
        {exclusion}
        ORDER BY te.embedding <=> :profile_embedding
        LIMIT :limit
    """)

    result = await session.execute(
        query,
        {"profile_embedding": str(profile.embedding), "limit": limit},
    )

    return [
        {
            "track_id": row.id,
            "title": row.title,
            "artist": row.artist,
            "album": row.album,
            "genre": row.genre,
            "similarity": round(float(row.similarity), 4),
        }
        for row in result
    ]
|
||||
88
src/haunt_fm/services/taste_profile.py
Normal file
88
src/haunt_fm/services/taste_profile.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from haunt_fm.models.track import (
|
||||
ListenEvent,
|
||||
TasteProfile,
|
||||
Track,
|
||||
TrackEmbedding,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _recency_weight(listened_at: datetime, now: datetime, half_life_days: float = 30.0) -> float:
    """Recency decay weight: halves every ``half_life_days`` (default 30)."""
    elapsed_days = (now - listened_at).total_seconds() / 86400
    return 0.5 ** (elapsed_days / half_life_days)
|
||||
|
||||
|
||||
async def build_taste_profile(session: AsyncSession, name: str = "default") -> TasteProfile | None:
    """Build a taste profile as the weighted average of listened-track embeddings.

    Each track that has both listen events and an embedding contributes its
    embedding weighted by ``play_count * recency_decay`` (30-day half-life).
    The weighted mean is L2-normalized and upserted as the TasteProfile row
    named *name*.

    Returns:
        The persisted TasteProfile, or None when no listened tracks have
        embeddings yet.
    """
    # All listened tracks joined with their embeddings, plus play statistics.
    result = await session.execute(
        select(
            Track.id,
            TrackEmbedding.embedding,
            func.count(ListenEvent.id).label("play_count"),
            func.max(ListenEvent.listened_at).label("last_played"),
        )
        .join(TrackEmbedding, TrackEmbedding.track_id == Track.id)
        .join(ListenEvent, ListenEvent.track_id == Track.id)
        .group_by(Track.id, TrackEmbedding.embedding)
    )
    rows = result.all()

    if not rows:
        logger.warning("No listened tracks with embeddings found")
        return None

    now = datetime.now(timezone.utc)
    embeddings = []
    weights = []

    for row in rows:
        emb = np.array(row.embedding, dtype=np.float32)
        # NOTE(review): assumes listened_at is stored timezone-aware; a naive
        # last_played would raise on subtraction against the aware `now` —
        # confirm against the column definition.
        recency = _recency_weight(row.last_played, now)
        embeddings.append(emb)
        weights.append(row.play_count * recency)

    embeddings_arr = np.array(embeddings)
    weights_arr = np.array(weights, dtype=np.float32)
    total_weight = weights_arr.sum()
    if total_weight > 0:
        weights_arr /= total_weight
    else:
        # All weights underflowed to zero (extremely old listens): fall back
        # to a uniform average instead of producing NaNs from 0/0.
        weights_arr = np.full(len(rows), 1.0 / len(rows), dtype=np.float32)

    # Weighted average, normalized to unit length for cosine similarity.
    profile_emb = (embeddings_arr * weights_arr[:, np.newaxis]).sum(axis=0)
    norm = np.linalg.norm(profile_emb)
    if norm > 0:
        profile_emb = profile_emb / norm

    # Upsert by profile name.
    existing = (
        await session.execute(select(TasteProfile).where(TasteProfile.name == name))
    ).scalar_one_or_none()

    if existing:
        existing.embedding = profile_emb.tolist()
        existing.track_count = len(rows)
        existing.updated_at = now
    else:
        existing = TasteProfile(
            name=name,
            embedding=profile_emb.tolist(),
            track_count=len(rows),
            updated_at=now,
        )
        session.add(existing)

    await session.commit()
    await session.refresh(existing)
    logger.info("Built taste profile '%s' from %d tracks", name, len(rows))
    return existing
|
||||
140
src/haunt_fm/templates/status.html
Normal file
140
src/haunt_fm/templates/status.html
Normal file
@@ -0,0 +1,140 @@
|
||||
<!-- haunt-fm status dashboard (Jinja2 template).
     Context: `data` (healthy, db_connected, clap_model_loaded, pipeline.*,
     dependencies.*) and `now` (render timestamp). Rendered by the status
     endpoint; no client-side JS, refresh the page to update. -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>haunt-fm status</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, monospace; background: #0d1117; color: #c9d1d9; padding: 2rem; }
h1 { color: #58a6ff; margin-bottom: 1.5rem; font-size: 1.5rem; }
.status-badge { display: inline-block; padding: 0.25rem 0.75rem; border-radius: 1rem; font-size: 0.85rem; font-weight: 600; margin-bottom: 1.5rem; }
.status-badge.healthy { background: #238636; color: #fff; }
.status-badge.degraded { background: #da3633; color: #fff; }
.section { background: #161b22; border: 1px solid #30363d; border-radius: 0.5rem; padding: 1.25rem; margin-bottom: 1rem; }
.section h2 { color: #8b949e; font-size: 0.8rem; text-transform: uppercase; letter-spacing: 0.05em; margin-bottom: 0.75rem; }
.row { display: flex; justify-content: space-between; padding: 0.35rem 0; border-bottom: 1px solid #21262d; }
.row:last-child { border-bottom: none; }
.label { color: #8b949e; }
.value { color: #c9d1d9; font-weight: 500; }
.dot { display: inline-block; width: 8px; height: 8px; border-radius: 50%; margin-right: 0.4rem; vertical-align: middle; }
.dot.green { background: #3fb950; }
.dot.red { background: #f85149; }
.dot.yellow { background: #d29922; }
.dot.gray { background: #484f58; }
.timestamp { color: #484f58; font-size: 0.8rem; margin-top: 1.5rem; text-align: center; }
</style>
</head>
<body>
<h1>haunt-fm</h1>
<!-- Overall health badge -->
<span class="status-badge {{ 'healthy' if data.healthy else 'degraded' }}">
{{ 'Healthy' if data.healthy else 'Degraded' }}
</span>

<!-- Core services: database connection and CLAP embedding model -->
<div class="section">
<h2>Pipeline</h2>
<div class="row">
<span class="label"><span class="dot {{ 'green' if data.db_connected else 'red' }}"></span>Database</span>
<span class="value">{{ 'Connected' if data.db_connected else 'Down' }}</span>
</div>
<div class="row">
<span class="label"><span class="dot {{ 'green' if data.clap_model_loaded else 'gray' }}"></span>CLAP Model</span>
<span class="value">{{ 'Loaded' if data.clap_model_loaded else 'Not loaded' }}</span>
</div>
</div>

<!-- Listen-event ingestion stats -->
<div class="section">
<h2>Listening History</h2>
<div class="row">
<span class="label">Total events</span>
<span class="value">{{ data.pipeline.listen_events.total }}</span>
</div>
<div class="row">
<span class="label">Last 24h</span>
<span class="value">{{ data.pipeline.listen_events.last_24h }}</span>
</div>
<div class="row">
<span class="label">Latest</span>
<span class="value">{{ data.pipeline.listen_events.latest or 'Never' }}</span>
</div>
</div>

<!-- Track catalog, split by origin -->
<div class="section">
<h2>Tracks</h2>
<div class="row">
<span class="label">Total</span>
<span class="value">{{ data.pipeline.tracks.total }}</span>
</div>
<div class="row">
<span class="label">From history</span>
<span class="value">{{ data.pipeline.tracks.from_history }}</span>
</div>
<div class="row">
<span class="label">From discovery</span>
<span class="value">{{ data.pipeline.tracks.from_discovery }}</span>
</div>
</div>

<!-- Embedding worker queue state -->
<div class="section">
<h2>Embeddings</h2>
<div class="row">
<span class="label"><span class="dot {{ 'green' if data.pipeline.embeddings.worker_running else 'gray' }}"></span>Worker</span>
<span class="value">{{ 'Running' if data.pipeline.embeddings.worker_running else 'Stopped' }}</span>
</div>
<div class="row">
<span class="label">Done</span>
<span class="value">{{ data.pipeline.embeddings.done }}</span>
</div>
<div class="row">
<span class="label">Pending</span>
<span class="value">{{ data.pipeline.embeddings.pending }}</span>
</div>
<div class="row">
<span class="label">Failed</span>
<span class="value">{{ data.pipeline.embeddings.failed }}</span>
</div>
<div class="row">
<span class="label">No preview</span>
<span class="value">{{ data.pipeline.embeddings.no_preview }}</span>
</div>
</div>

<!-- Taste profile existence and size -->
<div class="section">
<h2>Taste Profile</h2>
<div class="row">
<span class="label"><span class="dot {{ 'green' if data.pipeline.taste_profile.exists else 'gray' }}"></span>Profile</span>
<span class="value">{{ 'Built (' ~ data.pipeline.taste_profile.track_count ~ ' tracks)' if data.pipeline.taste_profile.exists else 'Not built' }}</span>
</div>
</div>

<!-- Playlist generation stats -->
<div class="section">
<h2>Playlists</h2>
<div class="row">
<span class="label">Generated</span>
<span class="value">{{ data.pipeline.playlists.total_generated }}</span>
</div>
<div class="row">
<span class="label">Last generated</span>
<span class="value">{{ data.pipeline.playlists.last_generated or 'Never' }}</span>
</div>
</div>

<!-- Reachability of external services -->
<div class="section">
<h2>Dependencies</h2>
<div class="row">
<span class="label"><span class="dot {{ 'green' if data.dependencies.lastfm_api == 'ok' else 'gray' }}"></span>Last.fm API</span>
<span class="value">{{ data.dependencies.lastfm_api }}</span>
</div>
<div class="row">
<span class="label"><span class="dot {{ 'green' if data.dependencies.itunes_api == 'ok' else 'gray' }}"></span>iTunes API</span>
<span class="value">{{ data.dependencies.itunes_api }}</span>
</div>
<div class="row">
<span class="label"><span class="dot {{ 'green' if data.dependencies.ha_reachable else 'gray' }}"></span>Home Assistant</span>
<span class="value">{{ 'Reachable' if data.dependencies.ha_reachable else 'Unknown' }}</span>
</div>
</div>

<p class="timestamp">Updated {{ now }}</p>
</body>
</html>
|
||||
0
src/haunt_fm/utils/__init__.py
Normal file
0
src/haunt_fm/utils/__init__.py
Normal file
31
src/haunt_fm/utils/audio.py
Normal file
31
src/haunt_fm/utils/audio.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def download_preview(url: str, cache_dir: str, filename: str) -> Path:
    """Download an AAC preview file into the cache and return its local path.

    The response body is written to a sibling ``.part`` file and atomically
    renamed into place, so an interrupted download can never leave a
    truncated file that a later call would mistake for a complete cached
    preview.

    Args:
        url: Preview URL to fetch.
        cache_dir: Directory to cache downloads in (created if missing).
        filename: File name to store the preview under.

    Raises:
        httpx.HTTPStatusError: If the server returns a non-2xx response.
    """
    cache_path = Path(cache_dir)
    cache_path.mkdir(parents=True, exist_ok=True)
    filepath = cache_path / filename

    # Cache hit: only ever true for fully-written files (see atomic rename).
    if filepath.exists():
        return filepath

    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.get(url)
        resp.raise_for_status()
        # Write-then-rename: a crash mid-write leaves only the .part file,
        # which the exists() check above ignores.
        tmp_path = filepath.with_name(filepath.name + ".part")
        tmp_path.write_bytes(resp.content)
        tmp_path.replace(filepath)

    return filepath
|
||||
|
||||
|
||||
def decode_audio(filepath: Path, target_sr: int = 48000) -> np.ndarray:
    """Decode an audio file into a mono float array resampled to ``target_sr``.

    Delegates decoding and resampling to librosa (ffmpeg backend); the
    sample-rate element of librosa's return pair is discarded.
    """
    return librosa.load(str(filepath), sr=target_sr, mono=True)[0]
|
||||
25
src/haunt_fm/utils/rate_limiter.py
Normal file
25
src/haunt_fm/utils/rate_limiter.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
class AsyncRateLimiter:
    """Async rate limiter that enforces even spacing between requests.

    Despite superficial similarity, this is NOT a token bucket: it never
    allows bursts. Each acquisition waits until at least ``per / rate``
    seconds have elapsed since the previous one (fixed-interval pacing).
    """

    def __init__(self, rate: float, per: float = 1.0):
        """Args:
            rate: Number of requests allowed per `per` seconds.
            per: Time window in seconds.
        """
        self._rate = rate
        self._per = per
        # Minimum spacing between two consecutive acquisitions.
        self._min_interval = per / rate
        # 0.0 means "never acquired", so the first call passes without waiting.
        self._last_time = 0.0
        self._lock = asyncio.Lock()

    async def acquire(self):
        """Block until the minimum interval since the last acquisition has passed."""
        # The lock serializes concurrent callers so they queue one interval apart.
        async with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_time
            if elapsed < self._min_interval:
                await asyncio.sleep(self._min_interval - elapsed)
            self._last_time = time.monotonic()
|
||||
Reference in New Issue
Block a user