diff --git a/src/haunt_fm/models/track.py b/src/haunt_fm/models/track.py index 5789cec..b64e576 100644 --- a/src/haunt_fm/models/track.py +++ b/src/haunt_fm/models/track.py @@ -1,7 +1,7 @@ from datetime import datetime from pgvector.sqlalchemy import Vector -from sqlalchemy import REAL, BigInteger, DateTime, Index, Integer, Text, func +from sqlalchemy import REAL, BigInteger, DateTime, ForeignKey, Index, Integer, Text, func from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -37,7 +37,7 @@ class ListenEvent(Base): __tablename__ = "listen_events" id: Mapped[int] = mapped_column(BigInteger, primary_key=True) - track_id: Mapped[int] = mapped_column(BigInteger, nullable=False) + track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False) source: Mapped[str] = mapped_column(Text, nullable=False, default="music_assistant") speaker_name: Mapped[str | None] = mapped_column(Text) listened_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) @@ -51,7 +51,7 @@ class TrackEmbedding(Base): __tablename__ = "track_embeddings" id: Mapped[int] = mapped_column(BigInteger, primary_key=True) - track_id: Mapped[int] = mapped_column(BigInteger, unique=True, nullable=False) + track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), unique=True, nullable=False) embedding = mapped_column(Vector(512), nullable=False) model_version: Mapped[str] = mapped_column(Text, nullable=False, default="laion/larger_clap_music") created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) @@ -67,8 +67,8 @@ class SimilarityLink(Base): __tablename__ = "similarity_links" id: Mapped[int] = mapped_column(BigInteger, primary_key=True) - source_track_id: Mapped[int] = mapped_column(BigInteger, nullable=False) - target_track_id: Mapped[int] = mapped_column(BigInteger, nullable=False) + source_track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False) + target_track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False) lastfm_match: Mapped[float | None] = mapped_column(REAL) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) @@ -103,8 +103,8 @@ class PlaylistTrack(Base): __tablename__ = "playlist_tracks" id: Mapped[int] = mapped_column(BigInteger, primary_key=True) - playlist_id: Mapped[int] = mapped_column(BigInteger, nullable=False) - track_id: Mapped[int] = mapped_column(BigInteger, nullable=False) + playlist_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("playlists.id", ondelete="CASCADE"), nullable=False) + track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False) position: Mapped[int] = mapped_column(Integer, nullable=False) is_known: Mapped[bool] = mapped_column(nullable=False) similarity_score: Mapped[float | None] = mapped_column(REAL)