Skip to content

Commit

Permalink
[GraphBolt] Fix blocks in minibatch when facing with empty edges in…
Browse files Browse the repository at this point in the history
… subgraph. (#7413)

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Skeleton003 <799284168.qq.com>
Co-authored-by: Skeleton003 <[email protected]>
Co-authored-by: Mingbang Wang <[email protected]>
Co-authored-by: Rhett Ying <[email protected]>
  • Loading branch information
6 people authored Jun 13, 2024
1 parent 7f1d164 commit e489e38
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 31 deletions.
8 changes: 6 additions & 2 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,9 @@ class NeighborSampler(NeighborSamplerImpl):
link prediction, the process needs another pre-process operation. That is,
gathering unique nodes from the given node pairs, encompassing both
positive and negative node pairs, and employs these nodes as the seed nodes
for subsequent steps.
for subsequent steps. When the graph is heterogeneous, the sampled subgraphs
in the minibatch will contain every edge type, even those that are empty
after sampling.
Parameters
----------
Expand Down Expand Up @@ -479,7 +481,9 @@ class LayerNeighborSampler(NeighborSamplerImpl):
link prediction, the process needs another pre-process operation. That is,
gathering unique nodes from the given node pairs, encompassing both
positive and negative node pairs, and employs these nodes as the seed nodes
for subsequent steps.
for subsequent steps. When the graph is heterogeneous, the sampled subgraphs
in the minibatch will contain every edge type, even those that are empty
after sampling.
Implements the approach described in Appendix A.3 of the paper. Similar to
dgl.dataloading.LaborSampler but this uses sequential poisson sampling
Expand Down
59 changes: 41 additions & 18 deletions python/dgl/graphbolt/minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import torch

import dgl
from dgl.utils import recursive_apply
from ..convert import create_block, DGLBlock, EID, NID
from ..utils import recursive_apply

from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr
from .internal import get_attributes, get_nonproperty_attributes
Expand Down Expand Up @@ -114,6 +114,8 @@ class MiniBatch:
all node ids inside are compacted.
"""

_blocks: List[DGLBlock] = None

def __repr__(self) -> str:
return _minibatch_str(self)

Expand Down Expand Up @@ -157,13 +159,21 @@ def set_edge_features(
self.edge_features = edge_features

@property
def blocks(self):
"""Extracts DGL blocks from `MiniBatch` to construct a graphical
structure and ID mappings.
def blocks(self) -> List[DGLBlock]:
"""DGL blocks extracted from `MiniBatch` containing graphical structures
and ID mappings.
"""
if not self.sampled_subgraphs:
return None

if self._blocks is None:
self._blocks = self.compute_blocks()
return self._blocks

def compute_blocks(self) -> List[DGLBlock]:
"""Extracts DGL blocks from `MiniBatch` to construct graphical
structures and ID mappings.
"""
is_heterogeneous = isinstance(
self.sampled_subgraphs[0].sampled_csc, Dict
)
Expand Down Expand Up @@ -192,10 +202,15 @@ def cast_to_minimum_dtype(v: CSCFormatBase):
original_column_node_ids is not None
), "Missing `original_column_node_ids` in sampled subgraph."
if is_heterogeneous:
node_types = set()
sampled_csc = {}
for v in subgraph.sampled_csc.values():
cast_to_minimum_dtype(v)
sampled_csc = {
etype_str_to_tuple(etype): (
for etype, v in subgraph.sampled_csc.items():
etype_tuple = etype_str_to_tuple(etype)
node_types.add(etype_tuple[0])
node_types.add(etype_tuple[2])
sampled_csc[etype_tuple] = (
"csc",
(
v.indptr,
Expand All @@ -208,15 +223,21 @@ def cast_to_minimum_dtype(v: CSCFormatBase):
),
),
)
for etype, v in subgraph.sampled_csc.items()
}
num_src_nodes = {
ntype: nodes.size(0)
for ntype, nodes in original_row_node_ids.items()
ntype: (
original_row_node_ids[ntype].size(0)
if original_row_node_ids.get(ntype) is not None
else 0
)
for ntype in node_types
}
num_dst_nodes = {
ntype: nodes.size(0)
for ntype, nodes in original_column_node_ids.items()
ntype: (
original_column_node_ids[ntype].size(0)
if original_column_node_ids.get(ntype) is not None
else 0
)
for ntype in node_types
}
else:
sampled_csc = cast_to_minimum_dtype(subgraph.sampled_csc)
Expand All @@ -236,7 +257,7 @@ def cast_to_minimum_dtype(v: CSCFormatBase):
num_src_nodes = original_row_node_ids.size(0)
num_dst_nodes = original_column_node_ids.size(0)
blocks.append(
dgl.create_block(
create_block(
sampled_csc,
num_src_nodes=num_src_nodes,
num_dst_nodes=num_dst_nodes,
Expand All @@ -249,7 +270,7 @@ def cast_to_minimum_dtype(v: CSCFormatBase):
for node_type, reverse_ids in self.sampled_subgraphs[
0
].original_row_node_ids.items():
blocks[0].srcnodes[node_type].data[dgl.NID] = reverse_ids
blocks[0].srcnodes[node_type].data[NID] = reverse_ids
# Assign reverse edges ids.
for block, subgraph in zip(blocks, self.sampled_subgraphs):
if subgraph.original_edge_ids:
Expand All @@ -258,16 +279,16 @@ def cast_to_minimum_dtype(v: CSCFormatBase):
reverse_ids,
) in subgraph.original_edge_ids.items():
block.edges[etype_str_to_tuple(edge_type)].data[
dgl.EID
EID
] = reverse_ids
else:
blocks[0].srcdata[dgl.NID] = self.sampled_subgraphs[
blocks[0].srcdata[NID] = self.sampled_subgraphs[
0
].original_row_node_ids
# Assign reverse edges ids.
for block, subgraph in zip(blocks, self.sampled_subgraphs):
if subgraph.original_edge_ids is not None:
block.edata[dgl.EID] = subgraph.original_edge_ids
block.edata[EID] = subgraph.original_edge_ids
return blocks

def to_pyg_data(self):
Expand Down Expand Up @@ -345,6 +366,8 @@ def _minibatch_str(minibatch: MiniBatch) -> str:
attributes.reverse()
# Insert key with its value into the string.
for name in attributes:
if name[0] == "_":
continue
val = getattr(minibatch, name)

def _add_indent(_str, indent):
Expand Down
114 changes: 103 additions & 11 deletions tests/python/pytorch/graphbolt/test_minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,16 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2], dtype=indptr_dtype),
indices=torch.tensor([1, 0], dtype=indices_dtype),
)
),
reverse_relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 2], dtype=indptr_dtype),
indices=torch.tensor([1, 0], dtype=indices_dtype),
),
},
]
original_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
{"B": torch.tensor([10, 11])},
{"B": torch.tensor([10, 11]), "A": torch.tensor([5])},
]
original_row_node_ids = [
{
Expand Down Expand Up @@ -196,10 +200,12 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
),
SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
)},
original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])},
original_edge_ids={'A:r:B': tensor([10, 12])},
original_column_node_ids={'B': tensor([10, 11])},
original_column_node_ids={'B': tensor([10, 11]), 'A': tensor([5])},
)],
node_features={('A', 'x'): tensor([6, 4, 0, 1])},
labels={'B': tensor([2, 5])},
Expand All @@ -213,9 +219,9 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2},
metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')]),
Block(num_src_nodes={'A': 2, 'B': 2},
num_dst_nodes={'B': 2},
num_edges={('A', 'r', 'B'): 2},
metagraph=[('A', 'B', 'r')])],
num_dst_nodes={'A': 1, 'B': 2},
num_edges={('A', 'r', 'B'): 2, ('B', 'rr', 'A'): 2},
metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')])],
)"""
)
result = str(minibatch)
Expand Down Expand Up @@ -284,12 +290,16 @@ def test_get_dgl_blocks_hetero():
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2]), indices=torch.tensor([1, 0])
)
),
reverse_relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1]),
indices=torch.tensor([1]),
),
},
]
original_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
{"B": torch.tensor([10, 11])},
{"B": torch.tensor([10, 11]), "A": torch.tensor([5])},
]
original_row_node_ids = [
{
Expand Down Expand Up @@ -328,14 +338,96 @@ def test_get_dgl_blocks_hetero():
num_dst_nodes={'A': 4, 'B': 3},
num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2},
metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')]), Block(num_src_nodes={'A': 2, 'B': 2},
num_dst_nodes={'B': 2},
num_edges={('A', 'r', 'B'): 2},
metagraph=[('A', 'B', 'r')])]"""
num_dst_nodes={'A': 1, 'B': 2},
num_edges={('A', 'r', 'B'): 2, ('B', 'rr', 'A'): 1},
metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')])]"""
)
result = str(dgl_blocks)
assert result == expect_result


