diff --git a/src/nendo/library/sqlalchemy_library.py b/src/nendo/library/sqlalchemy_library.py index ab0bc76..d89b460 100644 --- a/src/nendo/library/sqlalchemy_library.py +++ b/src/nendo/library/sqlalchemy_library.py @@ -2165,16 +2165,16 @@ def add_track_to_collection( def add_tracks_to_collection( self, - track_ids: List[Union[str, uuid.UUID]], collection_id: Union[str, uuid.UUID], + track_ids: List[Union[str, uuid.UUID]], meta: Optional[Dict[str, Any]] = None, ) -> schema.NendoCollection: """Creates a relationship from the track to the collection. Args: - track_ids (List[Union[str, uuid.UUID]]): List of track ids to add. collection_id (Union[str, uuid.UUID]): ID of the collection to which to add the track. + track_ids (List[Union[str, uuid.UUID]]): List of track ids to add. meta (Dict[str, Any], optional): Metadata of the relationship. Returns: @@ -2525,6 +2525,64 @@ def remove_track_from_collection( collection_id=collection_id, session=session, ) + + def remove_tracks_from_collection( + self, + collection_id: Union[str, uuid.UUID], + track_ids: List[Union[str, uuid.UUID]], + meta: Optional[Dict[str, Any]] = None, + ) -> bool: + """Creates a relationship from the track to the collection. + + Args: + collection_id (Union[str, uuid.UUID]): ID of the collection from + which to remove the tracks. + track_ids (List[Union[str, uuid.UUID]]): List of track ids to remove. + meta (Dict[str, Any], optional): Metadata of the relationship. + + Returns: + success (bool): True if removal was successful, False otherwise. + """ + with self.session_scope() as session: + # Convert IDs to UUIDs if they're strings + collection_id = ensure_uuid(collection_id) + track_ids = [ensure_uuid(track_id) for track_id in track_ids] + + # Check the collection object + collection = ( + session.query(model.NendoCollectionDB) + .filter_by(id=collection_id) + .first() + ) + if not collection: + raise schema.NendoCollectionNotFoundError( + "The collection does not exist", + collection_id, + ) + existing_track_ids = ( + session.query(model.NendoTrackDB) + .filter(model.NendoTrackDB.id.in_(track_ids)) + .all() + ) + existing_track_ids = [t.id for t in existing_track_ids] + missing_ids = [tid for tid in track_ids if tid not in existing_track_ids] + if len(missing_ids) > 0: + self.logger.warning( + f"Tracks with the following IDs not found: {missing_ids}" + ) + + # remove relationships from all tracks to the collection + results = [] + for track_id in track_ids: + # TODO this could be made more efficient by deleting them in bulk + result = self._remove_track_from_collection_db( + track_id=track_id, + collection_id=collection_id, + session=session, + ) + results.append(result) + session.commit() + return all(results) def update_collection( self,