// NaturalRanking.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.statistics.ranking;
import java.util.Arrays;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.function.DoubleUnaryOperator;
import java.util.function.IntUnaryOperator;
/**
* Ranking based on the natural ordering on floating-point values.
*
* <p>{@link Double#NaN NaNs} are treated according to the configured
* {@link NaNStrategy} and ties are handled using the selected
* {@link TiesStrategy}. Configuration settings are supplied in optional
* constructor arguments. Defaults are {@link NaNStrategy#FAILED} and
* {@link TiesStrategy#AVERAGE}, respectively.
*
* <p>When using {@link TiesStrategy#RANDOM}, a generator of random values in {@code [0, x)}
* can be supplied as a {@link IntUnaryOperator} argument; otherwise a default is created
* on-demand. The source of randomness can be supplied using a method reference.
* The following example creates a ranking with NaN values with the highest
* ranking and ties resolved randomly:
*
* <pre>
* NaturalRanking ranking = new NaturalRanking(NaNStrategy.MAXIMAL,
* new SplittableRandom()::nextInt);
* </pre>
*
* <p>Note: Using {@link TiesStrategy#RANDOM} is not thread-safe due to the mutable
* generator of randomness. Instances not using random resolution of ties are
* thread-safe.
*
* <p>Examples:
*
* <table border="">
* <caption>Examples</caption>
* <tr><th colspan="3">
* Input data: [20, 17, 30, 42.3, 17, 50, Double.NaN, Double.NEGATIVE_INFINITY, 17]
* </th></tr>
* <tr><th>NaNStrategy</th><th>TiesStrategy</th>
* <th>{@code rank(data)}</th></tr>
* <tr>
* <td>MAXIMAL</td>
* <td>default (ties averaged)</td>
* <td>[5, 3, 6, 7, 3, 8, 9, 1, 3]</td></tr>
* <tr>
* <td>MAXIMAL</td>
* <td>MINIMUM</td>
* <td>[5, 2, 6, 7, 2, 8, 9, 1, 2]</td></tr>
* <tr>
* <td>MINIMAL</td>
* <td>default (ties averaged)</td>
* <td>[6, 4, 7, 8, 4, 9, 1.5, 1.5, 4]</td></tr>
* <tr>
* <td>REMOVED</td>
* <td>SEQUENTIAL</td>
* <td>[5, 2, 6, 7, 3, 8, 1, 4]</td></tr>
* <tr>
* <td>MINIMAL</td>
* <td>MAXIMUM</td>
* <td>[6, 5, 7, 8, 5, 9, 2, 2, 5]</td></tr>
* <tr>
* <td>FIXED</td>
* <td>default (ties averaged)</td>
* <td>[5, 3, 6, 7, 3, 8, NaN, 1, 3]</td></tr>
* </table>
*
* @since 1.1
*/
public class NaturalRanking implements RankingAlgorithm {
/** Message passed to {@link Objects#requireNonNull(Object, String)} for a null user-supplied {@link NaNStrategy}. */
private static final String NULL_NAN_STRATEGY = "nanStrategy";
/** Message passed to {@link Objects#requireNonNull(Object, String)} for a null user-supplied {@link TiesStrategy}. */
private static final String NULL_TIES_STRATEGY = "tiesStrategy";
/** Message passed to {@link Objects#requireNonNull(Object, String)} for a null user-supplied source of randomness. */
private static final String NULL_RANDOM_SOURCE = "randomIntFunction";
/** Default NaN strategy: used by constructors that do not accept a {@link NaNStrategy}. */
private static final NaNStrategy DEFAULT_NAN_STRATEGY = NaNStrategy.FAILED;
/** Default ties strategy: used by constructors that accept neither a {@link TiesStrategy} nor a source of randomness. */
private static final TiesStrategy DEFAULT_TIES_STRATEGY = TiesStrategy.AVERAGE;
/** Maps any operand to positive infinity. Applied to NaN values for {@link NaNStrategy#MAXIMAL}. */
private static final DoubleUnaryOperator ACTION_POS_INF = x -> Double.POSITIVE_INFINITY;
/** Maps any operand to negative infinity. Applied to NaN values for {@link NaNStrategy#MINIMAL}. */
private static final DoubleUnaryOperator ACTION_NEG_INF = x -> Double.NEGATIVE_INFINITY;
/**
* Throws an {@link IllegalArgumentException} reporting the operand; never returns.
* Applied to NaN values for {@link NaNStrategy#FAILED}.
*/
private static final DoubleUnaryOperator ACTION_ERROR = operand -> {
throw new IllegalArgumentException("Invalid data: " + operand);
};
/** NaN strategy. Fixed at construction. */
private final NaNStrategy nanStrategy;
/** Ties strategy. Fixed at construction. */
private final TiesStrategy tiesStrategy;
/**
* Source of randomness when ties strategy is RANDOM.
* Function maps positive x to {@code [0, x)}.
* Can be null to default to a JDK implementation; in that case the field is
* assigned lazily on first use, which is why it is not {@code final}.
*/
private IntUnaryOperator randomIntFunction;
/**
* Creates an instance with {@link NaNStrategy#FAILED} and
* {@link TiesStrategy#AVERAGE}.
*
* <p>No source of randomness is configured.
*/
public NaturalRanking() {
this(DEFAULT_NAN_STRATEGY, DEFAULT_TIES_STRATEGY, null);
}
/**
* Creates an instance with {@link NaNStrategy#FAILED} and the
* specified {@code tiesStrategy}.
*
* <p>If the ties strategy is {@link TiesStrategy#RANDOM RANDOM} a default
* source of randomness is used to resolve ties. The default source is
* created the first time a tie has to be resolved.
*
* @param tiesStrategy TiesStrategy to use.
* @throws NullPointerException if the strategy is {@code null}
*/
public NaturalRanking(TiesStrategy tiesStrategy) {
this(DEFAULT_NAN_STRATEGY,
Objects.requireNonNull(tiesStrategy, NULL_TIES_STRATEGY), null);
}
/**
* Creates an instance with the specified {@code nanStrategy} and
* {@link TiesStrategy#AVERAGE}.
*
* <p>No source of randomness is configured.
*
* @param nanStrategy NaNStrategy to use.
* @throws NullPointerException if the strategy is {@code null}
*/
public NaturalRanking(NaNStrategy nanStrategy) {
this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY),
DEFAULT_TIES_STRATEGY, null);
}
/**
* Creates an instance with the specified {@code nanStrategy} and the
* specified {@code tiesStrategy}.
*
* <p>If the ties strategy is {@link TiesStrategy#RANDOM RANDOM} a default
* source of randomness is used to resolve ties. The default source is
* created the first time a tie has to be resolved.
*
* @param nanStrategy NaNStrategy to use.
* @param tiesStrategy TiesStrategy to use.
* @throws NullPointerException if any strategy is {@code null}
*/
public NaturalRanking(NaNStrategy nanStrategy,
TiesStrategy tiesStrategy) {
this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY),
Objects.requireNonNull(tiesStrategy, NULL_TIES_STRATEGY), null);
}
/**
* Creates an instance with {@link NaNStrategy#FAILED},
* {@link TiesStrategy#RANDOM} and the given source of random index data.
*
* @param randomIntFunction Source of random index data.
* Function maps positive {@code x} randomly to {@code [0, x)}
* @throws NullPointerException if the source of randomness is {@code null}
*/
public NaturalRanking(IntUnaryOperator randomIntFunction) {
this(DEFAULT_NAN_STRATEGY, TiesStrategy.RANDOM,
Objects.requireNonNull(randomIntFunction, NULL_RANDOM_SOURCE));
}
/**
* Creates an instance with the specified {@code nanStrategy},
* {@link TiesStrategy#RANDOM} and the given source of random index data.
*
* @param nanStrategy NaNStrategy to use.
* @param randomIntFunction Source of random index data.
* Function maps positive {@code x} randomly to {@code [0, x)}
* @throws NullPointerException if the strategy or source of randomness are {@code null}
*/
public NaturalRanking(NaNStrategy nanStrategy,
IntUnaryOperator randomIntFunction) {
this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY), TiesStrategy.RANDOM,
Objects.requireNonNull(randomIntFunction, NULL_RANDOM_SOURCE));
}
/**
 * Creates an instance from fully resolved settings. Every other constructor
 * of this class delegates to this one.
 *
 * @param nanStrategy NaNStrategy to use.
 * @param tiesStrategy TiesStrategy to use.
 * @param randomIntFunction Source of random index data. May be {@code null}, in
 * which case a default source is created if and when one is required.
 */
