diff --git a/crabnet/crabnet_.py b/crabnet/crabnet_.py index f6fd828..5529707 100644 --- a/crabnet/crabnet_.py +++ b/crabnet/crabnet_.py @@ -245,14 +245,13 @@ def __init__( self.max_lr = max_lr # Apply BCEWithLogitsLoss to model output if binary classification is True - if classification: - self.classification = True + self.classification = classification + self.model_name = model_name self.mat_prop = mat_prop self.data_loader = None self.train_loader = None - self.classification = False self.n_elements = n_elements self.fudge = fudge # expected fractional tolerance (std. dev) ~= 2%