Skip to content

Commit

Permalink
Merge pull request #1 from neuroneural/sklearn_update
Browse files Browse the repository at this point in the history
Sklearn update
  • Loading branch information
sergeyplis authored Feb 22, 2024
2 parents 872cfa7 + 9efb1af commit 72f8008
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 9 deletions.
4 changes: 2 additions & 2 deletions polyssifier/poly_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ def build_classifiers(exclude, scale, feature_selection, nCols):
if 'Decision Tree' not in exclude:
classifiers['Decision Tree'] = {
'clf': DecisionTreeClassifier(max_depth=None,
max_features='auto'),
max_features='sqrt'),
'parameters': {}}

if 'Random Forest' not in exclude:
classifiers['Random Forest'] = {
'clf': RandomForestClassifier(max_depth=None,
n_estimators=10,
max_features='auto'),
max_features='sqrt'),
'parameters': {'n_estimators': list(range(5, 20))}}

if 'Logistic Regression' not in exclude:
Expand Down
3 changes: 2 additions & 1 deletion polyssifier/polyssifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def poly(data, label, n_folds=10, scale=True, exclude=[],
_le.fit(label)
label = _le.transform(label)
n_class = len(np.unique(label))
logger.info(f'Detected {n_class} classes in label')
logger.info('Detected ' + str(n_class) + ' classes in label')

if save and not os.path.exists('poly_{}/models'.format(project_name)):
os.makedirs('poly_{}/models'.format(project_name))
Expand Down Expand Up @@ -84,6 +84,7 @@ def poly(data, label, n_folds=10, scale=True, exclude=[],
kf = list(skf.split(np.zeros(data.shape[0]), label))

# Parallel processing of tasks

manager = Manager()
args = manager.list()
args.append({}) # Store inputs
Expand Down
2 changes: 1 addition & 1 deletion polyssifier/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def plot_scores(scores, scoring='auc', file_name='temp', min_val=None):
break
ax1.text(rect.get_x() - rect.get_width() / 2., ymin + (1 - ymin) * .01,
data.index[n], ha='center', va='bottom',
rotation='90', color='black', fontsize=15)
rotation=90, color='black', fontsize=15)
plt.tight_layout()
plt.savefig(file_name + '.pdf')
plt.savefig(file_name + '.svg', transparent=False)
Expand Down
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
],

# What does your project relate to?
Expand All @@ -78,6 +84,6 @@
# your project is installed. For an analysis of "install_requires" vs pip's
# requirements files see:
# https://packaging.python.org/en/latest/requirements.html
install_requires=['pandas', 'sklearn', 'numpy', 'matplotlib'],
install_requires=['pandas','scikit-learn', 'numpy', 'matplotlib','joblib'],

zip_safe=False) # Override annoying default behavior of easy_install.
4 changes: 2 additions & 2 deletions tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_run():
report = poly(data, label, n_folds=2, verbose=1,
feature_selection=False,
save=False, project_name='test2')
for key, score in report.scores.mean().iteritems():
for key, score in report.scores.mean().items():
assert score < 5, '{} score is too low'.format(key)


Expand All @@ -45,7 +45,7 @@ def test_multiclass():
report = poly(data, label, n_folds=2, verbose=1,
feature_selection=False,
save=False, project_name='test3')
for key, score in report.scores.mean().iteritems():
for key, score in report.scores.mean().items():
assert score < 5, '{} score is too low'.format(key)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_run():
report = poly(data, label, n_folds=2, verbose=1,
feature_selection=False,
save=False, project_name='test2')
for key, score in report.scores.mean().iteritems():
for key, score in report.scores.mean().items():
assert score < 5, '{} score is too low'.format(key)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_feature_selection_regression():
assert (report_with_features.scores.mean()[:, 'train'] > 0.2).all(),\
'train score below chance'

for key, ypred in report_with_features.predictions.iteritems():
for key, ypred in report_with_features.predictions.items():
mse = np.linalg.norm(ypred - diabetes_target) / len(diabetes_target)
assert mse < 5, '{} Prediction error is too high'.format(key)

Expand Down

0 comments on commit 72f8008

Please sign in to comment.