private NaturalRanking(NaNStrategy nanStrategy,
                       TiesStrategy tiesStrategy,
                       IntUnaryOperator randomIntFunction) {
    // No validation here: the delegating constructors have already rejected
    // null for any argument that was supplied by the user.
    this.randomIntFunction = randomIntFunction;
    this.tiesStrategy = tiesStrategy;
    this.nanStrategy = nanStrategy;
}
/**
 * Gets the strategy this ranking applies to {@link Double#NaN NaN} values.
 *
 * @return the strategy for handling NaN
 */
public NaNStrategy getNanStrategy() {
    return this.nanStrategy;
}
/**
 * Gets the strategy this ranking applies to tied values.
 *
 * @return the strategy for handling ties
 */
public TiesStrategy getTiesStrategy() {
    return this.tiesStrategy;
}
/**
 * Rank {@code data} using the natural ordering on floating-point values, with
 * NaN values handled according to {@code nanStrategy} and ties resolved using
 * {@code tiesStrategy}.
 *
 * <p>The input array is not modified, and the returned array is always a new
 * array that is never the same object as {@code data}.
 *
 * @param data Values to rank.
 * @return the ranks. Element {@code i} is the rank of {@code data[i]}, except when
 * using {@link NaNStrategy#REMOVED} where NaN values are dropped before ranking
 * and the result is indexed by position among the remaining values. When using
 * {@link NaNStrategy#FIXED} the elements corresponding to NaN input are NaN.
 * @throws IllegalArgumentException if the selected {@link NaNStrategy} is
 * {@code FAILED} and a {@link Double#NaN} is encountered in the input data.
 */
