diff --git a/README.md b/README.md
index 33f68a6..b46788c 100755
--- a/README.md
+++ b/README.md
@@ -93,7 +93,7 @@ Figures can be plotted by: [`benchmark/notebook/fig_hop.ipynb`](benchmark/notebo
 
 #### Frequency response (*Table 12*):
 ```bash
-bash scripts/exp_filter.sh
+bash scripts/exp_regression.sh
 ```
 
 ## Customization
@@ -109,11 +109,11 @@ options:
 # Logging configuration
   --seed SEED           random seed
   --dev DEV             GPU id
-  --suffix SUFFIX       Save name suffix.
-  -quiet                Dry run without saving logs.
+  --suffix SUFFIX       Result log file name. None:not saving results
+  -quiet                File log. True:dry run without saving logs
   --storage {state_file,state_ram,state_gpu}
-                        Storage scheme for saving the checkpoints.
-  --loglevel LOGLEVEL   10:progress, 15:train, 20:info, 25:result
+                        Checkpoint log storage scheme.
+  --loglevel LOGLEVEL   Console log. 10:progress, 15:train, 20:info, 25:result
 # Data configuration
   --data DATA           Dataset name
   --data_split DATA_SPLIT Index or percentage of dataset split
diff --git a/benchmark/dataset/linkx.py b/benchmark/dataset/linkx.py
index 020ea69..a20c7d0 100644
--- a/benchmark/dataset/linkx.py
+++ b/benchmark/dataset/linkx.py
@@ -220,6 +220,23 @@ def forward(self, data: Any) -> Any:
 
 
 def get_data(datapath, transform, args: Namespace):
+    r"""Load data based on parameters and configurations.
+
+    Args:
+        datapath: Path to the root data directory.
+        transform: Data transformation pipeline.
+        args: Parameters.
+
+            * args.data (str): Dataset name.
+            * args.data_split (str): Index of dataset split.
+    Returns:
+        data (Data): The resolved data sample from the dataset.
+    Updates:
+        args.num_features (int): Number of input features.
+        args.num_classes (int): Number of output classes.
+        args.multi (bool): True for multi-label classification.
+        args.metric (str): Main metric name for evaluation.
+    """
     args.multi = False
     args.metric = 's_auroc' if args.data in ['genius'] else 's_f1i'
     # FIXME: check split
diff --git a/benchmark/dataset/ogbn.py b/benchmark/dataset/ogbn.py
index 93f677a..a5c1598 100644
--- a/benchmark/dataset/ogbn.py
+++ b/benchmark/dataset/ogbn.py
@@ -24,6 +24,23 @@ def forward(self, data: Any) -> Any:
 
 
 def get_data(datapath, transform, args: Namespace):
+    r"""Load data based on parameters and configurations.
+
+    Args:
+        datapath: Path to the root data directory.
+        transform: Data transformation pipeline.
+        args: Parameters.
+
+            * args.data (str): Dataset name.
+            * args.data_split (str): Index of dataset split.
+    Returns:
+        data (Data): The resolved data sample from the dataset.
+    Updates:
+        args.num_features (int): Number of input features.
+        args.num_classes (int): Number of output classes.
+        args.multi (bool): True for multi-label classification.
+        args.metric (str): Main metric name for evaluation.
+    """
     args.multi = True if args.data == 'ogbn-proteins' else False
     args.metric = 's_auroc' if args.data == 'ogbn-proteins' else 's_f1i'
     args.data_split = f"Original_0"
diff --git a/benchmark/dataset/pygn.py b/benchmark/dataset/pygn.py
index d02a141..e8069b1 100644
--- a/benchmark/dataset/pygn.py
+++ b/benchmark/dataset/pygn.py
@@ -30,6 +30,23 @@
 
 
 def get_data(datapath, transform, args: Namespace):
+    r"""Load data based on parameters and configurations.
+
+    Args:
+        datapath: Path to the root data directory.
+        transform: Data transformation pipeline.
+        args: Parameters.
+
+            * args.data (str): Dataset name.
+            * args.data_split (str): Index of dataset split.
+    Returns:
+        data (Data): The resolved data sample from the dataset.
+    Updates:
+        args.num_features (int): Number of input features.
+        args.num_classes (int): Number of output classes.
+        args.multi (bool): True for multi-label classification.
+        args.metric (str): Main metric name for evaluation.
+    """
     args.multi = False
     args.metric = 's_f1i'
     assert args.data_split.split('_')[0] in ['Random', 'Stratify']
diff --git a/benchmark/dataset/utils.py b/benchmark/dataset/utils.py
index b2b320d..2beb8db 100644
--- a/benchmark/dataset/utils.py
+++ b/benchmark/dataset/utils.py
@@ -23,6 +23,19 @@ def T_insert(transform, new_t: T.BaseTransform, index=-1) -> T.Compose:
 
 
 def resolve_data(args: Namespace, dataset: Dataset) -> Data:
+    r"""Acquire data and properties from dataset.
+
+    Args:
+        args: Parameters.
+
+            * args.multi (bool): ``True`` for multi-label classification.
+        dataset: PyG dataset object.
+    Returns:
+        data (Data): The resolved PyG data object from the dataset.
+    Updates:
+        args.num_features (int): Number of input features.
+        args.num_classes (int): Number of output classes.
+    """
     # Avoid triggering transform when getting simple properties.
     # data = dataset.get(dataset.indices()[0])
     # if hasattr(dataset, '_data_list') and dataset._data_list is not None:
@@ -44,6 +57,18 @@ def resolve_data(args: Namespace, dataset: Dataset) -> Data:
 
 
 def resolve_split(data_split: str, data: Data) -> Data:
+    r"""Apply data split masks.
+
+    Args:
+        data_split: Index of dataset split, formatted as ``scheme_split`` or ``scheme_split_seed``.
+
+            * ``scheme='Random'``: Random split, ``split`` is ``train/val/test`` ratio.
+            * ``scheme='Stratify'``: Stratified split, ``split`` is ``train/val/test`` ratio.
+            * ``scheme='Original'``: Original split, ``split`` is the index of split.
+        data: PyG data object containing the dataset and its attributes.
+    Returns:
+        data (Data): The updated PyG data object with split masks (train/val/test).
+    """
     ctx = data_split.split('_')
     if len(ctx) == 2:
         scheme, split = ctx
@@ -80,8 +105,8 @@ def split_crossval(label: torch.Tensor,
                    r_train: float,
                    r_val: float,
                    seed: int = None,
-                   ignore_neg=True,
-                   stratify=False) -> Tuple[torch.Tensor]:
+                   ignore_neg: bool = True,
+                   stratify: bool = False) -> Tuple[torch.Tensor]:
     r"""Split index by cross-validation"""
     node_labeled = torch.where(label >= 0)[0] if ignore_neg else np.arange(label.shape[0])
 
