Skip to content

Commit

Permalink
Add CacheMapper to map from remote URL to local cached basename
Browse files Browse the repository at this point in the history
  • Loading branch information
ianthomas23 committed Jun 20, 2023
1 parent aff3f42 commit caf9986
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 28 deletions.
57 changes: 57 additions & 0 deletions fsspec/implementations/cache_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import abc
import hashlib
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any


class AbstractCacheMapper(abc.ABC):
    """Abstract super-class for mappers from remote URLs to local cached
    basenames.
    """

    @abc.abstractmethod
    def __call__(self, path: str) -> str:
        """Map a remote URL/path to the basename used in the local cache."""
        ...

    def __eq__(self, other: Any) -> bool:
        # Identity only depends on class. Compare exact types rather than
        # using isinstance: isinstance(other, type(self)) made equality
        # asymmetric (base == subclass_instance but not the reverse) and
        # inconsistent with __hash__, which is exact-type based. When derived
        # classes have attributes they will need to be included.
        return type(self) is type(other)

    def __hash__(self) -> int:
        # Identity only depends on class. When derived classes have attributes
        # they will need to be included.
        return hash(type(self))


class BasenameCacheMapper(AbstractCacheMapper):
    """Map a remote URL to its basename for use as the local cache name.

    Distinct remote paths that share a basename therefore collide onto the
    same cached file.
    """

    def __call__(self, path: str) -> str:
        # Everything after the final path separator becomes the cached name.
        return os.path.basename(path)


class HashCacheMapper(AbstractCacheMapper):
    """Map a remote URL to the sha256 hex digest of the full path.

    Unlike :class:`BasenameCacheMapper`, distinct remote paths always get
    distinct cached names (up to hash collisions).
    """

    def __call__(self, path: str) -> str:
        digest = hashlib.sha256(path.encode())
        return digest.hexdigest()


def create_cache_mapper(same_names: bool) -> AbstractCacheMapper:
    """Factory method to create cache mapper for backward compatibility with
    ``CachingFileSystem`` constructor using ``same_names`` kwarg.

    ``same_names=True`` keeps the remote basename; ``False`` hashes the path.
    """
    return BasenameCacheMapper() if same_names else HashCacheMapper()
48 changes: 20 additions & 28 deletions fsspec/implementations/cached.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from __future__ import annotations

import contextlib
import hashlib
import inspect
import logging
import os
import pickle
import tempfile
import time
from shutil import rmtree
from typing import ClassVar
from typing import Any, ClassVar

from fsspec import AbstractFileSystem, filesystem
from fsspec.callbacks import _DEFAULT_CALLBACK
from fsspec.compression import compr
from fsspec.core import BaseCache, MMapCache
from fsspec.exceptions import BlocksizeMismatchError
from fsspec.implementations.cache_mapper import create_cache_mapper
from fsspec.spec import AbstractBufferedFile
from fsspec.utils import infer_compression