@Override
public double[] apply(double[] data) {
    // Convert data for sorting.
    // NaNs are counted for the FIXED strategy; the single-element array is
    // used as an output counter.
    final int[] nanCount = {0};
    final DataPosition[] ranks = createRankData(data, nanCount);

    // Sorting will move NaNs to the end and we do not have to resolve ties in them.
    final int nonNanSize = ranks.length - nanCount[0];

    // Edge case: nothing can be ranked (empty input, or only NaN remain)
    if (nonNanSize == 0) {
        // Either NaN are left in-place or removed.
        // A copy is returned for FIXED so that the result never aliases the
        // caller's array, consistent with every other return path.
        return nanStrategy == NaNStrategy.FIXED ? data.clone() : new double[0];
    }

    Arrays.sort(ranks);

    // Walk the sorted array, filling output array using sorted positions,
    // resolving ties as we go.
    int pos = 1;
    final double[] out = new double[ranks.length];
    DataPosition current = ranks[0];
    out[current.getPosition()] = pos;

    // Store all previous elements of a tie.
    // Note this lags behind the length of the tie sequence by 1.
    // In the event there are no ties this is not used.
    final IntList tiesTrace = new IntList(ranks.length);

    for (int i = 1; i < nonNanSize; i++) {
        final DataPosition previous = current;
        current = ranks[i];
        if (current.compareTo(previous) > 0) {
            // A strictly larger value ends any tie sequence in progress
            if (tiesTrace.size() != 0) {
                resolveTie(out, tiesTrace, previous.getPosition());
            }
            // Rank is the 1-based index in the sorted order
            pos = i + 1;
        } else {
            // Tie sequence. Add the matching previous element.
            // The current element keeps the rank of the first element of the tie.
            tiesTrace.add(previous.getPosition());
        }
        out[current.getPosition()] = pos;
    }

    // Handle tie sequence at end
    if (tiesTrace.size() != 0) {
        resolveTie(out, tiesTrace, current.getPosition());
    }

    // For the FIXED strategy consume the remaining NaN elements
    if (nanStrategy == NaNStrategy.FIXED) {
        for (int i = nonNanSize; i < ranks.length; i++) {
            out[ranks[i].getPosition()] = Double.NaN;
        }
    }
    return out;
}
/**
 * Builds the value/position pairs to be sorted. With {@link NaNStrategy#REMOVED}
 * the NaN values are filtered out; for every other strategy each NaN is passed
 * through the strategy's NaN action, which may map it to an infinite value, count
 * it to allow subsequent processing, or throw.
 *
 * @param data Source data.
 * @param nanCount Output counter for NaN values.
 * @return the rank data
 * @throws IllegalArgumentException if the data contains NaN values when using
 * {@link NaNStrategy#FAILED}.
 */
private DataPosition[] createRankData(double[] data, final int[] nanCount) {
    if (nanStrategy == NaNStrategy.REMOVED) {
        return createNonNaNRankData(data);
    }
    final DoubleUnaryOperator nanAction = createNaNAction(nanCount);
    return createMappedRankData(data, nanAction);
}
/**
 * Selects the operator to apply to each NaN value for the configured
 * {@link NaNStrategy}.
 *
 * @param nanCount Output counter for NaN values. Element 0 is incremented by the
 * returned operator for the strategies that keep NaN values in the data.
 * @return the operator applied to NaN values
 */