@@ -98,13 +123,15 @@ def split_crossval(label: torch.Tensor,
                            index_to_mask(torch.as_tensor(test_idx), size=label.shape[0]))
 
 
-def even_quantile_labels(vals, nclasses, verbose=True):
+def even_quantile_labels(vals: np.ndarray, nclasses: int, verbose: bool = True):
     """ partitions vals into nclasses by a quantile based split,
     where the first class is less than the 1/nclasses quantile,
     second class is less than the 2/nclasses quantile, and so on
-    vals is np array
-    returns an np array of int class labels
+    Args:
+        vals: The input array to be partitioned.
+        nclasses: The number of classes to partition the array into.
+        verbose: Prints the intervals for each class.
     """
     label = -1 * np.ones(vals.shape[0], dtype=int)
     interval_lst = []
diff --git a/benchmark/dataset/yandex.py b/benchmark/dataset/yandex.py
index 82ea9bb..56dff6c 100644
--- a/benchmark/dataset/yandex.py
+++ b/benchmark/dataset/yandex.py
@@ -76,6 +76,23 @@ def process(self) -> None:
 
 
 def get_data(datapath, transform, args: Namespace):
+    r"""Load data based on parameters and configurations.
+
+    Args:
+        datapath: Path to the root data directory.
+        transform: Data transformation pipeline.
+        args: Parameters.
+
+            * args.data (str): Dataset name.
+            * args.data_split (str): Index of dataset split.
+    Returns:
+        data (Data): The resolved data sample from the dataset.
+    Updates:
+        args.num_features (int): Number of input features.
+        args.num_classes (int): Number of output classes.
+        args.multi (bool): True for multi-label classification.
+        args.metric (str): Main metric name for evaluation.
+    """
     args.multi = False
     args.metric = {
         'chameleon_filtered': 's_f1i',
diff --git a/benchmark/trainer/base.py b/benchmark/trainer/base.py
index d6c6827..5a49e12 100755
--- a/benchmark/trainer/base.py
+++ b/benchmark/trainer/base.py
@@ -24,24 +24,25 @@ class TrnBase(object):
     r"""Base trainer class for general pipelines and tasks.
 
     Args:
-        model (nn.Module): Pytorch model to be trained.
-        data (Data): PyG style data.
-        logger (Logger): Logger object.
-        args (Namespace): Configuration arguments.
-        device (str): torch device.
-        metric (str): Metric for evaluation.
-        criterion (set): Loss function in :mod:`torch.nn`.
-        epoch (int): Number of training epochs.
-        lr_[lin/conv] (float): Learning rate for linear/conv.
-        wd_[lin/conv] (float): Weight decay for linear/conv.
-        patience (int): Patience for early stopping.
-        period (int): Period for checkpoint saving.
-        suffix (str): Suffix for checkpoint saving.
-        storage (str): Storage scheme for checkpoint saving.
-        logpath (Path): Path for logging.
-        multi (bool): True for multi-label classification.
-        num_features (int): Number of data input features.
-        num_classes (int): Number of data output classes.
+        model: PyTorch model to be trained.
+        data: PyG style data.
+        res_logger: Logger for results.
+        args: Configuration arguments.
+
+            * device (str): torch device.
+            * metric (str): Metric for evaluation.
+            * criterion (set): Loss function in :mod:`torch.nn`.
+            * epoch (int): Number of training epochs.
+            * lr_[lin/conv] (float): Learning rate for linear/conv.
+            * wd_[lin/conv] (float): Weight decay for linear/conv.
+            * patience (int): Patience for early stopping.
+            * period (int): Period for checkpoint saving.
+            * suffix (str): Suffix for checkpoint saving.
+            * storage (str): Storage scheme for checkpoint saving.
+            * logpath (Path): Path for logging.
+            * multi (bool): True for multi-label classification.
+            * num_features (int): Number of data input features.
+            * num_classes (int): Number of data output classes.
 
     Methods:
         setup_optimizer: Set up the optimizer and scheduler.
diff --git a/benchmark/trainer/fullbatch.py b/benchmark/trainer/fullbatch.py
index 768e1f7..b82a9d3 100755
--- a/benchmark/trainer/fullbatch.py
+++ b/benchmark/trainer/fullbatch.py
@@ -24,24 +24,22 @@ class TrnFullbatch(TrnBase):
     - Run pipeline: train_val -> test.
 
     Args:
-        --- TrnBase Args ---
-        model (nn.Module): Pytorch model to be trained.
-        data (Data): PyG style data.
-        logger (Logger): Logger object.
-        args (Namespace): Configuration arguments.
-        device (str): torch device.
-        metric (str): Metric for evaluation.
-        epoch (int): Number of training epochs.
-        lr_[lin/conv] (float): Learning rate for linear/conv.
-        wd_[lin/conv] (float): Weight decay for linear/conv.
-        patience (int): Patience for early stopping.
-        period (int): Period for checkpoint saving.
-        suffix (str): Suffix for checkpoint saving.
-        storage (str): Storage scheme for checkpoint saving.
-        logpath (Path): Path for logging.
-        multi (bool): True for multi-label classification.
-        num_features (int): Number of data input features.
-        num_classes (int): Number of data output classes.
+        model, data, res_logger: args for :class:`TrnBase`.
+        args: args for :class:`TrnBase`.
+
+            * device (str): torch device.
+            * metric (str): Metric for evaluation.
+            * epoch (int): Number of training epochs.
+            * lr_[lin/conv] (float): Learning rate for linear/conv.
+            * wd_[lin/conv] (float): Weight decay for linear/conv.
+            * patience (int): Patience for early stopping.
+            * period (int): Period for checkpoint saving.
+            * suffix (str): Suffix for checkpoint saving.
+            * storage (str): Storage scheme for checkpoint saving.
+            * logpath (Path): Path for logging.
+            * multi (bool): True for multi-label classification.
+            * num_features (int): Number of data input features.
+            * num_classes (int): Number of data output classes.
     """
 
     name: str = 'fb'
 
@@ -132,7 +130,7 @@ def test_deg(self) -> ResLogger:
         adj_t = self.data.adj_t
         if isinstance(adj_t, SparseTensor):
             deg = adj_t.sum(dim=0).cpu()
-        elif isinstance(adj_t, torch.Tensor) and adj_t.is_sparse_csr:
+        elif pyg_utils.is_torch_sparse_tensor(adj_t):
             deg = torch.sparse.sum(adj_t.to_sparse_coo(), [0]).cpu().to_dense()
         else:
             raise NotImplementedError(f"Type {type(adj_t)} not supported!")
diff --git a/benchmark/trainer/load_data.py b/benchmark/trainer/load_data.py
index 3f41893..066616d 100755
--- a/benchmark/trainer/load_data.py
+++ b/benchmark/trainer/load_data.py
@@ -22,16 +22,18 @@ class SingleGraphLoader(object):
-    r"""Loader for PyG.data.Data object for one graph.
+    r"""Loader for :class:`torch_geometric.data.Data` object for one graph.
 
     Args:
-        args.seed (int): Random seed.
-        args.data (str): Dataset name.
-        args.data_split (str): Index of dataset split.
+        args: Configuration arguments.
+
+            * args.seed (int): Random seed.
+            * args.data (str): Dataset name.
+            * args.data_split (str): Index of dataset split.
+        res_logger: Logger for results.
     """
     def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None:
-        r"""Assigning dataset identity.
-        """
+        # Assigning dataset identity.
         self.seed = args.seed
         self.data = args.data.lower()
 
@@ -57,7 +59,7 @@ def get(self, args: Namespace) -> Data:
         Args:
             args.normg (float): Generalized graph norm.
 
-        Returns (update in args):
+        Updates:
            args.num_features (int): Number of input features.
            args.num_classes (int): Number of output classes.
            args.multi (bool): True for multi-label classification.
diff --git a/benchmark/trainer/load_metric.py b/benchmark/trainer/load_metric.py
index 1593a81..f99cb2e 100755
--- a/benchmark/trainer/load_metric.py
+++ b/benchmark/trainer/load_metric.py
@@ -16,17 +16,19 @@ class ResCollection(MetricCollection):
     def compute(self) -> List[Tuple[str, Any, Callable]]:
-        r"""Wrap compute output to ResLogger style."""
+        r"""Wrap compute output to :class:`ResLogger` style."""
         dct = self._compute_and_reduce("compute")
         return [(k, v.cpu().numpy(), (lambda x: format(x*100, '.3f'))) for k, v in dct.items()]
 
 
 def metric_loader(args: Namespace) -> MetricCollection:
-    r"""Loader for torchmetrics.Metric object.
+    r"""Loader for :class:`torchmetrics.Metric` object.
 
     Args:
-        args.multi (bool): True for multi-label classification.
-        args.num_classes (int): Number of output classes/labels.
+        args: Configuration arguments.
+
+            args.multi (bool): True for multi-label classification.
+            args.num_classes (int): Number of output classes/labels.
""" # FEATURE: more metrics [glemos1](https://github.com/facebookresearch/glemos/blob/main/src/performances/node_classification.py), [glemos2](https://github.com/facebookresearch/glemos/blob/main/src/utils/eval_utils.py) if args.multi: diff --git a/benchmark/trainer/load_model.py b/benchmark/trainer/load_model.py index 12362d4..f122659 100755 --- a/benchmark/trainer/load_model.py +++ b/benchmark/trainer/load_model.py @@ -16,11 +16,14 @@ class ModelLoader(object): - r"""Loader for nn.Module object. + r"""Loader for :class:`torch.nn.Module` object. Args: - args.model (str): Model architecture name. - args.conv (str): Convolution layer name. + args: Configuration arguments. + + args.model (str): Model architecture name. + args.conv (str): Convolution layer name. + res_logger: Logger for results. """ def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None: r"""Assigning model identity. @@ -153,7 +156,7 @@ def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]: args.hidden (int): Number of hidden units. args.dp_[lin/conv] (float): Dropout rate for linear/conv. - Returns (update in args): + Updates: args.criterion (str): Criterion for loss calculation """ self.logger.debug('-'*20 + f" Loading model: {self} " + '-'*20) diff --git a/benchmark/trainer/minibatch.py b/benchmark/trainer/minibatch.py index 267d976..e24d15a 100755 --- a/benchmark/trainer/minibatch.py +++ b/benchmark/trainer/minibatch.py @@ -28,24 +28,22 @@ class TrnMinibatch(TrnBase): Args: args.batch (int): Batch size. args.normf (int): Embedding normalization. - --- TrnBase Args --- - model (nn.Module): Pytorch model to be trained. - data (Data): PyG style data. - logger (Logger): Logger object. - args (Namespace): Configuration arguments. - device (str): torch device. - metric (str): Metric for evaluation. - epoch (int): Number of training epochs. - lr_[lin/conv] (float): Learning rate for linear/conv. - wd_[lin/conv] (float): Weight decay for linear/conv. - patience (int): Patience for early stopping. - period (int): Period for checkpoint saving. - suffix (str): Suffix for checkpoint saving. - storage (str): Storage scheme for checkpoint saving. - logpath (Path): Path for logging. - multi (bool): True for multi-label classification. - num_features (int): Number of data input features. - num_classes (int): Number of data output classes. + model, data, res_logger: args for :class:`TrnBase`. + args: args for :class:`TrnBase`. + + * device (str): torch device. + * metric (str): Metric for evaluation. + * epoch (int): Number of training epochs. + * lr_[lin/conv] (float): Learning rate for linear/conv. + * wd_[lin/conv] (float): Weight decay for linear/conv. + * patience (int): Patience for early stopping. + * period (int): Period for checkpoint saving. + * suffix (str): Suffix for checkpoint saving. + * storage (str): Storage scheme for checkpoint saving. + * logpath (Path): Path for logging. + * multi (bool): True for multi-label classification. + * num_features (int): Number of data input features. + * num_classes (int): Number of data output classes. """ name: str = 'mb' diff --git a/benchmark/utils/checkpoint.py b/benchmark/utils/checkpoint.py index 428be07..e815ff2 100755 --- a/benchmark/utils/checkpoint.py +++ b/benchmark/utils/checkpoint.py @@ -16,14 +16,14 @@ class CkptLogger(object): stopping during training. Args: - logpath (Path or str): Path to checkpoints saving directory. - patience (int, optional): Patience for early stopping. Defaults no early stopping. 
-        period (int, optional): Periodic saving interval. Defaults to no periodic saving.
-        prefix (str, optional): Prefix for the checkpoint file names.
-        storage (str, optional): Storage scheme for saving the checkpoints.
-            - 'model' vs 'state': Save model object or state_dict.
-            - '_file', '_ram', '_gpu': Save as file, RAM, or GPU memory.
-        metric_cmp (function or ['max', 'min'], optional): Comparison function for the metric.
+        logpath: Path to checkpoints saving directory.
+        patience: Patience for early stopping. Defaults to no early stopping.
+        period: Periodic saving interval. Defaults to no periodic saving.
+        prefix: Prefix for the checkpoint file names.
+        storage: Storage scheme for saving the checkpoints.
+            * 'model' vs 'state': Save model object or state_dict.
+            * '_file', '_ram', '_gpu': Save as file, RAM, or GPU memory.
+        metric_cmp: Comparison function for the metric. Can be 'max' or 'min'.
     """
     def __init__(self, logpath: Union[Path, str],
 
@@ -108,8 +108,8 @@ def load(self, *suffix, model: nn.Module, map_location='cpu') -> nn.Module:
         Args:
             suffix: Variable length argument for suffix in the model file name.
-            model (nn.Module): The model structure to load.
-            map_location (str, optional): `map_location` argument for `torch.load`.
+            model: The model structure to load.
+            map_location: `map_location` argument for `torch.load`.
 
         Returns:
             model (nn.Module): The loaded model.
@@ -163,8 +163,8 @@ def step(self,
         """Step one epoch with periodic saving and early stopping.
 
         Args:
-            metric (float): Metric value for the current step.
-            model (nn.Module, optional): Model for the current step. Defaults to None.
+            metric: Metric value for the current step.
+            model: Model for the current step. Defaults to None.
 
         Returns:
             early_stop (bool): True if early stopping criteria is met.
diff --git a/benchmark/utils/config.py b/benchmark/utils/config.py
index fc7293b..e47a909 100755
--- a/benchmark/utils/config.py
+++ b/benchmark/utils/config.py
@@ -50,8 +50,8 @@ def setup_argparse():
     parser.add_argument('-v', '--dev', type=int, default=0, help='GPU id')
     parser.add_argument('-z', '--suffix', type=str, default=None, help='Result log file name. None:not saving results')
     parser.add_argument('-quiet', action='store_true', help='File log. True:dry run without saving logs')
+    parser.add_argument('--storage', type=str, default='state_gpu', choices=['state_file', 'state_ram', 'state_gpu'], help='Checkpoint log storage scheme')
     parser.add_argument('--loglevel', type=int, default=10, help='Console log. 10:progress, 15:train, 20:info, 25:result')
-    parser.add_argument('--storage', type=str, default='state_gpu', choices=['state_file', 'state_ram', 'state_gpu'], help='Storage scheme for saving the checkpoints')
     # Data configuration
     parser.add_argument('-d', '--data', type=str, default='cora', help='Dataset name')
     parser.add_argument('--data_split', type=str, default='Stratify_60/20/20', help='Dataset split')
@@ -120,7 +120,7 @@ def setup_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
     elif args.model in ['PrecomputedVar', 'PrecomputedVarCompose']:
         args.model_repr = 'PrecomputedVar'
         args.conv_repr = args.conv
-    # FIXME: separate arch for AdaGNN
+    # TODO: separate arch for AdaGNN
     elif args.model in ['AdaGNN']:
         args.model_repr = 'DecoupledVar'
         args.conv_repr = args.conv
diff --git a/benchmark/utils/logger.py b/benchmark/utils/logger.py
index 9d67310..a8e2b3d 100755
--- a/benchmark/utils/logger.py
+++ b/benchmark/utils/logger.py
@@ -76,9 +76,9 @@ def setup_logpath(dir: Union[Path, str] = LOGPATH,
     r"""Resolve log path for saving.
 
     Args:
-        dir (Path or str): Base directory for saving logs. Default is '../log/'.
-        folder_args (Tuple): Subfolder names.
-        quiet (bool, optional): Quiet run without creating directories.
+        dir: Base directory for saving logs.
+        folder_args: Subfolder names.
+        quiet: Quiet run without creating directories.
 
     Returns:
         logpath (Path): Path for log directory.
@@ -99,8 +99,8 @@ class ResLogger(object):
     r"""Logger for formatting result to strings by wrapping pd.DataFrame table.
 
     Args:
-        logpath (Path or str): Path to CSV file saving directory.
-        quiet (bool): Quiet run without saving file.
+        logpath: Path to CSV file saving directory.
+        quiet: Quiet run without saving file.
     """
     def __init__(self, logpath: Union[Path, str] = LOGPATH,
 
@@ -151,8 +151,8 @@ def _set(self, data: DataFrame, fmt: Series):
         r"""Sets the data from input DataFrame.
 
         Args:
-            data (DataFrame): Concat on columns, inner join on index.
-            fmt (Series): Inner join on columns.
+            data: Concat on columns, inner join on index.
+            fmt: Inner join on columns.
         """
         cols_left = self.data.columns.tolist()
         cols_right = data.columns.tolist()
@@ -169,9 +169,9 @@ def concat(self,
         r"""Concatenate data entries of a single row to data.
 
         Args:
-            vals (List or Dict): list of entries (key, value, formatter).
-            row (int): New index in self dataframe for vals to be logged.
-            suffix (str): Suffix string for input keys. Default is None.
+            vals: list of entries (key, value, formatter).
+            row: New index in self dataframe for vals to be logged.
+            suffix: Suffix string for input keys. Default is None.
 
         Returns:
             self (ResLogger)
@@ -208,9 +208,9 @@ def merge(self,
         r"""Merge from another logger.
 
         Args:
-            vals (TabLogger): Logger to merge.
-            row (List): New index in self dataframe.
-            suffix (str): Suffix string for input keys. Default is None.
+            logger: Logger to merge.
+            row: New index in self dataframe.
+            suffix: Suffix string for input keys. Default is None.
         """
         if rows:
             assert len(rows) == logger.nrows
@@ -226,19 +226,21 @@ def del_col(self, col: Union[List, str]) -> 'ResLogger':
         r"""Delete columns from data.
 
         Args:
-            col (str or list): Column(s) to delete.
+            col: Column(s) to delete.
         """
         self.data = self.data.drop(columns=col)
         self.fmt = self.fmt.drop(index=col)
         return self
 
     # ===== Output
-    def _get(self, col=None, row=None) -> Union[DataFrame, Series, str]:
+    def _get(self,
+             col: Union[List, str] = None,
+             row: Union[List, str] = None) -> Union[DataFrame, Series, str]:
         r"""Retrieve one or sliced data and apply string format.
         Args:
-            col (str or list): Column(s) to retrieve. Defaults to all.
-            row (str or list): Row(s) to retrieve. Defaults to all.
+            col: Column(s) to retrieve. Defaults to all.
+            row: Row(s) to retrieve. Defaults to all.
 
         Returns:
             val: Formatted data.
@@ -294,9 +296,9 @@ def get_str(self,
         and rows.
 
         Args:
-            col (str or list): Column(s) to retrieve. Defaults to all.
-            row (str or list): Row(s) to retrieve. Defaults to all.
-            maxlen (int): Max line length of the resulting string.
+            col: Column(s) to retrieve. Defaults to all.
+            row: Row(s) to retrieve. Defaults to all.
+            maxlen: Max line length of the resulting string.
 
         Returns:
             s (str): Formatted string representation.
diff --git a/docs/source/_tutorial/configure.rst b/docs/source/_tutorial/configure.rst
index 62e61b2..b93e98e 100644
--- a/docs/source/_tutorial/configure.rst
+++ b/docs/source/_tutorial/configure.rst
@@ -16,12 +16,12 @@ Refer to the help text by:
 
 --seed SEED           random seed
 --dev DEV             GPU id
---suffix SUFFIX       Save name suffix
--quiet                Dry run without saving logs
+--suffix SUFFIX       Result log file name. ``None``:not saving results
+-quiet                File log. ``True``:dry run without saving logs
 --storage STORAGE
-                      Storage scheme for saving the checkpoints.
+                      Checkpoint log storage scheme.
                       Options: ``state_file``, ``state_ram``, ``state_gpu``
---loglevel LOGLEVEL   ``10``:progress, ``15``:train, ``20``:info, ``25``:result
+--loglevel LOGLEVEL   Console log. ``10``:progress, ``15``:train, ``20``:info, ``25``:result
 
 .. rubric:: Data configuration
diff --git a/docs/source/_tutorial/reproduce.rst b/docs/source/_tutorial/reproduce.rst
index 7001a70..39d740c 100644
--- a/docs/source/_tutorial/reproduce.rst
+++ b/docs/source/_tutorial/reproduce.rst
@@ -42,4 +42,4 @@ Figures can be plotted by: `benchmark/notebook/fig_hop.ipynb
 Tuple[Tensor, Tensor]:
+    r"""Random inplace edge dropout for the adjacency matrix
+    :obj:`edge_index` with probability :obj:`p` using samples from
+    a Bernoulli distribution.
+    Expand :func:`torch_geometric.utils.dropout_edge` with type support.
+
+    Args:
+        edge_index: The edge indices.
+        p: Dropout probability.
+        force_undirected: If set to :obj:`True`, will either
+            drop or keep both edges of an undirected edge.
+        training: If set to :obj:`False`, this operation is a no-op.
+
+    Returns:
+        edge_index, edge_mask (LongTensor, BoolTensor): The edge indices and the edge mask.
+    """
     if p < 0. or p > 1.:
-        raise ValueError(f'Dropout probability has to be between 0 and 1 '
-                         f'(got {p}')
+        raise ValueError(f'Dropout probability has to be between 0 and 1 (got {p})')
 
     if isinstance(edge_index, SparseTensor):
         if not training or p == 0.0:
             edge_mask = edge_index.new_ones(edge_index.sparse_size()[0], dtype=torch.bool)
             return edge_index, edge_mask
 
-        row, col, value = edge_index.coo()
-        edge_tensor = torch.stack([row, col], dim=0)
-
-        _, edge_mask = dropout_edge_pyg(edge_tensor, p, force_undirected, training)
+        edge_tensor, _ = pyg_utils.to_edge_index(edge_index)
+        _, edge_mask = pyg_utils.dropout_edge(edge_tensor, p, force_undirected, training)
 
         return edge_index.masked_select_nnz(edge_mask), edge_mask
-    # FEATURE: support torch.sparse.Tensor
+    elif pyg_utils.is_torch_sparse_tensor(edge_index):
+        if not training or p == 0.0:
+            edge_mask = edge_index.new_ones(edge_index.size(1), dtype=torch.bool)
+            return edge_index, edge_mask
+
+        edge_tensor, _ = pyg_utils.to_edge_index(edge_index)
+        sparse_mask, edge_mask = pyg_utils.dropout_edge(edge_tensor, p, force_undirected, training)
+
+        sparse_mask = pyg_utils.to_torch_coo_tensor(sparse_mask, edge_mask, edge_index.size(1), is_coalesced=True)
+        return edge_index.sparse_mask(sparse_mask), edge_mask
+
+    else:
-    return dropout_edge_pyg(edge_index, p, force_undirected, training)
+        return pyg_utils.dropout_edge(edge_index, p, force_undirected, training)
diff --git a/pyg_spectral/utils/laplacian.py b/pyg_spectral/utils/laplacian.py
index 4f6484e..3465c50 100755
--- a/pyg_spectral/utils/laplacian.py
+++ b/pyg_spectral/utils/laplacian.py
@@ -4,7 +4,7 @@
 from torch import Tensor
 from torch_geometric.typing import Adj, OptTensor, SparseTensor
-from torch_geometric.utils import add_self_loops, remove_self_loops, scatter
+from torch_geometric.utils import add_self_loops, remove_self_loops, scatter, is_torch_sparse_tensor
 from torch_geometric.utils.num_nodes import maybe_num_nodes
 
 
@@ -19,14 +19,13 @@ def get_laplacian(
     r"""Computes the graph Laplacian of the graph given by :obj:`edge_index`
     and optional :obj:`edge_weight`.
     Remove the normalization of graph adjacency matrix in
-    :class:`torch_geometric.transforms.get_laplacian`.
+    :func:`torch_geometric.utils.get_laplacian`.
 
     Args:
-        edge_index (LongTensor or SparseTensor): The edge indices.
-        edge_weight (Tensor, optional): One-dimensional edge weights.
-            (default: :obj:`None`)
-        normalization (bool, optional): The normalization scheme for the graph
-            Laplacian (default: :obj:`True`):
+        edge_index: The edge indices.
+        edge_weight: One-dimensional edge weights.
+        normalization: The normalization scheme for the graph
+            Laplacian:
 
             1. :obj:`False`: No normalization
                :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`
 
             2. :obj:`"True"`: Normalization already applied
               :math:`\mathbf{L} = diag * \mathbf{I} - \mathbf{A}`
 
-        diag (float, optional): Weight of identity when normalization=True.
-            (default: :obj:`1.0`)
-        dtype (torch.dtype, optional): The desired data type of returned tensor
-            in case :obj:`edge_weight=None`. (default: :obj:`None`)
-        num_nodes (int, optional): The number of nodes, *i.e.*
-            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
+        diag: Weight of identity when normalization=True.
+        dtype: The desired data type of returned tensor
+            in case :obj:`edge_weight=None`.
+        num_nodes: The number of nodes, *i.e.*
+            :obj:`max_val + 1` of :attr:`edge_index`.
""" if isinstance(edge_index, SparseTensor): assert edge_weight is None @@ -55,8 +53,9 @@ def get_laplacian( edge_index = edge_index.set_diag(deg) return edge_index - elif isinstance(edge_index, Tensor) and edge_index.is_sparse_csr: + elif is_torch_sparse_tensor(edge_index): import scipy.sparse as sp + edge_index = edge_index.to_sparse_csr().coalesce() data = edge_index.values().cpu().detach().numpy() indices = edge_index.col_indices().cpu().detach().numpy() indptr = edge_index.crow_indices().cpu().detach().numpy() diff --git a/setup.py b/setup.py index 7a771cc..f578632 100755 --- a/setup.py +++ b/setup.py @@ -32,4 +32,3 @@ packages=find_packages(), ext_modules=ext_modules, ) -#FEATURE: [optional benckmark](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#optional-dependencies)