diff --git a/benchmark/dataset/utils.py b/benchmark/dataset/utils.py index 9778bdc..62657e6 100644 --- a/benchmark/dataset/utils.py +++ b/benchmark/dataset/utils.py @@ -2,7 +2,6 @@ import numpy as np import torch from torch_geometric.data import Data, Dataset -from torch_geometric.data.dataset import _get_flattened_data_list import torch_geometric.transforms as T @@ -50,7 +49,7 @@ def resolve_data(args: Namespace, dataset: Dataset) -> Data: def resolve_split(data_split: str, data: Data) -> Data: - # TODO: support more split schemes + # TODO: support more split schemes (also in hyperval) scheme, split = data_split.split('_') if scheme == 'Random': (r_train, r_val) = map(int, split.split('/')[:2]) @@ -60,9 +59,7 @@ def resolve_split(data_split: str, data: Data) -> Data: else: assert hasattr(data, 'train_mask') and hasattr(data, 'val_mask') and hasattr(data, 'test_mask') if data.train_mask.dim() > 1: - split = int(split) - if split >= data.train_mask.size(1): - split = split % data.train_mask.size(1) + split = int(split) % data.train_mask.size(1) data.train_mask = data.train_mask[:, split] data.val_mask = data.val_mask[:, split] data.test_mask = data.test_mask[:, split] diff --git a/benchmark/run_best.py b/benchmark/run_best.py index 4e63abb..f458c7d 100644 --- a/benchmark/run_best.py +++ b/benchmark/run_best.py @@ -26,6 +26,14 @@ def reverse_parse(parser, key, val): return type_func(val) +def filter_res(s, metric): + # remove all substring start with s but not contain metric + flt_common = lambda x: not x.startswith('s_') and not '_' in x + flt_metric = lambda x: metric in x and x.endswith('_test') + lst = [x for x in s.split(', ') if flt_common(x.split(':')[0]) or flt_metric(x.split(':')[0])] + return ', '.join(lst) + + def main(args): # ========== Run configuration logger = setup_logger(args.logpath, level_console=args.loglevel, quiet=args.quiet) @@ -51,7 +59,8 @@ def main(args): trn() logger.info(f"[args]: {args}") - logger.log(logging.LRES, f"[res]: {res_logger}") + resstr = filter_res(res_logger.get_str(), args.metric) + logger.log(logging.LRES, f"{resstr}") res_logger.save() save_args(args.logpath, vars(args)) clear_logger(logger) diff --git a/benchmark/run_param.py b/benchmark/run_param.py index df5c0a0..dd1d6f8 100644 --- a/benchmark/run_param.py +++ b/benchmark/run_param.py @@ -10,8 +10,8 @@ from trainer import ( SingleGraphLoader_Trial, ModelLoader_Trial, - TrnFullbatch, - TrnMinibatch_Trial, TrnBase_Trial) + TrnFullbatch, TrnFullbatch_Trial, + TrnMinibatch, TrnMinibatch_Trial) from utils import ( force_list_str, setup_seed, @@ -114,17 +114,18 @@ def __call__(self, trial): if self.data is None: self.data = self.data_loader.get(args) - self.model, trn_cls = self.model_loader(args) - if trn_cls == TrnFullbatch: - self.trn_cls = type('Trn_Trial', (trn_cls, TrnBase_Trial), {}) - else: - self.trn_cls = TrnMinibatch_Trial + self.model, trn_cls = self.model_loader.get(args) + self.trn_cls = { + TrnFullbatch: TrnFullbatch_Trial, + TrnMinibatch: TrnMinibatch_Trial, + }[trn_cls] for key in ['num_features', 'num_classes', 'metric', 'multi', 'criterion']: self.args.__dict__[key] = args.__dict__[key] self.metric = args.metric - self.data = self.data_loader.update(args, self.data) - self.model = self.model_loader.update(args, self.model) + else: + self.data = self.data_loader.update(args, self.data) + self.model = self.model_loader.update(args, self.model) res_logger = deepcopy(self.res_logger) for key in self.args.param: val = args.__dict__[key] @@ -136,32 +137,24 @@ def __call__(self, trial): res_logger.concat([(key, vali, fmt)]) self.fmt_logger[key] = fmt_logger[key] - if self.trn_cls.__name__ == 'Trn_Trial': - trn = self.trn_cls( + if self.trn is None: + self.trn = self.trn_cls( model=self.model, data=self.data, args=args, res_logger=res_logger,) else: - if self.trn is None: - self.trn = self.trn_cls( - model=self.model, - data=self.data, - args=args, - res_logger=res_logger,) - else: - self.trn.update( - model=self.model, - data=self.data, - args=args, - res_logger=res_logger,) - trn = self.trn - trn.trial = trial - trn() + self.trn.update( + model=self.model, + data=self.data, + args=args, + res_logger=res_logger,) + self.trn.trial = trial + self.trn.run() res_logger.save() trial.set_user_attr("s_test", res_logger._get(col=self.metric+'_test', row=0)) - return res_logger.data.loc[0, self.metric+'_val'] + return res_logger.data.loc[0, self.metric+'_hyperval'] def main(args): @@ -175,9 +168,15 @@ def main(args): study_path, _ = setup_logpath(folder_args=(study_path,)) study = optuna.create_study( study_name=args.logid, - storage=f'sqlite:///{str(study_path)}', + storage=optuna.storages.RDBStorage( + url=f'sqlite:///{str(study_path)}', + heartbeat_interval=3600), direction='maximize', - sampler=optuna.samplers.TPESampler(), + sampler=optuna.samplers.TPESampler( + n_startup_trials=8, + multivariate=True, + group=True, + warn_independent_sampling=False), pruner=optuna.pruners.HyperbandPruner( min_resource=2, max_resource=args.epoch, @@ -214,6 +213,9 @@ def main(args): else: best_params = {k: trn.fmt_logger[k](v) for k, v in study.best_params.items()} save_args(args.logpath, best_params) + axes = optuna.visualization.matplotlib.plot_parallel_coordinate( + study, params=best_params.keys()) + axes.get_figure().savefig(args.logpath.joinpath('parallel_coordinate.png')) clear_logger(logger) @@ -225,6 +227,7 @@ def main(args): args = setup_args(parser) seed_lst = args.seed.copy() + args.n_trials /= len(seed_lst) for seed in seed_lst: args.seed = setup_seed(seed, args.cuda) args.flag = f'param-{args.seed}' diff --git a/benchmark/trainer/__init__.py b/benchmark/trainer/__init__.py index 066ee5e..fa74a59 100755 --- a/benchmark/trainer/__init__.py +++ b/benchmark/trainer/__init__.py @@ -1,13 +1,13 @@ from .load_data import SingleGraphLoader, SingleGraphLoader_Trial from .load_model import ModelLoader, ModelLoader_Trial from .base import TrnBase_Trial -from .fullbatch import TrnFullbatch +from .fullbatch import TrnFullbatch, TrnFullbatch_Trial from .minibatch import TrnMinibatch, TrnMinibatch_Trial __all__ = [ 'SingleGraphLoader', 'SingleGraphLoader_Trial', 'ModelLoader', 'ModelLoader_Trial', 'TrnBase_Trial', - 'TrnFullbatch', + 'TrnFullbatch', 'TrnFullbatch_Trial', 'TrnMinibatch', 'TrnMinibatch_Trial', ] diff --git a/benchmark/trainer/base.py b/benchmark/trainer/base.py index 4178e82..689d99c 100755 --- a/benchmark/trainer/base.py +++ b/benchmark/trainer/base.py @@ -10,10 +10,14 @@ import torch import torch.nn as nn from torch_geometric.data import Data +import torch_geometric.utils as pyg_utils + from pyg_spectral import profile from pyg_spectral.utils import load_import +from dataset import split_random from utils import CkptLogger, ResLogger +from .load_metric import metric_loader class TrnBase(object): @@ -70,10 +74,12 @@ def __init__(self, self.data = data # Evaluation metrics - self.splits = ['train', 'val', 'test'] self.multi = args.multi self.num_features = args.num_features self.num_classes = args.num_classes + self.splits = ['train', 'val', 'test'] + metric = metric_loader(args).to(self.device) + self.evaluator = {k: metric.clone(postfix='_'+k) for k in self.splits} # Loggers self.logger = logging.getLogger('log') @@ -211,6 +217,29 @@ def __call__(self, *args, **kwargs): class TrnBase_Trial(TrnBase): r"""Trainer supporting optuna.pruners in training. """ + def __init__(self, + model: nn.Module, + data: Data, + args: Namespace, + res_logger: ResLogger = None, + **kwargs): + super().__init__(model, data, args, res_logger, **kwargs) + self.splits = ['train', 'val', 'hyperval', 'test'] + metric = metric_loader(args).to(self.device) + self.evaluator = {k: metric.clone(postfix='_'+k) for k in self.splits} + + def split_hyperval(self, data: Data) -> Data: + attr_to_index = lambda k: pyg_utils.mask_to_index(data[f'{k}_mask']) if hasattr(data, f'{k}_mask') else torch.tensor([]) + idx = {k: attr_to_index(k) for k in ['train', 'val', 'hyperval']} + r_train = 1.0 * len(idx['train']) / (len(idx['train']) + len(idx['val']) + len(idx['hyperval'])) + r_val = r_hyperval = (1.0 - r_train) / 2 + + label = data.y.detach().clone() + label[data.test_mask] = -1 + data.train_mask, data.val_mask, data.hyperval_mask = split_random(label, r_train, r_val, ignore_neg=True) + + return data + def clear(self): if self.evaluator: for k in self.splits: diff --git a/benchmark/trainer/fullbatch.py b/benchmark/trainer/fullbatch.py index 7c3ce0b..768e1f7 100755 --- a/benchmark/trainer/fullbatch.py +++ b/benchmark/trainer/fullbatch.py @@ -14,8 +14,7 @@ from pyg_spectral.profile import Stopwatch -from .base import TrnBase -from .load_metric import metric_loader +from .base import TrnBase, TrnBase_Trial from utils import ResLogger @@ -52,8 +51,6 @@ def __init__(self, args: Namespace, **kwargs): super(TrnFullbatch, self).__init__(model, data, args, **kwargs) - metric = metric_loader(args).to(self.device) - self.evaluator = {k: metric.clone(postfix='_'+k) for k in self.splits} self.mask: dict = None self.flag_test_deg = args.test_deg if hasattr(args, 'test_deg') else False @@ -64,7 +61,7 @@ def clear(self): def _fetch_data(self) -> Tuple[Data, dict]: r"""Process the single graph data.""" - t_to_device = T.ToDevice(self.device, attrs=['x', 'y', 'adj_t', 'edge_index', 'train_mask', 'val_mask', 'test_mask']) + t_to_device = T.ToDevice(self.device, attrs=['x', 'y', 'adj_t', 'edge_index'] + [f'{k}_mask' for k in self.splits]) self.data = t_to_device(self.data) # FIXME: Update to `EdgeIndex` [Release note 2.5.0](https://github.com/pyg-team/pytorch_geometric/releases/tag/2.5.0) # if not pyg_utils.is_sparse(self.data.adj_t): @@ -176,3 +173,26 @@ def run(self) -> ResLogger: res_run.merge(self.test_deg()) return self.res_logger.merge(res_run) + + +class TrnFullbatch_Trial(TrnFullbatch, TrnBase_Trial): + r"""Trainer supporting optuna.pruners in training. + """ + def run(self) -> ResLogger: + res_run = ResLogger() + self.data = self.split_hyperval(self.data) + self._fetch_data() + self.model = self.model.to(self.device) + self.setup_optimizer() + + res_train = self.train_val() + res_run.merge(res_train) + + self.model = self.ckpt_logger.load('best', model=self.model) + res_test = self.test(['train', 'val', 'hyperval', 'test']) + res_run.merge(res_test) + + return self.res_logger.merge(res_run) + + def update(self, *args, **kwargs): + self.__init__(*args, **kwargs) diff --git a/benchmark/trainer/load_data.py b/benchmark/trainer/load_data.py index 1ac4f46..032d92f 100755 --- a/benchmark/trainer/load_data.py +++ b/benchmark/trainer/load_data.py @@ -92,7 +92,8 @@ def get(self, args: Namespace) -> Data: T_insert(self.transform, Tspec.GenNorm(left=args.normg), index=-2) assert self.data in class_list, f"Invalid dataset: {self.data}" - data = func_list[class_list[self.data]](DATAPATH, self.transform, args) + get_data = func_list[class_list[self.data]] + data = get_data(DATAPATH, self.transform, args) self.res_logger.concat([('data', self.data, str), ('metric', args.metric, str)]) return data diff --git a/benchmark/trainer/minibatch.py b/benchmark/trainer/minibatch.py index 3274375..267d976 100755 --- a/benchmark/trainer/minibatch.py +++ b/benchmark/trainer/minibatch.py @@ -11,7 +11,6 @@ import torch.nn as nn from torch.utils.data import TensorDataset, DataLoader from torch_geometric.data import Data, Dataset -import torch_geometric.utils as pyg_utils from pyg_spectral.nn.norm import TensorStandardScaler from pyg_spectral.profile import Stopwatch, Accumulator @@ -61,9 +60,6 @@ def __init__(self, assert isinstance(args.normf, int) self.norm_prop = TensorStandardScaler(dim=args.normf) - metric = metric_loader(args).to(self.device) - self.evaluator = {k: metric.clone(postfix='_'+k) for k in self.splits} - self.shuffle = {'train': True, 'val': False, 'test': False} self.embed = None @@ -71,7 +67,7 @@ def clear(self): del self.data, self.embed return super().clear() - def _fetch_data(self) -> Tuple[Data, dict]: + def _fetch_data(self) -> tuple: r"""Process the single graph data.""" # FIXME: Update to `EdgeIndex` [Release note 2.5.0](https://github.com/pyg-team/pytorch_geometric/releases/tag/2.5.0) # if not pyg_utils.is_sparse(self.data.adj_t): @@ -80,17 +76,28 @@ def _fetch_data(self) -> Tuple[Data, dict]: # self.logger.warning(f"Graph {self.data} contains isolated nodes.") mask = {k: getattr(self.data, f'{k}_mask') for k in self.splits} - return self.data, mask - def _fetch_preprocess(self, data: Data) -> tuple: - r"""Call model preprocess for precomputation.""" - if hasattr(data, 'adj_t'): - input, label = (data.x, data.adj_t), data.y + if hasattr(self.data, 'adj_t'): + input, label = (self.data.x, self.data.adj_t), self.data.y else: - input, label = (data.x, data.edge_index), data.y + input, label = (self.data.x, self.data.edge_index), self.data.y if hasattr(self.model, 'preprocess'): self.model.preprocess(*input) - return input, label + del self.data + return input, label, mask + + def _fetch_preprocess(self, embed: torch.Tensor, label: torch.Tensor, mask: dict) -> dict: + r"""Call model preprocess for precomputation.""" + self.embed = {} + for k in self.splits: + dataset = TensorDataset(embed[mask[k]], label[mask[k]]) + # self.embed[k] = DataLoader(dataset, + # batch_size=self.batch, + # shuffle=self.shuffle[k], + # num_workers=0) + self.embed[k] = dataset + self.logger.log(logging.LTRN, f"[{k}]: n_sample={len(dataset)}, n_batch={len(self.embed[k]) // self.batch}") + return self.embed def _fetch_input(self, split: str) -> Generator: r"""Process each sample of model input and label for training.""" @@ -154,10 +161,7 @@ def preprocess(self) -> ResLogger: r"""Pipeline for precomputation on CPU.""" self.logger.debug('-'*20 + f" Start propagation: pre " + '-'*20) - data, mask = self._fetch_data() - input, label = self._fetch_preprocess(data) - del self.data - + input, label, mask = self._fetch_data() stopwatch = Stopwatch() if hasattr(self.model, 'convolute'): with stopwatch: @@ -169,15 +173,7 @@ def preprocess(self) -> ResLogger: self.norm_prop.fit(embed[mask['train']]) embed = self.norm_prop(embed) - self.embed = {} - for k in self.splits: - dataset = TensorDataset(embed[mask[k]], label[mask[k]]) - # self.embed[k] = DataLoader(dataset, - # batch_size=self.batch, - # shuffle=self.shuffle[k], - # num_workers=0) - self.embed[k] = dataset - self.logger.log(logging.LTRN, f"[{k}]: n_sample={len(dataset)}, n_batch={len(self.embed[k]) // self.batch}") + self._fetch_preprocess(embed, label, mask) return ResLogger()( [('time_pre', stopwatch.data)]) @@ -208,6 +204,28 @@ class TrnMinibatch_Trial(TrnMinibatch, TrnBase_Trial): r"""Trainer supporting optuna.pruners in training. Lazy calling precomputation. """ + @TrnBase._log_memory(split='pre') + def preprocess(self) -> ResLogger: + r"""Pipeline for precomputation on CPU.""" + self.logger.debug('-'*20 + f" Start propagation: pre " + '-'*20) + + self.data = self.split_hyperval(self.data) + input, label, mask = self._fetch_data() + stopwatch = Stopwatch() + if hasattr(self.model, 'convolute'): + with stopwatch: + self.raw_embed = self.model.convolute(*input) + else: + self.raw_embed = input[0] + + if hasattr(self, 'norm_prop'): + self.norm_prop.fit(self.raw_embed[mask['train']]) + self.raw_embed = self.norm_prop(self.raw_embed) + + self._fetch_preprocess(self.raw_embed, label, mask) + return ResLogger()( + [('time_pre', stopwatch.data)]) + def run(self) -> ResLogger: res_run = ResLogger() @@ -228,17 +246,17 @@ def run(self) -> ResLogger: res_run.merge(res_train) self.model = self.ckpt_logger.load('best', model=self.model) - res_test = self.test() + res_test = self.test(['train', 'val', 'hyperval', 'test']) res_run.merge(res_test) return self.res_logger.merge(res_run) def update(self, - model: nn.Module, - data: Data, - args: Namespace, - res_logger: ResLogger = None, - **kwargs): + model: nn.Module, + data: Data, + args: Namespace, + res_logger: ResLogger = None, + **kwargs): self.model = model self.data = data self.res_logger = res_logger or ResLogger() diff --git a/benchmark/utils/logger.py b/benchmark/utils/logger.py index 7e807c5..0606104 100755 --- a/benchmark/utils/logger.py +++ b/benchmark/utils/logger.py @@ -24,6 +24,7 @@ warnings.filterwarnings('ignore', '.*No negative samples in targets.*') warnings.filterwarnings('ignore', '.*No positive samples found in target.*') warnings.filterwarnings('ignore', '.*No negative samples found in target.*') +warnings.filterwarnings('ignore', '.*is( an)? experimental.*') def setup_logger(logpath: Union[Path, str] = LOGPATH,