diff --git a/docs/source/releases.rst b/docs/source/releases.rst index c26b7c44..66003aec 100644 --- a/docs/source/releases.rst +++ b/docs/source/releases.rst @@ -1,6 +1,11 @@ Release Notes ============= +Version 19.11.2 +--------------- +Bugfixes: + - `predict_proba_arff` now also accepts a `target_column` as expected from the previous update. + Version 19.11.1 --------------- Features: diff --git a/gama/GamaClassifier.py b/gama/GamaClassifier.py index c1332346..f2169834 100644 --- a/gama/GamaClassifier.py +++ b/gama/GamaClassifier.py @@ -1,5 +1,5 @@ import inspect -from typing import Union +from typing import Union, Optional import numpy as np import pandas as pd @@ -88,17 +88,25 @@ def predict_proba(self, x: Union[pd.DataFrame, np.ndarray]): x[col] = x[col].astype(self._X[col].dtype) return self._predict_proba(x) - def predict_proba_arff(self, arff_file_path: str): + def predict_proba_arff(self, arff_file_path: str, target_column: Optional[str] = None): """ Predict the class probabilities for input in the arff_file, must have empty target column. - Predict target for X, using the best found pipeline(s) during the `fit` call. - - :param arff_file_path: str + Parameters + ---------- + arff_file_path: str + An ARFF file with the same columns as the one that used in fit. + Target column must be present in file, but its values are ignored (can be '?'). + target_column: str, optional (default=None) + Specifies which column the model should predict. + If left None, the last column is taken to be the target. - :return: a numpy array with class probabilities. The array is of shape (N, K) where N is the length of the + Returns + ------- + numpy.ndarray + Numpy array with class probabilities. The array is of shape (N, K) where N is the length of the first dimension of X, and K is the number of class labels found in `y` of `fit`. """ - X, _ = X_y_from_arff(arff_file_path) + X, _ = X_y_from_arff(arff_file_path, target_column) return self._predict_proba(X) def fit(self, x, y, *args, **kwargs): diff --git a/tests/system/test_gamaclassifier.py b/tests/system/test_gamaclassifier.py index de80dac9..c61c9629 100644 --- a/tests/system/test_gamaclassifier.py +++ b/tests/system/test_gamaclassifier.py @@ -35,6 +35,7 @@ breast_cancer_missing = dict( name='breast_cancer_missing', load=load_breast_cancer, + target='status', test_size=143, n_classes=2, base_accuracy=0.62937, @@ -98,9 +99,9 @@ def _test_dataset_problem( y_test = [str(val) for val in y_test] with Stopwatch() as sw: - gama.fit_arff(train_path) - class_predictions = gama.predict_arff(test_path) - class_probabilities = gama.predict_proba_arff(test_path) + gama.fit_arff(train_path, target_column=data['target']) + class_predictions = gama.predict_arff(test_path, target_column=data['target']) + class_probabilities = gama.predict_proba_arff(test_path, target_column=data['target']) gama_score = gama.score_arff(test_path) else: X, y = data['load'](return_X_y=True)