#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Sequence, TypeVar, TYPE_CHECKING

from pyspark import since
from pyspark.ml.linalg import Vector
from pyspark.ml.param import Params
from pyspark.ml.param.shared import (
    HasCheckpointInterval,
    HasSeed,
    HasWeightCol,
    Param,
    TypeConverters,
    HasMaxIter,
    HasStepSize,
    HasValidationIndicatorCol,
)
from pyspark.ml.wrapper import JavaPredictionModel
from pyspark.ml.common import inherit_doc

if TYPE_CHECKING:
    from pyspark.ml._typing import P

T = TypeVar("T")


@inherit_doc
class _DecisionTreeModel(JavaPredictionModel[T]):
    """
    Abstraction for Decision Tree models.

    .. versionadded:: 1.5.0
    """

    @property  # type: ignore[misc]
    @since("1.5.0")
    def numNodes(self) -> int:
        """Return number of nodes of the decision tree."""
        return self._call_java("numNodes")

    @property  # type: ignore[misc]
    @since("1.5.0")
    def depth(self) -> int:
        """Return depth of the decision tree."""
        return self._call_java("depth")

    @property  # type: ignore[misc]
    @since("2.0.0")
    def toDebugString(self) -> str:
        """Full description of model."""
        return self._call_java("toDebugString")

    @since("3.0.0")
    def predictLeaf(self, value: Vector) -> float:
        """
        Predict the index of the leaf corresponding to the feature vector.
        """
        return self._call_java("predictLeaf", value)
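
# Illustrative usage sketch (editor's addition, not part of the upstream
# module): the accessors on _DecisionTreeModel surface on concrete fitted
# models such as DecisionTreeRegressionModel. Assumes an active SparkSession;
# the toy data and column names are invented for the example.
def _example_inspect_decision_tree_model() -> None:
    from pyspark.sql import SparkSession
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.regression import DecisionTreeRegressor

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [(0.0, Vectors.dense(0.0)), (1.0, Vectors.dense(1.0))],
        ["label", "features"],
    )
    model = DecisionTreeRegressor(maxDepth=2, seed=42).fit(df)

    # numNodes and depth are forwarded to the underlying Java model.
    print(model.numNodes, model.depth)
    # predictLeaf maps a feature vector to the index of the leaf it lands in.
    print(model.predictLeaf(Vectors.dense(1.0)))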


class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
    """
    Mixin for Decision Tree parameters.
    """

    leafCol: Param[str] = Param(
        Params._dummy(),
        "leafCol",
        "Leaf indices column name. Predicted leaf "
        + "index of each instance in each tree by preorder.",
        typeConverter=TypeConverters.toString,
    )

    maxDepth: Param[int] = Param(
        Params._dummy(),
        "maxDepth",
        "Maximum depth of the tree. (>= 0) E.g., "
        + "depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. "
        + "Must be in range [0, 30].",
        typeConverter=TypeConverters.toInt,
    )

    maxBins: Param[int] = Param(
        Params._dummy(),
        "maxBins",
        "Max number of bins for discretizing continuous "
        + "features. Must be >= 2 and >= number of categories for any categorical "
        + "feature.",
        typeConverter=TypeConverters.toInt,
    )

    minInstancesPerNode: Param[int] = Param(
        Params._dummy(),
        "minInstancesPerNode",
        "Minimum number of "
        + "instances each child must have after split. If a split causes "
        + "the left or right child to have fewer than "
        + "minInstancesPerNode, the split will be discarded as invalid. "
        + "Should be >= 1.",
        typeConverter=TypeConverters.toInt,
    )

    minWeightFractionPerNode: Param[float] = Param(
        Params._dummy(),
        "minWeightFractionPerNode",
        "Minimum "
        "fraction of the weighted sample count that each child "
        "must have after split. If a split causes the fraction "
        "of the total weight in the left or right child to be "
        "less than minWeightFractionPerNode, the split will be "
        "discarded as invalid. Should be in interval [0.0, 0.5).",
        typeConverter=TypeConverters.toFloat,
    )

    minInfoGain: Param[float] = Param(
        Params._dummy(),
        "minInfoGain",
        "Minimum information gain for a split "
        + "to be considered at a tree node.",
        typeConverter=TypeConverters.toFloat,
    )

    maxMemoryInMB: Param[int] = Param(
        Params._dummy(),
        "maxMemoryInMB",
        "Maximum memory in MB allocated to "
        + "histogram aggregation. If too small, then 1 node will be split per "
        + "iteration, and its aggregates may exceed this size.",
        typeConverter=TypeConverters.toInt,
    )

    cacheNodeIds: Param[bool] = Param(
        Params._dummy(),
        "cacheNodeIds",
        "If false, the algorithm will pass "
        + "trees to executors to match instances with nodes. If true, the "
        + "algorithm will cache node IDs for each instance. Caching can speed "
        + "up training of deeper trees. Users can set how often the cache "
        + "should be checkpointed, or disable it, by setting checkpointInterval.",
        typeConverter=TypeConverters.toBoolean,
    )

    def __init__(self) -> None:
        super(_DecisionTreeParams, self).__init__()

    def setLeafCol(self: "P", value: str) -> "P":
        """
        Sets the value of :py:attr:`leafCol`.
        """
        return self._set(leafCol=value)

    def getLeafCol(self) -> str:
        """
        Gets the value of leafCol or its default value.
        """
        return self.getOrDefault(self.leafCol)

    def getMaxDepth(self) -> int:
        """
        Gets the value of maxDepth or its default value.
        """
        return self.getOrDefault(self.maxDepth)

    def getMaxBins(self) -> int:
        """
        Gets the value of maxBins or its default value.
        """
        return self.getOrDefault(self.maxBins)

    def getMinInstancesPerNode(self) -> int:
        """
        Gets the value of minInstancesPerNode or its default value.
        """
        return self.getOrDefault(self.minInstancesPerNode)

    def getMinWeightFractionPerNode(self) -> float:
        """
        Gets the value of minWeightFractionPerNode or its default value.
        """
        return self.getOrDefault(self.minWeightFractionPerNode)

    def getMinInfoGain(self) -> float:
        """
        Gets the value of minInfoGain or its default value.
        """
        return self.getOrDefault(self.minInfoGain)

    def getMaxMemoryInMB(self) -> int:
        """
        Gets the value of maxMemoryInMB or its default value.
        """
        return self.getOrDefault(self.maxMemoryInMB)

    def getCacheNodeIds(self) -> bool:
        """
        Gets the value of cacheNodeIds or its default value.
        """
        return self.getOrDefault(self.cacheNodeIds)
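
# Illustrative usage sketch (editor's addition): the getters and setter in
# _DecisionTreeParams are inherited by every tree-based estimator, e.g.
# DecisionTreeClassifier. The values asserted below are the ones set here
# plus Spark's documented default for minInstancesPerNode.
def _example_decision_tree_params() -> None:
    from pyspark.ml.classification import DecisionTreeClassifier

    dt = DecisionTreeClassifier(maxDepth=3, maxBins=16)
    assert dt.getMaxDepth() == 3
    assert dt.getMaxBins() == 16
    # minInstancesPerNode defaults to 1 unless set explicitly.
    assert dt.getMinInstancesPerNode() == 1
    # setLeafCol turns on the leaf-index output column described by leafCol.
    dt = dt.setLeafCol("leafId")
    assert dt.getLeafCol() == "leafId"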
""" return self._call_java("predictLeaf", value) class _TreeEnsembleParams(_DecisionTreeParams): """ Mixin for Decision Tree-based ensemble algorithms parameters. """ subsamplingRate: Param[float] = Param( Params._dummy(), "subsamplingRate", "Fraction of the training data " + "used for learning each decision tree, in range (0, 1].", typeConverter=TypeConverters.toFloat, ) supportedFeatureSubsetStrategies: List[str] = ["auto", "all", "onethird", "sqrt", "log2"] featureSubsetStrategy: Param[str] = Param( Params._dummy(), "featureSubsetStrategy", "The number of features to consider for splits at each tree node. Supported " + "options: 'auto' (choose automatically for task: If numTrees == 1, set to " + "'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to " + "'onethird' for regression), 'all' (use all features), 'onethird' (use " + "1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use " + "log2(number of features)), 'n' (when n is in the range (0, 1.0], use " + "n * number of features. When n is in the range (1, number of features), use" + " n features). default = 'auto'", typeConverter=TypeConverters.toString, ) def __init__(self) -> None: super(_TreeEnsembleParams, self).__init__() @since("1.4.0") def getSubsamplingRate(self) -> float: """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) @since("1.4.0") def getFeatureSubsetStrategy(self) -> str: """ Gets the value of featureSubsetStrategy or its default value. """ return self.getOrDefault(self.featureSubsetStrategy) class _RandomForestParams(_TreeEnsembleParams): """ Private class to track supported random forest parameters. """ numTrees: Param[int] = Param( Params._dummy(), "numTrees", "Number of trees to train (>= 1).", typeConverter=TypeConverters.toInt, ) bootstrap: Param[bool] = Param( Params._dummy(), "bootstrap", "Whether bootstrap samples are used " "when building trees.", typeConverter=TypeConverters.toBoolean, ) def __init__(self) -> None: super(_RandomForestParams, self).__init__() @since("1.4.0") def getNumTrees(self) -> int: """ Gets the value of numTrees or its default value. """ return self.getOrDefault(self.numTrees) @since("3.0.0") def getBootstrap(self) -> bool: """ Gets the value of bootstrap or its default value. """ return self.getOrDefault(self.bootstrap) class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol): """ Private class to track supported GBT params. """ stepSize: Param[float] = Param( Params._dummy(), "stepSize", "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " + "the contribution of each estimator.", typeConverter=TypeConverters.toFloat, ) validationTol: Param[float] = Param( Params._dummy(), "validationTol", "Threshold for stopping early when fit with validation is used. " + "If the error rate on the validation input changes by less than the " + "validationTol, then learning will stop early (before `maxIter`). " + "This parameter is ignored when fit without validation is used.", typeConverter=TypeConverters.toFloat, ) @since("3.0.0") def getValidationTol(self) -> float: """ Gets the value of validationTol or its default value. """ return self.getOrDefault(self.validationTol) class _HasVarianceImpurity(Params): """ Private class to track supported impurity measures. """ supportedImpurities: List[str] = ["variance"] impurity: Param[str] = Param( Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). 
" + "Supported options: " + ", ".join(supportedImpurities), typeConverter=TypeConverters.toString, ) def __init__(self) -> None: super(_HasVarianceImpurity, self).__init__() @since("1.4.0") def getImpurity(self) -> str: """ Gets the value of impurity or its default value. """ return self.getOrDefault(self.impurity) class _TreeClassifierParams(Params): """ Private class to track supported impurity measures. .. versionadded:: 1.4.0 """ supportedImpurities: List[str] = ["entropy", "gini"] impurity: Param[str] = Param( Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + "Supported options: " + ", ".join(supportedImpurities), typeConverter=TypeConverters.toString, ) def __init__(self) -> None: super(_TreeClassifierParams, self).__init__() @since("1.6.0") def getImpurity(self) -> str: """ Gets the value of impurity or its default value. """ return self.getOrDefault(self.impurity) class _TreeRegressorParams(_HasVarianceImpurity): """ Private class to track supported impurity measures. """ pass