Expand Down Expand Up @@ -115,9 +115,7 @@ def __init__(
self.check_files = check_files
self.expiry = expiry_time
self.compression = compression
# TODO: same_names should allow for variable prefix, not only
# to keep the basename
self.same_names = same_names
self._mapper = create_cache_mapper(same_names)
self.target_protocol = (
target_protocol
if isinstance(target_protocol, str)
Expand Down Expand Up @@ -255,11 +253,11 @@ def clear_expired_cache(self, expiry_time=None):

for path, detail in self.cached_files[-1].copy().items():
if time.time() - detail["time"] > expiry_time:
if self.same_names:
basename = os.path.basename(detail["original"])
fn = os.path.join(self.storage[-1], basename)
else:
fn = os.path.join(self.storage[-1], detail["fn"])
fn = getattr(detail, "fn", "")
if not fn:
# fn should always be set, but be defensive here.
fn = self._mapper(detail["original"])
fn = os.path.join(self.storage[-1], fn)
if os.path.exists(fn):
os.remove(fn)
self.cached_files[-1].pop(path)
Expand Down Expand Up @@ -339,7 +337,7 @@ def _open(
# TODO: action where partial file exists in read-only cache
logger.debug("Opening partially cached copy of %s" % path)
else:
hash = self.hash_name(path, self.same_names)
hash = self._mapper(path)
fn = os.path.join(self.storage[-1], hash)
blocks = set()
detail = {
Expand Down Expand Up @@ -385,8 +383,10 @@ def _open(
self.save_cache()
return f

def hash_name(self, path, same_name):
return hash_name(path, same_name=same_name)
def hash_name(self, path: str, *args: Any) -> str:
# Kept for backward compatibility with downstream libraries.
# Ignores extra arguments, previously same_name boolean.
return self._mapper(path)

def close_and_update(self, f, close):
"""Called when a file is closing, so store the set of blocks"""
Expand Down Expand Up @@ -488,7 +488,7 @@ def __eq__(self, other):
and self.check_files == other.check_files
and self.expiry == other.expiry
and self.compression == other.compression
and self.same_names == other.same_names
and self._mapper == other._mapper
and self.target_protocol == other.target_protocol
)

Expand All @@ -501,7 +501,7 @@ def __hash__(self):
^ hash(self.check_files)
^ hash(self.expiry)
^ hash(self.compression)
^ hash(self.same_names)
^ hash(self._mapper)
^ hash(self.target_protocol)
)

Expand Down Expand Up @@ -546,7 +546,7 @@ def open_many(self, open_files):
details = [self._check_file(sp) for sp in paths]
downpath = [p for p, d in zip(paths, details) if not d]
downfn0 = [
os.path.join(self.storage[-1], self.hash_name(p, self.same_names))
os.path.join(self.storage[-1], self._mapper(p))
for p, d in zip(paths, details)
] # keep these path names for opening later
downfn = [fn for fn, d in zip(downfn0, details) if not d]
Expand All @@ -558,7 +558,7 @@ def open_many(self, open_files):
newdetail = [
{
"original": path,
"fn": self.hash_name(path, self.same_names),
"fn": self._mapper(path),
"blocks": True,
"time": time.time(),
"uid": self.fs.ukey(path),
Expand Down Expand Up @@ -590,7 +590,7 @@ def commit_many(self, open_files):
pass

def _make_local_details(self, path):
hash = self.hash_name(path, self.same_names)
hash = self._mapper(path)
fn = os.path.join(self.storage[-1], hash)
detail = {
"original": path,
Expand Down Expand Up @@ -731,7 +731,7 @@ def __init__(self, **kwargs):

def _check_file(self, path):
self._check_cache()
sha = self.hash_name(path, self.same_names)
sha = self._mapper(path)
for storage in self.storage:
fn = os.path.join(storage, sha)
if os.path.exists(fn):
Expand All @@ -752,7 +752,7 @@ def _open(self, path, mode="rb", **kwargs):
if fn:
return open(fn, mode)

sha = self.hash_name(path, self.same_names)
sha = self._mapper(path)
fn = os.path.join(self.storage[-1], sha)
logger.debug("Copying %s to local cache" % path)
kwargs["mode"] = mode
Expand Down Expand Up @@ -838,14 +838,6 @@ def __getattr__(self, item):
return getattr(self.fh, item)


def hash_name(path, same_name):
    """Map a remote path to a local cached file name.

    With ``same_name`` truthy, the plain basename of ``path`` is used;
    otherwise the sha256 hex digest of the whole path is used.
    """
    if same_name:
        return os.path.basename(path)
    return hashlib.sha256(path.encode()).hexdigest()


@contextlib.contextmanager
def atomic_write(path, mode="wb"):
"""
Expand Down
25 changes: 25 additions & 0 deletions fsspec/implementations/tests/test_cached.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import fsspec
from fsspec.compression import compr
from fsspec.exceptions import BlocksizeMismatchError
from fsspec.implementations.cache_mapper import create_cache_mapper
from fsspec.implementations.cached import CachingFileSystem, LocalTempFile

from .test_ftp import FTPFileSystem
Expand All @@ -32,6 +33,30 @@ def local_filecache():
return data, original_file, cache_location, fs


def test_mapper():
    # Basename mapper: different directories with the same file name collide.
    basename_mapper = create_cache_mapper(True)
    assert basename_mapper("/somedir/somefile") == "somefile"
    assert basename_mapper("/otherdir/somefile") == "somefile"

    # Hash mapper: the full path is hashed, so the names differ.
    hash_mapper = create_cache_mapper(False)
    assert (
        hash_mapper("/somedir/somefile")
        == "67a6956e5a5f95231263f03758c1fd9254fdb1c564d311674cec56b0372d2056"
    )
    assert (
        hash_mapper("/otherdir/somefile")
        == "f043dee01ab9b752c7f2ecaeb1a5e1b2d872018e2d0a1a26c43835ebf34e7d3e"
    )

    # Mappers compare equal by class, and hashing is consistent with equality.
    assert basename_mapper != hash_mapper
    assert create_cache_mapper(True) == basename_mapper
    assert create_cache_mapper(False) == hash_mapper

    assert hash(basename_mapper) != hash(hash_mapper)
    assert hash(create_cache_mapper(True)) == hash(basename_mapper)
    assert hash(create_cache_mapper(False)) == hash(hash_mapper)


def test_idempotent():
fs = CachingFileSystem("file")
fs2 = CachingFileSystem("file")
Expand Down
1 change: 1 addition & 0 deletions fsspec/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def test_chained_equivalent():
# since the parameters don't quite match. Also, the url understood by the two
# of s are not the same (path gets munged a bit differently)
assert of.fs == of2.fs
assert hash(of.fs) == hash(of2.fs)
assert of.open().read() == of2.open().read()


Expand Down

0 comments on commit caf9986

Please sign in to comment.