/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.distribution;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.Pair;

/**
 * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
 * mixture model</a> distributions.
 * <p>
 * The weights supplied at construction are normalized so that they sum to one;
 * the density of the mixture is the weighted sum of the component densities.
 *
 * @param <T> Type of the mixture components.
 *
 * @since 3.1
 */
public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
    extends AbstractMultivariateRealDistribution {
    /** Normalized weight of each mixture component. */
    private final double[] weight;
    /** Mixture components. */
    private final List<T> distribution;

    /**
     * Creates a mixture model from a list of distributions and their
     * associated weights.
     * <p>
     * <b>Note:</b> this constructor will implicitly create an instance of
     * {@link Well19937c} as random generator to be used for sampling only (see
     * {@link #sample()} and {@link #sample(int)}). In case no sampling is
     * needed for the created distribution, it is advised to pass {@code null}
     * as random generator via the appropriate constructors to avoid the
     * additional initialisation overhead.
     *
     * @param components List of (weight, distribution) pairs from which to sample.
     * Must contain at least one element.
     * @throws NotPositiveException if any of the weights is negative or NaN.
     * @throws DimensionMismatchException if not all components have the same
     * number of variables.
     * @throws MathArithmeticException if the sum of the weights overflows or
     * is zero.
     */
    public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
        this(new Well19937c(), components);
    }

    /**
     * Creates a mixture model from a list of distributions and their
     * associated weights.
     * <p>
     * The weights need not sum to one: they are divided by their sum before
     * being stored.
     *
     * @param rng Random number generator.
     * @param components List of (weight, distribution) pairs from which to sample.
     * Must contain at least one element, since the dimension of the mixture is
     * read from the first component.
     * @throws NotPositiveException if any of the weights is negative or NaN.
     * @throws DimensionMismatchException if not all components have the same
     * number of variables.
     * @throws MathArithmeticException if the sum of the weights overflows or
     * is zero (in which case the weights cannot be normalized).
     */
    public MixtureMultivariateRealDistribution(RandomGenerator rng,
                                               List<Pair<Double, T>> components) {
        // The dimension of the mixture is that of the first component.
        super(rng, components.get(0).getSecond().getDimension());

        final int numComp = components.size();
        final int dim = getDimension();
        double weightSum = 0;
        for (int i = 0; i < numComp; i++) {
            final Pair<Double, T> comp = components.get(i);
            final int compDim = comp.getSecond().getDimension();
            if (compDim != dim) {
                throw new DimensionMismatchException(compDim, dim);
            }
            final double w = comp.getFirst();
            // Written as a negated ">=" so that NaN is rejected too: a NaN
            // weight would otherwise turn every normalized weight into NaN.
            if (!(w >= 0)) {
                throw new NotPositiveException(w);
            }
            weightSum += w;
        }

        // Check for overflow.
        if (Double.isInfinite(weightSum)) {
            throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
        }

        // All weights are zero: normalizing would compute 0 / 0 = NaN.
        if (weightSum == 0) {
            throw new MathArithmeticException(LocalizedFormats.ZERO_NOT_ALLOWED);
        }

        // Store each distribution and its normalized weight.
        distribution = new ArrayList<T>(numComp);
        weight = new double[numComp];
        for (int i = 0; i < numComp; i++) {
            final Pair<Double, T> comp = components.get(i);
            weight[i] = comp.getFirst() / weightSum;
            distribution.add(comp.getSecond());
        }
    }

    /** {@inheritDoc} */
    public double density(final double[] values) {
        // Weighted sum of the component densities.
        double p = 0;
        for (int i = 0; i < weight.length; i++) {
            p += weight[i] * distribution.get(i).density(values);
        }
        return p;
    }

    /** {@inheritDoc} */
    @Override
    public double[] sample() {
        // Determine which component to sample from: component i owns the
        // half-open interval [sum of previous weights, sum including weight i),
        // so that a component whose weight is zero can never be selected.
        // The random value is expected to lie in [0, 1).
        final double randomValue = random.nextDouble();

        double cumulative = 0;
        int selected = -1;
        int lastPositive = weight.length - 1;
        for (int i = 0; i < weight.length; i++) {
            if (weight[i] > 0) {
                lastPositive = i;
                cumulative += weight[i];
                if (randomValue < cumulative) {
                    // pick model i
                    selected = i;
                    break;
                }
            }
        }

        if (selected < 0) {
            // This should never happen, but it ensures a component is always
            // chosen in case floating point rounding leaves the cumulative
            // sum slightly below the random value on the final iteration.
            selected = lastPositive;
        }

        return distribution.get(selected).sample();
    }

    /** {@inheritDoc} */
    @Override
    public void reseedRandomGenerator(long seed) {
        // Seed needs to be propagated to underlying components
        // in order to maintain consistency between runs.
        super.reseedRandomGenerator(seed);

        for (int i = 0; i < distribution.size(); i++) {
            // Make each component's seed different in order to avoid
            // using the same sequence of random numbers.
            distribution.get(i).reseedRandomGenerator(i + 1 + seed);
        }
    }

    /**
     * Gets the distributions that make up the mixture model.
     * <p>
     * A new list is built on each call; the weights it contains are the
     * normalized weights, not the values originally passed to the constructor.
     *
     * @return the component distributions and associated weights.
     */
    public List<Pair<Double, T>> getComponents() {
        final List<Pair<Double, T>> list = new ArrayList<Pair<Double, T>>(weight.length);

        for (int i = 0; i < weight.length; i++) {
            list.add(new Pair<Double, T>(weight[i], distribution.get(i)));
        }

        return list;
    }
}