[DistGB]enable prob/mask sampling on graphbolt partitions (#7458)
Rhett-Ying authored Jun 13, 2024
1 parent 94b691b commit 7f1d164
Showing 5 changed files with 302 additions and 25 deletions.
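For orientation (not part of the diff): "prob/mask sampling" means passing the name of a float edge feature to the neighbor sampler through its `prob` argument, so that edges whose value is zero are never sampled. Below is a minimal sketch of the workflow this commit enables on a GraphBolt-backed DistGraph, condensed from the example file that follows; the graph, feature, and variable names are illustrative.

import dgl

# A float32 "prob" DistTensor is assumed to already exist on every edge type
# of the DistGraph `g` (the example below fills it with rand_init_prob).
sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 10], prob="prob")
dataloader = dgl.dataloading.DistNodeDataLoader(
    g, train_nids, sampler, batch_size=1024, shuffle=True
)
# With use_graphbolt=True, the collator copies "prob" into the GraphBolt
# partition's edge attributes before the first minibatch is sampled.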
46 changes: 45 additions & 1 deletion examples/distributed/rgcn/lp_perf.py
@@ -18,7 +18,11 @@ def run(args, g, train_eids):

neg_sampler = dgl.dataloading.negative_sampler.Uniform(3)

sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
prob = args.prob_or_mask
sampler = dgl.dataloading.MultiLayerNeighborSampler(
fanouts,
prob=prob,
)

exclude = None
reverse_etypes = None
@@ -70,6 +74,12 @@ def run(args, g, train_eids):
epoch_tic = time.time()
for step, sample_data in enumerate(dataloader):
input_nodes, pos_graph, neg_graph, blocks = sample_data

for block in blocks:
for c_etype in block.canonical_etypes:
homo_eids = block.edges[c_etype].data[dgl.EID]
assert th.all(g.edges[c_etype].data[prob][homo_eids] > 0)

if args.debug:
# Verify exclude_edges functionality.
current_eids = blocks[-1].edata[dgl.EID]
@@ -118,6 +128,18 @@ def run(args, g, train_eids):
g.barrier()


def rand_init_prob(shape, dtype):
prob = th.rand(shape)
prob[th.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0
return prob


def rand_init_mask(shape, dtype):
prob = th.rand(shape)
prob[th.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0
return (prob > 0.2).to(th.float32)


def main(args):
dgl.distributed.initialize(args.ip_config, use_graphbolt=args.use_graphbolt)

@@ -127,6 +149,22 @@ def main(args):
g = dgl.distributed.DistGraph(args.graph_name)
print("rank:", g.rank())

# Assign prob/masks to edges.
for c_etype in g.canonical_etypes:
shape = (g.num_edges(etype=c_etype),)
g.edges[c_etype].data["prob"] = dgl.distributed.DistTensor(
shape,
th.float32,
init_func=rand_init_prob,
part_policy=g.get_edge_partition_policy(c_etype),
)
g.edges[c_etype].data["mask"] = dgl.distributed.DistTensor(
shape,
th.float32,
init_func=rand_init_mask,
part_policy=g.get_edge_partition_policy(c_etype),
)

pb = g.get_partition_book()
c_etype = ("author", "writes", "paper")
train_eids = dgl.distributed.edge_split(
@@ -200,6 +238,12 @@ def main(args):
action="store_true",
help="whether to remove edges during sampling",
)
parser.add_argument(
"--prob_or_mask",
type=str,
default="prob",
help="whether to use prob or mask during sampling",
)
args = parser.parse_args()

print(args)
27 changes: 27 additions & 0 deletions python/dgl/dataloading/dist_dataloader.py
@@ -135,6 +135,27 @@ def collate(self, items):
"""
raise NotImplementedError

@staticmethod
def add_edge_attribute_to_graph(g, data_name):
"""Add data into the graph as an edge attribute.
For some cases such as prob/mask-based sampling on GraphBolt partitions,
we need to prepare such data beforehand. This is because data are
usually saved in DistGraph.ndata/edata, but such data is not in the
format that GraphBolt partitions require. And in GraphBolt, such data
are saved as edge attributes. So we need to add such data into the graph
before any sampling is kicked off.
Parameters
----------
g : DistGraph
The graph.
data_name : str
The name of data that's stored in DistGraph.ndata/edata.
"""
if g._use_graphbolt and data_name:
g.add_edge_attribute(data_name)


class NodeCollator(Collator):
"""DGL collator to combine nodes and their computation dependencies within a minibatch for
@@ -181,6 +202,9 @@ def __init__(self, g, nids, graph_sampler):
self.nids = utils.prepare_tensor_or_dict(g, nids, "nids")
self._dataset = utils.maybe_flatten_dict(self.nids)

# Add prob/mask into graphbolt partition's edge attributes if needed.
Collator.add_edge_attribute_to_graph(self.g, self.graph_sampler.prob)

@property
def dataset(self):
return self._dataset
@@ -437,6 +461,9 @@ def __init__(
self.eids = utils.prepare_tensor_or_dict(g, eids, "eids")
self._dataset = utils.maybe_flatten_dict(self.eids)

# Add prob/mask into graphbolt partition's edge attributes if needed.
Collator.add_edge_attribute_to_graph(self.g, self.graph_sampler.prob)

@property
def dataset(self):
return self._dataset
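Not part of the diff: both collators feed their sampler's `prob` attribute into the static helper added above. A minimal sketch of calling it directly, assuming a GraphBolt-backed DistGraph `g` whose edge data already holds a float32 "mask" DistTensor:

from dgl.dataloading.dist_dataloader import Collator

# No-op when the graph is not GraphBolt-backed or when data_name is None
# (i.e. the sampler was built without prob/mask); otherwise it triggers
# DistGraph.add_edge_attribute("mask") so the GraphBolt partition can see it.
Collator.add_edge_attribute_to_graph(g, "mask")
Collator.add_edge_attribute_to_graph(g, None)  # sampler without prob/mask: does nothing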
201 changes: 199 additions & 2 deletions python/dgl/distributed/dist_graph.py
@@ -52,6 +52,8 @@

INIT_GRAPH = 800001
QUERY_IF_USE_GRAPHBOLT = 800002
ADD_EDGE_ATTRIBUTE_FROM_KV = 800003
ADD_EDGE_ATTRIBUTE_FROM_SHARED_MEM = 800004


class InitGraphRequest(rpc.Request):
@@ -118,6 +120,126 @@ def __setstate__(self, state):
self._use_graphbolt = state


def _copy_data_to_shared_mem(data, name):
"""Copy data to shared memory."""
# [TODO] Copy data to shared memory.
assert data.dtype == torch.float32, "Only float32 is supported."
data_type = F.reverse_data_type_dict[F.dtype(data)]
shared_data = empty_shared_mem(name, True, data.shape, data_type)
dlpack = shared_data.to_dlpack()
ret = F.zerocopy_from_dlpack(dlpack)
rpc.copy_data_to_shared_memory(ret, data)
return ret


def _copy_data_from_shared_mem(name, shape):
"""Copy data from shared memory."""
data_type = F.reverse_data_type_dict[F.float32]
data = empty_shared_mem(name, False, shape, data_type)
dlpack = data.to_dlpack()
return F.zerocopy_from_dlpack(dlpack)


class AddEdgeAttributeFromKVRequest(rpc.Request):
"""Add edge attribute from kvstore to local GraphBolt partition."""

def __init__(self, name, kv_names):
self._name = name
self._kv_names = kv_names

def __getstate__(self):
return self._name, self._kv_names

def __setstate__(self, state):
self._name, self._kv_names = state

def process_request(self, server_state):
# For now, this is only used to add prob/mask data to the graph.
name = self._name
g = server_state.graph
if name not in g.edge_attributes:
# Fetch target data from kvstore.
kv_store = server_state.kv_store
data = [
kv_store.data_store[kv_name] if kv_name else None
for kv_name in self._kv_names
]
# Due to data type limitation in GraphBolt's sampling, we only support float32.
data_type = torch.float32
gpb = server_state.partition_book
# Initialize the edge attribute.
num_edges = g.total_num_edges
attr_data = torch.zeros(num_edges, dtype=data_type)
# Map data from kvstore to the local partition for inner edges only.
num_inner_edges = gpb.metadata()[gpb.partid]["num_edges"]
homo_eids = g.edge_attributes[EID][:num_inner_edges]
etype_ids, typed_eids = gpb.map_to_per_etype(homo_eids)
for etype_id, c_etype in enumerate(gpb.canonical_etypes):
curr_indices = torch.nonzero(etype_ids == etype_id).squeeze()
curr_typed_eids = typed_eids[curr_indices]
curr_local_eids = gpb.eid2localeid(
curr_typed_eids, gpb.partid, etype=c_etype
)
if data[etype_id] is None:
continue
attr_data[curr_indices] = data[etype_id][curr_local_eids].to(
data_type
)
# Copy data to shared memory.
attr_data = _copy_data_to_shared_mem(attr_data, "__edge__" + name)
g.add_edge_attribute(name, attr_data)
return AddEdgeAttributeFromKVResponse(name)


class AddEdgeAttributeFromKVResponse(rpc.Response):
"""Ack the request of adding edge attribute."""

def __init__(self, name):
self._name = name

def __getstate__(self):
return self._name

def __setstate__(self, state):
self._name = state


class AddEdgeAttributeFromSharedMemRequest(rpc.Request):
"""Add edge attribute from shared memory to local GraphBolt partition."""

def __init__(self, name):
self._name = name

def __getstate__(self):
return self._name

def __setstate__(self, state):
self._name = state

def process_request(self, server_state):
name = self._name
g = server_state.graph
if name not in g.edge_attributes:
data = _copy_data_from_shared_mem(
"__edge__" + name, (g.total_num_edges,)
)
g.add_edge_attribute(name, data)
return AddEdgeAttributeFromSharedMemResponse(name)


class AddEdgeAttributeFromSharedMemResponse(rpc.Response):
"""Ack the request of adding edge attribute from shared memory."""

def __init__(self, name):
self._name = name

def __getstate__(self):
return self._name

def __setstate__(self, state):
self._name = state


def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt):
if use_graphbolt:
return g.copy_to_shared_memory(graph_name)
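Not part of the diff: a self-contained toy illustration (plain torch, made-up sizes) of the per-edge-type to homogeneous merge that `AddEdgeAttributeFromKVRequest.process_request` performs above. Per-etype tensors fetched from the kvstore are scattered into one float32 tensor laid out in the partition's local homogeneous edge order, leaving zeros for edge types that lack the attribute and for halo edges.

import torch

# Per-etype "prob" tensors as they would come back from the kvstore.
data = [torch.tensor([0.2, 0.0, 0.9]), torch.tensor([0.5, 0.7])]
# For each inner edge, in local homogeneous order: its edge type and its
# type-wise local edge ID (what map_to_per_etype / eid2localeid yield).
etype_ids = torch.tensor([0, 1, 0, 0, 1])
local_eids = torch.tensor([0, 0, 1, 2, 1])

attr_data = torch.zeros(5, dtype=torch.float32)  # g.total_num_edges in the real code
for etype_id, per_etype in enumerate(data):
    if per_etype is None:  # edge type without the attribute stays all-zero
        continue
    idx = torch.nonzero(etype_ids == etype_id).squeeze()
    attr_data[idx] = per_etype[local_eids[idx]].to(torch.float32)
print(attr_data)  # tensor([0.2000, 0.5000, 0.0000, 0.9000, 0.7000])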
@@ -592,6 +714,7 @@ class DistGraph:

def __init__(self, graph_name, gpb=None, part_config=None):
self.graph_name = graph_name
self._added_edge_attributes = [] # For prob/mask sampling on GB.
if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone":
# "GraphBolt is not supported in standalone mode."
self._use_graphbolt = False
@@ -725,16 +848,35 @@ def _init_metadata(self):
}

def __getstate__(self):
return self.graph_name, self._gpb, self._use_graphbolt
return (
self.graph_name,
self._gpb,
self._use_graphbolt,
self._added_edge_attributes,
)

def __setstate__(self, state):
self.graph_name, gpb, self._use_graphbolt = state
(
self.graph_name,
gpb,
self._use_graphbolt,
self._added_edge_attributes,
) = state
self._init(gpb)

self._init_ndata_store()
self._init_edata_store()
self._init_metadata()

# For prob/mask sampling on GB only.
if self._use_graphbolt and len(self._added_edge_attributes) > 0:
# Add edge attribute from main server's shared memory.
for name in self._added_edge_attributes:
data = _copy_data_from_shared_mem(
"__edge__" + name, (self.local_partition.total_num_edges,)
)
self.local_partition.add_edge_attribute(name, data)

@property
def local_partition(self):
"""Return the local partition on the client
@@ -1478,6 +1620,51 @@ def _get_edata_names(self, etype=None):
edata_names.append(name)
return edata_names

def add_edge_attribute(self, name):
"""Add an edge attribute into GraphBolt partition from edge data.
Parameters
----------
name : str
The name of the edge attribute.
"""
# Sanity checks.
if not self._use_graphbolt:
raise DGLError("GraphBolt is not used.")

# Send add request to main server on the same machine.
kv_names = [
(
self.edges[etype].data[name].kvstore_key
if name in self.edges[etype].data
else None
)
for etype in self.canonical_etypes
]
rpc.send_request(
self._client._main_server_id,
AddEdgeAttributeFromKVRequest(name, kv_names),
)
# Wait for the response.
assert rpc.recv_response()._name == name
# Send add request to backup servers.
for server_id in range(self._client.num_servers):
rpc.send_request(
server_id, AddEdgeAttributeFromSharedMemRequest(name)
)
for server_id in range(self._client.num_servers):
rpc.recv_response()
# Add edge attribute from main server's shared memory.
data = _copy_data_from_shared_mem(
"__edge__" + name, (self.local_partition.total_num_edges,)
)
self.local_partition.add_edge_attribute(name, data)
# Sync local clients.
self._client.barrier()

# Save the edge attribute into state. This is required by separate samplers.
self._added_edge_attributes.append(name)


def _get_overlap(mask_arr, ids):
"""Select the IDs given a boolean mask array.
@@ -1873,3 +2060,13 @@ def edge_split(
QueryIfUseGraphBoltRequest,
QueryIfUseGraphBoltResponse,
)
rpc.register_service(
ADD_EDGE_ATTRIBUTE_FROM_KV,
AddEdgeAttributeFromKVRequest,
AddEdgeAttributeFromKVResponse,
)
rpc.register_service(
ADD_EDGE_ATTRIBUTE_FROM_SHARED_MEM,
AddEdgeAttributeFromSharedMemRequest,
AddEdgeAttributeFromSharedMemResponse,
)
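Not part of the diff: a minimal sketch of what the new client-side method provides once it has run (in normal use the collators call it for you; the graph and feature names are illustrative).

# `g` is a GraphBolt-backed DistGraph, and a float32 "mask" DistTensor has
# already been written into g.edges[c_etype].data for the relevant edge types.
g.add_edge_attribute("mask")
# The merged attribute now sits on the local GraphBolt partition of every
# server and client on the machine, in homogeneous edge order.
print(g.local_partition.edge_attributes["mask"].shape)

Because the attribute names travel through __getstate__/__setstate__, a DistGraph that is pickled into separate sampler processes re-attaches the same attributes from the main server's shared memory instead of fetching them from the kvstore again.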