This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit SPARK-29116 (#285)
* Add dynamically typed stub for ml.tree

* Add statically typed stub for ml.tree

* Update annotations for tree models in ml.regression

* Update annotations for tree models in ml.classification
zero323 authored Jan 11, 2020
1 parent 52de6cc commit 93e6e41
Showing 4 changed files with 122 additions and 105 deletions.
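
The modules this commit touches are type stubs, so the change is visible to a static checker such as mypy rather than at runtime. As a quick, hypothetical illustration of what the updated DecisionTreeClassifier annotations in the diff below allow (not part of the commit; it assumes a Spark 3.0-line pyspark where minWeightFractionPerNode exists, and the master and parameter values are arbitrary):

from pyspark.sql import SparkSession
from pyspark.ml.classification import DecisionTreeClassifier

# An active SparkSession is required before constructing any pyspark.ml estimator.
spark = SparkSession.builder.master("local[1]").getOrCreate()

# The stub signatures let mypy check keyword arguments and setter chaining.
dt = DecisionTreeClassifier(maxDepth=5, impurity="gini", minWeightFractionPerNode=0.0)
dt = dt.setMaxBins(64)  # annotated to return DecisionTreeClassifier, so the chained value stays typed

# dt = DecisionTreeClassifier(maxDepth="deep")  # would be rejected by mypy: int expected

spark.stop()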
2 changes: 2 additions & 0 deletions README.rst
@@ -158,6 +158,8 @@ API Coverage
+------------------------------------------------+---------------------+--------------------+------------+
| pyspark.ml.tests ||| Tests |
+------------------------------------------------+---------------------+--------------------+------------+
+| pyspark.ml.tree | x || |
++------------------------------------------------+---------------------+--------------------+------------+
| pyspark.ml.tuning ||| |
+------------------------------------------------+---------------------+--------------------+------------+
| pyspark.ml.util ||| |
38 changes: 19 additions & 19 deletions third_party/3/pyspark/ml/classification.pyi
@@ -7,7 +7,8 @@ from pyspark.ml._typing import JM, M, P, T, ParamMap
from pyspark.ml.base import Estimator, Model, Transformer
from pyspark.ml.linalg import Matrix, Vector
from pyspark.ml.param.shared import *
-from pyspark.ml.regression import DecisionTreeModel, DecisionTreeParams, DecisionTreeRegressionModel, GBTParams, HasVarianceImpurity, RandomForestParams, TreeEnsembleModel
+from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, _TreeEnsembleModel, _RandomForestParams, _GBTParams, _HasVarianceImpurity, _TreeClassifierParams, _TreeEnsembleParams
+from pyspark.ml.regression import DecisionTreeRegressionModel
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaPredictionModel, JavaPredictor, JavaPredictorParams, JavaWrapper, JavaTransformer
from pyspark.sql.dataframe import DataFrame
@@ -137,30 +138,29 @@ class BinaryLogisticRegressionSummary(LogisticRegressionSummary):

class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary, LogisticRegressionTrainingSummary): ...

-class TreeClassifierParams:
-    supportedImpurities: List[str]
-    impurity: Param[str]
-    def __init__(self) -> None: ...
-    def getImpurity(self) -> str: ...
+class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams): ...

-class DecisionTreeClassifier(JavaProbabilisticClassifier[DecisionTreeClassificationModel], HasWeightCol, DecisionTreeParams, TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, JavaMLReadable[DecisionTreeClassifier]):
-    def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., seed: Optional[int] = ..., weightCol: Optional[str] = ..., leafCol: str = ...) -> None: ...
-    def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., seed: Optional[int] = ..., weightCol: Optional[str] = ..., leafCol: str = ...) -> DecisionTreeClassifier: ...
+class DecisionTreeClassifier(JavaProbabilisticClassifier[DecisionTreeClassificationModel], _DecisionTreeClassifierParams, JavaMLWritable, JavaMLReadable[DecisionTreeClassifier]):
+    def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., seed: Optional[int] = ..., weightCol: Optional[str] = ..., leafCol: str = ..., minWeightFractionPerNode: float = ...) -> None: ...
+    def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., seed: Optional[int] = ..., weightCol: Optional[str] = ..., leafCol: str = ..., minWeightFractionPerNode: float = ...) -> DecisionTreeClassifier: ...
    def setMaxDepth(self, value: int) -> DecisionTreeClassifier: ...
    def setMaxBins(self, value: int) -> DecisionTreeClassifier: ...
    def setMinInstancesPerNode(self, value: int) -> DecisionTreeClassifier: ...
+    def setMinWeightFractionPerNode(self, value: float) -> DecisionTreeClassifier: ...
    def setMinInfoGain(self, value: float) -> DecisionTreeClassifier: ...
    def setMaxMemoryInMB(self, value: int) -> DecisionTreeClassifier: ...
    def setCacheNodeIds(self, value: bool) -> DecisionTreeClassifier: ...
    def setImpurity(self, value: str) -> DecisionTreeClassifier: ...

-class DecisionTreeClassificationModel(DecisionTreeModel, JavaProbabilisticClassificationModel[Vector], JavaMLWritable, JavaMLReadable[DecisionTreeClassificationModel]):
+class DecisionTreeClassificationModel(_DecisionTreeModel, JavaProbabilisticClassificationModel[Vector], _DecisionTreeClassifierParams, JavaMLWritable, JavaMLReadable[DecisionTreeClassificationModel]):
    @property
    def featureImportances(self) -> Vector: ...

-class RandomForestClassifier(JavaProbabilisticClassifier[RandomForestClassificationModel], HasSeed, RandomForestParams, TreeClassifierParams, HasCheckpointInterval, JavaMLWritable, JavaMLReadable[RandomForestClassifier]):
-    def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., seed: Optional[int] = ..., subsamplingRate: float = ..., leafCol: str = ...) -> None: ...
-    def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., seed: Optional[int] = ..., impurity: str = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., subsamplingRate: float = ..., leafCol: str = ...) -> RandomForestClassifier: ...
+class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams): ...
+
+class RandomForestClassifier(JavaProbabilisticClassifier[RandomForestClassificationModel], _RandomForestClassifierParams, JavaMLWritable, JavaMLReadable[RandomForestClassifier]):
+    def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., seed: Optional[int] = ..., subsamplingRate: float = ..., leafCol: str = ..., minWeightFractionPerNode: float = ...) -> None: ...
+    def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., seed: Optional[int] = ..., impurity: str = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., subsamplingRate: float = ..., leafCol: str = ..., minWeightFractionPerNode: float = ...) -> RandomForestClassifier: ...
    def setMaxDepth(self, value: int) -> RandomForestClassifier: ...
    def setMaxBins(self, value: int) -> RandomForestClassifier: ...
    def setMinInstancesPerNode(self, value: int) -> RandomForestClassifier: ...
@@ -172,20 +172,20 @@ class RandomForestClassifier(JavaProbabilisticClassifier[RandomForestClassificat
    def setSubsamplingRate(self, value: float) -> RandomForestClassifier: ...
    def setFeatureSubsetStrategy(self, value: str) -> RandomForestClassifier: ...

-class RandomForestClassificationModel(TreeEnsembleModel, JavaProbabilisticClassificationModel[Vector], JavaMLWritable, JavaMLReadable[RandomForestClassificationModel]):
+class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel[Vector], _RandomForestClassifierParams, JavaMLWritable, JavaMLReadable[RandomForestClassificationModel]):
    @property
    def featureImportances(self) -> Vector: ...
    @property
    def trees(self) -> List[DecisionTreeClassificationModel]: ...

class GBTClassifierParams(GBTParams, HasVarianceImpurity):
class GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
supportedLossTypes: List[str]
lossType: Param[str]
def getLossType(self) -> str: ...

-class GBTClassifier(JavaProbabilisticClassifier[GBTClassificationModel], GBTClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, JavaMLReadable[GBTClassifier]):
-    def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., lossType: str = ..., maxIter: int = ..., stepSize: float = ..., seed: Optional[int] = ..., subsamplingRate: float = ..., featureSubsetStrategy: str = ..., validationTol: float = ..., validationIndicatorCol: Optional[str] = ..., leafCol: str = ...) -> None: ...
-    def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., lossType: str = ..., maxIter: int = ..., stepSize: float = ..., seed: Optional[int] = ..., subsamplingRate: float = ..., featureSubsetStrategy: str = ..., validationTol: float = ..., validationIndicatorCol: Optional[str] = ..., leafCol: str = ...) -> GBTClassifier: ...
+class GBTClassifier(JavaProbabilisticClassifier[GBTClassificationModel], GBTClassifierParams, JavaMLWritable, JavaMLReadable[GBTClassifier]):
+    def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., lossType: str = ..., maxIter: int = ..., stepSize: float = ..., seed: Optional[int] = ..., subsamplingRate: float = ..., featureSubsetStrategy: str = ..., validationTol: float = ..., validationIndicatorCol: Optional[str] = ..., leafCol: str = ..., minWeightFractionPerNode: float = ...) -> None: ...
+    def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., lossType: str = ..., maxIter: int = ..., stepSize: float = ..., seed: Optional[int] = ..., subsamplingRate: float = ..., featureSubsetStrategy: str = ..., validationTol: float = ..., validationIndicatorCol: Optional[str] = ..., leafCol: str = ..., minWeightFractionPerNode: float = ...) -> GBTClassifier: ...
    def setMaxDepth(self, value: int) -> GBTClassifier: ...
    def setMaxBins(self, value: int) -> GBTClassifier: ...
    def setMinInstancesPerNode(self, value: int) -> GBTClassifier: ...
@@ -198,7 +198,7 @@ class GBTClassifier(JavaProbabilisticClassifier[GBTClassificationModel], GBTClas
    def setFeatureSubsetStrategy(self, value: str) -> GBTClassifier: ...
    def setValidationIndicatorCol(self, value: str) -> GBTClassifier: ...

-class GBTClassificationModel(TreeEnsembleModel, JavaProbabilisticClassificationModel[Vector], JavaMLWritable, JavaMLReadable[GBTClassificationModel]):
+class GBTClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel[Vector], GBTClassifierParams, JavaMLWritable, JavaMLReadable[GBTClassificationModel]):
    @property
    def featureImportances(self) -> Vector: ...
    @property
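
The remaining two diffs (the new pyspark.ml.tree stub and the pyspark.ml.regression update) are not rendered above. The sketch below only illustrates the shape of the refactor suggested by the import line and class bases in the classification diff: shared tree parameters live in underscore-prefixed mixin classes, and each estimator/model pair mixes in one combined params class instead of redeclaring its getters. The class bodies are illustrative stand-ins, not the actual contents of the stub files.

from typing import List

class _DecisionTreeParams:
    # Stand-in attributes; the real mixin declares Param descriptors and getters.
    maxDepth: int
    maxBins: int

class _TreeClassifierParams:
    supportedImpurities: List[str]
    impurity: str
    def getImpurity(self) -> str:
        return self.impurity

# One combined params mixin is shared by the estimator and its fitted model,
# so the tree-specific accessors are declared in a single place.
class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams): ...

class DecisionTreeClassifier(_DecisionTreeClassifierParams): ...

class DecisionTreeClassificationModel(_DecisionTreeClassifierParams): ...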