Skip to content

Commit

Permalink
[GraphBolt][Standalone] Reduce and mark dependencies on DGL. (#7499)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Jul 9, 2024
1 parent 6afb29b commit 5434fc3
Show file tree
Hide file tree
Showing 13 changed files with 327 additions and 37 deletions.
4 changes: 3 additions & 1 deletion python/dgl/graphbolt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

### FROM DGL @todo
from .._ffi import libinfo


Expand Down Expand Up @@ -46,12 +47,13 @@ def load_graphbolt():
from .itemset import *
from .item_sampler import *
from .minibatch_transformer import *
from .internal_utils import *
from .negative_sampler import *
from .sampled_subgraph import *
from .subgraph_sampler import *
from .external_utils import add_reverse_edges, exclude_seed_edges
from .internal import (
compact_csc_format,
unique_and_compact,
unique_and_compact_csc_formats,
)
from .utils import add_reverse_edges, exclude_seed_edges
2 changes: 1 addition & 1 deletion python/dgl/graphbolt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

from ..utils import recursive_apply
from .internal_utils import recursive_apply

__all__ = [
"CANONICAL_ETYPE_DELIMITER",
Expand Down
File renamed without changes.
12 changes: 6 additions & 6 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,8 @@

import torch

from dgl.utils import recursive_apply

from ...base import EID, ETYPE, NID, NTYPE
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..internal_utils import recursive_apply
from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import CSCFormatBase, SampledSubgraphImpl

Expand Down Expand Up @@ -1538,11 +1534,15 @@ def load_from_shared_memory(


def from_dglgraph(
g: DGLGraph,
DGLGraphInstance,
is_homogeneous: bool = False,
include_original_edge_id: bool = False,
) -> FusedCSCSamplingGraph:
"""Convert a DGLGraph to FusedCSCSamplingGraph."""
from dgl.base import EID, ETYPE, NID, NTYPE
from dgl.convert import to_homogeneous

g = DGLGraphInstance

homo_g, ntype_count, _ = to_homogeneous(
g, ndata=g.ndata, edata=g.edata, return_count=True
Expand Down
9 changes: 5 additions & 4 deletions python/dgl/graphbolt/impl/legacy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import List, Union

from dgl.data import AsNodePredDataset, DGLDataset
from ..base import etype_tuple_to_str
from ..dataset import Dataset, Task
from ..itemset import HeteroItemSet, ItemSet
Expand All @@ -16,7 +15,7 @@
class LegacyDataset(Dataset):
"""A Graphbolt dataset for legacy DGLDataset."""

def __init__(self, legacy: DGLDataset):
def __init__(self, legacy):
# Only supports single graph cases.
assert len(legacy) == 1
graph = legacy[0]
Expand All @@ -28,7 +27,7 @@ def __init__(self, legacy: DGLDataset):
else:
self._init_as_heterogeneous_node_pred(legacy)

def _init_as_heterogeneous_node_pred(self, legacy: DGLDataset):
def _init_as_heterogeneous_node_pred(self, legacy):
def _init_item_set_dict(idx, labels):
item_set_dict = {}
for key in idx.keys():
Expand Down Expand Up @@ -87,7 +86,9 @@ def _init_item_set_dict(idx, labels):
"Only support heterogeneous ogn node pred dataset"
)

def _init_as_homogeneous_node_pred(self, legacy: DGLDataset):
def _init_as_homogeneous_node_pred(self, legacy):
from dgl.data import AsNodePredDataset

legacy = AsNodePredDataset(legacy)

# Initialize tasks.
Expand Down
9 changes: 4 additions & 5 deletions python/dgl/graphbolt/impl/ondisk_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import torch
import yaml

from ...base import dgl_warning
from ...data.utils import download, extract_archive
from ..base import etype_str_to_tuple, ORIGINAL_EDGE_ID
from ..dataset import Dataset, Task
from ..internal import (
Expand All @@ -25,6 +23,7 @@
read_data,
read_edges,
)
from ..internal_utils import download, extract_archive, gb_warning
from ..itemset import HeteroItemSet, ItemSet
from ..sampling_graph import SamplingGraph
from .fused_csc_sampling_graph import (
Expand Down Expand Up @@ -457,7 +456,7 @@ def preprocess_ondisk_dataset(
is_feature=True,
)
if has_edge_feature_data and not include_original_edge_id:
dgl_warning("Edge feature is stored, but edge IDs are not saved.")
gb_warning("Edge feature is stored, but edge IDs are not saved.")

# 6. Save tasks and train/val/test split according to the output_config.
if input_config.get("tasks", None):
Expand Down Expand Up @@ -831,7 +830,7 @@ def _init_tasks(
if selected_tasks:
not_found_tasks = set(selected_tasks) - task_names
if len(not_found_tasks):
dgl_warning(
gb_warning(
f"Below tasks are not found in YAML: {not_found_tasks}. Skipped."
)
return ret
Expand Down Expand Up @@ -887,7 +886,7 @@ def _init_tvt_set(

def _init_all_nodes_set(self, graph) -> Union[ItemSet, HeteroItemSet]:
if graph is None:
dgl_warning(
gb_warning(
"`all_nodes_set` is returned as None, since graph is None."
)
return None
Expand Down
3 changes: 2 additions & 1 deletion python/dgl/graphbolt/impl/ondisk_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import pydantic

from ...utils import version
from ..internal_utils import version


__all__ = [
"OnDiskFeatureDataFormat",
Expand Down
Loading

0 comments on commit 5434fc3

Please sign in to comment.