adjusting numpy.random.seed usage in cotraining
Jordan Stomps committed Jan 18, 2023
1 parent ec47a63 commit 95fc695
Showing 2 changed files with 19 additions and 12 deletions.
15 changes: 7 additions & 8 deletions models/SSML/CoTraining.py
@@ -16,20 +16,22 @@ class CoTraining:
     regression implementation with hyperparameter optimization.
     Data agnostic (i.e. user supplied data inputs).
     TODO: Currently only supports binary classification.
-        Add multinomial functions and unit tests.
-        Add functionality for regression(?)
+        - Add multinomial functions and unit tests.
+        - Add functionality for regression(?)
     Inputs:
     kwargs: logistic regression input functions.
-        keys random_state, max_iter, tol, and C supported.
-        random_state: int/float for reproducible initialization.
+        keys seed, random_state, max_iter, tol, and C supported.
+        seed/random_state: int/float for reproducible initialization.
     '''

     # only binary so far
     def __init__(self, **kwargs):
-        # supported keys = ['max_iter', 'tol', 'C', 'random_state']
+        # supported keys = ['max_iter', 'tol', 'C', 'random_state', 'seed']
         # defaults to a fixed value for reproducibility
         self.random_state = kwargs.pop('random_state', 0)
+        # set the random seed of training splits for reproducibility
+        self.seed = kwargs.pop('seed', 0)
+        np.random.seed(self.seed)
         # parameters for cotraining logistic regression models:
         # defaults to sklearn.linear_model.LogisticRegression default vals
         self.max_iter = kwargs.pop('max_iter', 100)
@@ -236,9 +238,6 @@ def train(self, trainx, trainy, Ux,
         # avoid overwriting when deleting in co-training loop
         U_lr = Ux.copy()

-        # set the random seed of training splits for reproducibility
-        np.random.seed(self.seed)
-
         # TODO: allow a user to specify uneven splits between the two models
         split_frac = 0.5
         # labeled training data
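The net effect of the two hunks above: `np.random.seed` moves from `train` into `__init__`, so NumPy's global RNG is seeded once when the model is constructed instead of being re-seeded on every training call. A minimal sketch of the resulting pattern — the kwarg handling and `model1`/`model2` attributes follow the diff and the tests below, but the full `train` signature is truncated in the diff, so the split logic here is illustrative, not the repository's exact code:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


class CoTraining:
    def __init__(self, **kwargs):
        # random_state seeds each LogisticRegression instance
        self.random_state = kwargs.pop('random_state', 0)
        # seed fixes NumPy's global RNG once, at construction, so the
        # training splits are reproducible across repeated train() calls
        self.seed = kwargs.pop('seed', 0)
        np.random.seed(self.seed)
        # remaining keys default to sklearn's LogisticRegression values
        self.max_iter = kwargs.pop('max_iter', 100)
        self.tol = kwargs.pop('tol', 1e-4)
        self.C = kwargs.pop('C', 1.0)
        self.model1 = LogisticRegression(random_state=self.random_state,
                                         max_iter=self.max_iter,
                                         tol=self.tol,
                                         C=self.C)
        self.model2 = LogisticRegression(random_state=self.random_state,
                                         max_iter=self.max_iter,
                                         tol=self.tol,
                                         C=self.C)

    def train(self, trainx, trainy, Ux):
        # avoid overwriting when deleting in co-training loop
        U_lr = Ux.copy()
        # np.random.seed is no longer called here; the permutation below
        # is governed by the seed consumed in __init__
        split_frac = 0.5
        idx = np.random.permutation(len(trainy))
        split = int(split_frac * len(idx))
        # each model trains on half the labeled data (illustrative);
        # the co-training loop that labels samples from U_lr is elided
        self.model1.fit(trainx[idx[:split]], trainy[idx[:split]])
        self.model2.fit(trainx[idx[split:]], trainy[idx[split:]])
```

One consequence of this design worth noting: `np.random.seed` sets process-global state, so constructing a second `CoTraining` instance re-seeds the RNG for any subsequent `np.random` draws in the same process.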
16 changes: 12 additions & 4 deletions tests/test_models.py
@@ -122,14 +122,16 @@ def test_pca():

 def test_LogReg():
     # test saving model input parameters
-    params = {'max_iter': 2022, 'tol': 0.5, 'C': 5.0}
+    params = {'max_iter': 2022, 'tol': 0.5, 'C': 5.0, 'random_state': 0}
     model = LogReg(max_iter=params['max_iter'],
                    tol=params['tol'],
-                   C=params['C'])
+                   C=params['C'],
+                   random_state=params['random_state'])

     assert model.model.max_iter == params['max_iter']
     assert model.model.tol == params['tol']
     assert model.model.C == params['C']
+    assert model.random_state == params['random_state']

     X_train, X_test, y_train, y_test = train_test_split(pytest.spectra,
                                                         pytest.labels,
@@ -187,10 +189,13 @@ def test_LogReg():

 def test_CoTraining():
     # test saving model input parameters
-    params = {'max_iter': 2022, 'tol': 0.5, 'C': 5.0}
+    params = {'max_iter': 2022, 'tol': 0.5, 'C': 5.0,
+              'random_state': 0, 'seed': 1}
     model = CoTraining(max_iter=params['max_iter'],
                        tol=params['tol'],
-                       C=params['C'])
+                       C=params['C'],
+                       random_state=params['random_state'],
+                       seed=params['seed'])

     assert model.model1.max_iter == params['max_iter']
     assert model.model1.tol == params['tol']
@@ -200,6 +205,9 @@ def test_CoTraining():
     assert model.model2.tol == params['tol']
     assert model.model2.C == params['C']

+    assert model.random_state == params['random_state']
+    assert model.seed == params['seed']
+
     X, Ux, y, Uy = train_test_split(pytest.spectra,
                                     pytest.labels,
                                     test_size=0.5,
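Since every key in `params` matches a constructor kwarg, the test's construction can also be written with dictionary unpacking. A small usage sketch under stated assumptions: stand-in data replaces the `pytest.spectra`/`pytest.labels` fixtures, the import path is taken from the file header above, and the `train` call is omitted because its remaining arguments are truncated in the diff:

```python
import numpy as np
from sklearn.model_selection import train_test_split

from models.SSML.CoTraining import CoTraining

# stand-in data; the real test uses the pytest.spectra/pytest.labels fixtures
spectra = np.random.rand(100, 8)
labels = np.random.randint(0, 2, 100)

params = {'max_iter': 2022, 'tol': 0.5, 'C': 5.0,
          'random_state': 0, 'seed': 1}
model = CoTraining(**params)  # equivalent to passing each key explicitly

# seed=1 was consumed by np.random.seed at construction, so the
# np.random-driven splits inside train() repeat across identical runs
X, Ux, y, Uy = train_test_split(spectra, labels, test_size=0.5)
```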