def test_get_dgl_blocks_hetero_partial_empty_edges():
# Checks `MiniBatch.blocks` on a heterograph where one edge type ("e2")
# has zero sampled edges in the first block: per `expected_str`, that block
# must still list ('n1', 'e2', 'n2') with 0 edges and report 'n2' with 0
# source/destination nodes.
#
# Two edge types sharing the source node type "n1": e1 is n1 -> n1 and
# e2 is n1 -> n2.
hg = dgl.heterograph(
{
("n1", "e1", "n1"): ([0, 1, 1], [1, 2, 0]),
("n1", "e2", "n2"): ([0, 1, 2], [1, 0, 2]),
}
)

gb_g = gb.from_dglgraph(hg, is_homogeneous=False)

# A single seed item [0, 1] registered under edge type "n1:e2:n2".
train_set = gb.HeteroItemSet(
{"n1:e2:n2": gb.ItemSet(torch.LongTensor([[0, 1]]), names="seeds")}
)
datapipe = gb.ItemSampler(train_set, batch_size=1)
# `fanouts` has two entries, matching the two blocks in `expected_str`.
# NOTE(review): -1 presumably means "take all neighbors" -- confirm against
# the `sample_neighbor` documentation.
datapipe = datapipe.sample_neighbor(gb_g, fanouts=[-1, -1])
dataloader = gb.DataLoader(datapipe)
# Only the first minibatch is inspected; the comparison is on the string
# form of the block list.
blocks_str = str(next(iter(dataloader)).blocks)
# NOTE(review): the continuation lines of this literal appear without
# leading whitespace in this view; the real Block repr may indent them --
# verify against the original test file before relying on this text.
expected_str = """[Block(num_src_nodes={'n1': 2, 'n2': 0},
num_dst_nodes={'n1': 2, 'n2': 0},
num_edges={('n1', 'e1', 'n1'): 2, ('n1', 'e2', 'n2'): 0},
metagraph=[('n1', 'n1', 'e1'), ('n1', 'n2', 'e2')]), Block(num_src_nodes={'n1': 2, 'n2': 0},
num_dst_nodes={'n1': 1, 'n2': 1},
num_edges={('n1', 'e1', 'n1'): 1, ('n1', 'e2', 'n2'): 1},
metagraph=[('n1', 'n1', 'e1'), ('n1', 'n2', 'e2')])]"""
assert expected_str == blocks_str


