# Source code for pyspark.ml.tree

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Sequence, TypeVar, TYPE_CHECKING

from pyspark import since
from pyspark.ml.linalg import Vector
from pyspark.ml.param import Params
from pyspark.ml.param.shared import (
    HasCheckpointInterval,
    HasSeed,
    HasWeightCol,
    Param,
    TypeConverters,
    HasMaxIter,
    HasStepSize,
    HasValidationIndicatorCol,
)
from pyspark.ml.wrapper import JavaPredictionModel
from pyspark.ml.common import inherit_doc

if TYPE_CHECKING:
    from pyspark.ml._typing import P

T = TypeVar("T")


@inherit_doc
class _DecisionTreeModel(JavaPredictionModel[T]):
    """
    Abstraction for Decision Tree models.

    Each accessor below forwards to the method of the same name through
    ``_call_java`` and returns whatever that call yields.

    .. versionadded:: 1.5.0
    """

    @property  # type: ignore[misc]
    @since("1.5.0")
    def numNodes(self) -> int:
        """Return number of nodes of the decision tree."""
        return self._call_java("numNodes")

    @property  # type: ignore[misc]
    @since("1.5.0")
    def depth(self) -> int:
        """Return depth of the decision tree."""
        return self._call_java("depth")

    @property  # type: ignore[misc]
    @since("2.0.0")
    def toDebugString(self) -> str:
        """Full description of model (exposed as a property, not a method)."""
        return self._call_java("toDebugString")

    @since("3.0.0")
    def predictLeaf(self, value: Vector) -> float:
        """
        Predict the index of the leaf corresponding to the feature vector.

        Parameters
        ----------
        value : :py:class:`pyspark.ml.linalg.Vector`
            Feature vector forwarded unchanged to ``predictLeaf``.

        Returns
        -------
        float
            Result of the ``predictLeaf`` call; presumably the leaf index
            encoded as a float -- confirm against the JVM implementation.
        """
        return self._call_java("predictLeaf", value)


class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
    """
    Mixin declaring the parameters shared by Decision Tree algorithms.

    Provides the :py:class:`Param` declarations, a getter for each of them,
    and the setter for :py:attr:`leafCol`.
    """

    leafCol: Param[str] = Param(
        Params._dummy(),
        "leafCol",
        "Leaf indices column name. "
        "Predicted leaf index of each instance in each tree by preorder.",
        typeConverter=TypeConverters.toString,
    )

    maxDepth: Param[int] = Param(
        Params._dummy(),
        "maxDepth",
        "Maximum depth of the tree. (>= 0) "
        "E.g., depth 0 means 1 leaf node; "
        "depth 1 means 1 internal node + 2 leaf nodes. "
        "Must be in range [0, 30].",
        typeConverter=TypeConverters.toInt,
    )

    maxBins: Param[int] = Param(
        Params._dummy(),
        "maxBins",
        # NOTE: the double space after "features." is part of the published text.
        "Max number of bins for discretizing continuous features.  "
        "Must be >=2 and >= number of categories for any categorical feature.",
        typeConverter=TypeConverters.toInt,
    )

    minInstancesPerNode: Param[int] = Param(
        Params._dummy(),
        "minInstancesPerNode",
        "Minimum number of instances each child must have after split. "
        "If a split causes the left or right child to have fewer than "
        "minInstancesPerNode, the split will be discarded as invalid. "
        "Should be >= 1.",
        typeConverter=TypeConverters.toInt,
    )

    minWeightFractionPerNode: Param[float] = Param(
        Params._dummy(),
        "minWeightFractionPerNode",
        "Minimum fraction of the weighted sample count that each child must have "
        "after split. If a split causes the fraction of the total weight in the "
        "left or right child to be less than minWeightFractionPerNode, the split "
        "will be discarded as invalid. Should be in interval [0.0, 0.5).",
        typeConverter=TypeConverters.toFloat,
    )

    minInfoGain: Param[float] = Param(
        Params._dummy(),
        "minInfoGain",
        "Minimum information gain for a split to be considered at a tree node.",
        typeConverter=TypeConverters.toFloat,
    )

    maxMemoryInMB: Param[int] = Param(
        Params._dummy(),
        "maxMemoryInMB",
        "Maximum memory in MB allocated to histogram aggregation. "
        "If too small, then 1 node will be split per iteration, "
        "and its aggregates may exceed this size.",
        typeConverter=TypeConverters.toInt,
    )

    cacheNodeIds: Param[bool] = Param(
        Params._dummy(),
        "cacheNodeIds",
        "If false, the algorithm will pass trees to executors to match "
        "instances with nodes. "
        "If true, the algorithm will cache node IDs for each instance. "
        "Caching can speed up training of deeper trees. "
        "Users can set how often should the cache be checkpointed or disable it "
        "by setting checkpointInterval.",
        typeConverter=TypeConverters.toBoolean,
    )

    def __init__(self) -> None:
        super().__init__()

    def setLeafCol(self: "P", value: str) -> "P":
        """
        Assign a new value to :py:attr:`leafCol` and return the result of ``_set``.
        """
        return self._set(leafCol=value)

    def getLeafCol(self) -> str:
        """
        Return the value of :py:attr:`leafCol`, or its default if unset.
        """
        return self.getOrDefault(self.leafCol)

    def getMaxDepth(self) -> int:
        """
        Return the value of :py:attr:`maxDepth`, or its default if unset.
        """
        return self.getOrDefault(self.maxDepth)

    def getMaxBins(self) -> int:
        """
        Return the value of :py:attr:`maxBins`, or its default if unset.
        """
        return self.getOrDefault(self.maxBins)

    def getMinInstancesPerNode(self) -> int:
        """
        Return the value of :py:attr:`minInstancesPerNode`, or its default if unset.
        """
        return self.getOrDefault(self.minInstancesPerNode)

    def getMinWeightFractionPerNode(self) -> float:
        """
        Return the value of :py:attr:`minWeightFractionPerNode`, or its default if unset.
        """
        return self.getOrDefault(self.minWeightFractionPerNode)

    def getMinInfoGain(self) -> float:
        """
        Return the value of :py:attr:`minInfoGain`, or its default if unset.
        """
        return self.getOrDefault(self.minInfoGain)

    def getMaxMemoryInMB(self) -> int:
        """
        Return the value of :py:attr:`maxMemoryInMB`, or its default if unset.
        """
        return self.getOrDefault(self.maxMemoryInMB)

    def getCacheNodeIds(self) -> bool:
        """
        Return the value of :py:attr:`cacheNodeIds`, or its default if unset.
        """
        return self.getOrDefault(self.cacheNodeIds)


