From e9cf1e9b170d8b9804d3d6544044e135bd6dc391 Mon Sep 17 00:00:00 2001 From: Thomas Hallock Date: Sun, 22 Feb 2026 10:43:44 -0600 Subject: [PATCH] Fix embedding dimensions and worker session management Two issues: 1. CLAP model output needed .flatten() to produce a 1-D vector for pgvector. Without it, the nested array caused "expected ndim to be 1". 2. Worker now uses a fresh session per track instead of sharing one across a batch, preventing PendingRollbackError cascading from one failure to the next. Co-Authored-By: Claude Opus 4.6 --- src/haunt_fm/services/embedding.py | 4 +- src/haunt_fm/services/embedding_worker.py | 123 +++++++++++----------- 2 files changed, 65 insertions(+), 62 deletions(-) diff --git a/src/haunt_fm/services/embedding.py b/src/haunt_fm/services/embedding.py index 9a1245e..842e0ce 100644 --- a/src/haunt_fm/services/embedding.py +++ b/src/haunt_fm/services/embedding.py @@ -44,7 +44,7 @@ def embed_audio(audio: np.ndarray, sample_rate: int = 48000) -> np.ndarray: with torch.no_grad(): embeddings = _model.get_audio_features(**inputs) - # Normalize - emb = embeddings[0].numpy() + # Flatten to 1-D and normalize + emb = embeddings[0].numpy().flatten() emb = emb / np.linalg.norm(emb) return emb diff --git a/src/haunt_fm/services/embedding_worker.py b/src/haunt_fm/services/embedding_worker.py index 7c4f2f6..69af361 100644 --- a/src/haunt_fm/services/embedding_worker.py +++ b/src/haunt_fm/services/embedding_worker.py @@ -26,63 +26,62 @@ def last_processed() -> datetime | None: return _last_processed -async def _process_track(session: AsyncSession, track: Track) -> bool: +async def _process_track(track_id: int, artist: str, title: str, preview_url: str | None) -> bool: """Process a single track: find preview, download, embed, store. Returns True on success.""" global _last_processed - # Mark as downloading - await session.execute( - update(Track).where(Track.id == track.id).values(embedding_status="downloading") - ) - await session.commit() - - # Find iTunes preview - if not track.itunes_preview_url: - result = await search_track(track.artist, track.title) - if result is None: - await session.execute( - update(Track).where(Track.id == track.id).values(embedding_status="no_preview") - ) - await session.commit() - return False - + async with async_session() as session: + # Mark as downloading await session.execute( - update(Track) - .where(Track.id == track.id) - .values( - itunes_track_id=result["track_id"], - itunes_preview_url=result["preview_url"], - apple_music_id=result["apple_music_id"], - duration_ms=result.get("duration_ms"), - genre=result.get("genre"), - ) + update(Track).where(Track.id == track_id).values(embedding_status="downloading") ) await session.commit() - preview_url = result["preview_url"] - else: - preview_url = track.itunes_preview_url - # Download and decode - filename = f"{track.id}.m4a" - filepath = await download_preview(preview_url, settings.audio_cache_dir, filename) - audio = decode_audio(filepath) + # Find iTunes preview + if not preview_url: + result = await search_track(artist, title) + if result is None: + await session.execute( + update(Track).where(Track.id == track_id).values(embedding_status="no_preview") + ) + await session.commit() + return False - # Embed - embedding = embed_audio(audio) + await session.execute( + update(Track) + .where(Track.id == track_id) + .values( + itunes_track_id=result["track_id"], + itunes_preview_url=result["preview_url"], + apple_music_id=result["apple_music_id"], + duration_ms=result.get("duration_ms"), + genre=result.get("genre"), + ) + ) + await session.commit() + preview_url = result["preview_url"] - # Store - track_embedding = TrackEmbedding( - track_id=track.id, - embedding=embedding.tolist(), - ) - session.add(track_embedding) - await session.execute( - update(Track).where(Track.id == track.id).values(embedding_status="done") - ) - await session.commit() + # Download and decode + filename = f"{track_id}.m4a" + filepath = await download_preview(preview_url, settings.audio_cache_dir, filename) + audio = decode_audio(filepath) - _last_processed = datetime.now(timezone.utc) - return True + # Embed + embedding = embed_audio(audio) + + # Store + track_embedding = TrackEmbedding( + track_id=track_id, + embedding=embedding.tolist(), + ) + session.add(track_embedding) + await session.execute( + update(Track).where(Track.id == track_id).values(embedding_status="done") + ) + await session.commit() + + _last_processed = datetime.now(timezone.utc) + return True async def run_worker(): @@ -109,27 +108,31 @@ async def run_worker(): .order_by(Track.created_at) .limit(settings.embedding_batch_size) ) - tracks = result.scalars().all() + tracks = [ + (t.id, t.artist, t.title, t.itunes_preview_url) + for t in result.scalars().all() + ] - if not tracks: - await asyncio.sleep(settings.embedding_interval_seconds) - continue + if not tracks: + await asyncio.sleep(settings.embedding_interval_seconds) + continue - for track in tracks: - try: - await _process_track(session, track) - logger.info("Embedded: %s - %s", track.artist, track.title) - except Exception as e: - logger.exception("Failed to embed %s - %s", track.artist, track.title) - await session.execute( + for track_id, artist, title, preview_url in tracks: + try: + await _process_track(track_id, artist, title, preview_url) + logger.info("Embedded: %s - %s", artist, title) + except Exception as e: + logger.exception("Failed to embed %s - %s", artist, title) + async with async_session() as err_session: + await err_session.execute( update(Track) - .where(Track.id == track.id) + .where(Track.id == track_id) .values( embedding_status="failed", embedding_error=str(e), ) ) - await session.commit() + await err_session.commit() except Exception: logger.exception("Embedding worker error")