private DoubleUnaryOperator createNaNAction(int[] nanCount) {
    switch (nanStrategy) {
    case FAILED: // NaNs are an error
        return ACTION_ERROR;
    case MINIMAL: // Replace NaNs with -INFs
        return ACTION_NEG_INF;
    case MAXIMAL: // Replace NaNs with +INFs
        return ACTION_POS_INF;
    case FIXED: // NaNs are unchanged
    case REMOVED: // NaNs are removed
        // Leave the value as it is but count it so it can be handled later
        return nan -> {
            nanCount[0]++;
            return nan;
        };
    default:
        // this should not happen unless NaNStrategy enum is changed
        throw new IllegalStateException();
    }
}
/**
 * Builds the value/position pairs for the non-NaN values only. The position
 * recorded for each value is its index among the retained values, not its
 * index in {@code data}.
 *
 * @param data Source data.
 * @return the rank data
 */
private static DataPosition[] createNonNaNRankData(double[] data) {
    final int n = data.length;
    final DataPosition[] kept = new DataPosition[n];
    int count = 0;
    for (int i = 0; i < n; i++) {
        final double value = data[i];
        if (Double.isNaN(value)) {
            continue;
        }
        kept[count] = new DataPosition(value, count);
        count++;
    }
    if (count == n) {
        // Nothing was removed; the array is already the correct length
        return kept;
    }
    return Arrays.copyOf(kept, count);
}
/**
 * Builds one value/position pair per input value, in input order. NaN values are
 * replaced by the result of {@code nanAction}; all other values are used as is.
 *
 * @param data Source data.
 * @param nanAction Mapping operator applied to NaN values.
 * @return the rank data
 */
private static DataPosition[] createMappedRankData(double[] data, DoubleUnaryOperator nanAction) {
    final int n = data.length;
    final DataPosition[] mapped = new DataPosition[n];
    for (int i = 0; i < n; i++) {
        final double raw = data[i];
        final double value = Double.isNaN(raw) ? nanAction.applyAsDouble(raw) : raw;
        mapped[i] = new DataPosition(value, i);
    }
    return mapped;
}
/**
 * Resolve a sequence of ties, using the configured {@link TiesStrategy}. On entry
 * every index in {@code tiesTrace} (and {@code finalIndex}) holds the same value in
 * {@code ranks}: the lowest rank of the tie. That common value is recoded according
 * to the strategy. For example, if ranks = [5,8,2,6,2,7,1,2], the tied indices are
 * [2,4,7] and the strategy is MINIMUM, ranks will be unchanged. The same array and
 * indices with AVERAGE will come out [5,8,3,6,3,7,1,3].
 *
 * <p>Note: For convenience the final index of the tie is passed as an argument and
 * appended here; it is assumed the list is already non-empty. The list of indices is
 * cleared before a normal return.
 *
 * @param ranks Array of ranks.
 * @param tiesTrace List of indices where {@code ranks} is constant, that is,
 * for any i and j in {@code tiesTrace}: {@code ranks[i] == ranks[j]}.
 * @param finalIndex The final index to add to the sequence of ties.
 */
private void resolveTie(double[] ranks, IntList tiesTrace, int finalIndex) {
    tiesTrace.add(finalIndex);

    // Lowest rank in the tie, shared by every tied element.
    // Note: this is a rank counter starting from 1 so it is limited to an int.
    final double c = ranks[tiesTrace.get(0)];

    // Number of tied elements
    final int length = tiesTrace.size();

    switch (tiesStrategy) {
    case MINIMUM:
        // The tie sequence already has all values set to the minimum
        break;
    case MAXIMUM:
        // Highest rank in the tie
        fill(ranks, tiesTrace, c + length - 1);
        break;
    case AVERAGE:
        // (lower + upper) / 2
        fill(ranks, tiesTrace, (2 * c + length - 1) * 0.5);
        break;
    case SEQUENTIAL:
        // Ranks c, c + 1, ... in the order the elements were recorded.
        // The cast is safe as c is a counter.
        assignSequential(ranks, tiesTrace, (int) c);
        break;
    case RANDOM:
        // As SEQUENTIAL but with the order of the elements randomised first
        tiesTrace.shuffle(getRandomIntFunction());
        assignSequential(ranks, tiesTrace, (int) c);
        break;
    default: // this should not happen unless TiesStrategy enum is changed
        throw new IllegalStateException();
    }

    tiesTrace.clear();
}

/**
 * Sets {@code ranks[i]} to consecutive values starting at {@code firstRank},
 * taking each i from {@code tiesTrace} in list order.
 *
 * @param ranks Array of ranks.
 * @param tiesTrace List of index values to set.
 * @param firstRank Rank given to the first index in the list.
 */
