diff --git a/released_box/perming/_utils.py b/released_box/perming/_utils.py index ab88d1c..669d1a4 100644 --- a/released_box/perming/_utils.py +++ b/released_box/perming/_utils.py @@ -236,8 +236,9 @@ def train_val(self, assert tolerance > 1e-9 and tolerance < 1.0, 'Set tolerance to early stop training and validation process within patience.' assert patience >= 10 and patience <= 100, 'Value coordinate with tolerance should fit about num_epochs and batch_size.' assert n_jobs == -1 or n_jobs > 0, 'Take full jobs with setting n_jobs=-1 or manually set nums of jobs.' + # if n_jobs==1, parallel processing will be turn off to save cuda memory. total_step: int = len(self.train_loader) - self.val_container = [set for set in self.val_loader] # replace [*iter] to avoid memory burden in jupyter kernel + self._set_container() # replace same operation with local assignment in _set_container val_length: int = len(self.val_container) self.stop_iter: bool = False # init state of train_val for epoch in range(num_epochs): @@ -370,6 +371,12 @@ def _pack_info(self, by: str, state: bool) -> Dict[str, Any]: else: regress.update(loss_) return regress + + def _set_container(self) -> None: + ''' + Validation Container from `self.val_loader`. + ''' + self.val_container = [set for set in self.val_loader] # replace [*iter] to reduce memory burden in jupyter kernel def train_test_val_split(features: TabularData, target: TabularData, ratio_set: Dict[str, int], random_seed: Optional[int]) -> Tuple[Dict[str, TabularData]]: '''