Fix embedding dimensions and worker session management

Two issues:
1. CLAP model output needed .flatten() to produce a 1-D vector for
   pgvector. Without it, the nested array caused "expected ndim to be 1".
2. Worker now uses a fresh session per track instead of sharing one
   across a batch, preventing PendingRollbackError cascading from one
   failure to the next.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-22 10:43:44 -06:00
parent 771f714384
commit e9cf1e9b17
2 changed files with 65 additions and 62 deletions

View File

@@ -44,7 +44,7 @@ def embed_audio(audio: np.ndarray, sample_rate: int = 48000) -> np.ndarray:
with torch.no_grad():
embeddings = _model.get_audio_features(**inputs)
# Normalize
emb = embeddings[0].numpy()
# Flatten to 1-D and normalize
emb = embeddings[0].numpy().flatten()
emb = emb / np.linalg.norm(emb)
return emb

View File

@@ -26,63 +26,62 @@ def last_processed() -> datetime | None:
return _last_processed
async def _process_track(session: AsyncSession, track: Track) -> bool:
async def _process_track(track_id: int, artist: str, title: str, preview_url: str | None) -> bool:
"""Process a single track: find preview, download, embed, store. Returns True on success."""
global _last_processed
# Mark as downloading
await session.execute(
update(Track).where(Track.id == track.id).values(embedding_status="downloading")
)
await session.commit()
# Find iTunes preview
if not track.itunes_preview_url:
result = await search_track(track.artist, track.title)
if result is None:
await session.execute(
update(Track).where(Track.id == track.id).values(embedding_status="no_preview")
)
await session.commit()
return False
async with async_session() as session:
# Mark as downloading
await session.execute(
update(Track)
.where(Track.id == track.id)
.values(
itunes_track_id=result["track_id"],
itunes_preview_url=result["preview_url"],
apple_music_id=result["apple_music_id"],
duration_ms=result.get("duration_ms"),
genre=result.get("genre"),
)
update(Track).where(Track.id == track_id).values(embedding_status="downloading")
)
await session.commit()
preview_url = result["preview_url"]
else:
preview_url = track.itunes_preview_url
# Download and decode
filename = f"{track.id}.m4a"
filepath = await download_preview(preview_url, settings.audio_cache_dir, filename)
audio = decode_audio(filepath)
# Find iTunes preview
if not preview_url:
result = await search_track(artist, title)
if result is None:
await session.execute(
update(Track).where(Track.id == track_id).values(embedding_status="no_preview")
)
await session.commit()
return False
# Embed
embedding = embed_audio(audio)
await session.execute(
update(Track)
.where(Track.id == track_id)
.values(
itunes_track_id=result["track_id"],
itunes_preview_url=result["preview_url"],
apple_music_id=result["apple_music_id"],
duration_ms=result.get("duration_ms"),
genre=result.get("genre"),
)
)
await session.commit()
preview_url = result["preview_url"]
# Store
track_embedding = TrackEmbedding(
track_id=track.id,
embedding=embedding.tolist(),
)
session.add(track_embedding)
await session.execute(
update(Track).where(Track.id == track.id).values(embedding_status="done")
)
await session.commit()
# Download and decode
filename = f"{track_id}.m4a"
filepath = await download_preview(preview_url, settings.audio_cache_dir, filename)
audio = decode_audio(filepath)
_last_processed = datetime.now(timezone.utc)
return True
# Embed
embedding = embed_audio(audio)
# Store
track_embedding = TrackEmbedding(
track_id=track_id,
embedding=embedding.tolist(),
)
session.add(track_embedding)
await session.execute(
update(Track).where(Track.id == track_id).values(embedding_status="done")
)
await session.commit()
_last_processed = datetime.now(timezone.utc)
return True
async def run_worker():
@@ -109,27 +108,31 @@ async def run_worker():
.order_by(Track.created_at)
.limit(settings.embedding_batch_size)
)
tracks = result.scalars().all()
tracks = [
(t.id, t.artist, t.title, t.itunes_preview_url)
for t in result.scalars().all()
]
if not tracks:
await asyncio.sleep(settings.embedding_interval_seconds)
continue
if not tracks:
await asyncio.sleep(settings.embedding_interval_seconds)
continue
for track in tracks:
try:
await _process_track(session, track)
logger.info("Embedded: %s - %s", track.artist, track.title)
except Exception as e:
logger.exception("Failed to embed %s - %s", track.artist, track.title)
await session.execute(
for track_id, artist, title, preview_url in tracks:
try:
await _process_track(track_id, artist, title, preview_url)
logger.info("Embedded: %s - %s", artist, title)
except Exception as e:
logger.exception("Failed to embed %s - %s", artist, title)
async with async_session() as err_session:
await err_session.execute(
update(Track)
.where(Track.id == track.id)
.where(Track.id == track_id)
.values(
embedding_status="failed",
embedding_error=str(e),
)
)
await session.commit()
await err_session.commit()
except Exception:
logger.exception("Embedding worker error")