diff --git a/fsspec/implementations/cache_mapper.py b/fsspec/implementations/cache_mapper.py new file mode 100644 index 000000000..f9ee29ac2 --- /dev/null +++ b/fsspec/implementations/cache_mapper.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import abc +import hashlib +import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + + +class AbstractCacheMapper(abc.ABC): + """Abstract super-class for mappers from remote URLs to local cached + basenames. + """ + + @abc.abstractmethod + def __call__(self, path: str) -> str: + ... + + def __eq__(self, other: Any) -> bool: + # Identity only depends on class. When derived classes have attributes + # they will need to be included. + return isinstance(other, type(self)) + + def __hash__(self) -> int: + # Identity only depends on class. When derived classes have attributes + # they will need to be included. + return hash(type(self)) + + +class BasenameCacheMapper(AbstractCacheMapper): + """Cache mapper that uses the basename of the remote URL. + + Different paths with the same basename will therefore have the same cached + basename. + """ + + def __call__(self, path: str) -> str: + return os.path.basename(path) + + +class HashCacheMapper(AbstractCacheMapper): + """Cache mapper that uses a hash of the remote URL.""" + + def __call__(self, path: str) -> str: + return hashlib.sha256(path.encode()).hexdigest() + + +def create_cache_mapper(same_names: bool) -> AbstractCacheMapper: + """Factory method to create cache mapper for backward compatibility with + ``CachingFileSystem`` constructor using ``same_names`` kwarg. 
+ """ + if same_names: + return BasenameCacheMapper() + else: + return HashCacheMapper() diff --git a/fsspec/implementations/cached.py b/fsspec/implementations/cached.py index 379cf04cf..8838b81a4 100644 --- a/fsspec/implementations/cached.py +++ b/fsspec/implementations/cached.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -import hashlib import inspect import logging import os @@ -9,13 +8,14 @@ import tempfile import time from shutil import rmtree -from typing import ClassVar +from typing import Any, ClassVar from fsspec import AbstractFileSystem, filesystem from fsspec.callbacks import _DEFAULT_CALLBACK from fsspec.compression import compr from fsspec.core import BaseCache, MMapCache from fsspec.exceptions import BlocksizeMismatchError +from fsspec.implementations.cache_mapper import create_cache_mapper from fsspec.spec import AbstractBufferedFile from fsspec.utils import infer_compression @@ -115,9 +115,7 @@ def __init__( self.check_files = check_files self.expiry = expiry_time self.compression = compression - # TODO: same_names should allow for variable prefix, not only - # to keep the basename - self.same_names = same_names + self._mapper = create_cache_mapper(same_names) self.target_protocol = ( target_protocol if isinstance(target_protocol, str) @@ -255,11 +253,11 @@ def clear_expired_cache(self, expiry_time=None): for path, detail in self.cached_files[-1].copy().items(): if time.time() - detail["time"] > expiry_time: - if self.same_names: - basename = os.path.basename(detail["original"]) - fn = os.path.join(self.storage[-1], basename) - else: - fn = os.path.join(self.storage[-1], detail["fn"]) + fn = detail.get("fn", "") + if not fn: + # fn should always be set, but be defensive here. 
+ fn = self._mapper(detail["original"]) + fn = os.path.join(self.storage[-1], fn) if os.path.exists(fn): os.remove(fn) self.cached_files[-1].pop(path) @@ -339,7 +337,7 @@ def _open( # TODO: action where partial file exists in read-only cache logger.debug("Opening partially cached copy of %s" % path) else: - hash = self.hash_name(path, self.same_names) + hash = self._mapper(path) fn = os.path.join(self.storage[-1], hash) blocks = set() detail = { @@ -385,8 +383,10 @@ def _open( self.save_cache() return f - def hash_name(self, path, same_name): - return hash_name(path, same_name=same_name) + def hash_name(self, path: str, *args: Any) -> str: + # Kept for backward compatibility with downstream libraries. + # Ignores extra arguments, previously same_name boolean. + return self._mapper(path) def close_and_update(self, f, close): """Called when a file is closing, so store the set of blocks""" @@ -488,7 +488,7 @@ def __eq__(self, other): and self.check_files == other.check_files and self.expiry == other.expiry and self.compression == other.compression - and self.same_names == other.same_names + and self._mapper == other._mapper and self.target_protocol == other.target_protocol ) @@ -501,7 +501,7 @@ def __hash__(self): ^ hash(self.check_files) ^ hash(self.expiry) ^ hash(self.compression) - ^ hash(self.same_names) + ^ hash(self._mapper) ^ hash(self.target_protocol) ) @@ -546,7 +546,7 @@ def open_many(self, open_files): details = [self._check_file(sp) for sp in paths] downpath = [p for p, d in zip(paths, details) if not d] downfn0 = [ - os.path.join(self.storage[-1], self.hash_name(p, self.same_names)) + os.path.join(self.storage[-1], self._mapper(p)) for p, d in zip(paths, details) ] # keep these path names for opening later downfn = [fn for fn, d in zip(downfn0, details) if not d] @@ -558,7 +558,7 @@ def open_many(self, open_files): newdetail = [ { "original": path, - "fn": self.hash_name(path, self.same_names), + "fn": self._mapper(path), "blocks": True, "time": 
time.time(), "uid": self.fs.ukey(path), @@ -590,7 +590,7 @@ def commit_many(self, open_files): pass def _make_local_details(self, path): - hash = self.hash_name(path, self.same_names) + hash = self._mapper(path) fn = os.path.join(self.storage[-1], hash) detail = { "original": path, @@ -731,7 +731,7 @@ def __init__(self, **kwargs): def _check_file(self, path): self._check_cache() - sha = self.hash_name(path, self.same_names) + sha = self._mapper(path) for storage in self.storage: fn = os.path.join(storage, sha) if os.path.exists(fn): @@ -752,7 +752,7 @@ def _open(self, path, mode="rb", **kwargs): if fn: return open(fn, mode) - sha = self.hash_name(path, self.same_names) + sha = self._mapper(path) fn = os.path.join(self.storage[-1], sha) logger.debug("Copying %s to local cache" % path) kwargs["mode"] = mode @@ -838,14 +838,6 @@ def __getattr__(self, item): return getattr(self.fh, item) -def hash_name(path, same_name): - if same_name: - hash = os.path.basename(path) - else: - hash = hashlib.sha256(path.encode()).hexdigest() - return hash - - @contextlib.contextmanager def atomic_write(path, mode="wb"): """ diff --git a/fsspec/implementations/tests/test_cached.py b/fsspec/implementations/tests/test_cached.py index 53ec289fc..d0b7fca66 100644 --- a/fsspec/implementations/tests/test_cached.py +++ b/fsspec/implementations/tests/test_cached.py @@ -8,6 +8,7 @@ import fsspec from fsspec.compression import compr from fsspec.exceptions import BlocksizeMismatchError +from fsspec.implementations.cache_mapper import create_cache_mapper from fsspec.implementations.cached import CachingFileSystem, LocalTempFile from .test_ftp import FTPFileSystem @@ -32,6 +33,30 @@ def local_filecache(): return data, original_file, cache_location, fs +def test_mapper(): + mapper0 = create_cache_mapper(True) + assert mapper0("/somedir/somefile") == "somefile" + assert mapper0("/otherdir/somefile") == "somefile" + + mapper1 = create_cache_mapper(False) + assert ( + mapper1("/somedir/somefile") + == 
"67a6956e5a5f95231263f03758c1fd9254fdb1c564d311674cec56b0372d2056" + ) + assert ( + mapper1("/otherdir/somefile") + == "f043dee01ab9b752c7f2ecaeb1a5e1b2d872018e2d0a1a26c43835ebf34e7d3e" + ) + + assert mapper0 != mapper1 + assert create_cache_mapper(True) == mapper0 + assert create_cache_mapper(False) == mapper1 + + assert hash(mapper0) != hash(mapper1) + assert hash(create_cache_mapper(True)) == hash(mapper0) + assert hash(create_cache_mapper(False)) == hash(mapper1) + + def test_idempotent(): fs = CachingFileSystem("file") fs2 = CachingFileSystem("file") diff --git a/fsspec/tests/test_api.py b/fsspec/tests/test_api.py index ed7aa0d3a..8187dda81 100644 --- a/fsspec/tests/test_api.py +++ b/fsspec/tests/test_api.py @@ -308,6 +308,7 @@ def test_chained_equivalent(): # since the parameters don't quite match. Also, the url understood by the two # of s are not the same (path gets munged a bit differently) assert of.fs == of2.fs + assert hash(of.fs) == hash(of2.fs) assert of.open().read() == of2.open().read()