Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add max_concurrency option to limit number of concurrent BlobClient connections #288

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 39 additions & 24 deletions adlfs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from __future__ import absolute_import, division, print_function

import asyncio
import contextlib
from glob import has_magic
import io
import logging
import os
import warnings
import weakref
from typing import Optional

from azure.core.exceptions import (
ClientAuthenticationError,
Expand Down Expand Up @@ -39,6 +41,7 @@
get_blob_metadata,
close_service_client,
close_container_client,
_nullcontext,
)

from datetime import datetime, timedelta
Expand Down Expand Up @@ -354,6 +357,12 @@ class AzureBlobFileSystem(AsyncFileSystem):
default_cache_type: string ('bytes')
If given, the default cache_type value used for "open()". Set to none if no caching
is desired. Docs in fsspec
max_concurrency : int, optional
The maximum number of BlobClient connections that can exist simultaneously for this
filesystem instance. By default, there is no limit. Setting this might be helpful if
you have a very large number of small, independent blob operations to perform. By
default a single BlobClient is created per blob, which might cause high memory usage
and clog the asyncio event loop as many instances are created and quickly destroyed.

Pass on to fsspec:

Expand Down Expand Up @@ -412,6 +421,7 @@ def __init__(
asynchronous: bool = False,
default_fill_cache: bool = True,
default_cache_type: str = "bytes",
max_concurrency: Optional[int] = None,
**kwargs,
):
super_kwargs = {
Expand Down Expand Up @@ -440,6 +450,13 @@ def __init__(
self.blocksize = blocksize
self.default_fill_cache = default_fill_cache
self.default_cache_type = default_cache_type
self.max_concurrency = max_concurrency

if self.max_concurrency is None:
self._blob_client_semaphore = _nullcontext()
else:
self._blob_client_semaphore = asyncio.Semaphore(max_concurrency)

if (
self.credential is None
and self.account_key is None
Expand All @@ -452,6 +469,7 @@ def __init__(
) = self._get_credential_from_service_principal()
else:
self.sync_credential = None

self.do_connect()
weakref.finalize(self, sync, self.loop, close_service_client, self)

Expand Down Expand Up @@ -491,6 +509,15 @@ def _strip_protocol(cls, path: str):
logger.debug(f"_strip_protocol({path}) = {ops}")
return ops["path"]

@contextlib.asynccontextmanager
async def _get_blob_client(self, container_name, path):
"""
Get a blob client, respecting `self.max_concurrency` if set.
"""
async with self._blob_client_semaphore:
async with self.service_client.get_blob_client(container_name, path) as bc:
yield bc

def _get_credential_from_service_principal(self):
"""
Create a Credential for authentication. This can include a TokenCredential
Expand Down Expand Up @@ -1332,9 +1359,7 @@ async def _isfile(self, path):
return False
else:
try:
async with self.service_client.get_blob_client(
container_name, path
) as bc:
async with self._get_blob_client(container_name, path) as bc:
props = await bc.get_blob_properties()
if props["metadata"]["is_directory"] == "false":
return True
Expand Down Expand Up @@ -1393,7 +1418,7 @@ async def _exists(self, path):
# Empty paths exist by definition
return True

async with self.service_client.get_blob_client(container_name, path) as bc:
async with self._get_blob_client(container_name, path) as bc:
if await bc.exists():
return True

Expand All @@ -1411,9 +1436,7 @@ async def _exists(self, path):
async def _pipe_file(self, path, value, overwrite=True, **kwargs):
"""Set the bytes of given file"""
container_name, path = self.split_path(path)
async with self.service_client.get_blob_client(
container=container_name, blob=path
) as bc:
async with self._get_blob_client(container_name, path) as bc:
result = await bc.upload_blob(
data=value, overwrite=overwrite, metadata={"is_directory": "false"}
)
Expand All @@ -1430,9 +1453,7 @@ async def _cat_file(self, path, start=None, end=None, **kwargs):
else:
length = None
container_name, path = self.split_path(path)
async with self.service_client.get_blob_client(
container=container_name, blob=path
) as bc:
async with self._get_blob_client(container_name, path) as bc:
try:
stream = await bc.download_blob(offset=start, length=length)
except ResourceNotFoundError as e:
Expand Down Expand Up @@ -1494,7 +1515,7 @@ async def _url(self, path, expires=3600, **kwargs):
expiry=datetime.utcnow() + timedelta(seconds=expires),
)

async with self.service_client.get_blob_client(container_name, blob) as bc:
async with self._get_blob_client(container_name, blob) as bc:
url = f"{bc.url}?{sas_token}"
return url

Expand Down Expand Up @@ -1569,9 +1590,7 @@ async def _put_file(
else:
try:
with open(lpath, "rb") as f1:
async with self.service_client.get_blob_client(
container_name, path
) as bc:
async with self._get_blob_client(container_name, path) as bc:
await bc.upload_blob(
f1,
overwrite=overwrite,
Expand All @@ -1596,14 +1615,10 @@ async def _cp_file(self, path1, path2, **kwargs):
container1, path1 = self.split_path(path1, delimiter="/")
container2, path2 = self.split_path(path2, delimiter="/")

cc1 = self.service_client.get_container_client(container1)
blobclient1 = cc1.get_blob_client(blob=path1)
if container1 == container2:
blobclient2 = cc1.get_blob_client(blob=path2)
else:
cc2 = self.service_client.get_container_client(container2)
blobclient2 = cc2.get_blob_client(blob=path2)
await blobclient2.start_copy_from_url(blobclient1.url)
# TODO: this could cause a deadlock. Can we protect the user?
async with self._get_blob_client(container1, path1) as blobclient1:
async with self._get_blob_client(container2, path1) as blobclient2:
await blobclient2.start_copy_from_url(blobclient1.url)
self.invalidate_cache(container1)
self.invalidate_cache(container2)

Expand All @@ -1623,7 +1638,7 @@ async def _get_file(
""" Copy single file remote to local """
container_name, path = self.split_path(rpath, delimiter=delimiter)
try:
async with self.service_client.get_blob_client(
async with self._get_blob_client(
container_name, path.rstrip(delimiter)
) as bc:
with open(lpath, "wb") as my_blob:
Expand All @@ -1645,7 +1660,7 @@ def getxattr(self, path, attr):
async def _setxattrs(self, rpath, **kwargs):
container_name, path = self.split_path(rpath)
try:
async with self.service_client.get_blob_client(container_name, path) as bc:
async with self._get_blob_client(container_name, path) as bc:
await bc.set_blob_metadata(metadata=kwargs)
self.invalidate_cache(self._parent(rpath))
except Exception as e:
Expand Down
16 changes: 16 additions & 0 deletions adlfs/tests/test_spec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import os
import tempfile
from unittest import mock
import datetime
import dask.dataframe as dd
from fsspec.implementations.local import LocalFileSystem
Expand Down Expand Up @@ -1348,3 +1350,17 @@ def test_find_with_prefix(storage):
assert test_1s == [test_bucket_name + "/prefixes/test_1"] + [
test_bucket_name + f"/prefixes/test_{cursor}" for cursor in range(10, 20)
]


def test_max_concurrency(storage):
    """``max_concurrency=2`` installs a real semaphore that gates each blob op."""
    fs = AzureBlobFileSystem(
        account_name=storage.account_name, connection_string=CONN_STR, max_concurrency=2
    )

    assert isinstance(fs._blob_client_semaphore, asyncio.Semaphore)

    # Spy on the semaphore so each acquisition can be counted.
    spy = mock.MagicMock(fs._blob_client_semaphore)
    fs._blob_client_semaphore = spy

    blobs = {f"/data/{i}": b"value" for i in range(10)}
    fs.pipe(blobs)

    # One semaphore entry per blob written.
    assert spy.__aenter__.call_count == 10
15 changes: 15 additions & 0 deletions adlfs/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import contextlib
import sys


async def filter_blobs(blobs, target_path, delimiter="/"):
"""
Filters out blobs that do not come from target_path
Expand Down Expand Up @@ -43,3 +47,14 @@ async def close_container_client(file_obj):
AzureBlobFile objects
"""
await file_obj.container_client.close()


if sys.version_info < (3, 10):
# PYthon 3.10 added support for async to nullcontext
@contextlib.asynccontextmanager
async def _nullcontext(*args):
yield


else:
_nullcontext = contextlib.nullcontext