diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index 903b84540914..de1e48e14e08 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -352,11 +352,14 @@ def _process_partition_gb( sorted_idx = ( th.repeat_interleave(indptr[:-1], split_size, dim=0) + sorted_idx ) + else: + sorted_idxs=th.arange(len(edge_ids)) return indptr, indices[sorted_idx], edge_ids[sorted_idx] -def update_node_map(node_map_val, end_ids_per_rank, id_ntypes, prev_last_id): +def _update_node_map(node_map_val, end_ids_per_rank, id_ntypes, prev_last_id): + """this function is modified from the function '_update_node_edge_map' in dgl.distributed.partition """ # Update the node_map_val to be contiguous. rank = dist.get_rank() prev_end_id = ( @@ -561,7 +564,7 @@ def create_graph_object( ] dist.all_gather(gather_last_ids, last_id) - prev_last_id = update_node_map( + prev_last_id = _update_node_map( node_map_val, gather_last_ids, id_ntypes, prev_last_id ) last_ids[part_id] = prev_last_id