diff --git a/test/distributed/test_dist_link_neighbor_loader.py b/test/distributed/test_dist_link_neighbor_loader.py
new file mode 100644
index 000000000000..39a6a655b2a1
--- /dev/null
+++ b/test/distributed/test_dist_link_neighbor_loader.py
@@ -0,0 +1,260 @@
+import socket
+from typing import Tuple
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+from torch_geometric.data import Data, HeteroData
+from torch_geometric.datasets import FakeDataset, FakeHeteroDataset
+from torch_geometric.distributed import (
+    DistContext,
+    DistLinkNeighborLoader,
+    DistNeighborSampler,
+    LocalFeatureStore,
+    LocalGraphStore,
+    Partitioner,
+)
+from torch_geometric.distributed.partition import load_partition_info
+from torch_geometric.testing import onlyLinux, withPackage
+
+
+def create_dist_data(tmp_path: str, rank: int):
+    graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)
+    feat_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)
+    (
+        meta,
+        num_partitions,
+        partition_idx,
+        node_pb,
+        edge_pb,
+    ) = load_partition_info(tmp_path, rank)
+    if meta['is_hetero']:
+        node_pb = torch.cat(list(node_pb.values()))
+        edge_pb = torch.cat(list(edge_pb.values()))
+
+    graph_store.partition_idx = partition_idx
+    graph_store.num_partitions = num_partitions
+    graph_store.node_pb = node_pb
+    graph_store.edge_pb = edge_pb
+    graph_store.meta = meta
+
+    feat_store.partition_idx = partition_idx
+    feat_store.num_partitions = num_partitions
+    feat_store.node_feat_pb = node_pb
+    feat_store.edge_feat_pb = edge_pb
+    feat_store.meta = meta
+
+    return feat_store, graph_store
+
+
+def dist_link_neighbor_loader_homo(
+    tmp_path: str,
+    world_size: int,
+    rank: int,
+    master_addr: str,
+    master_port: int,
+    num_workers: int,
+    async_sampling: bool,
+    neg_ratio: float,
+):
+    part_data = create_dist_data(tmp_path, rank)
+    current_ctx = DistContext(
+        rank=rank,
+        global_rank=rank,
+        world_size=world_size,
+        global_world_size=world_size,
+        group_name='dist-loader-test',
+    )
+
+    edge_label_index = part_data[1].get_edge_index(None, 'coo')
+    edge_label = torch.randint(high=2, size=(edge_label_index.size(1), ))
+
+    loader = DistLinkNeighborLoader(
+        data=part_data,
+        edge_label_index=(None, edge_label_index),
+        edge_label=edge_label if neg_ratio is not None else None,
+        num_neighbors=[1],
+        batch_size=10,
+        num_workers=num_workers,
+        master_addr=master_addr,
+        master_port=master_port,
+        current_ctx=current_ctx,
+        rpc_worker_names={},
+        concurrency=10,
+        drop_last=True,
+        async_sampling=async_sampling,
+    )
+
+    assert str(loader).startswith('DistLinkNeighborLoader')
+    assert str(mp.current_process().pid) in str(loader)
+    assert isinstance(loader.neighbor_sampler, DistNeighborSampler)
+    assert not part_data[0].meta['is_hetero']
+
+    for batch in loader:
+        assert isinstance(batch, Data)
+        assert batch.n_id.size() == (batch.num_nodes, )
+        assert batch.input_id.numel() == batch.batch_size == 10
+        assert batch.edge_index.min() >= 0
+        assert batch.edge_index.max() < batch.num_nodes
+
+
+def dist_link_neighbor_loader_hetero(
+    tmp_path: str,
+    world_size: int,
+    rank: int,
+    master_addr: str,
+    master_port: int,
+    num_workers: int,
+    async_sampling: bool,
+    neg_ratio: float,
+    edge_type: Tuple[str, str, str],
+):
+    part_data = create_dist_data(tmp_path, rank)
+    current_ctx = DistContext(
+        rank=rank,
+        global_rank=rank,
+        world_size=world_size,
+        global_world_size=world_size,
+        group_name="dist-loader-test",
+    )
+
+    edge_label_index = part_data[1].get_edge_index(edge_type, 'coo')
+    edge_label = torch.randint(high=2, size=(edge_label_index.size(1), ))
+
+    loader = DistLinkNeighborLoader(
+        data=part_data,
+        edge_label_index=(edge_type, edge_label_index),
+        edge_label=edge_label if neg_ratio is not None else None,
+        num_neighbors=[1],
+        batch_size=10,
+        num_workers=num_workers,
+        master_addr=master_addr,
+        master_port=master_port,
+        current_ctx=current_ctx,
+        rpc_worker_names={},
+        concurrency=10,
+        drop_last=True,
+        async_sampling=async_sampling,
+    )
+
+    assert str(loader).startswith('DistLinkNeighborLoader')
+    assert str(mp.current_process().pid) in str(loader)
+    assert isinstance(loader.neighbor_sampler, DistNeighborSampler)
+    assert part_data[0].meta['is_hetero']
+
+    for batch in loader:
+        assert isinstance(batch, HeteroData)
+        assert (batch[edge_type].input_id.numel() ==
+                batch[edge_type].batch_size == 10)
+
+        assert len(batch.node_types) == 2
+        for node_type in batch.node_types:
+            assert torch.equal(batch[node_type].x, batch.x_dict[node_type])
+            assert batch.x_dict[node_type].size(0) >= 0
+            assert batch[node_type].n_id.size(0) == batch[node_type].num_nodes
+
+        assert len(batch.edge_types) == 4
+        for edge_type in batch.edge_types:
+            assert (batch[edge_type].edge_attr.size(0) ==
+                    batch[edge_type].edge_index.size(1))
+
+
+@onlyLinux
+@withPackage('pyg_lib')
+@pytest.mark.parametrize('num_parts', [2])
+@pytest.mark.parametrize('num_workers', [0])
+@pytest.mark.parametrize('async_sampling', [True])
+@pytest.mark.parametrize('neg_ratio', [None])
+@pytest.mark.skip(reason="'sample_from_edges' not yet implemented")
+def test_dist_link_neighbor_loader_homo(
+    tmp_path,
+    num_parts,
+    num_workers,
+    async_sampling,
+    neg_ratio,
+):
+    mp_context = torch.multiprocessing.get_context('spawn')
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.bind(('127.0.0.1', 0))
+    port = s.getsockname()[1]
+    s.close()
+    addr = 'localhost'
+
+    data = FakeDataset(
+        num_graphs=1,
+        avg_num_nodes=100,
+        avg_degree=3,
+        edge_dim=2,
+    )[0]
+    partitioner = Partitioner(data, num_parts, tmp_path)
+    partitioner.generate_partition()
+
+    w0 = mp_context.Process(
+        target=dist_link_neighbor_loader_homo,
+        args=(tmp_path, num_parts, 0, addr, port, num_workers, async_sampling,
+              neg_ratio),
+    )
+
+    w1 = mp_context.Process(
+        target=dist_link_neighbor_loader_homo,
+        args=(tmp_path, num_parts, 1, addr, port, num_workers, async_sampling,
+              neg_ratio),
+    )
+
+    w0.start()
+    w1.start()
+    w0.join()
+    w1.join()
+
+
+@onlyLinux
+@withPackage('pyg_lib')
+@pytest.mark.parametrize('num_parts', [2])
+@pytest.mark.parametrize('num_workers', [0])
+@pytest.mark.parametrize('async_sampling', [True])
+@pytest.mark.parametrize('neg_ratio', [None])
+@pytest.mark.parametrize('edge_type', [('v0', 'e0', 'v0')])
+@pytest.mark.skip(reason="'sample_from_edges' not yet implemented")
+def test_dist_link_neighbor_loader_hetero(
+    tmp_path,
+    num_parts,
+    num_workers,
+    async_sampling,
+    neg_ratio,
+    edge_type,
+):
+    mp_context = torch.multiprocessing.get_context('spawn')
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.bind(('127.0.0.1', 0))
+    port = s.getsockname()[1]
+    s.close()
+    addr = 'localhost'
+
+    data = FakeHeteroDataset(
+        num_graphs=1,
+        avg_num_nodes=100,
+        avg_degree=3,
+        num_node_types=2,
+        num_edge_types=4,
+        edge_dim=2,
+    )[0]
+    partitioner = Partitioner(data, num_parts, tmp_path)
+    partitioner.generate_partition()
+
+    w0 = mp_context.Process(
+        target=dist_link_neighbor_loader_hetero,
+        args=(tmp_path, num_parts, 0, addr, port, num_workers, async_sampling,
+              neg_ratio, edge_type),
+    )
+
+    w1 = mp_context.Process(
+        target=dist_link_neighbor_loader_hetero,
+        args=(tmp_path, num_parts, 1, addr, port, num_workers, async_sampling,
+              neg_ratio, edge_type),
+    )
+
+    w0.start()
+    w1.start()
+    w0.join()
+    w1.join()
diff --git a/test/distributed/test_dist_neighbor_loader.py b/test/distributed/test_dist_neighbor_loader.py
index 8341b8d0be4f..d651e6af0453 100644
--- a/test/distributed/test_dist_neighbor_loader.py
+++ b/test/distributed/test_dist_neighbor_loader.py
@@ -18,7 +18,7 @@
 from torch_geometric.testing import onlyLinux, withPackage
 
 
-def create_dist_data(tmp_path, rank):
+def create_dist_data(tmp_path: str, rank: int):
     graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)
     feat_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)
     (
@@ -50,14 +50,13 @@ def create_dist_data(tmp_path, rank):
 
 
 def dist_neighbor_loader_homo(
-        tmp_path: str,
-        world_size: int,
-        rank: int,
-        master_addr: str,
-        master_port: int,
-        num_workers: int,
-        async_sampling: bool,
-        device=torch.device('cpu'),
+    tmp_path: str,
+    world_size: int,
+    rank: int,
+    master_addr: str,
+    master_port: int,
+    num_workers: int,
+    async_sampling: bool,
 ):
     part_data = create_dist_data(tmp_path, rank)
     input_nodes = part_data[0].get_global_id(None)
@@ -80,7 +79,6 @@ def dist_neighbor_loader_homo(
         current_ctx=current_ctx,
         rpc_worker_names={},
         concurrency=10,
-        device=device,
         drop_last=True,
         async_sampling=async_sampling,
     )
@@ -96,7 +94,6 @@ def dist_neighbor_loader_homo(
         assert isinstance(batch, Data)
         assert batch.n_id.size() == (batch.num_nodes, )
         assert batch.input_id.numel() == batch.batch_size == 10
-        assert batch.edge_index.device == device
         assert batch.edge_index.min() >= 0
         assert batch.edge_index.max() < batch.num_nodes
         assert torch.equal(
@@ -106,14 +103,13 @@ def dist_neighbor_loader_homo(
 
 
 def dist_neighbor_loader_hetero(
-        tmp_path: str,
-        world_size: int,
-        rank: int,
-        master_addr: str,
-        master_port: int,
-        num_workers: int,
-        async_sampling: bool,
-        device=torch.device('cpu'),
+    tmp_path: str,
+    world_size: int,
+    rank: int,
+    master_addr: str,
+    master_port: int,
+    num_workers: int,
+    async_sampling: bool,
 ):
     part_data = create_dist_data(tmp_path, rank)
     input_nodes = ('v0', part_data[0].get_global_id('v0'))
@@ -127,7 +123,7 @@
 
     loader = DistNeighborLoader(
         part_data,
-        num_neighbors=[10, 10],
+        num_neighbors=[1],
         batch_size=10,
         num_workers=num_workers,
         input_nodes=input_nodes,
@@ -136,7 +132,6 @@
         current_ctx=current_ctx,
         rpc_worker_names={},
         concurrency=10,
-        device=device,
         drop_last=True,
         async_sampling=async_sampling,
     )
@@ -153,14 +148,11 @@
         assert len(batch.node_types) == 2
         for node_type in batch.node_types:
             assert torch.equal(batch[node_type].x, batch.x_dict[node_type])
-            assert batch.x_dict[node_type].device == device
             assert batch.x_dict[node_type].size(0) >= 0
             assert batch[node_type].n_id.size(0) == batch[node_type].num_nodes
 
         assert len(batch.edge_types) == 4
         for edge_type in batch.edge_types:
-            assert batch[edge_type].edge_index.device == device
-            assert batch[edge_type].edge_attr.device == device
             assert (batch[edge_type].edge_attr.size(0) ==
                     batch[edge_type].edge_index.size(1))
 
diff --git a/torch_geometric/distributed/__init__.py b/torch_geometric/distributed/__init__.py
index fff9850625a7..d452b2a7fce3 100644
--- a/torch_geometric/distributed/__init__.py
+++ b/torch_geometric/distributed/__init__.py
@@ -5,6 +5,7 @@
 from .dist_neighbor_sampler import DistNeighborSampler
 from .dist_loader import DistLoader
 from .dist_neighbor_loader import DistNeighborLoader
+from .dist_link_neighbor_loader import DistLinkNeighborLoader
 
 __all__ = classes = [
     'DistContext',
@@ -14,4 +15,5 @@
     'DistNeighborSampler',
     'DistLoader',
     'DistNeighborLoader',
+    'DistLinkNeighborLoader',
 ]
diff --git a/torch_geometric/distributed/dist_link_neighbor_loader.py b/torch_geometric/distributed/dist_link_neighbor_loader.py
new file mode 100644
index 000000000000..db10e9dba03f
--- /dev/null
+++ b/torch_geometric/distributed/dist_link_neighbor_loader.py
@@ -0,0 +1,126 @@
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from torch_geometric.distributed import (
+    DistLoader,
+    DistNeighborSampler,
+    LocalFeatureStore,
+    LocalGraphStore,
+)
+from torch_geometric.distributed.dist_context import DistContext, DistRole
+from torch_geometric.loader import LinkLoader
+from torch_geometric.sampler.base import NegativeSampling, SubgraphType
+from torch_geometric.typing import EdgeType, InputEdges, OptTensor
+
+
+class DistLinkNeighborLoader(LinkLoader, DistLoader):
+    r"""A distributed loader that performs sampling from edges.
+
+    Args:
+        data: A (:class:`~torch_geometric.data.FeatureStore`,
+            :class:`~torch_geometric.data.GraphStore`) data object.
+        num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]):
+            The number of neighbors to sample for each node in each iteration.
+            If an entry is set to :obj:`-1`, all neighbors will be included.
+            In heterogeneous graphs, may also take in a dictionary denoting
+            the number of neighbors to sample for each individual edge type.
+        master_addr (str): RPC address for distributed loader communication,
+            *i.e.* the IP address of the master node.
+        master_port (Union[int, str]): Open port for RPC communication with
+            the master node.
+        current_ctx (DistContext): Distributed context information of the
+            current process.
+        rpc_worker_names (Dict[DistRole, List[str]]): RPC worker identifiers.
+        concurrency (int, optional): RPC concurrency used for defining the
+            maximum size of the asynchronous processing queue.
+            (default: :obj:`1`)
+
+    All other arguments follow the interface of
+    :class:`torch_geometric.loader.LinkNeighborLoader`.
+ """ + def __init__( + self, + data: Tuple[LocalFeatureStore, LocalGraphStore], + num_neighbors: Union[List[int], Dict[EdgeType, List[int]]], + master_addr: str, + master_port: Union[int, str], + current_ctx: DistContext, + rpc_worker_names: Dict[DistRole, List[str]], + edge_label_index: InputEdges = None, + edge_label: OptTensor = None, + edge_label_time: OptTensor = None, + neighbor_sampler: Optional[DistNeighborSampler] = None, + replace: bool = False, + subgraph_type: Union[SubgraphType, str] = "directional", + disjoint: bool = False, + temporal_strategy: str = "uniform", + neg_sampling: Optional[NegativeSampling] = None, + neg_sampling_ratio: Optional[Union[int, float]] = None, + time_attr: Optional[str] = None, + transform: Optional[Callable] = None, + concurrency: int = 1, + filter_per_worker: Optional[bool] = None, + async_sampling: bool = True, + device: Optional[torch.device] = None, + **kwargs, + ): + assert isinstance(data[0], LocalFeatureStore) + assert isinstance(data[1], LocalGraphStore) + assert concurrency >= 1, "RPC concurrency must be greater than 1" + + if (edge_label_time is not None) != (time_attr is not None): + raise ValueError( + f"Received conflicting 'edge_label_time' and 'time_attr' " + f"arguments: 'edge_label_time' is " + f"{'set' if edge_label_time is not None else 'not set'} " + f"while 'time_attr' is " + f"{'set' if time_attr is not None else 'not set'}. " + f"Both arguments must be provided for temporal sampling.") + + channel = torch.multiprocessing.Queue() if async_sampling else None + + if neighbor_sampler is None: + neighbor_sampler = DistNeighborSampler( + data=data, + current_ctx=current_ctx, + rpc_worker_names=rpc_worker_names, + num_neighbors=num_neighbors, + replace=replace, + subgraph_type=subgraph_type, + disjoint=disjoint, + temporal_strategy=temporal_strategy, + time_attr=time_attr, + device=device, + channel=channel, + concurrency=concurrency, + ) + + self.neighbor_sampler = neighbor_sampler + + DistLoader.__init__( + self, + channel=channel, + master_addr=master_addr, + master_port=master_port, + current_ctx=current_ctx, + rpc_worker_names=rpc_worker_names, + **kwargs, + ) + LinkLoader.__init__( + self, + data=data, + link_sampler=neighbor_sampler, + edge_label_index=edge_label_index, + edge_label=edge_label, + neg_sampling=neg_sampling, + neg_sampling_ratio=neg_sampling_ratio, + transform=transform, + filter_per_worker=filter_per_worker, + worker_init_fn=self.worker_init_fn, + transform_sampler_output=self.channel_get, + **kwargs, + ) + + def __repr__(self) -> str: + return DistLoader.__repr__(self) diff --git a/torch_geometric/distributed/dist_neighbor_loader.py b/torch_geometric/distributed/dist_neighbor_loader.py index f428790ca85f..3c91d464f095 100644 --- a/torch_geometric/distributed/dist_neighbor_loader.py +++ b/torch_geometric/distributed/dist_neighbor_loader.py @@ -40,29 +40,27 @@ class DistNeighborLoader(NodeLoader, DistLoader): :class:`torch_geometric.loader.NeighborLoader`. 
""" def __init__( - self, - data: Tuple[LocalFeatureStore, LocalGraphStore], - num_neighbors: Union[List[int], Dict[EdgeType, List[int]]], - master_addr: str, - master_port: Union[int, str], - current_ctx: DistContext, - rpc_worker_names: Dict[DistRole, List[str]], - input_nodes: InputNodes = None, - input_time: OptTensor = None, - neighbor_sampler: Optional[DistNeighborSampler] = None, - replace: bool = False, - subgraph_type: Union[SubgraphType, str] = "directional", - disjoint: bool = False, - temporal_strategy: str = "uniform", - time_attr: Optional[str] = None, - transform: Optional[Callable] = None, - is_sorted: bool = False, - with_edge: bool = True, - concurrency: int = 1, - filter_per_worker: Optional[bool] = False, - async_sampling: bool = True, - device: torch.device = torch.device("cpu"), - **kwargs, + self, + data: Tuple[LocalFeatureStore, LocalGraphStore], + num_neighbors: Union[List[int], Dict[EdgeType, List[int]]], + master_addr: str, + master_port: Union[int, str], + current_ctx: DistContext, + rpc_worker_names: Dict[DistRole, List[str]], + input_nodes: InputNodes = None, + input_time: OptTensor = None, + neighbor_sampler: Optional[DistNeighborSampler] = None, + replace: bool = False, + subgraph_type: Union[SubgraphType, str] = "directional", + disjoint: bool = False, + temporal_strategy: str = "uniform", + time_attr: Optional[str] = None, + transform: Optional[Callable] = None, + concurrency: int = 1, + filter_per_worker: Optional[bool] = False, + async_sampling: bool = True, + device: Optional[torch.device] = None, + **kwargs, ): assert isinstance(data[0], LocalFeatureStore) assert isinstance(data[1], LocalGraphStore) @@ -81,14 +79,11 @@ def __init__( current_ctx=current_ctx, rpc_worker_names=rpc_worker_names, num_neighbors=num_neighbors, - with_edge=with_edge, replace=replace, subgraph_type=subgraph_type, disjoint=disjoint, temporal_strategy=temporal_strategy, time_attr=time_attr, - is_sorted=is_sorted, - share_memory=kwargs.get('num_workers', 0) > 0, device=device, channel=channel, concurrency=concurrency, diff --git a/torch_geometric/distributed/dist_neighbor_sampler.py b/torch_geometric/distributed/dist_neighbor_sampler.py index 59eab0fa25b6..9c0748432693 100644 --- a/torch_geometric/distributed/dist_neighbor_sampler.py +++ b/torch_geometric/distributed/dist_neighbor_sampler.py @@ -60,20 +60,20 @@ class DistNeighborSampler: used by :class:`~torch_geometric.distributed.DistNeighborLoader`. 
""" def __init__( - self, - current_ctx: DistContext, - rpc_worker_names: Dict[DistRole, List[str]], - data: Tuple[LocalFeatureStore, LocalGraphStore], - num_neighbors: NumNeighborsType, - channel: Optional[mp.Queue] = None, - replace: bool = False, - subgraph_type: Union[SubgraphType, str] = 'directional', - disjoint: bool = False, - temporal_strategy: str = 'uniform', - time_attr: Optional[str] = None, - concurrency: int = 1, - device: Optional[torch.device] = torch.device('cpu'), - **kwargs, + self, + current_ctx: DistContext, + rpc_worker_names: Dict[DistRole, List[str]], + data: Tuple[LocalFeatureStore, LocalGraphStore], + num_neighbors: NumNeighborsType, + channel: Optional[mp.Queue] = None, + replace: bool = False, + subgraph_type: Union[SubgraphType, str] = 'directional', + disjoint: bool = False, + temporal_strategy: str = 'uniform', + time_attr: Optional[str] = None, + concurrency: int = 1, + device: Optional[torch.device] = None, + **kwargs, ): self.current_ctx = current_ctx self.rpc_worker_names = rpc_worker_names diff --git a/torch_geometric/loader/link_loader.py b/torch_geometric/loader/link_loader.py index a9555db4383e..e004ef66f79b 100644 --- a/torch_geometric/loader/link_loader.py +++ b/torch_geometric/loader/link_loader.py @@ -218,8 +218,16 @@ def filter_fn( out = self.transform_sampler_output(out) if isinstance(out, SamplerOutput): - data = filter_data(self.data, out.node, out.row, out.col, out.edge, - self.link_sampler.edge_permutation) + if isinstance(self.data, Data): + data = filter_data( # + self.data, out.node, out.row, out.col, out.edge, + self.link_sampler.edge_permutation) + + else: # Tuple[FeatureStore, GraphStore] + # TODO Respect `custom_cls`. + # TODO Integrate features. + edge_index = torch.stack([out.row, out.col]) + data = Data(edge_index=edge_index) if 'n_id' not in data: data.n_id = out.node @@ -250,12 +258,22 @@ def filter_fn( elif isinstance(out, HeteroSamplerOutput): if isinstance(self.data, HeteroData): - data = filter_hetero_data(self.data, out.node, out.row, - out.col, out.edge, - self.link_sampler.edge_permutation) + data = filter_hetero_data( # + self.data, out.node, out.row, out.col, out.edge, + self.link_sampler.edge_permutation) + else: # Tuple[FeatureStore, GraphStore] - data = filter_custom_store(*self.data, out.node, out.row, - out.col, out.edge, self.custom_cls) + # Hack to detect whether we are in a distributed setting. 
+                if (self.link_sampler.__class__.__name__ ==
+                        'DistNeighborSampler'):
+                    import torch_geometric.distributed as dist
+                    data = dist.utils.filter_dist_store(
+                        *self.data, out.node, out.row, out.col, out.edge,
+                        self.custom_cls, out.metadata)
+                else:
+                    data = filter_custom_store(  #
+                        *self.data, out.node, out.row, out.col, out.edge,
+                        self.custom_cls)
 
             for key, node in out.node.items():
                 if 'n_id' not in data[key]:
@@ -264,8 +282,10 @@ def filter_fn(
             for key, edge in (out.edge or {}).items():
                 if edge is not None and 'e_id' not in data[key]:
                     edge = edge.to(torch.long)
-                    perm = self.link_sampler.edge_permutation[key]
-                    data[key].e_id = perm[edge] if perm is not None else edge
+                    perm = self.link_sampler.edge_permutation
+                    if perm is not None and perm.get(key, None) is not None:
+                        edge = perm[key][edge]
+                    data[key].e_id = edge
 
             data.set_value_dict('batch', out.batch)
             data.set_value_dict('num_sampled_nodes', out.num_sampled_nodes)
diff --git a/torch_geometric/loader/node_loader.py b/torch_geometric/loader/node_loader.py
index 2605dc9e41cd..5455bd579a0c 100644
--- a/torch_geometric/loader/node_loader.py
+++ b/torch_geometric/loader/node_loader.py
@@ -203,8 +203,10 @@ def filter_fn(
             for key, edge in (out.edge or {}).items():
                 if edge is not None and 'e_id' not in data[key]:
                     edge = edge.to(torch.long)
-                    perm = self.node_sampler.edge_permutation[key]
-                    data[key].e_id = perm[edge] if perm is not None else edge
+                    perm = self.node_sampler.edge_permutation
+                    if perm is not None and perm.get(key, None) is not None:
+                        edge = perm[key][edge]
+                    data[key].e_id = edge
 
             data.set_value_dict('batch', out.batch)
             data.set_value_dict('num_sampled_nodes', out.num_sampled_nodes)
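For orientation, the snippet below sketches how the new DistLinkNeighborLoader is driven on a single partition, distilled from dist_link_neighbor_loader_homo() in the test above. It is an illustration, not part of the patch: run(), part_data, addr, and port are placeholders, the function is meant to be spawned once per partition rank (the tests use torch.multiprocessing with two processes), and the end-to-end sampling path is still gated on 'sample_from_edges' (the tests are skipped for now).

from torch_geometric.distributed import DistContext, DistLinkNeighborLoader


def run(rank: int, world_size: int, part_data, addr: str, port: int):
    # `part_data` is the (LocalFeatureStore, LocalGraphStore) pair for this
    # partition, e.g. as returned by `create_dist_data()` in the test above.
    ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-link-loader-example',  # placeholder group name
    )

    # Use all local edges as seed links (mirrors the homogeneous test):
    edge_label_index = part_data[1].get_edge_index(None, 'coo')

    loader = DistLinkNeighborLoader(
        data=part_data,
        edge_label_index=(None, edge_label_index),
        num_neighbors=[10, 10],  # illustrative fan-out
        batch_size=128,  # illustrative batch size
        master_addr=addr,
        master_port=port,
        current_ctx=ctx,
        rpc_worker_names={},
        async_sampling=True,
    )

    for batch in loader:  # Each `batch` is a `torch_geometric.data.Data`.
        pass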