@inherit_doc
class _TreeEnsembleModel(JavaPredictionModel[T]):
    """
    (private abstraction)
    Represents a tree ensemble model.

    Every accessor forwards to the JVM model through ``_call_java``; the
    results are returned as-is except where noted on the individual member.
    """

    @property  # type: ignore[misc]
    @since("2.0.0")
    def trees(self) -> Sequence["_DecisionTreeModel"]:
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        # Iterate the Java-side collection directly; materialising it into a
        # temporary list first only added a redundant copy.
        return [_DecisionTreeModel(m) for m in self._call_java("trees")]

    @property  # type: ignore[misc]
    @since("2.0.0")
    def getNumTrees(self) -> int:
        """Number of trees in ensemble (a property, despite the getter-style name)."""
        return self._call_java("getNumTrees")

    @property  # type: ignore[misc]
    @since("1.5.0")
    def treeWeights(self) -> List[float]:
        """Return the weights for each tree, copied into a new Python list."""
        return list(self._call_java("javaTreeWeights"))

    @property  # type: ignore[misc]
    @since("2.0.0")
    def totalNumNodes(self) -> int:
        """Total number of nodes, summed over all trees in the ensemble."""
        return self._call_java("totalNumNodes")

    @property  # type: ignore[misc]
    @since("2.0.0")
    def toDebugString(self) -> str:
        """Full description of model."""
        return self._call_java("toDebugString")

    @since("3.0.0")
    def predictLeaf(self, value: Vector) -> Vector:
        """
        Predict the indices of the leaves corresponding to the feature vector.

        Parameters
        ----------
        value : :py:class:`pyspark.ml.linalg.Vector`
            Feature vector forwarded unchanged to the JVM ``predictLeaf``.

        Returns
        -------
        :py:class:`pyspark.ml.linalg.Vector`
            One leaf index per tree of the ensemble. (A single decision tree
            returns a lone float instead; an ensemble returns a vector.)
        """
        return self._call_java("predictLeaf", value)


