Skip to content

Commit

Permalink
[FSTORE-1534] Add support for working with multiple S3 connectors within the same application (#1380)
Browse files Browse the repository at this point in the history
  • Loading branch information
SirOibaf authored Sep 9, 2024
1 parent 0507730 commit 32125ae
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 12 deletions.
33 changes: 23 additions & 10 deletions python/hsfs/engine/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,41 +1047,54 @@ def setup_storage_connector(self, storage_connector, path=None):
return path

def _setup_s3_hadoop_conf(self, storage_connector, path):
    """Configure the Spark Hadoop configuration for an S3 storage connector.

    The connector credentials are applied twice: once at the global
    `fs.s3a` level to preserve the legacy behaviour, and once scoped to
    the connector's bucket (`fs.s3a.bucket.<bucket>`) so that multiple
    S3 storage connectors can be used within the same application.

    # Arguments
        storage_connector: S3 storage connector to apply.
        path: optional `s3://` path to rewrite to the `s3a` scheme.

    # Returns
        The path with the scheme rewritten to `s3a`, or `None` if no
        path was provided.
    """
    # For legacy behaviour set the S3 values at global level.
    self._set_s3_hadoop_conf(storage_connector, "fs.s3a")

    # Set credentials at bucket level as well to allow users to use
    # multiple storage connectors in the same application.
    self._set_s3_hadoop_conf(
        storage_connector, f"fs.s3a.bucket.{storage_connector.bucket}"
    )
    return path.replace("s3", "s3a", 1) if path is not None else None

def _set_s3_hadoop_conf(self, storage_connector, prefix):
    """Write the S3 connector credentials into the Hadoop configuration.

    All properties are set under `prefix`: either the global `fs.s3a`
    namespace (legacy behaviour) or a bucket-scoped
    `fs.s3a.bucket.<bucket>` namespace, which allows multiple S3
    connectors to coexist in the same application.

    # Arguments
        storage_connector: S3 storage connector providing the credentials.
        prefix: Hadoop property prefix, e.g. `fs.s3a` or
            `fs.s3a.bucket.my-bucket`.
    """
    hadoop_conf = self._spark_context._jsc.hadoopConfiguration()
    if storage_connector.access_key:
        hadoop_conf.set(f"{prefix}.access.key", storage_connector.access_key)
    if storage_connector.secret_key:
        hadoop_conf.set(f"{prefix}.secret.key", storage_connector.secret_key)
    if storage_connector.server_encryption_algorithm:
        hadoop_conf.set(
            f"{prefix}.server-side-encryption-algorithm",
            storage_connector.server_encryption_algorithm,
        )
    if storage_connector.server_encryption_key:
        hadoop_conf.set(
            f"{prefix}.server-side-encryption-key",
            storage_connector.server_encryption_key,
        )
    if storage_connector.session_token:
        # Session tokens require the temporary credentials provider.
        hadoop_conf.set(
            f"{prefix}.aws.credentials.provider",
            "org.apache.hadoop.fs.s3a.TemporaryAWSCredentialsProvider",
        )
        hadoop_conf.set(
            f"{prefix}.session.token",
            storage_connector.session_token,
        )

    # This is the name of the property as expected from the user,
    # without the bucket name.
    FS_S3_ENDPOINT = "fs.s3a.endpoint"
    if FS_S3_ENDPOINT in storage_connector.arguments:
        hadoop_conf.set(
            f"{prefix}.endpoint",
            storage_connector.spark_options().get(FS_S3_ENDPOINT),
        )

def _setup_adls_hadoop_conf(self, storage_connector, path):
for k, v in storage_connector.spark_options().items():
self._spark_context._jsc.hadoopConfiguration().set(k, v)
Expand Down
64 changes: 62 additions & 2 deletions python/tests/engine/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -4105,7 +4105,7 @@ def test_setup_storage_connector_jdbc(self, mocker):
assert mock_spark_engine_setup_adls_hadoop_conf.call_count == 0
assert mock_spark_engine_setup_gcp_hadoop_conf.call_count == 0

def test_setup_s3_hadoop_conf(self, mocker):
def test_setup_s3_hadoop_conf_legacy(self, mocker):
# Arrange
mock_pyspark_getOrCreate = mocker.patch(
"pyspark.sql.session.SparkSession.builder.getOrCreate"
Expand All @@ -4117,6 +4117,7 @@ def test_setup_s3_hadoop_conf(self, mocker):
id=1,
name="test_connector",
featurestore_id=99,
bucket="bucket-name",
access_key="1",
secret_key="2",
server_encryption_algorithm="3",
Expand All @@ -4135,7 +4136,7 @@ def test_setup_s3_hadoop_conf(self, mocker):
assert result == "s3a_test_path"
assert (
mock_pyspark_getOrCreate.return_value.sparkContext._jsc.hadoopConfiguration.return_value.set.call_count
== 7
== 14
)
mock_pyspark_getOrCreate.return_value.sparkContext._jsc.hadoopConfiguration.return_value.set.assert_any_call(
"fs.s3a.access.key", s3_connector.access_key
Expand All @@ -4161,6 +4162,65 @@ def test_setup_s3_hadoop_conf(self, mocker):
"fs.s3a.endpoint", s3_connector.arguments.get("fs.s3a.endpoint")
)

def test_setup_s3_hadoop_conf_bucket_scope(self, mocker):
    """Credentials must also be set under the bucket-scoped fs.s3a prefix."""
    # Arrange
    mock_pyspark_getOrCreate = mocker.patch(
        "pyspark.sql.session.SparkSession.builder.getOrCreate"
    )

    spark_engine = spark.Engine()

    s3_connector = storage_connector.S3Connector(
        id=1,
        name="test_connector",
        featurestore_id=99,
        bucket="bucket-name",
        access_key="1",
        secret_key="2",
        server_encryption_algorithm="3",
        server_encryption_key="4",
        session_token="5",
        arguments=[{"name": "fs.s3a.endpoint", "value": "testEndpoint"}],
    )

    # Act
    result = spark_engine._setup_s3_hadoop_conf(
        storage_connector=s3_connector,
        path="s3_test_path",
    )

    # Assert
    hadoop_set = (
        mock_pyspark_getOrCreate.return_value.sparkContext._jsc.hadoopConfiguration.return_value.set
    )
    assert result == "s3a_test_path"
    # 7 properties at the global level + 7 at the bucket level
    assert hadoop_set.call_count == 14
    expected_bucket_calls = [
        ("fs.s3a.bucket.bucket-name.access.key", s3_connector.access_key),
        ("fs.s3a.bucket.bucket-name.secret.key", s3_connector.secret_key),
        (
            "fs.s3a.bucket.bucket-name.server-side-encryption-algorithm",
            s3_connector.server_encryption_algorithm,
        ),
        (
            "fs.s3a.bucket.bucket-name.server-side-encryption-key",
            s3_connector.server_encryption_key,
        ),
        (
            "fs.s3a.bucket.bucket-name.aws.credentials.provider",
            "org.apache.hadoop.fs.s3a.TemporaryAWSCredentialsProvider",
        ),
        ("fs.s3a.bucket.bucket-name.session.token", s3_connector.session_token),
        (
            "fs.s3a.bucket.bucket-name.endpoint",
            s3_connector.arguments.get("fs.s3a.endpoint"),
        ),
    ]
    for key, value in expected_bucket_calls:
        hadoop_set.assert_any_call(key, value)

def test_setup_adls_hadoop_conf(self, mocker):
# Arrange
mock_pyspark_getOrCreate = mocker.patch(
Expand Down

0 comments on commit 32125ae

Please sign in to comment.