Update docs and APIs
nyLiao committed Oct 2, 2024
1 parent a8e1e6b commit 4b75dee
Showing 21 changed files with 269 additions and 144 deletions.
10 changes: 5 additions & 5 deletions README.md
@@ -93,7 +93,7 @@ Figures can be plotted by: [`benchmark/notebook/fig_hop.ipynb`](benchmark/notebo

#### Frequency response (*Table 12*):
```bash
bash scripts/exp_filter.sh
bash scripts/exp_regression.sh
```

## Customization
@@ -109,11 +109,11 @@ options:
# Logging configuration
--seed SEED random seed
--dev DEV GPU id
--suffix SUFFIX Save name suffix.
-quiet Dry run without saving logs.
--suffix SUFFIX Result log file name. None: do not save results
-quiet File log. True: dry run without saving logs
--storage {state_file,state_ram,state_gpu}
Storage scheme for saving the checkpoints.
--loglevel LOGLEVEL 10:progress, 15:train, 20:info, 25:result
Checkpoint log storage scheme.
--loglevel LOGLEVEL Console log. 10:progress, 15:train, 20:info, 25:result
# Data configuration
--data DATA Dataset name
--data_split DATA_SPLIT Index or percentage of dataset split
17 changes: 17 additions & 0 deletions benchmark/dataset/linkx.py
@@ -220,6 +220,23 @@ def forward(self, data: Any) -> Any:


def get_data(datapath, transform, args: Namespace):
r"""Load data based on parameters and configurations.
Args:
datapath: Path to the root data directory.
transform: Data transformation pipeline.
args: Parameters.
* args.data (str): Dataset name.
* args.data_split (str): Index of dataset split.
Returns:
data (Data): The resolved data sample from the dataset.
Updates:
args.num_features (int): Number of input features.
args.num_classes (int): Number of output classes.
args.multi (bool): True for multi-label classification.
args.metric (str): Main metric name for evaluation.
"""
args.multi = False
args.metric = 's_auroc' if args.data in ['genius'] else 's_f1i'
# FIXME: check split
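The same `get_data` docstring pattern repeats in the `ogbn`, `pygn`, and `yandex` modules below, so one usage sketch covers all four. This is a minimal sketch only: the import assumes the `benchmark` package is importable, and the split-string syntax is an assumption based on `benchmark/dataset/utils.py`.

```python
from argparse import Namespace

import torch_geometric.transforms as T

from benchmark.dataset.linkx import get_data  # assumed import path

# Hypothetical configuration; get_data itself reads args.data and args.data_split.
args = Namespace(data='genius', data_split='Random_50/25/25')
data = get_data('./data', T.NormalizeFeatures(), args)

# Per the docstring, get_data also updates args in place:
print(args.num_features, args.num_classes, args.multi, args.metric)
```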
17 changes: 17 additions & 0 deletions benchmark/dataset/ogbn.py
@@ -24,6 +24,23 @@ def forward(self, data: Any) -> Any:


def get_data(datapath, transform, args: Namespace):
r"""Load data based on parameters and configurations.
Args:
datapath: Path to the root data directory.
transform: Data transformation pipeline.
args: Parameters.
* args.data (str): Dataset name.
* args.data_split (str): Index of dataset split.
Returns:
data (Data): The resolved data sample from the dataset.
Updates:
args.num_features (int): Number of input features.
args.num_classes (int): Number of output classes.
args.multi (bool): True for multi-label classification.
args.metric (str): Main metric name for evaluation.
"""
args.multi = True if args.data == 'ogbn-proteins' else False
args.metric = 's_auroc' if args.data == 'ogbn-proteins' else 's_f1i'
args.data_split = "Original_0"
17 changes: 17 additions & 0 deletions benchmark/dataset/pygn.py
@@ -30,6 +30,23 @@


def get_data(datapath, transform, args: Namespace):
r"""Load data based on parameters and configurations.
Args:
datapath: Path to the root data directory.
transform: Data transformation pipeline.
args: Parameters.
* args.data (str): Dataset name.
* args.data_split (str): Index of dataset split.
Returns:
data (Data): The resolved data sample from the dataset.
Updates:
args.num_features (int): Number of input features.
args.num_classes (int): Number of output classes.
args.multi (bool): True for multi-label classification.
args.metric (str): Main metric name for evaluation.
"""
args.multi = False
args.metric = 's_f1i'
assert args.data_split.split('_')[0] in ['Random', 'Stratify']
37 changes: 32 additions & 5 deletions benchmark/dataset/utils.py
@@ -23,6 +23,19 @@ def T_insert(transform, new_t: T.BaseTransform, index=-1) -> T.Compose:


def resolve_data(args: Namespace, dataset: Dataset) -> Data:
r"""Acquire data and properties from dataset.
Args:
args: Parameters.
* args.multi (bool): ``True`` for multi-label classification.
dataset: PyG dataset object.
Returns:
data (Data): The resolved PyG data object from the dataset.
Updates:
args.num_features (int): Number of input features.
args.num_classes (int): Number of output classes.
"""
# Avoid triggering transform when getting simple properties.
# data = dataset.get(dataset.indices()[0])
# if hasattr(dataset, '_data_list') and dataset._data_list is not None:
@@ -44,6 +57,18 @@ def resolve_data(args: Namespace, dataset: Dataset) -> Data:


def resolve_split(data_split: str, data: Data) -> Data:
r"""Apply data split masks.
Args:
data_split: Index of dataset split, formatted as ``scheme_split`` or ``scheme_split_seed``.
* ``scheme='Random'``: Random split, ``split`` is ``train/val/test`` ratio.
* ``scheme='Stratify'``: Stratified split, ``split`` is ``train/val/test`` ratio.
* ``scheme='Original'``: Original split, ``split`` is the index of split.
data: PyG data object containing the dataset and its attributes.
Returns:
data (Data): The updated PyG data object with split masks (train/val/test).
"""
ctx = data_split.split('_')
if len(ctx) == 2:
scheme, split = ctx
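Two illustrative split strings that this parsing accepts. `Original_0` appears verbatim in `benchmark/dataset/ogbn.py`; the ratio separator in the `Random` form is an assumption:

```python
# 'scheme_split': take the dataset's pre-defined split with index 0.
data = resolve_split('Original_0', data)

# 'scheme_split_seed': seeded random split (ratio syntax assumed).
data = resolve_split('Random_50/25/25_42', data)

train_mask, val_mask, test_mask = data.train_mask, data.val_mask, data.test_mask
```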
@@ -80,8 +105,8 @@ def split_crossval(label: torch.Tensor,
r_train: float,
r_val: float,
seed: int = None,
ignore_neg=True,
stratify=False) -> Tuple[torch.Tensor]:
ignore_neg: bool = True,
stratify: bool = False) -> Tuple[torch.Tensor]:
r"""Split index by cross-validation"""
node_labeled = torch.where(label >= 0)[0] if ignore_neg else np.arange(label.shape[0])

@@ -98,13 +123,15 @@ def split_crossval(label: torch.Tensor,
index_to_mask(torch.as_tensor(test_idx), size=label.shape[0]))
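A toy call, assuming the three returned masks follow train/val/test order (the hunk above builds them with `index_to_mask`):

```python
import torch

# -1 marks unlabeled nodes; they are skipped when ignore_neg=True.
label = torch.tensor([0, 1, 0, 1, -1, 2, 2, 0])
train_mask, val_mask, test_mask = split_crossval(
    label, r_train=0.5, r_val=0.25, seed=0, stratify=True)
```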


def even_quantile_labels(vals, nclasses, verbose=True):
def even_quantile_labels(vals: np.ndarray, nclasses: int, verbose: bool = True):
""" partitions vals into nclasses by a quantile based split,
where the first class is less than the 1/nclasses quantile,
second class is less than the 2/nclasses quantile, and so on
vals is np array
returns an np array of int class labels
Args:
vals: The input array to be partitioned.
nclasses: The number of classes to partition the array into.
verbose: Prints the intervals for each class.
"""
label = -1 * np.ones(vals.shape[0], dtype=int)
interval_lst = []
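A worked example of the quantile partition; the expected output assumes NumPy's default interpolated quantiles:

```python
import numpy as np

vals = np.array([0.1, 0.4, 0.5, 0.9, 1.3, 2.0])
labels = even_quantile_labels(vals, nclasses=3, verbose=False)
# The 1/3 and 2/3 quantiles (~0.47 and ~1.03) split vals into thirds:
# expected labels -> array([0, 0, 1, 1, 2, 2])
```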
17 changes: 17 additions & 0 deletions benchmark/dataset/yandex.py
@@ -76,6 +76,23 @@ def process(self) -> None:


def get_data(datapath, transform, args: Namespace):
r"""Load data based on parameters and configurations.
Args:
datapath: Path to the root data directory.
transform: Data transformation pipeline.
args: Parameters.
* args.data (str): Dataset name.
* args.data_split (str): Index of dataset split.
Returns:
data (Data): The resolved data sample from the dataset.
Updates:
args.num_features (int): Number of input features.
args.num_classes (int): Number of output classes.
args.multi (bool): True for multi-label classification.
args.metric (str): Main metric name for evaluation.
"""
args.multi = False
args.metric = {
'chameleon_filtered': 's_f1i',
37 changes: 19 additions & 18 deletions benchmark/trainer/base.py
@@ -24,24 +24,25 @@ class TrnBase(object):
r"""Base trainer class for general pipelines and tasks.
Args:
model (nn.Module): Pytorch model to be trained.
data (Data): PyG style data.
logger (Logger): Logger object.
args (Namespace): Configuration arguments.
device (str): torch device.
metric (str): Metric for evaluation.
criterion (set): Loss function in :mod:`torch.nn`.
epoch (int): Number of training epochs.
lr_[lin/conv] (float): Learning rate for linear/conv.
wd_[lin/conv] (float): Weight decay for linear/conv.
patience (int): Patience for early stopping.
period (int): Period for checkpoint saving.
suffix (str): Suffix for checkpoint saving.
storage (str): Storage scheme for checkpoint saving.
logpath (Path): Path for logging.
multi (bool): True for multi-label classification.
num_features (int): Number of data input features.
num_classes (int): Number of data output classes.
model: PyTorch model to be trained.
data: PyG style data.
res_logger: Logger for results.
args: Configuration arguments.
* device (str): torch device.
* metric (str): Metric for evaluation.
* criterion (set): Loss function in :mod:`torch.nn`.
* epoch (int): Number of training epochs.
* lr_[lin/conv] (float): Learning rate for linear/conv.
* wd_[lin/conv] (float): Weight decay for linear/conv.
* patience (int): Patience for early stopping.
* period (int): Period for checkpoint saving.
* suffix (str): Suffix for checkpoint saving.
* storage (str): Storage scheme for checkpoint saving.
* logpath (Path): Path for logging.
* multi (bool): True for multi-label classification.
* num_features (int): Number of data input features.
* num_classes (int): Number of data output classes.
Methods:
setup_optimizer: Set up the optimizer and scheduler.
36 changes: 17 additions & 19 deletions benchmark/trainer/fullbatch.py
@@ -24,24 +24,22 @@ class TrnFullbatch(TrnBase):
- Run pipeline: train_val -> test.
Args:
--- TrnBase Args ---
model (nn.Module): Pytorch model to be trained.
data (Data): PyG style data.
logger (Logger): Logger object.
args (Namespace): Configuration arguments.
device (str): torch device.
metric (str): Metric for evaluation.
epoch (int): Number of training epochs.
lr_[lin/conv] (float): Learning rate for linear/conv.
wd_[lin/conv] (float): Weight decay for linear/conv.
patience (int): Patience for early stopping.
period (int): Period for checkpoint saving.
suffix (str): Suffix for checkpoint saving.
storage (str): Storage scheme for checkpoint saving.
logpath (Path): Path for logging.
multi (bool): True for multi-label classification.
num_features (int): Number of data input features.
num_classes (int): Number of data output classes.
model, data, res_logger: arguments for :class:`TrnBase`.
args: configuration arguments for :class:`TrnBase`.
* device (str): torch device.
* metric (str): Metric for evaluation.
* epoch (int): Number of training epochs.
* lr_[lin/conv] (float): Learning rate for linear/conv.
* wd_[lin/conv] (float): Weight decay for linear/conv.
* patience (int): Patience for early stopping.
* period (int): Period for checkpoint saving.
* suffix (str): Suffix for checkpoint saving.
* storage (str): Storage scheme for checkpoint saving.
* logpath (Path): Path for logging.
* multi (bool): True for multi-label classification.
* num_features (int): Number of data input features.
* num_classes (int): Number of data output classes.
"""
name: str = 'fb'

@@ -132,7 +130,7 @@ def test_deg(self) -> ResLogger:
adj_t = self.data.adj_t
if isinstance(adj_t, SparseTensor):
deg = adj_t.sum(dim=0).cpu()
elif isinstance(adj_t, torch.Tensor) and adj_t.is_sparse_csr:
elif pyg_utils.is_torch_sparse_tensor(adj_t):
deg = torch.sparse.sum(adj_t.to_sparse_coo(), [0]).cpu().to_dense()
else:
raise NotImplementedError(f"Type {type(adj_t)} not supported!")
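The new branch accepts any torch sparse layout rather than only CSR. A standalone sketch of the same dispatch, using only calls shown in the hunk:

```python
import torch
import torch_geometric.utils as pyg_utils

adj = torch.eye(3).to_sparse_csr()  # any torch.sparse layout works here
if pyg_utils.is_torch_sparse_tensor(adj):
    # Reduce over dim 0 after converting to COO, as in test_deg above.
    deg = torch.sparse.sum(adj.to_sparse_coo(), [0]).to_dense()
```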
16 changes: 9 additions & 7 deletions benchmark/trainer/load_data.py
@@ -22,16 +22,18 @@


class SingleGraphLoader(object):
r"""Loader for PyG.data.Data object for one graph.
r"""Loader for :class:`torch_geometric.data.Data` object for one graph.
Args:
args.seed (int): Random seed.
args.data (str): Dataset name.
args.data_split (str): Index of dataset split.
args: Configuration arguments.
* args.seed (int): Random seed.
* args.data (str): Dataset name.
* args.data_split (str): Index of dataset split.
res_logger: Logger for results.
"""
def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None:
r"""Assigning dataset identity.
"""
# Assigning dataset identity.
self.seed = args.seed
self.data = args.data.lower()

@@ -57,7 +59,7 @@ def get(self, args: Namespace) -> Data:
Args:
args.normg (float): Generalized graph norm.
Returns (update in args):
Updates:
args.num_features (int): Number of input features.
args.num_classes (int): Number of output classes.
args.multi (bool): True for multi-label classification.
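A minimal sketch of the loader's two-step protocol; the field values are hypothetical, and `normg` is the parameter read by `get` per its docstring:

```python
from argparse import Namespace

args = Namespace(seed=42, data='cora', data_split='Random_50/25/25', normg=0.5)
loader = SingleGraphLoader(args)
data = loader.get(args)  # also fills args.num_features / num_classes / multi
```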
10 changes: 6 additions & 4 deletions benchmark/trainer/load_metric.py
@@ -16,17 +16,19 @@

class ResCollection(MetricCollection):
def compute(self) -> List[Tuple[str, Any, Callable]]:
r"""Wrap compute output to ResLogger style."""
r"""Wrap compute output to :class:`ResLogger` style."""
dct = self._compute_and_reduce("compute")
return [(k, v.cpu().numpy(), (lambda x: format(x*100, '.3f'))) for k, v in dct.items()]


def metric_loader(args: Namespace) -> MetricCollection:
r"""Loader for torchmetrics.Metric object.
r"""Loader for :class:`torchmetrics.Metric` object.
Args:
args.multi (bool): True for multi-label classification.
args.num_classes (int): Number of output classes/labels.
args: Configuration arguments.
args.multi (bool): True for multi-label classification.
args.num_classes (int): Number of output classes/labels.
"""
# FEATURE: more metrics [glemos1](https://github.com/facebookresearch/glemos/blob/main/src/performances/node_classification.py), [glemos2](https://github.com/facebookresearch/glemos/blob/main/src/utils/eval_utils.py)
if args.multi:
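A usage sketch; the returned collection is presumably the `ResCollection` defined above, so its `compute()` yields `(name, value, formatter)` tuples ready for `ResLogger`. Shapes and field values are illustrative only:

```python
from argparse import Namespace

import torch

args = Namespace(multi=False, num_classes=7)
metrics = metric_loader(args)            # a torchmetrics MetricCollection

preds = torch.randn(10, 7).softmax(dim=-1)
target = torch.randint(0, 7, (10,))
metrics.update(preds, target)
result = metrics.compute()               # ResCollection-style tuples
```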
11 changes: 7 additions & 4 deletions benchmark/trainer/load_model.py
@@ -16,11 +16,14 @@


class ModelLoader(object):
r"""Loader for nn.Module object.
r"""Loader for :class:`torch.nn.Module` object.
Args:
args.model (str): Model architecture name.
args.conv (str): Convolution layer name.
args: Configuration arguments.
args.model (str): Model architecture name.
args.conv (str): Convolution layer name.
res_logger: Logger for results.
"""
def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None:
r"""Assigning model identity.
@@ -153,7 +156,7 @@ def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]:
args.hidden (int): Number of hidden units.
args.dp_[lin/conv] (float): Dropout rate for linear/conv.
Returns (update in args):
Updates:
args.criterion (str): Criterion for loss calculation.
"""
self.logger.debug('-'*20 + f" Loading model: {self} " + '-'*20)
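Taken together, the loaders compose into a single-graph pipeline. A rough end-to-end sketch under stated assumptions: the `benchmark` package is importable, the model/conv names and other `args` fields are hypothetical, `get()` is assumed to return a ready trainer instance (its annotation only guarantees a `TrnBase`), and `run()` is an assumed entry point based on the "train_val -> test" pipeline described above:

```python
from argparse import Namespace

from benchmark.trainer.load_data import SingleGraphLoader   # assumed paths
from benchmark.trainer.load_model import ModelLoader

# Hypothetical experiment configuration drawn from the documented fields.
args = Namespace(seed=42, data='cora', data_split='Random_50/25/25',
                 normg=0.5, model='DecoupledVar', conv='AdjConv')

data = SingleGraphLoader(args).get(args)       # fills num_features/num_classes/multi
model, trainer = ModelLoader(args).get(args)   # get() -> (nn.Module, TrnBase)
res = trainer.run()                            # hypothetical entry point
```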