From 1b25b0a1349b32cb3fa1745b6e3cf1ad9159132d Mon Sep 17 00:00:00 2001 From: Dominic Tarro <57306102+dominictarro@users.noreply.github.com> Date: Mon, 25 Mar 2024 11:01:11 -0400 Subject: [PATCH] Fix `S3Bucket.copy_object` target path resolution (#385) Co-authored-by: nate nowack --- prefect_aws/s3.py | 34 +++++++++++++--------- tests/test_s3.py | 72 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 14 deletions(-) diff --git a/prefect_aws/s3.py b/prefect_aws/s3.py index 2a2feebc..fc81898a 100644 --- a/prefect_aws/s3.py +++ b/prefect_aws/s3.py @@ -1234,20 +1234,23 @@ async def copy_object( """ s3_client = self.credentials.get_s3_client() + source_bucket_name = self.bucket_name source_path = self._resolve_path(Path(from_path).as_posix()) - target_path = self._resolve_path(Path(to_path).as_posix()) - source_bucket_name = self.bucket_name - target_bucket_name = self.bucket_name + # Default to copying within the same bucket + to_bucket = to_bucket or self + + target_bucket_name: str + target_path: str if isinstance(to_bucket, S3Bucket): target_bucket_name = to_bucket.bucket_name - target_path = to_bucket._resolve_path(target_path) + target_path = to_bucket._resolve_path(Path(to_path).as_posix()) elif isinstance(to_bucket, str): target_bucket_name = to_bucket - elif to_bucket is not None: + target_path = Path(to_path).as_posix() + else: raise TypeError( - "to_bucket must be a string or S3Bucket, not" - f" {type(target_bucket_name)}" + f"to_bucket must be a string or S3Bucket, not {type(to_bucket)}" ) self.logger.info( @@ -1316,20 +1319,23 @@ async def move_object( """ s3_client = self.credentials.get_s3_client() + source_bucket_name = self.bucket_name source_path = self._resolve_path(Path(from_path).as_posix()) - target_path = self._resolve_path(Path(to_path).as_posix()) - source_bucket_name = self.bucket_name - target_bucket_name = self.bucket_name + # Default to moving within the same bucket + to_bucket = to_bucket or self + + target_bucket_name: str + target_path: str if isinstance(to_bucket, S3Bucket): target_bucket_name = to_bucket.bucket_name - target_path = to_bucket._resolve_path(target_path) + target_path = to_bucket._resolve_path(Path(to_path).as_posix()) elif isinstance(to_bucket, str): target_bucket_name = to_bucket - elif to_bucket is not None: + target_path = Path(to_path).as_posix() + else: raise TypeError( - "to_bucket must be a string or S3Bucket, not" - f" {type(target_bucket_name)}" + f"to_bucket must be a string or S3Bucket, not {type(to_bucket)}" ) self.logger.info( diff --git a/tests/test_s3.py b/tests/test_s3.py index 3dc83d91..4958b07e 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -1002,6 +1002,42 @@ def test_copy_object( s3_bucket_with_object.copy_object("object", "object_copy_4", s3_bucket_2_empty) assert s3_bucket_2_empty.read_path("object_copy_4") == b"TEST" + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + @pytest.mark.parametrize( + "to_bucket, bucket_folder, expected_path", + [ + # to_bucket=None uses the s3_bucket_2_empty fixture + (None, None, "object"), + (None, "subfolder", "subfolder/object"), + ("bucket_2", None, "object"), + (None, None, "object"), + (None, "subfolder", "subfolder/object"), + ("bucket_2", None, "object"), + ], + ) + def test_copy_subpaths( + self, + s3_bucket_with_object: S3Bucket, + s3_bucket_2_empty: S3Bucket, + to_bucket, + bucket_folder, + expected_path, + ): + if to_bucket is None: + to_bucket = s3_bucket_2_empty + if bucket_folder is not None: + to_bucket.bucket_folder = bucket_folder + else: + # For testing purposes, don't use bucket folder unless specified + to_bucket.bucket_folder = None + + key = s3_bucket_with_object.copy_object( + "object", + "object", + to_bucket=to_bucket, + ) + assert key == expected_path + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) def test_move_object_within_bucket( self, @@ -1046,3 +1082,39 @@ def test_move_object_between_buckets( with pytest.raises(ClientError): assert s3_bucket_with_object.read_path("object") == b"TEST" + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + @pytest.mark.parametrize( + "to_bucket, bucket_folder, expected_path", + [ + # to_bucket=None uses the s3_bucket_2_empty fixture + (None, None, "object"), + (None, "subfolder", "subfolder/object"), + ("bucket_2", None, "object"), + (None, None, "object"), + (None, "subfolder", "subfolder/object"), + ("bucket_2", None, "object"), + ], + ) + def test_move_subpaths( + self, + s3_bucket_with_object: S3Bucket, + s3_bucket_2_empty: S3Bucket, + to_bucket, + bucket_folder, + expected_path, + ): + if to_bucket is None: + to_bucket = s3_bucket_2_empty + if bucket_folder is not None: + to_bucket.bucket_folder = bucket_folder + else: + # For testing purposes, don't use bucket folder unless specified + to_bucket.bucket_folder = None + + key = s3_bucket_with_object.move_object( + "object", + "object", + to_bucket=to_bucket, + ) + assert key == expected_path