diff --git a/CHANGELOG.md b/CHANGELOG.md index 520931e2417e..677156bcde89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,7 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838)) - Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816)) - Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807), [#4853](https://github.com/pyg-team/pytorch_geometric/pull/4853)) -- Added `FeatureStore` and `GraphStore` abstractions ([#4534](https://github.com/pyg-team/pytorch_geometric/pull/4534), [#4568](https://github.com/pyg-team/pytorch_geometric/pull/4568)) +- Added `FeatureStore` and `GraphStore` abstractions ([#4534](https://github.com/pyg-team/pytorch_geometric/pull/4534), [#4568](https://github.com/pyg-team/pytorch_geometric/pull/4568), [#5120](https://github.com/pyg-team/pytorch_geometric/pull/5120)) - Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827)) - Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825)) - Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805)) diff --git a/setup.py b/setup.py index 8b492f3a16d0..74ddbaaf6aa2 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ 'yacs', 'hydra-core', 'protobuf<4.21', - 'pytorch-lightning==1.6.*', + 'pytorch-lightning', ] full_requires = graphgym_requires + [ diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index 663f2329b155..0a8768f3ff75 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -22,7 +22,6 @@ """ import copy from abc import abstractmethod -from collections.abc import MutableMapping from dataclasses import dataclass from enum import Enum from typing import Any, List, Optional, Tuple, Union @@ -241,7 +240,14 @@ def __repr__(self) -> str: f'attr={self._attr})') -class FeatureStore(MutableMapping): +# TODO (manan, matthias) Ideally, we want to let `FeatureStore` inherit from +# `MutableMapping` to clearly indicate its behavior and usage to the user. +# However, having `MutableMapping` as a base class leads to strange behavior +# in combination with PyTorch and PyTorch Lightning, in particular since these +# libraries use customized logic during mini-batch for `Mapping` base classes. + + +class FeatureStore: def __init__(self, tensor_attr_cls: Any = TensorAttr): r"""Initializes the feature store. Implementor classes can customize the ordering and required nature of their :class:`TensorAttr` tensor diff --git a/torch_geometric/loader/dataloader.py b/torch_geometric/loader/dataloader.py index 6196d118efe5..6f5ebc74e6df 100644 --- a/torch_geometric/loader/dataloader.py +++ b/torch_geometric/loader/dataloader.py @@ -1,5 +1,4 @@ from collections.abc import Mapping, Sequence -from inspect import signature from typing import List, Optional, Union import torch.utils.data @@ -40,28 +39,6 @@ def collate(self, batch): # Deprecated... return self(batch) -# PyG 'Data' objects are subclasses of MutableMapping, which is an -# instance of collections.abc.Mapping. Currently, PyTorch pin_memory -# for DataLoaders treats the returned batches as Mapping objects and -# calls `pin_memory` on each element in `Data.__dict__`, which is not -# desired behavior if 'Data' has a `pin_memory` function. We patch -# this behavior here by monkeypatching `pin_memory`, but can hopefully patch -# this in PyTorch in the future: -__torch_pin_memory = torch.utils.data._utils.pin_memory.pin_memory -__torch_pin_memory_params = signature(__torch_pin_memory).parameters - - -def pin_memory(data, device=None): - if hasattr(data, "pin_memory"): - return data.pin_memory() - if len(__torch_pin_memory_params) > 1: - return __torch_pin_memory(data, device) - return __torch_pin_memory(data) - - -torch.utils.data._utils.pin_memory.pin_memory = pin_memory - - class DataLoader(torch.utils.data.DataLoader): r"""A data loader which merges data objects from a :class:`torch_geometric.data.Dataset` to a mini-batch.