AbstractDiscreteDistribution.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.statistics.distribution;

import java.util.function.IntUnaryOperator;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.InverseTransformDiscreteSampler;

/**
 * Base class for integer-valued discrete distributions.  Default
 * implementations are provided for some of the methods that do not vary
 * from distribution to distribution.
 *
 * <p>This base class provides a default factory method for creating
 * a {@linkplain DiscreteDistribution.Sampler sampler instance} that uses the
 * <a href="https://en.wikipedia.org/wiki/Inverse_transform_sampling">
 * inversion method</a> for generating random samples that follow the
 * distribution.
 *
 * <p>The class provides functionality to evaluate the probability in a range
 * using either the cumulative probability or the survival probability.
 * The survival probability is used if both arguments to
 * {@link #probability(int, int)} are at or above the median.
 * Child classes with a known median can override the default {@link #getMedian()}
 * method.
 */
abstract class AbstractDiscreteDistribution
    implements DiscreteDistribution {
    /** Marker value for no median.
     * This is a long to be outside the value of any possible int valued median. */
    private static final long NO_MEDIAN = Long.MIN_VALUE;

    /**
     * Cached value of the median, or {@link #NO_MEDIAN} if not yet computed.
     *
     * <p>The cache is filled lazily without locking. The field is {@code volatile}
     * because reads and writes of a plain {@code long} are not guaranteed to be
     * atomic (JLS 17.7): without it a concurrent thread could observe a torn value
     * that is neither the marker nor the computed median. Several threads may
     * each compute and store the median; this assumes
     * {@code inverseCumulativeProbability(0.5)} returns the same value on every call.
     */
    private volatile long median = NO_MEDIAN;

    /**
     * Gets the median. This is used to determine if the arguments to the
     * {@link #probability(int, int)} function are in the upper or lower domain.
     *
     * <p>The default implementation calls {@link #inverseCumulativeProbability(double)}
     * with a value of 0.5 and caches the result.
     *
     * @return the median
     */
    int getMedian() {
        // Single volatile read; the cached value is only ever NO_MEDIAN or an int.
        long m = median;
        if (m == NO_MEDIAN) {
            m = inverseCumulativeProbability(0.5);
            median = m;
        }
        return (int) m;
    }

    /** {@inheritDoc} */
    @Override
    public double probability(int x0,
                              int x1) {
        if (x0 > x1) {
            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
        }
        // As per the default interface method handle special cases:
        // x0     = x1 : return 0
        // x0 + 1 = x1 : return probability(x1)
        // Long addition avoids overflow
        if (x0 + 1L >= x1) {
            return x0 == x1 ? 0.0 : probability(x1);
        }

        // Use the survival probability when in the upper domain [3]:
        //
        //  lower          median         upper
        //    |              |              |
        // 1.     |------|
        //        x0     x1
        // 2.         |----------|
        //            x0         x1
        // 3.                  |--------|
        //                     x0       x1
        //
        // Case 3 is detected by x0 >= median (x1 > x0 is already known).

        final int m = getMedian();
        if (x0 >= m) {
            return survivalProbability(x0) - survivalProbability(x1);
        }
        return cumulativeProbability(x1) - cumulativeProbability(x0);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The default implementation returns:
     * <ul>
     * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
     * <li>{@link #getSupportUpperBound()} for {@code p = 1}, or</li>
     * <li>the result of a binary search between the lower and upper bound using
     *     {@link #cumulativeProbability(int) cumulativeProbability(x)}.
     *     The bounds may be bracketed for efficiency.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
     */
    @Override
    public int inverseCumulativeProbability(double p) {
        ArgumentUtils.checkProbability(p);
        return inverseProbability(p, 1 - p, false);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The default implementation returns:
     * <ul>
     * <li>{@link #getSupportLowerBound()} for {@code p = 1},</li>
     * <li>{@link #getSupportUpperBound()} for {@code p = 0}, or</li>
     * <li>the result of a binary search between the lower and upper bound using
     *     {@link #survivalProbability(int) survivalProbability(x)}.
     *     The bounds may be bracketed for efficiency.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
     */
    @Override
    public int inverseSurvivalProbability(double p) {
        ArgumentUtils.checkProbability(p);
        return inverseProbability(1 - p, p, true);
    }

    /**
     * Implementation for the inverse cumulative or survival probability.
     *
     * <p>Both {@code p} and {@code q} are supplied so the search can be
     * performed on whichever function was requested without recomputing
     * {@code 1 - p}; callers pass complementary values.
     *
     * @param p Cumulative probability.
     * @param q Survival probability.
     * @param complement Set to true to compute the inverse survival probability
     * @return the value
     */
    private int inverseProbability(final double p, final double q, boolean complement) {

        int lower = getSupportLowerBound();
        if (p == 0) {
            return lower;
        }
        int upper = getSupportUpperBound();
        if (q == 0) {
            return upper;
        }

        // The binary search sets the upper value to the mid-point
        // based on fun(x) >= 0. The upper value is returned.
        //
        // Create a function to search for x where the upper bound can be
        // lowered if:
        // cdf(x) >= p
        // sf(x)  <= q
        final IntUnaryOperator fun = complement ?
            x -> Double.compare(q, survivalProbability(x)) :
            x -> Double.compare(cumulativeProbability(x), p);

        if (lower == Integer.MIN_VALUE) {
            // Cannot step below the support; test the bound directly instead.
            if (fun.applyAsInt(lower) >= 0) {
                return lower;
            }
        } else {
            // this ensures:
            // cumulativeProbability(lower) < p
            // survivalProbability(lower) > q
            // which is important for the solving step
            lower -= 1;
        }

        // Use the one-sided Chebyshev (Cantelli) inequality to narrow the bracket.
        // cf. AbstractContinuousDistribution.inverseCumulativeProbability(double)
        //   P(X <= mu - k sig) <= 1 / (1 + k^2)
        //   P(X >= mu + k sig) <= 1 / (1 + k^2)
        // With p + q = 1, k = sqrt(q/p) bounds the lower tail by p, and
        // k = sqrt(p/q) bounds the upper tail by q.
        final double mu = getMean();
        final double sig = Math.sqrt(getVariance());
        final boolean chebyshevApplies = Double.isFinite(mu) &&
                                         ArgumentUtils.isFiniteStrictlyPositive(sig);

        if (chebyshevApplies) {
            double tmp = mu - sig * Math.sqrt(q / p);
            if (tmp > lower) {
                // Largest integer strictly below tmp: cdf(lower) < p is preserved.
                lower = ((int) Math.ceil(tmp)) - 1;
            }
            tmp = mu + sig * Math.sqrt(p / q);
            if (tmp < upper) {
                // sf(upper) = P(X >= ceil(tmp)) <= q, so fun(upper) >= 0 is preserved.
                upper = ((int) Math.ceil(tmp)) - 1;
            }
        }

        return solveInverseProbability(fun, lower, upper);
    }

    /**
     * This is a utility function used by {@link
     * #inverseProbability(double, double, boolean)}. It assumes
     * that the inverse probability lies in the bracket {@code
     * (lower, upper]}. The implementation does simple bisection to find the
     * smallest {@code x} such that {@code fun(x) >= 0}.
     *
     * @param fun Probability function.
     * @param lowerBound Value satisfying {@code fun(lower) < 0}.
     * @param upperBound Value satisfying {@code fun(upper) >= 0}.
     * @return the smallest x
     */
    private static int solveInverseProbability(IntUnaryOperator fun,
                                               int lowerBound,
                                               int upperBound) {
        // Use long to prevent overflow during computation of the middle
        long lower = lowerBound;
        long upper = upperBound;
        while (lower + 1 < upper) {
            // Division truncates toward zero. As lower + 1 < upper the midpoint
            // is always strictly inside (lower, upper), so the bracket shrinks.
            // Note: an unsigned shift (>>> 1) must not be used here because
            // (lower + upper) can be negative.
            final long middle = (lower + upper) / 2;
            final int pm = fun.applyAsInt((int) middle);
            if (pm < 0) {
                lower = middle;
            } else {
                upper = middle;
            }
        }
        return (int) upper;
    }

    /** {@inheritDoc} */
    @Override
    public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
        // Inversion method distribution sampler.
        return InverseTransformDiscreteSampler.of(rng, this::inverseCumulativeProbability)::sample;
    }
}