Skip to content

Commit

Permalink
[GraphBolt] Alias ItemSetDict to HeteroItemSet (#7466)
Browse files Browse the repository at this point in the history
Co-authored-by: Rhett Ying <[email protected]>
  • Loading branch information
Skeleton003 and Rhett-Ying authored Jun 20, 2024
1 parent f898053 commit 7c3d418
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
48 changes: 47 additions & 1 deletion python/dgl/graphbolt/itemset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

import torch

__all__ = ["ItemSet", "HeteroItemSet"]
from ..base import dgl_warning

__all__ = ["ItemSet", "HeteroItemSet", "ItemSetDict"]


def is_scalar(x):
Expand Down Expand Up @@ -406,3 +408,47 @@ def __repr__(self) -> str:
itemsets=itemsets_str,
names=self._names,
)


class ItemSetDict:
"""`ItemSetDict` is a deprecated class and will be removed in a future
version. Please use `HeteroItemSet` instead.
This class is an alias for `HeteroItemSet` and serves as a wrapper to
provide a smooth transition for users of the old class name. It issues a
deprecation warning upon instantiation and forwards all attribute access
and method calls to an instance of `HeteroItemSet`.
"""

def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
dgl_warning(
"ItemSetDict is deprecated and will be removed in the future. "
"Please use HeteroItemSet instead.",
category=DeprecationWarning,
)
self._new_instance = HeteroItemSet(itemsets)

def __getattr__(self, name: str):
return getattr(self._new_instance, name)

def __getitem__(self, index):
return self._new_instance[index]

def __len__(self) -> int:
return len(self._new_instance)

def __repr__(self) -> str:
ret = (
"{Classname}(\n"
" itemsets={itemsets},\n"
" names={names},\n"
")"
)
itemsets_str = textwrap.indent(
repr(self._itemsets), " " * len(" itemsets=")
).strip()
return ret.format(
Classname=self.__class__.__name__,
itemsets=itemsets_str,
names=self._names,
)
52 changes: 52 additions & 0 deletions tests/python/pytorch/graphbolt/test_itemset.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,3 +606,55 @@ def test_HeteroItemSet_repr():
")"
)
assert str(item_set) == expected_str, item_set


def test_deprecation_alias():
"""Test `ItemSetDict` as the alias for `HeteroItemSet`."""

user_ids = torch.arange(0, 5)
item_ids = torch.arange(5, 10)
ids = {
"user": gb.ItemSet(user_ids, names="seeds"),
"item": gb.ItemSet(item_ids, names="seeds"),
}
with pytest.warns(
DeprecationWarning,
match="ItemSetDict is deprecated and will be removed in the future. Please use HeteroItemSet instead.",
):
item_set_dict = gb.ItemSetDict(ids)
hetero_item_set = gb.HeteroItemSet(ids)
assert len(item_set_dict) == len(hetero_item_set)
assert item_set_dict.names == hetero_item_set.names
assert item_set_dict._keys == hetero_item_set._keys
assert torch.equal(item_set_dict._offsets, hetero_item_set._offsets)
assert (
repr(item_set_dict)[len("ItemSetDict") :]
== repr(hetero_item_set)[len("HeteroItemSet") :]
)
# Indexing all with a slice.
assert torch.equal(item_set_dict[:]["user"], hetero_item_set[:]["user"])
assert torch.equal(item_set_dict[:]["item"], hetero_item_set[:]["item"])
# Indexing partial with a slice.
partial_data = item_set_dict[:3]
assert len(list(partial_data.keys())) == 1
assert torch.equal(partial_data["user"], hetero_item_set[:3]["user"])
partial_data = item_set_dict[7:]
assert len(list(partial_data.keys())) == 1
assert torch.equal(partial_data["item"], hetero_item_set[7:]["item"])
partial_data = item_set_dict[3:8:2]
assert len(list(partial_data.keys())) == 2
assert torch.equal(partial_data["user"], hetero_item_set[3:8:2]["user"])
assert torch.equal(partial_data["item"], hetero_item_set[3:8:2]["item"])
# Indexing with an iterable of int.
partial_data = item_set_dict[torch.tensor([1, 0, 4])]
assert len(list(partial_data.keys())) == 1
assert torch.equal(partial_data["user"], hetero_item_set[1, 0, 4]["user"])
partial_data = item_set_dict[torch.tensor([9, 8, 5])]
assert len(list(partial_data.keys())) == 1
assert torch.equal(partial_data["item"], hetero_item_set[9, 8, 5]["item"])
partial_data = item_set_dict[torch.tensor([8, 1, 0, 9, 7, 5])]
assert len(list(partial_data.keys())) == 2
assert torch.equal(partial_data["user"], hetero_item_set[1, 0]["user"])
assert torch.equal(
partial_data["item"], hetero_item_set[8, 9, 7, 5]["item"]
)

0 comments on commit 7c3d418

Please sign in to comment.