private static void assignSequential(double[] ranks, IntList tiesTrace, int firstRank) {
    final int count = tiesTrace.size();
    for (int k = 0; k < count; k++) {
        ranks[tiesTrace.get(k)] = firstRank + k;
    }
}
/**
 * Writes {@code value} into {@code data} at every index held in {@code tiesTrace}.
 *
 * @param data Array to modify.
 * @param tiesTrace List of index values to set.
 * @param value Value to set.
 */
private static void fill(double[] data, IntList tiesTrace, double value) {
    final int count = tiesTrace.size();
    int k = 0;
    while (k < count) {
        data[tiesTrace.get(k)] = value;
        k++;
    }
}
/**
 * Gets the function to map positive {@code x} randomly to {@code [0, x)}.
 * If no source of randomness has been set, one backed by a new
 * {@link SplittableRandom} is created, stored for subsequent calls and returned.
 *
 * @return the RNG
 */
private IntUnaryOperator getRandomIntFunction() {
    // Read the field once into a local
    final IntUnaryOperator existing = randomIntFunction;
    if (existing != null) {
        return existing;
    }
    // Default to a SplittableRandom
    final IntUnaryOperator created = new SplittableRandom()::nextInt;
    randomIntFunction = created;
    return created;
}
/**
 * A growable list of primitive {@code int} values. This allows tracking array
 * positions without using boxed values in a {@code List<Integer>}.
 */
private static class IntList {
    /** Upper limit on the length of the backing array when it is grown. */
    private final int max;
    /** Number of values currently stored. */
    private int size;
    /** Backing array. Starts with room for a tie of 2 values. */
    private int[] data = new int[2];

    /**
     * @param max Maximum size of array to allocate. Can use the length of the parent array
     * for which this is used to track indices.
     */
    IntList(int max) {
        this.max = max;
    }

    /**
     * Appends the value to the list, growing the backing array if it is full.
     *
     * @param value the value
     */
    void add(int value) {
        if (size == data.length) {
            // Double the capacity, computed as a long so it cannot overflow,
            // but never allocate more than the configured maximum.
            final long doubled = size * 2L;
            final int capacity = (int) Math.min(max, doubled);
            data = Arrays.copyOf(data, capacity);
        }
        data[size++] = value;
    }

    /**
     * Gets the element at the specified {@code index}.
     * The index is not checked against the current size.
     *
     * @param index Element index
     * @return the element
     */
    int get(int index) {
        return data[index];
    }

    /**
     * Gets the number of elements in the list.
     *
     * @return the size
     */
    int size() {
        return size;
    }

    /**
     * Clear the list. The backing array is retained for reuse.
     */
    void clear() {
        size = 0;
    }

    /**
     * Shuffle the list in place. Working from the last element down to the second,
     * each element is swapped with the element at an index drawn from
     * {@code [0, bound)} where {@code bound} is one more than its own index.
     *
     * @param randomIntFunction Function maps positive {@code x} randomly to {@code [0, x)}.
     */
    void shuffle(IntUnaryOperator randomIntFunction) {
        final int[] array = data;
        int bound = size;
        while (bound > 1) {
            final int pick = randomIntFunction.applyAsInt(bound);
            bound--;
            // Swap the element at the end of the current range with the picked element
            final int tmp = array[bound];
            array[bound] = array[pick];
            array[pick] = tmp;
        }
    }
}
/**
 * Pairs a {@code double} value with a position in a data array. The
 * Comparable interface is implemented so Arrays.sort can be used to sort an
 * array of data positions by value. Note that the implicitly defined natural
 * ordering is NOT consistent with equals.
 */
private static class DataPosition implements Comparable<DataPosition> {
    /** Data value. */
    private final double value;
    /** Data position. */
    private final int position;

    /**
     * Create an instance with the given value and position.
     *
     * @param value Data value.
     * @param position Data position.
     */
    DataPosition(double value, int position) {
        this.position = position;
        this.value = value;
    }

    /**
     * Compare this value to another.
     * Only the <strong>values</strong> are compared; the positions are ignored.
     *
     * <p>The ordering is that of {@link Double#compare(double, double)}, which
     * places {@code -0.0} before {@code 0.0} and NaN after all other values.
     *
     * @param other the other pair to compare this to
     * @return result of {@code Double.compare(value, other.value)}
     */
    @Override
    public int compareTo(DataPosition other) {
        return Double.compare(this.value, other.value);
    }

    // N.B. equals() and hashCode() are not implemented; see MATH-610 for discussion.

    /**
     * Returns the data position.
     *
     * @return position
     */
    int getPosition() {
        return this.position;
    }
}
}