Skip to content

Commit

Permalink
[Dist] fix node/edge map for few nodes/edges (#7785)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Sep 10, 2024
1 parent ad1551d commit bf125d8
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 1 deletion.
103 changes: 103 additions & 0 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,95 @@ def _set_trainer_ids(g, sim_g, node_parts):
g.edges[c_etype].data["trainer_id"] = trainer_id


def _update_node_edge_map(node_map_val, edge_map_val, g, num_parts):
"""
If the original graph contains few nodes or edges for specific node/edge
types, the partitioned graph may have empty partitions for these types. And
the node_map_val and edge_map_val will have -1 for the start and end ID of
these types. This function updates the node_map_val and edge_map_val to be
contiguous.
Example case:
Suppose we have a heterogeneous graph with 3 node/edge types and the number
of partitions is 3. A possible node_map_val or edge_map_val is as follows:
| part_id\\Node/Edge Type| Type A | Type B | Type C |
|------------------------|--------|---------|--------|
| 0 | 0, 1 | -1, -1 | 2, 3 |
| 1 | -1, -1 | 3, 4 | 4, 5 |
| 2 | 5, 6 | 7, 8 | -1, -1|
As node/edge IDs are contiguous in node/edge type for each partition, we can
update the node_map_val and edge_map_val via updating the start and end ID
in row-wise order.
Updated node_map_val or edge_map_val:
| part_id\\Node/Edge Type| Type A | Type B | Type C |
|------------------------|--------|---------|--------|
| 0 | 0, 1 | 1, 1 | 2, 3 |
| 1 | 3, 3 | 3, 4 | 4, 5 |
| 2 | 5, 6 | 7, 8 | 8, 8 |
"""
# Update the node_map_val to be contiguous.
ntype_ids = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
ntype_ids_reverse = {v: k for k, v in ntype_ids.items()}
for part_id in range(num_parts):
for ntype_id in list(ntype_ids.values()):
ntype = ntype_ids_reverse[ntype_id]
start_id = node_map_val[ntype][part_id][0]
end_id = node_map_val[ntype][part_id][1]
if not (start_id == -1 and end_id == -1):
continue
prev_ntype_id = (
ntype_ids[ntype] - 1
if ntype_ids[ntype] > 0
else max(ntype_ids.values())
)
prev_ntype = ntype_ids_reverse[prev_ntype_id]
if ntype_ids[ntype] == 0:
if part_id == 0:
node_map_val[ntype][part_id][0] = 0
else:
node_map_val[ntype][part_id][0] = node_map_val[prev_ntype][
part_id - 1
][1]
else:
node_map_val[ntype][part_id][0] = node_map_val[prev_ntype][
part_id
][1]
node_map_val[ntype][part_id][1] = node_map_val[ntype][part_id][0]
# Update the edge_map_val to be contiguous.
etype_ids = {etype: g.get_etype_id(etype) for etype in g.canonical_etypes}
etype_ids_reverse = {v: k for k, v in etype_ids.items()}
for part_id in range(num_parts):
for etype_id in list(etype_ids.values()):
etype = etype_ids_reverse[etype_id]
start_id = edge_map_val[etype][part_id][0]
end_id = edge_map_val[etype][part_id][1]
if not (start_id == -1 and end_id == -1):
continue
prev_etype_id = (
etype_ids[etype] - 1
if etype_ids[etype] > 0
else max(etype_ids.values())
)
prev_etype = etype_ids_reverse[prev_etype_id]
if etype_ids[etype] == 0:
if part_id == 0:
edge_map_val[etype][part_id][0] = 0
else:
edge_map_val[etype][part_id][0] = edge_map_val[prev_etype][
part_id - 1
][1]
else:
edge_map_val[etype][part_id][0] = edge_map_val[prev_etype][
part_id
][1]
edge_map_val[etype][part_id][1] = edge_map_val[etype][part_id][0]


def partition_graph(
g,
graph_name,
Expand Down Expand Up @@ -1081,6 +1170,9 @@ def get_homogeneous(g, balance_ntypes):
)
for ntype in g.ntypes:
inner_ntype_mask = inner_ntype == g.get_ntype_id(ntype)
if F.sum(F.astype(inner_ntype_mask, F.int64), 0) == 0:
# Skip if there is no node of this type in this partition.
continue
typed_nids = F.boolean_mask(inner_nids, inner_ntype_mask)
# inner node IDs are in a contiguous ID range.
expected_range = np.arange(
Expand All @@ -1097,6 +1189,9 @@ def get_homogeneous(g, balance_ntypes):
)
for etype in g.canonical_etypes:
inner_etype_mask = inner_etype == g.get_etype_id(etype)
if F.sum(F.astype(inner_etype_mask, F.int64), 0) == 0:
# Skip if there is no edge of this type in this partition.
continue
typed_eids = np.sort(
F.asnumpy(F.boolean_mask(inner_eids, inner_etype_mask))
)
Expand All @@ -1123,6 +1218,9 @@ def get_homogeneous(g, balance_ntypes):
val.append(
F.as_scalar(F.sum(F.astype(inner_node_mask, F.int64), 0))
)
if F.sum(F.astype(inner_node_mask, F.int64), 0) == 0:
node_map_val[ntype].append([-1, -1])
continue
inner_nids = F.boolean_mask(
parts[i].ndata[NID], inner_node_mask
)
Expand All @@ -1143,6 +1241,9 @@ def get_homogeneous(g, balance_ntypes):
val.append(
F.as_scalar(F.sum(F.astype(inner_edge_mask, F.int64), 0))
)
if F.sum(F.astype(inner_edge_mask, F.int64), 0) == 0:
edge_map_val[etype].append([-1, -1])
continue
inner_eids = np.sort(
F.asnumpy(
F.boolean_mask(parts[i].edata[EID], inner_edge_mask)
Expand All @@ -1153,6 +1254,8 @@ def get_homogeneous(g, balance_ntypes):
)
val = np.cumsum(val).tolist()
assert val[-1] == g.num_edges(etype)
# Update the node_map_val and edge_map_val to be contiguous.
_update_node_edge_map(node_map_val, edge_map_val, g, num_parts)
else:
node_map_val = {}
edge_map_val = {}
Expand Down
78 changes: 77 additions & 1 deletion tests/distributed/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def _verify_hetero_graph_node_edge_num(
else part.edata
)
if dgl.ETYPE in edata:
assert len(g.canonical_etypes) == len(F.unique(edata[dgl.ETYPE]))
# edata may not contain all edge types.
assert len(g.canonical_etypes) >= len(F.unique(edata[dgl.ETYPE]))
if debug_mode or isinstance(part, dgl.DGLGraph):
for ntype in g.ntypes:
ntype_id = g.get_ntype_id(ntype)
Expand Down Expand Up @@ -516,6 +517,8 @@ def check_hetero_partition(
gpb.map_to_per_etype(F.tensor([0], F.int32))
# These are original per-type IDs.
for etype_id, etype in enumerate(hg.canonical_etypes):
if F.sum((etype_ids == etype_id), 0) == 0:
continue
part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id)
src_ntype_ids1 = F.boolean_mask(
src_ntype_ids, etype_ids == etype_id
Expand Down Expand Up @@ -2160,3 +2163,76 @@ def test_partition_graph_graphbolt_hetero_find_edges_multi(
graph_formats="coo",
n_jobs=4,
)


@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [4])
@pytest.mark.parametrize("num_trainers_per_machine", [1])
@pytest.mark.parametrize("graph_formats", [None])
def test_partition_hetero_few_edges(
part_method,
num_parts,
num_trainers_per_machine,
graph_formats,
):
os.environ["DGL_DIST_DEBUG"] = "1"
if part_method == "random" and num_parts > 1:
num_trainers_per_machine = 1

# Create a heterograph with 2 edges for one edge type.
hg = create_random_hetero()
edges_coo = {
c_etype: hg.edges(etype=c_etype) for c_etype in hg.canonical_etypes
}
edges_coo[("n1", "a0", "n2")] = (th.tensor([0, 1]), th.tensor([1, 0]))
edges_coo[("n1", "a1", "n3")] = (th.tensor([0, 1]), th.tensor([1, 0]))
hg = dgl.heterograph(edges_coo)

check_hetero_partition(
hg,
part_method,
num_parts,
num_trainers_per_machine,
load_feats=False,
graph_formats=graph_formats,
)
reset_envs()


@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [4])
@pytest.mark.parametrize("num_trainers_per_machine", [1])
@pytest.mark.parametrize("graph_formats", [None])
def test_partition_hetero_few_nodes(
part_method,
num_parts,
num_trainers_per_machine,
graph_formats,
):
os.environ["DGL_DIST_DEBUG"] = "1"
if part_method == "random" and num_parts > 1:
num_trainers_per_machine = 1

# Create a heterograph with 2 nodes for one node type.
hg = create_random_hetero()
edges_coo = {
c_etype: hg.edges(etype=c_etype) for c_etype in hg.canonical_etypes
}
edges_coo[("n1", "r_few", "n_few")] = (th.tensor([0, 1]), th.tensor([1, 0]))
edges_coo[("a0", "a01", "n_1")] = (th.tensor([0, 1]), th.tensor([1, 0]))
hg = dgl.heterograph(edges_coo)

expected_exception = False
try:
check_hetero_partition(
hg,
part_method,
num_parts,
num_trainers_per_machine,
load_feats=False,
graph_formats=graph_formats,
)
except Exception as e:
expected_exception = True
assert expected_exception == (part_method == "metis")
reset_envs()

0 comments on commit bf125d8

Please sign in to comment.