diff --git a/fsspec/implementations/cache_mapper.py b/fsspec/implementations/cache_mapper.py new file mode 100644 index 000000000..f9ee29ac2 --- /dev/null +++ b/fsspec/implementations/cache_mapper.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import abc +import hashlib +import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + + +class AbstractCacheMapper(abc.ABC): + """Abstract super-class for mappers from remote URLs to local cached + basenames. + """ + + @abc.abstractmethod + def __call__(self, path: str) -> str: + ... + + def __eq__(self, other: Any) -> bool: + # Identity only depends on class. When derived classes have attributes + # they will need to be included. + return isinstance(other, type(self)) + + def __hash__(self) -> int: + # Identity only depends on class. When derived classes have attributes + # they will need to be included. + return hash(type(self)) + + +class BasenameCacheMapper(AbstractCacheMapper): + """Cache mapper that uses the basename of the remote URL. + + Different paths with the same basename will therefore have the same cached + basename. + """ + + def __call__(self, path: str) -> str: + return os.path.basename(path) + + +class HashCacheMapper(AbstractCacheMapper): + """Cache mapper that uses a hash of the remote URL.""" + + def __call__(self, path: str) -> str: + return hashlib.sha256(path.encode()).hexdigest() + + +def create_cache_mapper(same_names: bool) -> AbstractCacheMapper: + """Factory method to create cache mapper for backward compatibility with + ``CachingFileSystem`` constructor using ``same_names`` kwarg. + """ + if same_names: + return BasenameCacheMapper() + else: + return HashCacheMapper() diff --git a/fsspec/implementations/cached.py b/fsspec/implementations/cached.py index 1e3bf4fc3..c47c3b290 100644 --- a/fsspec/implementations/cached.py +++ b/fsspec/implementations/cached.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -import hashlib import inspect import logging import os @@ -9,13 +8,14 @@ import tempfile import time from shutil import rmtree -from typing import ClassVar +from typing import Any, ClassVar from fsspec import AbstractFileSystem, filesystem from fsspec.callbacks import _DEFAULT_CALLBACK from fsspec.compression import compr from fsspec.core import BaseCache, MMapCache from fsspec.exceptions import BlocksizeMismatchError +from fsspec.implementations.cache_mapper import create_cache_mapper from fsspec.spec import AbstractBufferedFile from fsspec.utils import infer_compression @@ -115,9 +115,7 @@ def __init__( self.check_files = check_files self.expiry = expiry_time self.compression = compression - # TODO: same_names should allow for variable prefix, not only - # to keep the basename - self.same_names = same_names + self._mapper = create_cache_mapper(same_names) self.target_protocol = ( target_protocol if isinstance(target_protocol, str) @@ -255,11 +253,12 @@ def clear_expired_cache(self, expiry_time=None): for path, detail in self.cached_files[-1].copy().items(): if time.time() - detail["time"] > expiry_time: - if self.same_names: - basename = os.path.basename(detail["original"]) - fn = os.path.join(self.storage[-1], basename) - else: - fn = os.path.join(self.storage[-1], detail["fn"]) + fn = detail.get("fn", "") + if not fn: + raise RuntimeError( + f"Cache metadata does not contain 'fn' for {path}" + ) + fn = os.path.join(self.storage[-1], fn) if os.path.exists(fn): os.remove(fn) self.cached_files[-1].pop(path) @@ -339,7 +338,7 @@ def _open( # TODO: action where partial file exists in read-only cache logger.debug("Opening partially cached copy of %s" % path) else: - hash = self.hash_name(path, self.same_names) + hash = self._mapper(path) fn = os.path.join(self.storage[-1], hash) blocks = set() detail = { @@ -385,8 +384,10 @@ def _open( self.save_cache() return f - def hash_name(self, path, same_name): - return hash_name(path, same_name=same_name) + def hash_name(self, path: str, *args: Any) -> str: + # Kept for backward compatibility with downstream libraries. + # Ignores extra arguments, previously same_name boolean. + return self._mapper(path) def close_and_update(self, f, close): """Called when a file is closing, so store the set of blocks""" @@ -488,7 +489,7 @@ def __eq__(self, other): and self.check_files == other.check_files and self.expiry == other.expiry and self.compression == other.compression - and self.same_names == other.same_names + and self._mapper == other._mapper and self.target_protocol == other.target_protocol ) @@ -501,7 +502,7 @@ def __hash__(self): ^ hash(self.check_files) ^ hash(self.expiry) ^ hash(self.compression) - ^ hash(self.same_names) + ^ hash(self._mapper) ^ hash(self.target_protocol) ) @@ -546,7 +547,7 @@ def open_many(self, open_files): details = [self._check_file(sp) for sp in paths] downpath = [p for p, d in zip(paths, details) if not d] downfn0 = [ - os.path.join(self.storage[-1], self.hash_name(p, self.same_names)) + os.path.join(self.storage[-1], self._mapper(p)) for p, d in zip(paths, details) ] # keep these path names for opening later downfn = [fn for fn, d in zip(downfn0, details) if not d] @@ -558,7 +559,7 @@ def open_many(self, open_files): newdetail = [ { "original": path, - "fn": self.hash_name(path, self.same_names), + "fn": self._mapper(path), "blocks": True, "time": time.time(), "uid": self.fs.ukey(path), @@ -590,7 +591,7 @@ def commit_many(self, open_files): pass def _make_local_details(self, path): - hash = self.hash_name(path, self.same_names) + hash = self._mapper(path) fn = os.path.join(self.storage[-1], hash) detail = { "original": path, @@ -731,7 +732,7 @@ def __init__(self, **kwargs): def _check_file(self, path): self._check_cache() - sha = self.hash_name(path, self.same_names) + sha = self._mapper(path) for storage in self.storage: fn = os.path.join(storage, sha) if os.path.exists(fn): @@ -752,7 +753,7 @@ def _open(self, path, mode="rb", **kwargs): if fn: return open(fn, mode) - sha = self.hash_name(path, self.same_names) + sha = self._mapper(path) fn = os.path.join(self.storage[-1], sha) logger.debug("Copying %s to local cache" % path) kwargs["mode"] = mode @@ -838,14 +839,6 @@ def __getattr__(self, item): return getattr(self.fh, item) -def hash_name(path, same_name): - if same_name: - hash = os.path.basename(path) - else: - hash = hashlib.sha256(path.encode()).hexdigest() - return hash - - @contextlib.contextmanager def atomic_write(path, mode="wb"): """ diff --git a/fsspec/implementations/tests/test_cached.py b/fsspec/implementations/tests/test_cached.py index 53ec289fc..d8295e778 100644 --- a/fsspec/implementations/tests/test_cached.py +++ b/fsspec/implementations/tests/test_cached.py @@ -8,7 +8,9 @@ import fsspec from fsspec.compression import compr from fsspec.exceptions import BlocksizeMismatchError +from fsspec.implementations.cache_mapper import create_cache_mapper from fsspec.implementations.cached import CachingFileSystem, LocalTempFile +from fsspec.implementations.local import make_path_posix from .test_ftp import FTPFileSystem @@ -32,6 +34,61 @@ def local_filecache(): return data, original_file, cache_location, fs +def test_mapper(): + mapper0 = create_cache_mapper(True) + assert mapper0("/somedir/somefile") == "somefile" + assert mapper0("/otherdir/somefile") == "somefile" + + mapper1 = create_cache_mapper(False) + assert ( + mapper1("/somedir/somefile") + == "67a6956e5a5f95231263f03758c1fd9254fdb1c564d311674cec56b0372d2056" + ) + assert ( + mapper1("/otherdir/somefile") + == "f043dee01ab9b752c7f2ecaeb1a5e1b2d872018e2d0a1a26c43835ebf34e7d3e" + ) + + assert mapper0 != mapper1 + assert create_cache_mapper(True) == mapper0 + assert create_cache_mapper(False) == mapper1 + + assert hash(mapper0) != hash(mapper1) + assert hash(create_cache_mapper(True)) == hash(mapper0) + assert hash(create_cache_mapper(False)) == hash(mapper1) + + +@pytest.mark.parametrize("same_names", [False, True]) +def test_metadata(tmpdir, same_names): + source = os.path.join(tmpdir, "source") + afile = os.path.join(source, "afile") + os.mkdir(source) + open(afile, "w").write("test") + + fs = fsspec.filesystem( + "filecache", + target_protocol="file", + cache_storage=os.path.join(tmpdir, "cache"), + same_names=same_names, + ) + + with fs.open(afile, "rb") as f: + assert f.read(5) == b"test" + + afile_posix = make_path_posix(afile) + detail = fs.cached_files[0][afile_posix] + assert sorted(detail.keys()) == ["blocks", "fn", "original", "time", "uid"] + assert isinstance(detail["blocks"], bool) + assert isinstance(detail["fn"], str) + assert isinstance(detail["time"], float) + assert isinstance(detail["uid"], str) + + assert detail["original"] == afile_posix + assert detail["fn"] == fs._mapper(afile_posix) + if same_names: + assert detail["fn"] == "afile" + + def test_idempotent(): fs = CachingFileSystem("file") fs2 = CachingFileSystem("file") @@ -154,7 +211,7 @@ def test_clear(): def test_clear_expired(tmp_path): - def __ager(cache_fn, fn): + def __ager(cache_fn, fn, del_fn=False): """ Modify the cache file to virtually add time lag to selected files. @@ -164,6 +221,8 @@ def __ager(cache_fn, fn): cache path fn: str file name to be modified + del_fn: bool + whether or not to delete 'fn' from cache details """ import pathlib import time @@ -174,6 +233,8 @@ def __ager(cache_fn, fn): fn_posix = pathlib.Path(fn).as_posix() cached_files[fn_posix]["time"] = cached_files[fn_posix]["time"] - 691200 assert os.access(cache_fn, os.W_OK), "Cache is not writable" + if del_fn: + del cached_files[fn_posix]["fn"] with open(cache_fn, "wb") as f: pickle.dump(cached_files, f) time.sleep(1) @@ -255,6 +316,22 @@ def __ager(cache_fn, fn): fs.clear_expired_cache() assert not fs._check_file(str(f4)) + # check cache metadata lacking 'fn' raises RuntimeError. + fs = fsspec.filesystem( + "filecache", + target_protocol="file", + cache_storage=str(cache1), + same_names=True, + cache_check=1, + ) + assert fs.cat(str(f1)) == data + + cache_fn = os.path.join(fs.storage[-1], "cache") + __ager(cache_fn, f1, del_fn=True) + + with pytest.raises(RuntimeError, match="Cache metadata does not contain 'fn' for"): + fs.clear_expired_cache() + def test_pop(): import tempfile diff --git a/fsspec/tests/test_api.py b/fsspec/tests/test_api.py index ed7aa0d3a..8187dda81 100644 --- a/fsspec/tests/test_api.py +++ b/fsspec/tests/test_api.py @@ -308,6 +308,7 @@ def test_chained_equivalent(): # since the parameters don't quite match. Also, the url understood by the two # of s are not the same (path gets munged a bit differently) assert of.fs == of2.fs + assert hash(of.fs) == hash(of2.fs) assert of.open().read() == of2.open().read()