class _TreeEnsembleParams(_DecisionTreeParams):
    """
    Mixin with the parameters common to decision-tree ensemble algorithms.

    Adds :py:attr:`subsamplingRate` and :py:attr:`featureSubsetStrategy`
    (with their getters) on top of the single-tree parameters.
    """

    subsamplingRate: Param[float] = Param(
        Params._dummy(),
        "subsamplingRate",
        "Fraction of the training data used for learning each decision tree, "
        "in range (0, 1].",
        typeConverter=TypeConverters.toFloat,
    )

    # Named strategies listed in the featureSubsetStrategy help text below.
    supportedFeatureSubsetStrategies: List[str] = [
        "auto",
        "all",
        "onethird",
        "sqrt",
        "log2",
    ]

    featureSubsetStrategy: Param[str] = Param(
        Params._dummy(),
        "featureSubsetStrategy",
        "The number of features to consider for splits at each tree node. "
        "Supported options: "
        "'auto' (choose automatically for task: If numTrees == 1, set to 'all'. "
        "If numTrees > 1 (forest), set to 'sqrt' for classification and to "
        "'onethird' for regression), "
        "'all' (use all features), "
        "'onethird' (use 1/3 of the features), "
        "'sqrt' (use sqrt(number of features)), "
        "'log2' (use log2(number of features)), "
        "'n' (when n is in the range (0, 1.0], use n * number of features. "
        "When n is in the range (1, number of features), use n features). "
        "default = 'auto'",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self) -> None:
        super().__init__()

    @since("1.4.0")
    def getSubsamplingRate(self) -> float:
        """
        Return the value of :py:attr:`subsamplingRate`, or its default if unset.
        """
        return self.getOrDefault(self.subsamplingRate)

    @since("1.4.0")
    def getFeatureSubsetStrategy(self) -> str:
        """
        Return the value of :py:attr:`featureSubsetStrategy`, or its default if unset.
        """
        return self.getOrDefault(self.featureSubsetStrategy)


class _RandomForestParams(_TreeEnsembleParams):
    """
    Private mixin holding the parameters specific to random forests:
    :py:attr:`numTrees` and :py:attr:`bootstrap`.
    """

    numTrees: Param[int] = Param(
        Params._dummy(),
        "numTrees",
        "Number of trees to train (>= 1).",
        typeConverter=TypeConverters.toInt,
    )

    bootstrap: Param[bool] = Param(
        Params._dummy(),
        "bootstrap",
        "Whether bootstrap samples are used when building trees.",
        typeConverter=TypeConverters.toBoolean,
    )

    def __init__(self) -> None:
        super().__init__()

    @since("1.4.0")
    def getNumTrees(self) -> int:
        """
        Return the value of :py:attr:`numTrees`, or its default if unset.
        """
        return self.getOrDefault(self.numTrees)

    @since("3.0.0")
    def getBootstrap(self) -> bool:
        """
        Return the value of :py:attr:`bootstrap`, or its default if unset.
        """
        return self.getOrDefault(self.bootstrap)


class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
    """
    Private mixin holding the parameters supported by gradient-boosted trees.

    Redeclares :py:attr:`stepSize` with GBT-specific help text and adds
    :py:attr:`validationTol` together with its getter.
    """

    stepSize: Param[float] = Param(
        Params._dummy(),
        "stepSize",
        "Step size (a.k.a. learning rate) in interval (0, 1] "
        "for shrinking the contribution of each estimator.",
        typeConverter=TypeConverters.toFloat,
    )

    validationTol: Param[float] = Param(
        Params._dummy(),
        "validationTol",
        "Threshold for stopping early when fit with validation is used. "
        "If the error rate on the validation input changes by less than "
        "the validationTol, then learning will stop early (before `maxIter`). "
        "This parameter is ignored when fit without validation is used.",
        typeConverter=TypeConverters.toFloat,
    )

    @since("3.0.0")
    def getValidationTol(self) -> float:
        """
        Return the value of :py:attr:`validationTol`, or its default if unset.
        """
        return self.getOrDefault(self.validationTol)


class _HasVarianceImpurity(Params):
    """
    Private mixin declaring the :py:attr:`impurity` param for trees whose
    only supported impurity measure is variance.
    """

    supportedImpurities: List[str] = ["variance"]

    impurity: Param[str] = Param(
        Params._dummy(),
        "impurity",
        # The option list is rendered once, at class-definition time.
        "Criterion used for information gain calculation (case-insensitive). "
        f"Supported options: {', '.join(supportedImpurities)}",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self) -> None:
        super().__init__()

    @since("1.4.0")
    def getImpurity(self) -> str:
        """
        Return the value of :py:attr:`impurity`, or its default if unset.
        """
        return self.getOrDefault(self.impurity)


class _TreeClassifierParams(Params):
    """
    Private mixin declaring the :py:attr:`impurity` param for tree
    classifiers, which support the entropy and gini measures.

    .. versionadded:: 1.4.0
    """

    supportedImpurities: List[str] = ["entropy", "gini"]

    impurity: Param[str] = Param(
        Params._dummy(),
        "impurity",
        # The option list is rendered once, at class-definition time.
        "Criterion used for information gain calculation (case-insensitive). "
        f"Supported options: {', '.join(supportedImpurities)}",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self) -> None:
        super().__init__()

    @since("1.6.0")
    def getImpurity(self) -> str:
        """
        Return the value of :py:attr:`impurity`, or its default if unset.
        """
        return self.getOrDefault(self.impurity)


class _TreeRegressorParams(_HasVarianceImpurity):
    """
    Private mixin for tree regressors.

    Adds nothing of its own; the impurity param and its getter are inherited
    unchanged from :py:class:`_HasVarianceImpurity`.
    """