Skip to content

Commit

Permalink
modified project
Browse files Browse the repository at this point in the history
  • Loading branch information
linjing-lab committed Oct 29, 2023
1 parent 5db7c84 commit ce33de2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion released_box/perming/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(self,
batch_size,
learning_rate_init,
lr_scheduler)
assert num_classes >= 2, 'The predefined options of Multipler are more suitable for Multi-classification.'
assert num_classes >= 2

def _activate(self, activation: str):
'''
Expand Down
6 changes: 3 additions & 3 deletions released_box/perming/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ def _criterion(self, criterion: str):
return torch.nn.CrossEntropyLoss()
elif criterion == 'NLLLoss':
return torch.nn.NLLLoss()
elif criterion == 'MultiLabelSoftMarginLoss':
elif criterion == 'MultiLabelSoftMarginLoss': # multi-outputs
return torch.nn.MultiLabelSoftMarginLoss()
elif criterion == 'BCELoss': # classification with num_classes = 2
return torch.nn.BCELoss()
elif criterion == 'BCEWithLogitsLoss':
elif criterion == 'BCEWithLogitsLoss': # multi-outputs
return torch.nn.BCEWithLogitsLoss()
elif criterion == 'MSELoss': # regression
elif criterion == 'MSELoss': # regression (or multi-outputs)
return torch.nn.MSELoss()
elif criterion == 'L1Loss':
return torch.nn.L1Loss()
Expand Down

0 comments on commit ce33de2

Please sign in to comment.