Skip to content

Commit

Permalink
Tune with random splits
Browse files Browse the repository at this point in the history
  • Loading branch information
nyLiao committed Sep 23, 2024
1 parent ca38bb0 commit c308468
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 75 deletions.
7 changes: 2 additions & 5 deletions benchmark/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import numpy as np
import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.data.dataset import _get_flattened_data_list
import torch_geometric.transforms as T


Expand Down Expand Up @@ -50,7 +49,7 @@ def resolve_data(args: Namespace, dataset: Dataset) -> Data:


def resolve_split(data_split: str, data: Data) -> Data:
# TODO: support more split schemes
# TODO: support more split schemes (also in hyperval)
scheme, split = data_split.split('_')
if scheme == 'Random':
(r_train, r_val) = map(int, split.split('/')[:2])
Expand All @@ -60,9 +59,7 @@ def resolve_split(data_split: str, data: Data) -> Data:
else:
assert hasattr(data, 'train_mask') and hasattr(data, 'val_mask') and hasattr(data, 'test_mask')
if data.train_mask.dim() > 1:
split = int(split)
if split >= data.train_mask.size(1):
split = split % data.train_mask.size(1)
split = int(split) % data.train_mask.size(1)
data.train_mask = data.train_mask[:, split]
data.val_mask = data.val_mask[:, split]
data.test_mask = data.test_mask[:, split]
Expand Down
11 changes: 10 additions & 1 deletion benchmark/run_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ def reverse_parse(parser, key, val):
return type_func(val)


def filter_res(s, metric):
    """Filter a comma-separated ``key:value`` result string.

    An entry is kept when either:
      * its key is a plain summary field (does not start with ``'s_'`` and
        contains no underscore), or
      * its key contains *metric* and ends with ``'_test'``.

    Args:
        s: result string of entries joined by ``', '``, each ``key:value``.
        metric: metric name whose ``*_test`` entries should be kept.

    Returns:
        The kept entries re-joined by ``', '``.
    """
    def is_common(key):
        # Plain fields such as 'data' — no 's_' prefix and no underscore.
        return not key.startswith('s_') and '_' not in key

    def is_metric_test(key):
        # Test-split score of the selected metric, e.g. 'acc_test'.
        return metric in key and key.endswith('_test')

    kept = []
    for entry in s.split(', '):
        key = entry.split(':')[0]
        if is_common(key) or is_metric_test(key):
            kept.append(entry)
    return ', '.join(kept)


def main(args):
# ========== Run configuration
logger = setup_logger(args.logpath, level_console=args.loglevel, quiet=args.quiet)
Expand All @@ -51,7 +59,8 @@ def main(args):
trn()

