Fix embedding dimensions and worker session management
Two issues:
1. The CLAP model output needed .flatten() to produce a 1-D vector for pgvector. Without it, the nested array caused an "expected ndim to be 1" error.
2. The worker now uses a fresh session per track instead of sharing one across a batch, which prevents a PendingRollbackError from cascading from one failed track to the next.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -44,7 +44,7 @@ def embed_audio(audio: np.ndarray, sample_rate: int = 48000) -> np.ndarray:
|
||||
with torch.no_grad():
|
||||
embeddings = _model.get_audio_features(**inputs)
|
||||
|
||||
# Normalize
|
||||
emb = embeddings[0].numpy()
|
||||
# Flatten to 1-D and normalize
|
||||
emb = embeddings[0].numpy().flatten()
|
||||
emb = emb / np.linalg.norm(emb)
|
||||
return emb
|
||||
|
||||
@@ -26,29 +26,30 @@ def last_processed() -> datetime | None:
|
||||
return _last_processed
|
||||
|
||||
|
||||
async def _process_track(session: AsyncSession, track: Track) -> bool:
|
||||
async def _process_track(track_id: int, artist: str, title: str, preview_url: str | None) -> bool:
|
||||
"""Process a single track: find preview, download, embed, store. Returns True on success."""
|
||||
global _last_processed
|
||||
|
||||
async with async_session() as session:
|
||||
# Mark as downloading
|
||||
await session.execute(
|
||||
update(Track).where(Track.id == track.id).values(embedding_status="downloading")
|
||||
update(Track).where(Track.id == track_id).values(embedding_status="downloading")
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Find iTunes preview
|
||||
if not track.itunes_preview_url:
|
||||
result = await search_track(track.artist, track.title)
|
||||
if not preview_url:
|
||||
result = await search_track(artist, title)
|
||||
if result is None:
|
||||
await session.execute(
|
||||
update(Track).where(Track.id == track.id).values(embedding_status="no_preview")
|
||||
update(Track).where(Track.id == track_id).values(embedding_status="no_preview")
|
||||
)
|
||||
await session.commit()
|
||||
return False
|
||||
|
||||
await session.execute(
|
||||
update(Track)
|
||||
.where(Track.id == track.id)
|
||||
.where(Track.id == track_id)
|
||||
.values(
|
||||
itunes_track_id=result["track_id"],
|
||||
itunes_preview_url=result["preview_url"],
|
||||
@@ -59,11 +60,9 @@ async def _process_track(session: AsyncSession, track: Track) -> bool:
|
||||
)
|
||||
await session.commit()
|
||||
preview_url = result["preview_url"]
|
||||
else:
|
||||
preview_url = track.itunes_preview_url
|
||||
|
||||
# Download and decode
|
||||
filename = f"{track.id}.m4a"
|
||||
filename = f"{track_id}.m4a"
|
||||
filepath = await download_preview(preview_url, settings.audio_cache_dir, filename)
|
||||
audio = decode_audio(filepath)
|
||||
|
||||
@@ -72,12 +71,12 @@ async def _process_track(session: AsyncSession, track: Track) -> bool:
|
||||
|
||||
# Store
|
||||
track_embedding = TrackEmbedding(
|
||||
track_id=track.id,
|
||||
track_id=track_id,
|
||||
embedding=embedding.tolist(),
|
||||
)
|
||||
session.add(track_embedding)
|
||||
await session.execute(
|
||||
update(Track).where(Track.id == track.id).values(embedding_status="done")
|
||||
update(Track).where(Track.id == track_id).values(embedding_status="done")
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
@@ -109,27 +108,31 @@ async def run_worker():
|
||||
.order_by(Track.created_at)
|
||||
.limit(settings.embedding_batch_size)
|
||||
)
|
||||
tracks = result.scalars().all()
|
||||
tracks = [
|
||||
(t.id, t.artist, t.title, t.itunes_preview_url)
|
||||
for t in result.scalars().all()
|
||||
]
|
||||
|
||||
if not tracks:
|
||||
await asyncio.sleep(settings.embedding_interval_seconds)
|
||||
continue
|
||||
|
||||
for track in tracks:
|
||||
for track_id, artist, title, preview_url in tracks:
|
||||
try:
|
||||
await _process_track(session, track)
|
||||
logger.info("Embedded: %s - %s", track.artist, track.title)
|
||||
await _process_track(track_id, artist, title, preview_url)
|
||||
logger.info("Embedded: %s - %s", artist, title)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to embed %s - %s", track.artist, track.title)
|
||||
await session.execute(
|
||||
logger.exception("Failed to embed %s - %s", artist, title)
|
||||
async with async_session() as err_session:
|
||||
await err_session.execute(
|
||||
update(Track)
|
||||
.where(Track.id == track.id)
|
||||
.where(Track.id == track_id)
|
||||
.values(
|
||||
embedding_status="failed",
|
||||
embedding_error=str(e),
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
await err_session.commit()
|
||||
|
||||
except Exception:
|
||||
logger.exception("Embedding worker error")
|
||||
|
||||
Reference in New Issue
Block a user