def test_get_dgl_blocks_hetero_empty_edges():
# Checks `MiniBatch.blocks` on a heterograph where, per `expected_str`,
# the first block has zero edges for every edge type yet still lists both
# edge types and all three node types (with 0 counts where applicable).
#
# Both edge types originate from node type "n3": e1 is n3 -> n1 and
# e2 is n3 -> n2.
hg = dgl.heterograph(
{
("n3", "e1", "n1"): ([0, 1, 1], [1, 2, 0]),
("n3", "e2", "n2"): ([0, 1, 2], [1, 0, 2]),
}
)

gb_g = gb.from_dglgraph(hg, is_homogeneous=False)

# A single seed item [2, 1] registered under edge type "n3:e1:n1".
train_set = gb.HeteroItemSet(
{"n3:e1:n1": gb.ItemSet(torch.LongTensor([[2, 1]]), names="seeds")}
)
datapipe = gb.ItemSampler(train_set, batch_size=1)
# `fanouts` has two entries, matching the two blocks in `expected_str`.
# NOTE(review): -1 presumably means "take all neighbors" -- confirm against
# the `sample_neighbor` documentation.
datapipe = datapipe.sample_neighbor(gb_g, fanouts=[-1, -1])
dataloader = gb.DataLoader(datapipe)
# Only the first minibatch is inspected; the comparison is on the string
# form of the block list.
blocks_str = str(next(iter(dataloader)).blocks)
# NOTE(review): the continuation lines of this literal appear without
# leading whitespace in this view; the real Block repr may indent them --
# verify against the original test file before relying on this text.
expected_str = """[Block(num_src_nodes={'n1': 0, 'n2': 0, 'n3': 2},
num_dst_nodes={'n1': 0, 'n2': 0, 'n3': 2},
num_edges={('n3', 'e1', 'n1'): 0, ('n3', 'e2', 'n2'): 0},
metagraph=[('n3', 'n1', 'e1'), ('n3', 'n2', 'e2')]), Block(num_src_nodes={'n1': 0, 'n2': 0, 'n3': 2},
num_dst_nodes={'n1': 1, 'n2': 0, 'n3': 1},
num_edges={('n3', 'e1', 'n1'): 1, ('n3', 'e2', 'n2'): 0},
metagraph=[('n3', 'n1', 'e1'), ('n3', 'n2', 'e2')])]"""
assert expected_str == blocks_str