logger.info(f"[args]: {args}")
logger.log(logging.LRES, f"[res]: {res_logger}")
resstr = filter_res(res_logger.get_str(), args.metric)
logger.log(logging.LRES, f"{resstr}")
res_logger.save()
save_args(args.logpath, vars(args))
clear_logger(logger)
Expand Down
61 changes: 32 additions & 29 deletions benchmark/run_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from trainer import (
SingleGraphLoader_Trial,
ModelLoader_Trial,
TrnFullbatch,
TrnMinibatch_Trial, TrnBase_Trial)
TrnFullbatch, TrnFullbatch_Trial,
TrnMinibatch, TrnMinibatch_Trial)
from utils import (
force_list_str,
setup_seed,
Expand Down Expand Up @@ -114,17 +114,18 @@ def __call__(self, trial):

if self.data is None:
self.data = self.data_loader.get(args)
self.model, trn_cls = self.model_loader(args)
if trn_cls == TrnFullbatch:
self.trn_cls = type('Trn_Trial', (trn_cls, TrnBase_Trial), {})
else:
self.trn_cls = TrnMinibatch_Trial
self.model, trn_cls = self.model_loader.get(args)
self.trn_cls = {
TrnFullbatch: TrnFullbatch_Trial,
TrnMinibatch: TrnMinibatch_Trial,
}[trn_cls]

for key in ['num_features', 'num_classes', 'metric', 'multi', 'criterion']:
self.args.__dict__[key] = args.__dict__[key]
self.metric = args.metric
self.data = self.data_loader.update(args, self.data)
self.model = self.model_loader.update(args, self.model)
else:
self.data = self.data_loader.update(args, self.data)
self.model = self.model_loader.update(args, self.model)
res_logger = deepcopy(self.res_logger)
for key in self.args.param:
val = args.__dict__[key]
Expand All @@ -136,32 +137,24 @@ def __call__(self, trial):
res_logger.concat([(key, vali, fmt)])
self.fmt_logger[key] = fmt_logger[key]

if self.trn_cls.__name__ == 'Trn_Trial':
trn = self.trn_cls(
if self.trn is None:
self.trn = self.trn_cls(
model=self.model,
data=self.data,
args=args,
res_logger=res_logger,)
else:
if self.trn is None:
self.trn = self.trn_cls(
model=self.model,
data=self.data,
args=args,
res_logger=res_logger,)
else:
self.trn.update(
model=self.model,
data=self.data,
args=args,
res_logger=res_logger,)
trn = self.trn
trn.trial = trial
trn()
self.trn.update(
model=self.model,
data=self.data,
args=args,
res_logger=res_logger,)
self.trn.trial = trial
self.trn.run()

res_logger.save()
trial.set_user_attr("s_test", res_logger._get(col=self.metric+'_test', row=0))
return res_logger.data.loc[0, self.metric+'_val']
return res_logger.data.loc[0, self.metric+'_hyperval']


def main(args):
Expand All @@ -175,9 +168,15 @@ def main(args):
study_path, _ = setup_logpath(folder_args=(study_path,))
study = optuna.create_study(
study_name=args.logid,
storage=f'sqlite:///{str(study_path)}',
storage=optuna.storages.RDBStorage(
url=f'sqlite:///{str(study_path)}',
heartbeat_interval=3600),
direction='maximize',
sampler=optuna.samplers.TPESampler(),
sampler=optuna.samplers.TPESampler(
n_startup_trials=8,
multivariate=True,
group=True,
warn_independent_sampling=False),
pruner=optuna.pruners.HyperbandPruner(
min_resource=2,
max_resource=args.epoch,
Expand Down Expand Up @@ -214,6 +213,9 @@ def main(args):
else:
best_params = {k: trn.fmt_logger[k](v) for k, v in study.best_params.items()}
save_args(args.logpath, best_params)
axes = optuna.visualization.matplotlib.plot_parallel_coordinate(
study, params=best_params.keys())
axes.get_figure().savefig(args.logpath.joinpath('parallel_coordinate.png'))
clear_logger(logger)


Expand All @@ -225,6 +227,7 @@ def main(args):
args = setup_args(parser)

seed_lst = args.seed.copy()
args.n_trials /= len(seed_lst)
for seed in seed_lst:
args.seed = setup_seed(seed, args.cuda)
args.flag = f'param-{args.seed}'
Expand Down
4 changes: 2 additions & 2 deletions benchmark/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from .load_data import SingleGraphLoader, SingleGraphLoader_Trial
from .load_model import ModelLoader, ModelLoader_Trial
from .base import TrnBase_Trial
from .fullbatch import TrnFullbatch
from .fullbatch import TrnFullbatch, TrnFullbatch_Trial
from .minibatch import TrnMinibatch, TrnMinibatch_Trial

__all__ = [
'SingleGraphLoader', 'SingleGraphLoader_Trial',
'ModelLoader', 'ModelLoader_Trial',
'TrnBase_Trial',
'TrnFullbatch',
'TrnFullbatch', 'TrnFullbatch_Trial',
'TrnMinibatch', 'TrnMinibatch_Trial',
]
31 changes: 30 additions & 1 deletion benchmark/trainer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
import torch
import torch.nn as nn
from torch_geometric.data import Data
import torch_geometric.utils as pyg_utils

from pyg_spectral import profile
from pyg_spectral.utils import load_import

from dataset import split_random
from utils import CkptLogger, ResLogger
from .load_metric import metric_loader


class TrnBase(object):
Expand Down Expand Up @@ -70,10 +74,12 @@ def __init__(self,
self.data = data

# Evaluation metrics
self.splits = ['train', 'val', 'test']
self.multi = args.multi
self.num_features = args.num_features
self.num_classes = args.num_classes
self.splits = ['train', 'val', 'test']
metric = metric_loader(args).to(self.device)
self.evaluator = {k: metric.clone(postfix='_'+k) for k in self.splits}

# Loggers
self.logger = logging.getLogger('log')
Expand Down Expand Up @@ -211,6 +217,29 @@ def __call__(self, *args, **kwargs):
class TrnBase_Trial(TrnBase):
r"""Trainer supporting optuna.pruners in training.
"""
def __init__(self,
model: nn.Module,
data: Data,
args: Namespace,
res_logger: ResLogger = None,
**kwargs):
super().__init__(model, data, args, res_logger, **kwargs)
self.splits = ['train', 'val', 'hyperval', 'test']
metric = metric_loader(args).to(self.device)
self.evaluator = {k: metric.clone(postfix='_'+k) for k in self.splits}

def split_hyperval(self, data: Data) -> Data:
attr_to_index = lambda k: pyg_utils.mask_to_index(data[f'{k}_mask']) if hasattr(data, f'{k}_mask') else torch.tensor([])
idx = {k: attr_to_index(k) for k in ['train', 'val', 'hyperval']}
r_train = 1.0 * len(idx['train']) / (len(idx['train']) + len(idx['val']) + len(idx['hyperval']))
r_val = r_hyperval = (1.0 - r_train) / 2

label = data.y.detach().clone()
label[data.test_mask] = -1
data.train_mask, data.val_mask, data.hyperval_mask = split_random(label, r_train, r_val, ignore_neg=True)

return data

def clear(self):
if self.evaluator:
for k in self.splits:
Expand Down
30 changes: 25 additions & 5 deletions benchmark/trainer/fullbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

from pyg_spectral.profile import Stopwatch

from .base import TrnBase
from .load_metric import metric_loader
from .base import TrnBase, TrnBase_Trial
from utils import ResLogger


Expand Down Expand Up @@ -52,8 +51,6 @@ def __init__(self,
args: Namespace,
**kwargs):
super(TrnFullbatch, self).__init__(model, data, args, **kwargs)
metric = metric_loader(args).to(self.device)
self.evaluator = {k: metric.clone(postfix='_'+k) for k in self.splits}

self.mask: dict = None
self.flag_test_deg = args.test_deg if hasattr(args, 'test_deg') else False
Expand All @@ -64,7 +61,7 @@ def clear(self):

def _fetch_data(self) -> Tuple[Data, dict]:
r"""Process the single graph data."""
t_to_device = T.ToDevice(self.device, attrs=['x', 'y', 'adj_t', 'edge_index', 'train_mask', 'val_mask', 'test_mask'])
t_to_device = T.ToDevice(self.device, attrs=['x', 'y', 'adj_t', 'edge_index'] + [f'{k}_mask' for k in self.splits])
self.data = t_to_device(self.data)
# FIXME: Update to `EdgeIndex` [Release note 2.5.0](https://github.com/pyg-team/pytorch_geometric/releases/tag/2.5.0)
# if not pyg_utils.is_sparse(self.data.adj_t):
Expand Down Expand Up @@ -176,3 +173,26 @@ def run(self) -> ResLogger:
res_run.merge(self.test_deg())

return self.res_logger.merge(res_run)


class TrnFullbatch_Trial(TrnFullbatch, TrnBase_Trial):
    r"""Trainer supporting optuna.pruners in training.

    Full-batch trainer variant used inside an Optuna trial: it first carves a
    'hyperval' split out of the non-test nodes (via
    :meth:`TrnBase_Trial.split_hyperval`) and reports metrics on that split
    in addition to the usual ones.
    """
    def run(self) -> ResLogger:
        """Run one full train/evaluate cycle for a trial.

        Returns:
            ResLogger: trainer's result logger merged with the training
            and per-split test results of this run.
        """
        res_run = ResLogger()
        # Re-split non-test nodes to obtain a 'hyperval' mask before the
        # data is moved to the device by _fetch_data().
        self.data = self.split_hyperval(self.data)
        self._fetch_data()
        self.model = self.model.to(self.device)
        self.setup_optimizer()

        res_train = self.train_val()
        res_run.merge(res_train)

        # Reload the 'best' checkpoint, then evaluate on all four splits
        # (presumably test() computes metrics per given split — inherited
        # from TrnFullbatch).
        self.model = self.ckpt_logger.load('best', model=self.model)
        res_test = self.test(['train', 'val', 'hyperval', 'test'])
        res_run.merge(res_test)

        return self.res_logger.merge(res_run)

    def update(self, *args, **kwargs):
        # NOTE(review): rebuilds the trainer from scratch on every trial by
        # re-running __init__; relies on __init__ being safe to call again
        # on an existing instance — confirm against TrnBase.__init__.
        self.__init__(*args, **kwargs)
3 changes: 2 additions & 1 deletion benchmark/trainer/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def get(self, args: Namespace) -> Data:

T_insert(self.transform, Tspec.GenNorm(left=args.normg), index=-2)
assert self.data in class_list, f"Invalid dataset: {self.data}"
data = func_list[class_list[self.data]](DATAPATH, self.transform, args)
get_data = func_list[class_list[self.data]]
data = get_data(DATAPATH, self.transform, args)

self.res_logger.concat([('data', self.data, str), ('metric', args.metric, str)])
return data
Expand Down
Loading

0 comments on commit c308468

Please sign in to comment.