
[distGB] Fix distributed partitioning when the graph has few nodes or edges #7824

Merged
merged 20 commits on Oct 18, 2024
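The scenario this PR targets is a heterograph in which some node or edge type has only a handful of entries, so several partitions end up with zero nodes or edges of that type. A minimal sketch of such a graph (hypothetical values, mirroring the `create_mini_chunked_dataset` helper added below; not part of this PR):

```python
import dgl
import torch

# Most node types are large, but "n_few" has only two nodes, so a random
# partitioning into several parts leaves most partitions with zero "n_few"
# nodes; this is the case the PR makes the distributed pipeline tolerate.
edges_coo = {
    ("n1", "r1", "n2"): (
        torch.randint(0, 1000, (500,)),
        torch.randint(0, 1010, (500,)),
    ),
    ("n1", "r_few", "n_few"): (torch.tensor([0, 1]), torch.tensor([1, 0])),
}
g = dgl.heterograph(
    edges_coo, num_nodes_dict={"n1": 1000, "n2": 1010, "n_few": 2}
)
print(g)
```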
193 changes: 192 additions & 1 deletion tests/tools/test_dist_part.py
@@ -3,6 +3,7 @@
import tempfile

import dgl
import dgl.backend as F

import numpy as np
import pyarrow.parquet as pq
@@ -19,7 +20,8 @@

from distpartitioning import array_readwriter
from distpartitioning.utils import generate_read_list
from pytest_utils import create_chunked_dataset
from pytest_utils import chunk_graph, create_chunked_dataset
from scipy import sparse as spsp

from tools.verification_utils import (
verify_graph_feats,
@@ -202,6 +204,103 @@ def test_chunk_graph_arbitrary_chunks(
)


def create_mini_chunked_dataset(
root_dir,
num_chunks,
data_fmt,
edges_fmt,
vector_rows,
few_entity="node",
**kwargs,
):
num_nodes = {"n1": 1000, "n2": 1010, "n3": 1020}
etypes = [
("n1", "r1", "n2"),
("n2", "r1", "n1"),
("n1", "r2", "n3"),
("n2", "r3", "n3"),
]
node_items = ["n1", "n2", "n3"]
edges_coo = {}
for etype in etypes:
src_ntype, _, dst_ntype = etype
arr = spsp.random(
num_nodes[src_ntype],
num_nodes[dst_ntype],
density=0.001,
format="coo",
random_state=100,
)
edges_coo[etype] = (arr.row, arr.col)
edge_items = []
if few_entity == "edge":
edges_coo[("n1", "a0", "n2")] = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
edges_coo[("n1", "a1", "n3")] = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
edge_items.append(("n1", "a0", "n2"))
edge_items.append(("n1", "a1", "n3"))
elif few_entity == "node":
edges_coo[("n1", "r_few", "n_few")] = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
edges_coo[("a0", "a01", "n_1")] = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
edge_items.append(("n1", "r_few", "n_few"))
edge_items.append(("a0", "a01", "n_1"))
node_items.append("n_few")
node_items.append("n_1")
num_nodes["n_few"] = 2
num_nodes["n_1"] = 2
g = dgl.heterograph(edges_coo)

node_data = {}
edge_data = {}
# Save features to disk and record their paths.
input_dir = os.path.join(root_dir, "data_test")

for ntype in node_items:
os.makedirs(os.path.join(input_dir, ntype))
feat = np.random.randn(num_nodes[ntype], 3)
feat_path = os.path.join(input_dir, f"{ntype}/feat.npy")
with open(feat_path, "wb") as f:
np.save(f, feat)
g.nodes[ntype].data["feat"] = torch.from_numpy(feat)
node_data[ntype] = {"feat": feat_path}

for etype in set(edge_items):
os.makedirs(os.path.join(input_dir, etype[1]))
num_edge = len(edges_coo[etype][0])
feat = np.random.randn(num_edge, 4)
feat_path = os.path.join(input_dir, f"{etype[1]}/feat.npy")
with open(feat_path, "wb") as f:
np.save(f, feat)
g.edges[etype].data["feat"] = torch.from_numpy(feat)
edge_data[etype] = {"feat": feat_path}

output_dir = os.path.join(root_dir, "chunked-data")
chunk_graph(
g,
"mag240m",
node_data,
edge_data,
num_chunks=num_chunks,
output_path=output_dir,
data_fmt=data_fmt,
edges_fmt=edges_fmt,
vector_rows=vector_rows,
**kwargs,
)
return g


def _test_pipeline(
num_chunks,
num_parts,
@@ -373,6 +472,98 @@ def test_pipeline_feature_format(data_fmt):
_test_pipeline(4, 4, 4, data_fmt=data_fmt)


@pytest.mark.parametrize(
"num_chunks, num_parts, world_size",
[[4, 4, 4], [8, 4, 2], [8, 4, 4], [9, 6, 3], [11, 11, 1], [11, 4, 1]],
)
@pytest.mark.parametrize("few_entity", ["node", "edge"])
def test_partition_hetero_few_entity(
num_chunks,
num_parts,
world_size,
few_entity,
graph_formats=None,
data_fmt="numpy",
edges_fmt="csv",
vector_rows=False,
num_chunks_nodes=None,
num_chunks_edges=None,
num_chunks_node_data=None,
num_chunks_edge_data=None,
):
with tempfile.TemporaryDirectory() as root_dir:
g = create_mini_chunked_dataset(
root_dir,
num_chunks,
few_entity=few_entity,
data_fmt=data_fmt,
edges_fmt=edges_fmt,
vector_rows=vector_rows,
num_chunks_nodes=num_chunks_nodes,
num_chunks_edges=num_chunks_edges,
num_chunks_node_data=num_chunks_node_data,
num_chunks_edge_data=num_chunks_edge_data,
)

# Step1: graph partition
in_dir = os.path.join(root_dir, "chunked-data")
output_dir = os.path.join(root_dir, "parted_data")
os.system(
"python3 tools/partition_algo/random_partition.py "
"--in_dir {} --out_dir {} --num_partitions {}".format(
in_dir, output_dir, num_parts
)
)

# Step2: data dispatch
partition_dir = os.path.join(root_dir, "parted_data")
out_dir = os.path.join(root_dir, "partitioned")
ip_config = os.path.join(root_dir, "ip_config.txt")
with open(ip_config, "w") as f:
for i in range(world_size):
f.write(f"127.0.0.{i + 1}\n")

cmd = "python3 tools/dispatch_data.py"
cmd += f" --in-dir {in_dir}"
cmd += f" --partitions-dir {partition_dir}"
cmd += f" --out-dir {out_dir}"
cmd += f" --ip-config {ip_config}"
cmd += " --ssh-port 22"
cmd += " --process-group-timeout 60"
cmd += " --save-orig-nids"
cmd += " --save-orig-eids"
cmd += f" --graph-formats {graph_formats}" if graph_formats else ""
os.system(cmd)

# read original node/edge IDs
def read_orig_ids(fname):
orig_ids = {}
for i in range(num_parts):
ids_path = os.path.join(out_dir, f"part{i}", fname)
part_ids = load_tensors(ids_path)
for type, data in part_ids.items():
if type not in orig_ids:
orig_ids[type] = data
else:
orig_ids[type] = torch.cat((orig_ids[type], data))
return orig_ids

orig_nids = read_orig_ids("orig_nids.dgl")
orig_eids = read_orig_ids("orig_eids.dgl")

# load partitions and verify
part_config = os.path.join(out_dir, "metadata.json")
for i in range(num_parts):
part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
part_config, i
)
verify_partition_data_types(part_g)
verify_partition_formats(part_g, graph_formats)
verify_graph_feats(
g, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids
)


def test_utils_generate_read_list():
read_list = generate_read_list(10, 4)
assert np.array_equal(read_list[0], np.array([0, 1, 2]))
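One way to exercise only the new test once this branch is checked out (hypothetical invocation; it assumes the DGL repository root as the working directory, since the test shells out to scripts under tools/):

```python
import pytest

# Run just the new few-entity partition test; paths assume the repo root.
pytest.main(
    ["tests/tools/test_dist_part.py", "-k", "test_partition_hetero_few_entity"]
)
```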
54 changes: 51 additions & 3 deletions tools/distpartitioning/convert_partition.py
@@ -9,6 +9,7 @@
import dgl.graphbolt as gb
import numpy as np
import torch as th
import torch.distributed as dist
from dgl import EID, ETYPE, NID, NTYPE

from dgl.distributed.constants import DGL2GB_EID, GB_DST_ID
@@ -355,6 +356,34 @@ def _process_partition_gb(
return indptr, indices[sorted_idx], edge_ids[sorted_idx]


def _update_node_map(node_map_val, end_ids_per_rank, id_ntypes, prev_last_id):
"""this function is modified from the function '_update_node_edge_map' in dgl.distributed.partition"""
# Update the node_map_val to be contiguous.
rank = dist.get_rank()
prev_end_id = (
end_ids_per_rank[rank - 1].item() if rank > 0 else prev_last_id
)
ntype_ids = {ntype: ntype_id for ntype_id, ntype in enumerate(id_ntypes)}
for ntype_id in list(ntype_ids.values()):
ntype = id_ntypes[ntype_id]
start_id = node_map_val[ntype][0][0]
end_id = node_map_val[ntype][0][1]
if not (start_id == -1 and end_id == -1):
continue
prev_ntype_id = (
ntype_ids[ntype] - 1
if ntype_ids[ntype] > 0
else max(ntype_ids.values())
)
prev_ntype = id_ntypes[prev_ntype_id]
if ntype_ids[ntype] == 0:
node_map_val[ntype][0][0] = prev_end_id
else:
node_map_val[ntype][0][0] = node_map_val[prev_ntype][0][1]
node_map_val[ntype][0][1] = node_map_val[ntype][0][0]
return node_map_val[ntype][0][-1]


def create_graph_object(
tot_node_count,
tot_edge_count,
@@ -368,6 +397,7 @@ def create_graph_object(
edgeid_offset,
node_typecounts,
edge_typecounts,
last_ids={},
return_orig_nids=False,
return_orig_eids=False,
use_graphbolt=False,
@@ -512,12 +542,30 @@
shuffle_global_nid_range = (shuffle_global_nids[0], shuffle_global_nids[-1])

# Determine the node ID ranges of different node types.
prev_last_id = last_ids.get(part_id - 1, 0)
for ntype_name in global_nid_ranges:
ntype_id = ntypes_map[ntype_name]
type_nids = shuffle_global_nids[ntype_ids == ntype_id]
node_map_val[ntype_name].append(
[int(type_nids[0]), int(type_nids[-1]) + 1]
)
if len(type_nids) == 0:
node_map_val[ntype_name].append([-1, -1])
else:
node_map_val[ntype_name].append(
[int(type_nids[0]), int(type_nids[-1]) + 1]
)
last_id = th.tensor(
[max(prev_last_id, int(type_nids[-1]) + 1)], dtype=th.int64
)
id_ntypes = list(global_nid_ranges.keys())

gather_last_ids = [
th.zeros(1, dtype=th.int64) for _ in range(dist.get_world_size())
]

dist.all_gather(gather_last_ids, last_id)
prev_last_id = _update_node_map(
node_map_val, gather_last_ids, id_ntypes, prev_last_id
)
last_ids[part_id] = prev_last_id

# process edges
memory_snapshot("CreateDGLObj_AssignEdgeData: ", part_id)
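In the convert_partition.py hunk above, a node type with no nodes in a partition is recorded as [-1, -1] in node_map_val, and `_update_node_map` later turns those placeholders into empty, contiguous ID ranges. A standalone sketch of that fix-up (plain Python, no torch.distributed; the dictionary values below are made up for illustration and this is not the PR's code):

```python
def fixup_node_map(node_map_val, id_ntypes, prev_end_id):
    """Collapse [-1, -1] placeholders into empty, contiguous ranges."""
    for ntype_id, ntype in enumerate(id_ntypes):
        start_id, end_id = node_map_val[ntype][0]
        if not (start_id == -1 and end_id == -1):
            continue  # this type has real nodes in the partition
        if ntype_id == 0:
            # First type on this rank starts where the previous rank ended.
            node_map_val[ntype][0][0] = prev_end_id
        else:
            # Otherwise it starts where the previous type on this rank ended.
            prev_ntype = id_ntypes[ntype_id - 1]
            node_map_val[ntype][0][0] = node_map_val[prev_ntype][0][1]
        # Empty range: end == start.
        node_map_val[ntype][0][1] = node_map_val[ntype][0][0]
    return node_map_val


node_map_val = {"n1": [[0, 500]], "n_few": [[-1, -1]], "n2": [[500, 1000]]}
print(fixup_node_map(node_map_val, ["n1", "n_few", "n2"], prev_end_id=0))
# {'n1': [[0, 500]], 'n_few': [[500, 500]], 'n2': [[500, 1000]]}
```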
25 changes: 24 additions & 1 deletion tools/distpartitioning/data_shuffle.py
@@ -489,6 +489,10 @@ def exchange_feature(
feat_dims_dtype.append(DATA_TYPE_ID[torch.float32])
feature_dimension = 0

feature_dimension_tensor = torch.tensor([feature_dimension])
dist.all_reduce(feature_dimension_tensor, op=dist.ReduceOp.MAX)
feature_dimension = feature_dimension_tensor.item()

logging.debug(f"Sending the feature shape information - {feat_dims_dtype}")
all_dims_dtype = allgather_sizes(
feat_dims_dtype, world_size, num_parts, return_sizes=True
@@ -553,7 +557,11 @@ def exchange_feature(
else:
cur_features[local_feat_key] = output_feat_list
cur_global_ids[local_feat_key] = output_id_list

else:
cur_features[local_feat_key] = torch.empty(
(0, feature_dimension), dtype=torch.float32
)
cur_global_ids[local_feat_key] = torch.empty((0,), dtype=torch.int64)
return cur_features, cur_global_ids


@@ -1301,6 +1309,7 @@ def prepare_local_data(src_data, local_part_id):
if params.graph_formats:
graph_formats = params.graph_formats.split(",")

prev_last_ids = {}
for local_part_id in range(params.num_parts // world_size):
# Synchronize for each local partition of the graph object.
dist.barrier()
@@ -1340,6 +1349,7 @@ def prepare_local_data(src_data, local_part_id):
schema_map[constants.STR_NUM_NODES_PER_TYPE],
),
edge_typecounts,
prev_last_ids,
return_orig_nids=params.save_orig_nids,
return_orig_eids=params.save_orig_eids,
use_graphbolt=params.use_graphbolt,
@@ -1390,6 +1400,19 @@ def prepare_local_data(src_data, local_part_id):
] = json_metadata
memory_snapshot("MetadataCreateComplete: ", rank)

last_id_tensor = torch.tensor(
[prev_last_ids[rank + (local_part_id * world_size)]],
dtype=torch.int64,
)
gather_list = [
torch.zeros(1, dtype=torch.int64) for _ in range(world_size)
]
dist.all_gather(gather_list, last_id_tensor)
for rank_id, last_id in enumerate(gather_list):
prev_last_ids[
rank_id + (local_part_id * world_size)
] = last_id.item()

if rank == 0:
# get meta-data from all partitions and merge them on rank-0
metadata_list = gather_metadata_json(output_meta_json, rank, world_size)
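The data_shuffle.py changes above let a rank that received no features for a key agree on the feature dimension through an all-reduce before allocating empty placeholders, and keep prev_last_ids in sync across ranks with an all-gather. A minimal sketch of the all-reduce pattern (single-process gloo group; the address and port are placeholders, and with world_size > 1 the other ranks would contribute the real dimension):

```python
import torch
import torch.distributed as dist

# Reduce the local feature dimension with MAX so that a rank holding no
# features for this key still learns the true dimension from its peers.
dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

local_dim = 0  # this rank received no features for the key
dim_t = torch.tensor([local_dim])
dist.all_reduce(dim_t, op=dist.ReduceOp.MAX)
feature_dimension = dim_t.item()

# Correctly shaped empty placeholders keep later concatenation code simple.
empty_feats = torch.empty((0, feature_dimension), dtype=torch.float32)
empty_ids = torch.empty((0,), dtype=torch.int64)

dist.destroy_process_group()
```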
5 changes: 5 additions & 0 deletions tools/distpartitioning/dataset_utils.py
@@ -547,6 +547,11 @@ def get_dataset(
autogenerate_column_names=True,
)
parse_options = pyarrow.csv.ParseOptions(delimiter=" ")

if os.path.getsize(edge_file) == 0:
# If getsize() == 0, the file is empty, meaning this partition has no
# edges of this type; src_ids and dst_ids should remain empty.
continue
with pyarrow.csv.open_csv(
edge_file,
read_options=read_options,