diff --git a/src/nendo/library/duckdb_library.py b/src/nendo/library/duckdb_library.py index 392be29..6e297ee 100644 --- a/src/nendo/library/duckdb_library.py +++ b/src/nendo/library/duckdb_library.py @@ -48,7 +48,6 @@ def __init__( self.storage_driver = schema.NendoStorageLocalFS( library_path=self.config.library_path, user_id=self.config.user_id, ) - self._connect(db, session) def _connect( diff --git a/src/nendo/library/sqlalchemy_library.py b/src/nendo/library/sqlalchemy_library.py index b4a14f7..bf7d1ab 100644 --- a/src/nendo/library/sqlalchemy_library.py +++ b/src/nendo/library/sqlalchemy_library.py @@ -152,6 +152,7 @@ def _create_track_from_file( raise schema.NendoResourceError("Unsupported filetype", file_path) file_checksum = md5sum(file_path) + file_stats = os.stat(file_path) # skip adding a duplicate based on config flag and hashsum of the file skip_duplicate = skip_duplicate or self.config.skip_duplicate @@ -174,7 +175,6 @@ def _create_track_from_file( copy_to_library = copy_to_library or self.config.copy_to_library if copy_to_library or (self.config.auto_convert and file_path.endswith(".mp3")): try: - file_stats = os.stat(file_path) sr = None if self.config.auto_convert: if file_path.endswith(".mp3"): @@ -220,14 +220,6 @@ def _create_track_from_file( if sr is not None: meta["sr"] = sr - meta.update( - { - "original_filename": os.path.basename(file_path), - "original_filepath": os.path.dirname(file_path), - "original_size": file_stats.st_size, - "original_checksum": file_checksum, - }, - ) location = self.storage_driver.get_driver_location() except Exception as e: # noqa: BLE001 raise schema.NendoLibraryError( @@ -237,6 +229,15 @@ def _create_track_from_file( path_in_library = file_path location = schema.ResourceLocation.original + meta.update( + { + "original_filename": os.path.basename(file_path), + "original_filepath": os.path.dirname(file_path), + "original_size": file_stats.st_size, + "original_checksum": file_checksum, + }, + ) + resource = schema.NendoResource( file_path=self.storage_driver.get_file_path( src=path_in_library, @@ -469,7 +470,7 @@ def _upsert_track_db( return db_track def _upsert_tracks_db( - self, tracks: List[schema.NendoTrackCreate], session: Session, + self, tracks: List[schema.NendoTrackBase], session: Session, ) -> List[model.NendoTrackDB]: """Create multiple tracks in DB or update if it exists. @@ -482,9 +483,25 @@ def _upsert_tracks_db( """ db_tracks = [] for track in tracks: - track_dict = track.model_dump() - track_dict.pop("nendo_instance") - db_tracks.append(model.NendoTrackDB(**track_dict)) + if type(track) == schema.NendoTrackCreate: + # create new track + track_dict = track.model_dump() + track_dict.pop("nendo_instance") + db_tracks.append(model.NendoTrackDB(**track_dict)) + else: + # update existing track + db_tracks.append( + session.query(model.NendoTrackDB).filter_by(id=track.id).one_or_none() + ) + db_track = db_tracks[-1] + if db_track is None: + raise schema.NendoTrackNotFoundError("Track not found", id=track.id) + db_track.user_id = track.user_id + db_track.visibility = track.visibility + db_track.resource = track.resource.model_dump() + db_track.track_type = track.track_type + db_track.images = track.images + db_track.meta = track.meta session.add_all(db_tracks) session.commit() return db_tracks @@ -1377,10 +1394,10 @@ def remove_track( session.delete(target) # only delete if file has been copied to the library # ("original_filepath" is present) + target_track = schema.NendoTrack.model_validate(target) if ( remove_resources - and "original_filepath" - in schema.NendoTrack.model_validate(target).resource.meta + and target_track.resource.location != "original" ): logger.info("Removing resources associated with Track %s", str(track_id)) return self.storage_driver.remove_file( diff --git a/tests/assets/sub_assets/test.mp3 b/tests/assets/sub_assets/test.mp3 index 1762c05..23c2e3a 100644 Binary files a/tests/assets/sub_assets/test.mp3 and b/tests/assets/sub_assets/test.mp3 differ diff --git a/tests/test_library.py b/tests/test_library.py index 093be31..6837435 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -260,6 +260,10 @@ def test_add_tracks_adds_all_files_in_folder(self): self.assertEqual(len(results), 3) results = nd.library.get_tracks(limit=1) self.assertEqual(len(results), 1) + # try adding again, should update existing + nd.library.add_tracks(path="tests/assets") + results = nd.library.get_tracks() + self.assertEqual(len(results), 3) def test_remove_file_from_library(self): """Test the `nd.library.remove_track()` function."""