AbstractContinuousDistribution.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.statistics.distribution;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import org.apache.commons.numbers.rootfinder.BrentSolver;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.InverseTransformContinuousSampler;
/**
* Base class for probability distributions on the reals.
* Default implementations are provided for some of the methods
* that do not vary from distribution to distribution.
*
* <p>This base class provides a default factory method for creating
* a {@linkplain ContinuousDistribution.Sampler sampler instance} that uses the
* <a href="https://en.wikipedia.org/wiki/Inverse_transform_sampling">
* inversion method</a> for generating random samples that follow the
* distribution.
*
* <p>The class provides functionality to evaluate the probability in a range
* using either the cumulative probability or the survival probability.
* The survival probability is used if both arguments to
* {@link #probability(double, double)} are above the median.
* Child classes with a known median can override the default {@link #getMedian()}
* method.
*/
abstract class AbstractContinuousDistribution
implements ContinuousDistribution {
// Notes on the inverse probability implementation:
//
// The Brent solver does not allow a stopping criteria for the proximity
// to the root; it uses equality to zero within 1 ULP. The search is
// iterated until there is a small difference between the upper
// and lower bracket of the root, expressed as a combination of relative
// and absolute thresholds.
/** BrentSolver relative accuracy.
* This is used with {@code tol = 2 * relEps * abs(b) + absEps} so the minimum
* non-zero value with an effect is half of machine epsilon (2^-53). */
private static final double SOLVER_RELATIVE_ACCURACY = 0x1.0p-53;
/** BrentSolver absolute accuracy.
* This is used with {@code tol = 2 * relEps * abs(b) + absEps} so set to MIN_VALUE
* so that when the relative epsilon has no effect (as b is too small) the tolerance
* is at least 1 ULP for sub-normal numbers. */
private static final double SOLVER_ABSOLUTE_ACCURACY = Double.MIN_VALUE;
/** BrentSolver function value accuracy.
* Determines if the Brent solver performs a search. It is not used during the search.
* Set to a very low value to search using Brent's method unless
* the starting point is correct, or within 1 ULP for sub-normal probabilities. */
private static final double SOLVER_FUNCTION_VALUE_ACCURACY = Double.MIN_VALUE;
/** Cached value of the median. */
private double median = Double.NaN;
/**
* Gets the median. This is used to determine if the arguments to the
* {@link #probability(double, double)} function are in the upper or lower domain.
*
* <p>The default implementation calls {@link #inverseCumulativeProbability(double)}
* with a value of 0.5.
*
* @return the median
*/
double getMedian() {
double m = median;
if (Double.isNaN(m)) {
median = m = inverseCumulativeProbability(0.5);
}
return m;
}
/** {@inheritDoc} */
@Override
public double probability(double x0,
double x1) {
if (x0 > x1) {
throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
}
// Use the survival probability when in the upper domain [3]:
//
// lower median upper
// | | |
// 1. |------|
// x0 x1
// 2. |----------|
// x0 x1
// 3. |--------|
// x0 x1
final double m = getMedian();
if (x0 >= m) {
return survivalProbability(x0) - survivalProbability(x1);
}
return cumulativeProbability(x1) - cumulativeProbability(x0);
}
/**
* {@inheritDoc}
*
* <p>The default implementation returns:
* <ul>
* <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
* <li>{@link #getSupportUpperBound()} for {@code p = 1}, or</li>
* <li>the result of a search for a root between the lower and upper bound using
* {@link #cumulativeProbability(double) cumulativeProbability(x) - p}.
* The bounds may be bracketed for efficiency.</li>
* </ul>
*
* @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
*/
@Override
public double inverseCumulativeProbability(double p) {
ArgumentUtils.checkProbability(p);
return inverseProbability(p, 1 - p, false);
}
/**
* {@inheritDoc}
*
* <p>The default implementation returns:
* <ul>
* <li>{@link #getSupportLowerBound()} for {@code p = 1},</li>
* <li>{@link #getSupportUpperBound()} for {@code p = 0}, or</li>
* <li>the result of a search for a root between the lower and upper bound using
* {@link #survivalProbability(double) survivalProbability(x) - p}.
* The bounds may be bracketed for efficiency.</li>
* </ul>
*
* @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
*/
@Override
public double inverseSurvivalProbability(double p) {
ArgumentUtils.checkProbability(p);
return inverseProbability(1 - p, p, true);
}
/**
* Implementation for the inverse cumulative or survival probability.
*
* @param p Cumulative probability.
* @param q Survival probability.
* @param complement Set to true to compute the inverse survival probability
* @return the value
*/
private double inverseProbability(final double p, final double q, boolean complement) {
/* IMPLEMENTATION NOTES
* --------------------
* Where applicable, use is made of the one-sided Chebyshev inequality
* to bracket the root. This inequality states that
* P(X - mu >= k * sig) <= 1 / (1 + k^2),
* mu: mean, sig: standard deviation. Equivalently
* 1 - P(X < mu + k * sig) <= 1 / (1 + k^2),
* F(mu + k * sig) >= k^2 / (1 + k^2).
*
* For k = sqrt(p / (1 - p)), we find
* F(mu + k * sig) >= p,
* and (mu + k * sig) is an upper-bound for the root.
*
* Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and
* P(Y >= -mu + k * sig) <= 1 / (1 + k^2),
* P(-X >= -mu + k * sig) <= 1 / (1 + k^2),
* P(X <= mu - k * sig) <= 1 / (1 + k^2),
* F(mu - k * sig) <= 1 / (1 + k^2).
*
* For k = sqrt((1 - p) / p), we find
* F(mu - k * sig) <= p,
* and (mu - k * sig) is a lower-bound for the root.
*
* In cases where the Chebyshev inequality does not apply, geometric
* progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket
* the root.
*
* In the case of the survival probability the bracket can be set using the same
* bound given that the argument p = 1 - q, with q the survival probability.
*/
double lowerBound = getSupportLowerBound();
if (p == 0) {
return lowerBound;
}
double upperBound = getSupportUpperBound();
if (q == 0) {
return upperBound;
}
final double mu = getMean();
final double sig = Math.sqrt(getVariance());
final boolean chebyshevApplies = Double.isFinite(mu) &&
ArgumentUtils.isFiniteStrictlyPositive(sig);
if (lowerBound == Double.NEGATIVE_INFINITY) {
lowerBound = createFiniteLowerBound(p, q, complement, upperBound, mu, sig, chebyshevApplies);
}
if (upperBound == Double.POSITIVE_INFINITY) {
upperBound = createFiniteUpperBound(p, q, complement, lowerBound, mu, sig, chebyshevApplies);
}
// Here the bracket [lower, upper] uses finite values. If the support
// is infinite the bracket can truncate the distribution and the target
// probability can be outside the range of [lower, upper].
if (upperBound == Double.MAX_VALUE) {
if (complement) {
if (survivalProbability(upperBound) > q) {
return getSupportUpperBound();
}
} else if (cumulativeProbability(upperBound) < p) {
return getSupportUpperBound();
}
}
if (lowerBound == -Double.MAX_VALUE) {
if (complement) {
if (survivalProbability(lowerBound) < q) {
return getSupportLowerBound();
}
} else if (cumulativeProbability(lowerBound) > p) {
return getSupportLowerBound();
}
}
final DoubleUnaryOperator fun = complement ?
arg -> survivalProbability(arg) - q :
arg -> cumulativeProbability(arg) - p;
// Note the initial value is robust to overflow.
// Do not use 0.5 * (lowerBound + upperBound).
final double x = new BrentSolver(SOLVER_RELATIVE_ACCURACY,
SOLVER_ABSOLUTE_ACCURACY,
SOLVER_FUNCTION_VALUE_ACCURACY)
.findRoot(fun,
lowerBound,
lowerBound + 0.5 * (upperBound - lowerBound),
upperBound);
if (!isSupportConnected()) {
return searchPlateau(complement, lowerBound, x);
}
return x;
}
/**
* Create a finite lower bound. Assumes the current lower bound is negative infinity.
*
* @param p Cumulative probability.
* @param q Survival probability.
* @param complement Set to true to compute the inverse survival probability
* @param upperBound Current upper bound
* @param mu Mean
* @param sig Standard deviation
* @param chebyshevApplies True if the Chebyshev inequality applies (mean is finite and {@code sig > 0}}
* @return the finite lower bound
*/
private double createFiniteLowerBound(final double p, final double q, boolean complement,
double upperBound, final double mu, final double sig, final boolean chebyshevApplies) {
double lowerBound;
if (chebyshevApplies) {
lowerBound = mu - sig * Math.sqrt(q / p);
} else {
lowerBound = Double.NEGATIVE_INFINITY;
}
// Bound may have been set as infinite
if (lowerBound == Double.NEGATIVE_INFINITY) {
lowerBound = Math.min(-1, upperBound);
if (complement) {
while (survivalProbability(lowerBound) < q) {
lowerBound *= 2;
}
} else {
while (cumulativeProbability(lowerBound) >= p) {
lowerBound *= 2;
}
}
// Ensure finite
lowerBound = Math.max(lowerBound, -Double.MAX_VALUE);
}
return lowerBound;
}
/**
* Create a finite upper bound. Assumes the current upper bound is positive infinity.
*
* @param p Cumulative probability.
* @param q Survival probability.
* @param complement Set to true to compute the inverse survival probability
* @param lowerBound Current lower bound
* @param mu Mean
* @param sig Standard deviation
* @param chebyshevApplies True if the Chebyshev inequality applies (mean is finite and {@code sig > 0}}
* @return the finite lower bound
*/
private double createFiniteUpperBound(final double p, final double q, boolean complement,
double lowerBound, final double mu, final double sig, final boolean chebyshevApplies) {
double upperBound;
if (chebyshevApplies) {
upperBound = mu + sig * Math.sqrt(p / q);
} else {
upperBound = Double.POSITIVE_INFINITY;
}
// Bound may have been set as infinite
if (upperBound == Double.POSITIVE_INFINITY) {
upperBound = Math.max(1, lowerBound);
if (complement) {
while (survivalProbability(upperBound) >= q) {
upperBound *= 2;
}
} else {
while (cumulativeProbability(upperBound) < p) {
upperBound *= 2;
}
}
// Ensure finite
upperBound = Math.min(upperBound, Double.MAX_VALUE);
}
return upperBound;
}
/**
* Indicates whether the support is connected, i.e. whether all values between the
* lower and upper bound of the support are included in the support.
*
* <p>This method is used in the default implementation of the inverse cumulative and
* survival probability functions.
*
* <p>The default value is true which assumes the cdf and sf have no plateau regions
* where the same probability value is returned for a large range of x.
* Override this method if there are gaps in the support of the cdf and sf.
*
* <p>If false then the inverse will perform an additional step to ensure that the
* lower-bound of the interval on which the cdf is constant should be returned. This
* will search from the initial point x downwards if a smaller value also has the same
* cumulative (survival) probability.
*
* <p>Any plateau with a width in x smaller than the inverse absolute accuracy will
* not be searched.
*
* <p>Note: This method was public in commons math. It has been reduced to package private
* in commons statistics as it is an implementation detail.
*
* @return whether the support is connected.
* @see <a href="https://issues.apache.org/jira/browse/MATH-699">MATH-699</a>
*/
boolean isSupportConnected() {
return true;
}
/**
* Test the probability function for a plateau at the point x. If detected
* search the plateau for the lowest point y such that
* {@code inf{y in R | P(y) == P(x)}}.
*
* <p>This function is used when the distribution support is not connected
* to satisfy the inverse probability requirements of {@link ContinuousDistribution}
* on the returned value.
*
* @param complement Set to true to search the survival probability.
* @param lower Lower bound used to limit the search downwards.
* @param x Current value.
* @return the infimum y
*/
private double searchPlateau(boolean complement, double lower, final double x) {
// Test for plateau. Lower the value x if the probability is the same.
// Ensure the step is robust to the solver accuracy being less
// than 1 ulp of x (e.g. dx=0 will infinite loop)
final double dx = Math.max(SOLVER_ABSOLUTE_ACCURACY, Math.ulp(x));
if (x - dx >= lower) {
final DoubleUnaryOperator fun = complement ?
this::survivalProbability :
this::cumulativeProbability;
final double px = fun.applyAsDouble(x);
if (fun.applyAsDouble(x - dx) == px) {
double upperBound = x;
double lowerBound = lower;
// Bisection search
// Require cdf(x) < px and sf(x) > px to move the lower bound
// to the midpoint.
final DoubleBinaryOperator cmp = complement ?
(a, b) -> a > b ? -1 : 1 :
(a, b) -> a < b ? -1 : 1;
while (upperBound - lowerBound > dx) {
final double midPoint = 0.5 * (lowerBound + upperBound);
if (cmp.applyAsDouble(fun.applyAsDouble(midPoint), px) < 0) {
lowerBound = midPoint;
} else {
upperBound = midPoint;
}
}
return upperBound;
}
}
return x;
}
/** {@inheritDoc} */
@Override
public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
// Inversion method distribution sampler.
return InverseTransformContinuousSampler.of(rng, this::inverseCumulativeProbability)::sample;
}
}