diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index 7396dbe197db..832be742e7c8 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -369,7 +369,9 @@ class NeighborSampler(NeighborSamplerImpl): link prediction, the process needs another pre-peocess operation. That is, gathering unique nodes from the given node pairs, encompassing both positive and negative node pairs, and employs these nodes as the seed nodes - for subsequent steps. + for subsequent steps. When the graph is hetero, sampled subgraphs in + minibatch will contain every edge type even though it is empty after + sampling. Parameters ---------- @@ -479,7 +481,9 @@ class LayerNeighborSampler(NeighborSamplerImpl): link prediction, the process needs another pre-process operation. That is, gathering unique nodes from the given node pairs, encompassing both positive and negative node pairs, and employs these nodes as the seed nodes - for subsequent steps. + for subsequent steps. When the graph is hetero, sampled subgraphs in + minibatch will contain every edge type even though it is empty after + sampling. Implements the approach described in Appendix A.3 of the paper. Similar to dgl.dataloading.LaborSampler but this uses sequential poisson sampling diff --git a/python/dgl/graphbolt/minibatch.py b/python/dgl/graphbolt/minibatch.py index 726010453333..0692599bcb0c 100644 --- a/python/dgl/graphbolt/minibatch.py +++ b/python/dgl/graphbolt/minibatch.py @@ -5,8 +5,8 @@ import torch -import dgl -from dgl.utils import recursive_apply +from ..convert import create_block, DGLBlock, EID, NID +from ..utils import recursive_apply from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr from .internal import get_attributes, get_nonproperty_attributes @@ -114,6 +114,8 @@ class MiniBatch: all node ids inside are compacted. """ + _blocks: List[DGLBlock] = None + def __repr__(self) -> str: return _minibatch_str(self) @@ -157,13 +159,21 @@ def set_edge_features( self.edge_features = edge_features @property - def blocks(self): - """Extracts DGL blocks from `MiniBatch` to construct a graphical - structure and ID mappings. + def blocks(self) -> List[DGLBlock]: + """DGL blocks extracted from `MiniBatch` containing graphical structures + and ID mappings. """ if not self.sampled_subgraphs: return None + if self._blocks is None: + self._blocks = self.compute_blocks() + return self._blocks + + def compute_blocks(self) -> List[DGLBlock]: + """Extracts DGL blocks from `MiniBatch` to construct graphical + structures and ID mappings. + """ is_heterogeneous = isinstance( self.sampled_subgraphs[0].sampled_csc, Dict ) @@ -192,10 +202,15 @@ def cast_to_minimum_dtype(v: CSCFormatBase): original_column_node_ids is not None ), "Missing `original_column_node_ids` in sampled subgraph." if is_heterogeneous: + node_types = set() + sampled_csc = {} for v in subgraph.sampled_csc.values(): cast_to_minimum_dtype(v) - sampled_csc = { - etype_str_to_tuple(etype): ( + for etype, v in subgraph.sampled_csc.items(): + etype_tuple = etype_str_to_tuple(etype) + node_types.add(etype_tuple[0]) + node_types.add(etype_tuple[2]) + sampled_csc[etype_tuple] = ( "csc", ( v.indptr, @@ -208,15 +223,21 @@ def cast_to_minimum_dtype(v: CSCFormatBase): ), ), ) - for etype, v in subgraph.sampled_csc.items() - } num_src_nodes = { - ntype: nodes.size(0) - for ntype, nodes in original_row_node_ids.items() + ntype: ( + original_row_node_ids[ntype].size(0) + if original_row_node_ids.get(ntype) is not None + else 0 + ) + for ntype in node_types } num_dst_nodes = { - ntype: nodes.size(0) - for ntype, nodes in original_column_node_ids.items() + ntype: ( + original_column_node_ids[ntype].size(0) + if original_column_node_ids.get(ntype) is not None + else 0 + ) + for ntype in node_types } else: sampled_csc = cast_to_minimum_dtype(subgraph.sampled_csc) @@ -236,7 +257,7 @@ def cast_to_minimum_dtype(v: CSCFormatBase): num_src_nodes = original_row_node_ids.size(0) num_dst_nodes = original_column_node_ids.size(0) blocks.append( - dgl.create_block( + create_block( sampled_csc, num_src_nodes=num_src_nodes, num_dst_nodes=num_dst_nodes, @@ -249,7 +270,7 @@ def cast_to_minimum_dtype(v: CSCFormatBase): for node_type, reverse_ids in self.sampled_subgraphs[ 0 ].original_row_node_ids.items(): - blocks[0].srcnodes[node_type].data[dgl.NID] = reverse_ids + blocks[0].srcnodes[node_type].data[NID] = reverse_ids # Assign reverse edges ids. for block, subgraph in zip(blocks, self.sampled_subgraphs): if subgraph.original_edge_ids: @@ -258,16 +279,16 @@ def cast_to_minimum_dtype(v: CSCFormatBase): reverse_ids, ) in subgraph.original_edge_ids.items(): block.edges[etype_str_to_tuple(edge_type)].data[ - dgl.EID + EID ] = reverse_ids else: - blocks[0].srcdata[dgl.NID] = self.sampled_subgraphs[ + blocks[0].srcdata[NID] = self.sampled_subgraphs[ 0 ].original_row_node_ids # Assign reverse edges ids. for block, subgraph in zip(blocks, self.sampled_subgraphs): if subgraph.original_edge_ids is not None: - block.edata[dgl.EID] = subgraph.original_edge_ids + block.edata[EID] = subgraph.original_edge_ids return blocks def to_pyg_data(self): @@ -345,6 +366,8 @@ def _minibatch_str(minibatch: MiniBatch) -> str: attributes.reverse() # Insert key with its value into the string. for name in attributes: + if name[0] == "_": + continue val = getattr(minibatch, name) def _add_indent(_str, indent): diff --git a/tests/python/pytorch/graphbolt/test_minibatch.py b/tests/python/pytorch/graphbolt/test_minibatch.py index 91522722430a..e762cee3fdc6 100644 --- a/tests/python/pytorch/graphbolt/test_minibatch.py +++ b/tests/python/pytorch/graphbolt/test_minibatch.py @@ -128,12 +128,16 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype): relation: gb.CSCFormatBase( indptr=torch.tensor([0, 1, 2], dtype=indptr_dtype), indices=torch.tensor([1, 0], dtype=indices_dtype), - ) + ), + reverse_relation: gb.CSCFormatBase( + indptr=torch.tensor([0, 2], dtype=indptr_dtype), + indices=torch.tensor([1, 0], dtype=indices_dtype), + ), }, ] original_column_node_ids = [ {"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])}, - {"B": torch.tensor([10, 11])}, + {"B": torch.tensor([10, 11]), "A": torch.tensor([5])}, ] original_row_node_ids = [ { @@ -196,10 +200,12 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype): ), SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32), indices=tensor([1, 0], dtype=torch.int32), + ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 2], dtype=torch.int32), + indices=tensor([1, 0], dtype=torch.int32), )}, original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])}, original_edge_ids={'A:r:B': tensor([10, 12])}, - original_column_node_ids={'B': tensor([10, 11])}, + original_column_node_ids={'B': tensor([10, 11]), 'A': tensor([5])}, )], node_features={('A', 'x'): tensor([6, 4, 0, 1])}, labels={'B': tensor([2, 5])}, @@ -213,9 +219,9 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype): num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2}, metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')]), Block(num_src_nodes={'A': 2, 'B': 2}, - num_dst_nodes={'B': 2}, - num_edges={('A', 'r', 'B'): 2}, - metagraph=[('A', 'B', 'r')])], + num_dst_nodes={'A': 1, 'B': 2}, + num_edges={('A', 'r', 'B'): 2, ('B', 'rr', 'A'): 2}, + metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')])], )""" ) result = str(minibatch) @@ -284,12 +290,16 @@ def test_get_dgl_blocks_hetero(): { relation: gb.CSCFormatBase( indptr=torch.tensor([0, 1, 2]), indices=torch.tensor([1, 0]) - ) + ), + reverse_relation: gb.CSCFormatBase( + indptr=torch.tensor([0, 1]), + indices=torch.tensor([1]), + ), }, ] original_column_node_ids = [ {"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])}, - {"B": torch.tensor([10, 11])}, + {"B": torch.tensor([10, 11]), "A": torch.tensor([5])}, ] original_row_node_ids = [ { @@ -328,14 +338,96 @@ def test_get_dgl_blocks_hetero(): num_dst_nodes={'A': 4, 'B': 3}, num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2}, metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')]), Block(num_src_nodes={'A': 2, 'B': 2}, - num_dst_nodes={'B': 2}, - num_edges={('A', 'r', 'B'): 2}, - metagraph=[('A', 'B', 'r')])]""" + num_dst_nodes={'A': 1, 'B': 2}, + num_edges={('A', 'r', 'B'): 2, ('B', 'rr', 'A'): 1}, + metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')])]""" ) result = str(dgl_blocks) assert result == expect_result +def test_get_dgl_blocks_hetero_partial_empty_edges(): + hg = dgl.heterograph( + { + ("n1", "e1", "n1"): ([0, 1, 1], [1, 2, 0]), + ("n1", "e2", "n2"): ([0, 1, 2], [1, 0, 2]), + } + ) + + gb_g = gb.from_dglgraph(hg, is_homogeneous=False) + + train_set = gb.HeteroItemSet( + {"n1:e2:n2": gb.ItemSet(torch.LongTensor([[0, 1]]), names="seeds")} + ) + datapipe = gb.ItemSampler(train_set, batch_size=1) + datapipe = datapipe.sample_neighbor(gb_g, fanouts=[-1, -1]) + dataloader = gb.DataLoader(datapipe) + blocks_str = str(next(iter(dataloader)).blocks) + expected_str = """[Block(num_src_nodes={'n1': 2, 'n2': 0}, + num_dst_nodes={'n1': 2, 'n2': 0}, + num_edges={('n1', 'e1', 'n1'): 2, ('n1', 'e2', 'n2'): 0}, + metagraph=[('n1', 'n1', 'e1'), ('n1', 'n2', 'e2')]), Block(num_src_nodes={'n1': 2, 'n2': 0}, + num_dst_nodes={'n1': 1, 'n2': 1}, + num_edges={('n1', 'e1', 'n1'): 1, ('n1', 'e2', 'n2'): 1}, + metagraph=[('n1', 'n1', 'e1'), ('n1', 'n2', 'e2')])]""" + assert expected_str == blocks_str + + +def test_get_dgl_blocks_hetero_empty_edges(): + hg = dgl.heterograph( + { + ("n3", "e1", "n1"): ([0, 1, 1], [1, 2, 0]), + ("n3", "e2", "n2"): ([0, 1, 2], [1, 0, 2]), + } + ) + + gb_g = gb.from_dglgraph(hg, is_homogeneous=False) + + train_set = gb.HeteroItemSet( + {"n3:e1:n1": gb.ItemSet(torch.LongTensor([[2, 1]]), names="seeds")} + ) + datapipe = gb.ItemSampler(train_set, batch_size=1) + datapipe = datapipe.sample_neighbor(gb_g, fanouts=[-1, -1]) + dataloader = gb.DataLoader(datapipe) + blocks_str = str(next(iter(dataloader)).blocks) + expected_str = """[Block(num_src_nodes={'n1': 0, 'n2': 0, 'n3': 2}, + num_dst_nodes={'n1': 0, 'n2': 0, 'n3': 2}, + num_edges={('n3', 'e1', 'n1'): 0, ('n3', 'e2', 'n2'): 0}, + metagraph=[('n3', 'n1', 'e1'), ('n3', 'n2', 'e2')]), Block(num_src_nodes={'n1': 0, 'n2': 0, 'n3': 2}, + num_dst_nodes={'n1': 1, 'n2': 0, 'n3': 1}, + num_edges={('n3', 'e1', 'n1'): 1, ('n3', 'e2', 'n2'): 0}, + metagraph=[('n3', 'n1', 'e1'), ('n3', 'n2', 'e2')])]""" + assert expected_str == blocks_str + + +def test_get_dgl_blocks_homo_empty_edges(): + g = dgl.graph(([2, 3, 4], [3, 4, 5])) + + gb_g = gb.from_dglgraph(g, is_homogeneous=True) + train_set = gb.ItemSet(torch.LongTensor([[0, 1]]), names="seeds") + datapipe = gb.ItemSampler(train_set, batch_size=1) + datapipe = datapipe.sample_neighbor(gb_g, fanouts=[-1, -1]) + dataloader = gb.DataLoader(datapipe) + blocks_str = str(next(iter(dataloader)).blocks) + expected_str = "[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=0), Block(num_src_nodes=2, num_dst_nodes=2, num_edges=0)]" + assert expected_str == blocks_str + + +def test_seeds_ntype_being_passed(): + hg = dgl.heterograph({("n1", "e1", "n2"): ([0, 1, 2], [2, 0, 1])}) + + gb_g = gb.from_dglgraph(hg, is_homogeneous=False) + train_set = gb.HeteroItemSet( + {"n2": gb.ItemSet(torch.LongTensor([0, 1]), names="seeds")} + ) + datapipe = gb.ItemSampler(train_set, batch_size=2) + datapipe = datapipe.sample_neighbor(gb_g, [-1, -1, -1]) + dataloader = gb.DataLoader(datapipe) + blocks = next(iter(dataloader)).blocks + for block in blocks: + assert "n2" in block.srctypes + + def create_homo_minibatch(): csc_formats = [ gb.CSCFormatBase(