Skip to content

Commit

Permalink
move assign of val_container
Browse files Browse the repository at this point in the history
  • Loading branch information
linjing-lab committed Oct 28, 2023
1 parent a8fbf74 commit 3ac323b
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion released_box/perming/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,9 @@ def train_val(self,
assert tolerance > 1e-9 and tolerance < 1.0, 'Set tolerance to early stop training and validation process within patience.'
assert patience >= 10 and patience <= 100, 'Value coordinate with tolerance should fit about num_epochs and batch_size.'
assert n_jobs == -1 or n_jobs > 0, 'Take full jobs with setting n_jobs=-1 or manually set nums of jobs.'
# if n_jobs==1, parallel processing will be turn off to save cuda memory.
total_step: int = len(self.train_loader)
self.val_container = [set for set in self.val_loader] # replace [*iter] to avoid memory burden in jupyter kernel
self._set_container() # replace same operation with local assignment in _set_container
val_length: int = len(self.val_container)
self.stop_iter: bool = False # init state of train_val
for epoch in range(num_epochs):
Expand Down Expand Up @@ -370,6 +371,12 @@ def _pack_info(self, by: str, state: bool) -> Dict[str, Any]:
else:
regress.update(loss_)
return regress

def _set_container(self) -> None:
'''
Validation Container from `self.val_loader`.
'''
self.val_container = [set for set in self.val_loader] # replace [*iter] to reduce memory burden in jupyter kernel

def train_test_val_split(features: TabularData, target: TabularData, ratio_set: Dict[str, int], random_seed: Optional[int]) -> Tuple[Dict[str, TabularData]]:
'''
Expand Down

0 comments on commit 3ac323b

Please sign in to comment.