Skip to content

Commit

Permalink
[DistDGL] add testcase for prob or mask sampling (#7448)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Jun 6, 2024
1 parent 3edc195 commit 4913a7b
Showing 1 changed file with 94 additions and 2 deletions.
96 changes: 94 additions & 2 deletions tests/distributed/test_mp_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ def start_node_dataloader(
groundtruth_g,
use_graphbolt=False,
return_eids=False,
prob_or_mask=None,
):
dgl.distributed.initialize(ip_config)
gpb = None
Expand All @@ -459,6 +460,16 @@ def start_node_dataloader(
part, _, _, _, _, _, _ = load_partition(part_config, i)

# Create sampler
_prob = None
_mask = None
if prob_or_mask is None:
pass
elif prob_or_mask == "prob":
_prob = "prob"
elif prob_or_mask == "mask":
_mask = "mask"
else:
raise ValueError(f"Unsupported prob type: {prob_or_mask}")
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[
(
Expand All @@ -468,7 +479,9 @@ def start_node_dataloader(
else 5
),
10,
]
],
prob=_prob,
mask=_mask,
) # test int for hetero

# Enable sanity check in distributed sampling.
Expand Down Expand Up @@ -514,6 +527,12 @@ def start_node_dataloader(
assert th.equal(
eids, expected_eids
), f"{eids} != {expected_eids}"
# Verify the prob/mask functionality.
if prob_or_mask is not None:
prob_data = groundtruth_g.edges[c_etype].data[
prob_or_mask
][eids]
assert th.all(prob_data > 0)
del dataloader
# This is needed since there are two tests here in one process.
dgl.distributed.exit_client()
Expand All @@ -532,6 +551,7 @@ def start_edge_dataloader(
reverse_eids,
reverse_etypes,
negative,
prob_or_mask,
):
dgl.distributed.initialize(ip_config)
gpb = None
Expand All @@ -554,7 +574,19 @@ def start_edge_dataloader(
part, _, _, _, _, _, _ = load_partition(part_config, i)

# Create sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler([5, -1])
_prob = None
_mask = None
if prob_or_mask is None:
pass
elif prob_or_mask == "prob":
_prob = "prob"
elif prob_or_mask == "mask":
_mask = "mask"
else:
raise ValueError(f"Unsupported prob type: {prob_or_mask}")
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[5, -1], prob=_prob, mask=_mask
)

# Negative sampler.
negative_sampler = None
Expand Down Expand Up @@ -639,6 +671,12 @@ def start_edge_dataloader(
assert th.equal(
raw_dst, orig_nid[dst_type][sampled_orig_dst]
)
# Verify the prob/mask functionality.
if prob_or_mask is not None:
prob_data = groundtruth_g.edges[etype].data[
prob_or_mask
][sampled_orig_eids]
assert th.all(prob_data > 0)
# Verify the exclude functionality.
if dgl.EID not in blocks[-1].edata.keys():
continue
Expand Down Expand Up @@ -701,6 +739,7 @@ def check_dataloader(
reverse_eids=None,
reverse_etypes=None,
negative=False,
prob_or_mask=None,
):
with tempfile.TemporaryDirectory() as test_dir:
ip_config = "ip_config.txt"
Expand Down Expand Up @@ -760,6 +799,7 @@ def check_dataloader(
g,
use_graphbolt,
return_eids,
prob_or_mask,
),
)
p.start()
Expand All @@ -780,6 +820,7 @@ def check_dataloader(
reverse_eids,
reverse_etypes,
negative,
prob_or_mask,
),
)
p.start()
Expand Down Expand Up @@ -879,6 +920,31 @@ def test_edge_dataloader_homograph(
)


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
@pytest.mark.parametrize("use_graphbolt", [False])
@pytest.mark.parametrize("prob_or_mask", ["prob", "mask"])
def test_dataloader_homograph_prob_or_mask(
    num_server, num_workers, dataloader_type, use_graphbolt, prob_or_mask
):
    """Run ``check_dataloader`` on the "cora" graph with edge prob/mask data.

    Every edge gets a random weight stored under ``"prob"`` and a boolean
    flag stored under ``"mask"``; ``prob_or_mask`` names which of the two
    edge features is forwarded to ``check_dataloader``.
    """
    reset_envs()
    graph = CitationGraphDataset("cora")[0]

    # One uniform random weight per edge; the mask flags edges whose
    # weight is above 0.2.
    edge_weights = th.rand(graph.num_edges())
    graph.edata["prob"] = F.tensor(edge_weights)
    graph.edata["mask"] = F.tensor(edge_weights > 0.2)

    # Edge IDs are requested so that sampled edges can be looked up in the
    # original graph's edge data.
    check_dataloader(
        graph,
        num_server,
        num_workers,
        dataloader_type,
        use_graphbolt=use_graphbolt,
        return_eids=True,
        prob_or_mask=prob_or_mask,
    )


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
Expand Down Expand Up @@ -927,6 +993,32 @@ def test_edge_dataloader_heterograph(
)


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
@pytest.mark.parametrize("use_graphbolt", [False])
@pytest.mark.parametrize("prob_or_mask", ["prob", "mask"])
def test_dataloader_heterograph_prob_or_mask(
    num_server, num_workers, dataloader_type, use_graphbolt, prob_or_mask
):
    """Run ``check_dataloader`` on a random heterograph with prob/mask data.

    For each canonical edge type, a random per-edge weight is stored under
    ``"prob"`` and a boolean flag under ``"mask"``; ``prob_or_mask`` names
    which of the two edge features is forwarded to ``check_dataloader``.
    """
    reset_envs()
    hetero_graph = create_random_hetero()

    for canonical_etype in hetero_graph.canonical_etypes:
        weights = th.rand(hetero_graph.num_edges(canonical_etype))
        hetero_graph.edges[canonical_etype].data["prob"] = weights
        # The mask flags edges whose weight is strictly above the median
        # weight of this edge type.
        hetero_graph.edges[canonical_etype].data["mask"] = (
            weights > weights.median()
        )

    # Edge IDs are requested so that sampled edges can be looked up in the
    # original graph's edge data.
    check_dataloader(
        hetero_graph,
        num_server,
        num_workers,
        dataloader_type,
        use_graphbolt=use_graphbolt,
        return_eids=True,
        prob_or_mask=prob_or_mask,
    )


@unittest.skip(reason="Skip due to glitch in CI")
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_workers", [0, 4])
Expand Down

0 comments on commit 4913a7b

Please sign in to comment.