From 4913a7b691f183a67eddaba129118c3bc34f3ae6 Mon Sep 17 00:00:00 2001 From: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com> Date: Thu, 6 Jun 2024 08:10:49 +0800 Subject: [PATCH] [DistDGL] add testcase for prob or mask sampling (#7448) --- tests/distributed/test_mp_dataloader.py | 96 ++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_mp_dataloader.py b/tests/distributed/test_mp_dataloader.py index 93bb5667105a..4cf867ecd217 100644 --- a/tests/distributed/test_mp_dataloader.py +++ b/tests/distributed/test_mp_dataloader.py @@ -433,6 +433,7 @@ def start_node_dataloader( groundtruth_g, use_graphbolt=False, return_eids=False, + prob_or_mask=None, ): dgl.distributed.initialize(ip_config) gpb = None @@ -459,6 +460,16 @@ def start_node_dataloader( part, _, _, _, _, _, _ = load_partition(part_config, i) # Create sampler + _prob = None + _mask = None + if prob_or_mask is None: + pass + elif prob_or_mask == "prob": + _prob = "prob" + elif prob_or_mask == "mask": + _mask = "mask" + else: + raise ValueError(f"Unsupported prob type: {prob_or_mask}") sampler = dgl.dataloading.MultiLayerNeighborSampler( [ ( @@ -468,7 +479,9 @@ def start_node_dataloader( else 5 ), 10, - ] + ], + prob=_prob, + mask=_mask, ) # test int for hetero # Enable santity check in distributed sampling. @@ -514,6 +527,12 @@ def start_node_dataloader( assert th.equal( eids, expected_eids ), f"{eids} != {expected_eids}" + # Verify the prob/mask functionality. + if prob_or_mask is not None: + prob_data = groundtruth_g.edges[c_etype].data[ + prob_or_mask + ][eids] + assert th.all(prob_data > 0) del dataloader # this is needed since there's two test here in one process dgl.distributed.exit_client() @@ -532,6 +551,7 @@ def start_edge_dataloader( reverse_eids, reverse_etypes, negative, + prob_or_mask, ): dgl.distributed.initialize(ip_config) gpb = None @@ -554,7 +574,19 @@ def start_edge_dataloader( part, _, _, _, _, _, _ = load_partition(part_config, i) # Create sampler - sampler = dgl.dataloading.MultiLayerNeighborSampler([5, -1]) + _prob = None + _mask = None + if prob_or_mask is None: + pass + elif prob_or_mask == "prob": + _prob = "prob" + elif prob_or_mask == "mask": + _mask = "mask" + else: + raise ValueError(f"Unsupported prob type: {prob_or_mask}") + sampler = dgl.dataloading.MultiLayerNeighborSampler( + [5, -1], prob=_prob, mask=_mask + ) # Negative sampler. negative_sampler = None @@ -639,6 +671,12 @@ def start_edge_dataloader( assert th.equal( raw_dst, orig_nid[dst_type][sampled_orig_dst] ) + # Verify the prob/mask functionality. + if prob_or_mask is not None: + prob_data = groundtruth_g.edges[etype].data[ + prob_or_mask + ][sampled_orig_eids] + assert th.all(prob_data > 0) # Verify the exclude functionality. if dgl.EID not in blocks[-1].edata.keys(): continue @@ -701,6 +739,7 @@ def check_dataloader( reverse_eids=None, reverse_etypes=None, negative=False, + prob_or_mask=None, ): with tempfile.TemporaryDirectory() as test_dir: ip_config = "ip_config.txt" @@ -760,6 +799,7 @@ def check_dataloader( g, use_graphbolt, return_eids, + prob_or_mask, ), ) p.start() @@ -780,6 +820,7 @@ def check_dataloader( reverse_eids, reverse_etypes, negative, + prob_or_mask, ), ) p.start() @@ -879,6 +920,31 @@ def test_edge_dataloader_homograph( ) +@pytest.mark.parametrize("num_server", [1]) +@pytest.mark.parametrize("num_workers", [1]) +@pytest.mark.parametrize("dataloader_type", ["node", "edge"]) +@pytest.mark.parametrize("use_graphbolt", [False]) +@pytest.mark.parametrize("prob_or_mask", ["prob", "mask"]) +def test_dataloader_homograph_prob_or_mask( + num_server, num_workers, dataloader_type, use_graphbolt, prob_or_mask +): + reset_envs() + g = CitationGraphDataset("cora")[0] + prob = th.rand(g.num_edges()) + mask = prob > 0.2 + g.edata["prob"] = F.tensor(prob) + g.edata["mask"] = F.tensor(mask) + check_dataloader( + g, + num_server, + num_workers, + dataloader_type, + use_graphbolt=use_graphbolt, + return_eids=True, + prob_or_mask=prob_or_mask, + ) + + @pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_workers", [0, 1]) @pytest.mark.parametrize("dataloader_type", ["node", "edge"]) @@ -927,6 +993,32 @@ def test_edge_dataloader_heterograph( ) +@pytest.mark.parametrize("num_server", [1]) +@pytest.mark.parametrize("num_workers", [1]) +@pytest.mark.parametrize("dataloader_type", ["node", "edge"]) +@pytest.mark.parametrize("use_graphbolt", [False]) +@pytest.mark.parametrize("prob_or_mask", ["prob", "mask"]) +def test_dataloader_heterograph_prob_or_mask( + num_server, num_workers, dataloader_type, use_graphbolt, prob_or_mask +): + reset_envs() + g = create_random_hetero() + for etype in g.canonical_etypes: + prob = th.rand(g.num_edges(etype)) + mask = prob > prob.median() + g.edges[etype].data["prob"] = prob + g.edges[etype].data["mask"] = mask + check_dataloader( + g, + num_server, + num_workers, + dataloader_type, + use_graphbolt=use_graphbolt, + return_eids=True, + prob_or_mask=prob_or_mask, + ) + + @unittest.skip(reason="Skip due to glitch in CI") @pytest.mark.parametrize("num_server", [3]) @pytest.mark.parametrize("num_workers", [0, 4])