Fix embedding dimensions and worker session management

Two issues:
1. The CLAP model output needed .flatten() to produce a 1-D vector for
   pgvector. Without it, the nested array caused an "expected ndim to
   be 1" error.
2. The worker now opens a fresh session per track instead of sharing
   one across a batch, so a PendingRollbackError from one failed track
   no longer cascades to the next.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-22 10:43:44 -06:00
parent 771f714384
commit e9cf1e9b17
2 changed files with 65 additions and 62 deletions

View File

@@ -44,7 +44,7 @@ def embed_audio(audio: np.ndarray, sample_rate: int = 48000) -> np.ndarray:
with torch.no_grad(): with torch.no_grad():
embeddings = _model.get_audio_features(**inputs) embeddings = _model.get_audio_features(**inputs)
# Normalize # Flatten to 1-D and normalize
emb = embeddings[0].numpy() emb = embeddings[0].numpy().flatten()
emb = emb / np.linalg.norm(emb) emb = emb / np.linalg.norm(emb)
return emb return emb

View File

@@ -26,63 +26,62 @@ def last_processed() -> datetime | None:
return _last_processed return _last_processed
async def _process_track(session: AsyncSession, track: Track) -> bool: async def _process_track(track_id: int, artist: str, title: str, preview_url: str | None) -> bool:
"""Process a single track: find preview, download, embed, store. Returns True on success.""" """Process a single track: find preview, download, embed, store. Returns True on success."""
global _last_processed global _last_processed
# Mark as downloading async with async_session() as session:
await session.execute( # Mark as downloading
update(Track).where(Track.id == track.id).values(embedding_status="downloading")
)
await session.commit()
# Find iTunes preview
if not track.itunes_preview_url:
result = await search_track(track.artist, track.title)
if result is None:
await session.execute(
update(Track).where(Track.id == track.id).values(embedding_status="no_preview")
)
await session.commit()
return False
await session.execute( await session.execute(
update(Track) update(Track).where(Track.id == track_id).values(embedding_status="downloading")
.where(Track.id == track.id)
.values(
itunes_track_id=result["track_id"],
itunes_preview_url=result["preview_url"],
apple_music_id=result["apple_music_id"],
duration_ms=result.get("duration_ms"),
genre=result.get("genre"),
)
) )
await session.commit() await session.commit()
preview_url = result["preview_url"]
else:
preview_url = track.itunes_preview_url
# Download and decode # Find iTunes preview
filename = f"{track.id}.m4a" if not preview_url:
filepath = await download_preview(preview_url, settings.audio_cache_dir, filename) result = await search_track(artist, title)
audio = decode_audio(filepath) if result is None:
await session.execute(
update(Track).where(Track.id == track_id).values(embedding_status="no_preview")
)
await session.commit()
return False
# Embed await session.execute(
embedding = embed_audio(audio) update(Track)
.where(Track.id == track_id)
.values(
itunes_track_id=result["track_id"],
itunes_preview_url=result["preview_url"],
apple_music_id=result["apple_music_id"],
duration_ms=result.get("duration_ms"),
genre=result.get("genre"),
)
)
await session.commit()
preview_url = result["preview_url"]
# Store # Download and decode
track_embedding = TrackEmbedding( filename = f"{track_id}.m4a"
track_id=track.id, filepath = await download_preview(preview_url, settings.audio_cache_dir, filename)
embedding=embedding.tolist(), audio = decode_audio(filepath)
)
session.add(track_embedding)
await session.execute(
update(Track).where(Track.id == track.id).values(embedding_status="done")
)
await session.commit()
_last_processed = datetime.now(timezone.utc) # Embed
return True embedding = embed_audio(audio)
# Store
track_embedding = TrackEmbedding(
track_id=track_id,
embedding=embedding.tolist(),
)
session.add(track_embedding)
await session.execute(
update(Track).where(Track.id == track_id).values(embedding_status="done")
)
await session.commit()
_last_processed = datetime.now(timezone.utc)
return True
async def run_worker(): async def run_worker():
@@ -109,27 +108,31 @@ async def run_worker():
.order_by(Track.created_at) .order_by(Track.created_at)
.limit(settings.embedding_batch_size) .limit(settings.embedding_batch_size)
) )
tracks = result.scalars().all() tracks = [
(t.id, t.artist, t.title, t.itunes_preview_url)
for t in result.scalars().all()
]
if not tracks: if not tracks:
await asyncio.sleep(settings.embedding_interval_seconds) await asyncio.sleep(settings.embedding_interval_seconds)
continue continue
for track in tracks: for track_id, artist, title, preview_url in tracks:
try: try:
await _process_track(session, track) await _process_track(track_id, artist, title, preview_url)
logger.info("Embedded: %s - %s", track.artist, track.title) logger.info("Embedded: %s - %s", artist, title)
except Exception as e: except Exception as e:
logger.exception("Failed to embed %s - %s", track.artist, track.title) logger.exception("Failed to embed %s - %s", artist, title)
await session.execute( async with async_session() as err_session:
await err_session.execute(
update(Track) update(Track)
.where(Track.id == track.id) .where(Track.id == track_id)
.values( .values(
embedding_status="failed", embedding_status="failed",
embedding_error=str(e), embedding_error=str(e),
) )
) )
await session.commit() await err_session.commit()
except Exception: except Exception:
logger.exception("Embedding worker error") logger.exception("Embedding worker error")