Fix embedding dimensions and worker session management
Two issues: 1. CLAP model output needed .flatten() to produce a 1-D vector for pgvector. Without it, the nested array caused "expected ndim to be 1". 2. Worker now uses a fresh session per track instead of sharing one across a batch, preventing PendingRollbackError cascading from one failure to the next. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -44,7 +44,7 @@ def embed_audio(audio: np.ndarray, sample_rate: int = 48000) -> np.ndarray:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
embeddings = _model.get_audio_features(**inputs)
|
embeddings = _model.get_audio_features(**inputs)
|
||||||
|
|
||||||
# Normalize
|
# Flatten to 1-D and normalize
|
||||||
emb = embeddings[0].numpy()
|
emb = embeddings[0].numpy().flatten()
|
||||||
emb = emb / np.linalg.norm(emb)
|
emb = emb / np.linalg.norm(emb)
|
||||||
return emb
|
return emb
|
||||||
|
|||||||
@@ -26,63 +26,62 @@ def last_processed() -> datetime | None:
|
|||||||
return _last_processed
|
return _last_processed
|
||||||
|
|
||||||
|
|
||||||
async def _process_track(session: AsyncSession, track: Track) -> bool:
|
async def _process_track(track_id: int, artist: str, title: str, preview_url: str | None) -> bool:
|
||||||
"""Process a single track: find preview, download, embed, store. Returns True on success."""
|
"""Process a single track: find preview, download, embed, store. Returns True on success."""
|
||||||
global _last_processed
|
global _last_processed
|
||||||
|
|
||||||
# Mark as downloading
|
async with async_session() as session:
|
||||||
await session.execute(
|
# Mark as downloading
|
||||||
update(Track).where(Track.id == track.id).values(embedding_status="downloading")
|
|
||||||
)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
# Find iTunes preview
|
|
||||||
if not track.itunes_preview_url:
|
|
||||||
result = await search_track(track.artist, track.title)
|
|
||||||
if result is None:
|
|
||||||
await session.execute(
|
|
||||||
update(Track).where(Track.id == track.id).values(embedding_status="no_preview")
|
|
||||||
)
|
|
||||||
await session.commit()
|
|
||||||
return False
|
|
||||||
|
|
||||||
await session.execute(
|
await session.execute(
|
||||||
update(Track)
|
update(Track).where(Track.id == track_id).values(embedding_status="downloading")
|
||||||
.where(Track.id == track.id)
|
|
||||||
.values(
|
|
||||||
itunes_track_id=result["track_id"],
|
|
||||||
itunes_preview_url=result["preview_url"],
|
|
||||||
apple_music_id=result["apple_music_id"],
|
|
||||||
duration_ms=result.get("duration_ms"),
|
|
||||||
genre=result.get("genre"),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
preview_url = result["preview_url"]
|
|
||||||
else:
|
|
||||||
preview_url = track.itunes_preview_url
|
|
||||||
|
|
||||||
# Download and decode
|
# Find iTunes preview
|
||||||
filename = f"{track.id}.m4a"
|
if not preview_url:
|
||||||
filepath = await download_preview(preview_url, settings.audio_cache_dir, filename)
|
result = await search_track(artist, title)
|
||||||
audio = decode_audio(filepath)
|
if result is None:
|
||||||
|
await session.execute(
|
||||||
|
update(Track).where(Track.id == track_id).values(embedding_status="no_preview")
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
return False
|
||||||
|
|
||||||
# Embed
|
await session.execute(
|
||||||
embedding = embed_audio(audio)
|
update(Track)
|
||||||
|
.where(Track.id == track_id)
|
||||||
|
.values(
|
||||||
|
itunes_track_id=result["track_id"],
|
||||||
|
itunes_preview_url=result["preview_url"],
|
||||||
|
apple_music_id=result["apple_music_id"],
|
||||||
|
duration_ms=result.get("duration_ms"),
|
||||||
|
genre=result.get("genre"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
preview_url = result["preview_url"]
|
||||||
|
|
||||||
# Store
|
# Download and decode
|
||||||
track_embedding = TrackEmbedding(
|
filename = f"{track_id}.m4a"
|
||||||
track_id=track.id,
|
filepath = await download_preview(preview_url, settings.audio_cache_dir, filename)
|
||||||
embedding=embedding.tolist(),
|
audio = decode_audio(filepath)
|
||||||
)
|
|
||||||
session.add(track_embedding)
|
|
||||||
await session.execute(
|
|
||||||
update(Track).where(Track.id == track.id).values(embedding_status="done")
|
|
||||||
)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
_last_processed = datetime.now(timezone.utc)
|
# Embed
|
||||||
return True
|
embedding = embed_audio(audio)
|
||||||
|
|
||||||
|
# Store
|
||||||
|
track_embedding = TrackEmbedding(
|
||||||
|
track_id=track_id,
|
||||||
|
embedding=embedding.tolist(),
|
||||||
|
)
|
||||||
|
session.add(track_embedding)
|
||||||
|
await session.execute(
|
||||||
|
update(Track).where(Track.id == track_id).values(embedding_status="done")
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
_last_processed = datetime.now(timezone.utc)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def run_worker():
|
async def run_worker():
|
||||||
@@ -109,27 +108,31 @@ async def run_worker():
|
|||||||
.order_by(Track.created_at)
|
.order_by(Track.created_at)
|
||||||
.limit(settings.embedding_batch_size)
|
.limit(settings.embedding_batch_size)
|
||||||
)
|
)
|
||||||
tracks = result.scalars().all()
|
tracks = [
|
||||||
|
(t.id, t.artist, t.title, t.itunes_preview_url)
|
||||||
|
for t in result.scalars().all()
|
||||||
|
]
|
||||||
|
|
||||||
if not tracks:
|
if not tracks:
|
||||||
await asyncio.sleep(settings.embedding_interval_seconds)
|
await asyncio.sleep(settings.embedding_interval_seconds)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for track in tracks:
|
for track_id, artist, title, preview_url in tracks:
|
||||||
try:
|
try:
|
||||||
await _process_track(session, track)
|
await _process_track(track_id, artist, title, preview_url)
|
||||||
logger.info("Embedded: %s - %s", track.artist, track.title)
|
logger.info("Embedded: %s - %s", artist, title)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to embed %s - %s", track.artist, track.title)
|
logger.exception("Failed to embed %s - %s", artist, title)
|
||||||
await session.execute(
|
async with async_session() as err_session:
|
||||||
|
await err_session.execute(
|
||||||
update(Track)
|
update(Track)
|
||||||
.where(Track.id == track.id)
|
.where(Track.id == track_id)
|
||||||
.values(
|
.values(
|
||||||
embedding_status="failed",
|
embedding_status="failed",
|
||||||
embedding_error=str(e),
|
embedding_error=str(e),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await session.commit()
|
await err_session.commit()
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Embedding worker error")
|
logger.exception("Embedding worker error")
|
||||||
|
|||||||
Reference in New Issue
Block a user