def test_get_dgl_blocks_homo_empty_edges():
    """Blocks are still produced for a homogeneous graph with no sampled edges.

    The graph only has the edges 2->3, 3->4 and 4->5, while the seed item is
    [0, 1], so both expected blocks report ``num_edges=0``.
    """
    src_ids, dst_ids = [2, 3, 4], [3, 4, 5]
    graph = gb.from_dglgraph(
        dgl.graph((src_ids, dst_ids)), is_homogeneous=True
    )

    seeds = gb.ItemSet(torch.LongTensor([[0, 1]]), names="seeds")
    # Two fanout entries -> two blocks in the resulting minibatch.
    pipeline = gb.ItemSampler(seeds, batch_size=1).sample_neighbor(
        graph, fanouts=[-1, -1]
    )
    first_minibatch = next(iter(gb.DataLoader(pipeline)))

    actual_str = str(first_minibatch.blocks)
    expected_str = "[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=0), Block(num_src_nodes=2, num_dst_nodes=2, num_edges=0)]"
    assert expected_str == actual_str


def test_seeds_ntype_being_passed():
    """The seed node type must appear among the source types of every block.

    Seeds are given for node type "n2" on a heterograph whose only edge type
    is n1 -[e1]-> n2; each of the sampled blocks is expected to list "n2" in
    ``srctypes``.
    """
    edges_by_type = {("n1", "e1", "n2"): ([0, 1, 2], [2, 0, 1])}
    graph = gb.from_dglgraph(
        dgl.heterograph(edges_by_type), is_homogeneous=False
    )

    seeds = gb.HeteroItemSet(
        {"n2": gb.ItemSet(torch.LongTensor([0, 1]), names="seeds")}
    )
    # Three fanout entries, i.e. three sampling layers.
    pipeline = gb.ItemSampler(seeds, batch_size=2).sample_neighbor(
        graph, [-1, -1, -1]
    )
    first_minibatch = next(iter(gb.DataLoader(pipeline)))

    for sampled_block in first_minibatch.blocks:
        assert "n2" in sampled_block.srctypes


def create_homo_minibatch():
csc_formats = [
gb.CSCFormatBase(
Expand Down

0 comments on commit e489e38

Please sign in to comment.