1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 package org.apache.commons.math3.distribution;
18 
19 import java.util.ArrayList;
20 import java.util.List;
21 
22 import org.apache.commons.math3.exception.DimensionMismatchException;
23 import org.apache.commons.math3.exception.MathArithmeticException;
24 import org.apache.commons.math3.exception.NotPositiveException;
25 import org.apache.commons.math3.exception.util.LocalizedFormats;
26 import org.apache.commons.math3.random.RandomGenerator;
27 import org.apache.commons.math3.random.Well19937c;
28 import org.apache.commons.math3.util.Pair;
29 
30 /**
31  * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
32  * mixture model</a> distributions.
33  *
34  * @param <T> Type of the mixture components.
35  *
36  * @since 3.1
37  */
38 public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
39     extends AbstractMultivariateRealDistribution {
40     /** Normalized weight of each mixture component. */
41     private final double[] weight;
42     /** Mixture components. */
43     private final List<T> distribution;
44 
45     /**
46      * Creates a mixture model from a list of distributions and their
47      * associated weights.
48      * <p>
49      * <b>Note:</b> this constructor will implicitly create an instance of
50      * {@link Well19937c} as random generator to be used for sampling only (see
51      * {@link #sample()} and {@link #sample(int)}). In case no sampling is
52      * needed for the created distribution, it is advised to pass {@code null}
53      * as random generator via the appropriate constructors to avoid the
54      * additional initialisation overhead.
55      *
56      * @param components List of (weight, distribution) pairs from which to sample.
57      */
MixtureMultivariateRealDistribution(List<Pair<Double, T>> components)58     public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
59         this(new Well19937c(), components);
60     }
61 
62     /**
63      * Creates a mixture model from a list of distributions and their
64      * associated weights.
65      *
66      * @param rng Random number generator.
67      * @param components Distributions from which to sample.
68      * @throws NotPositiveException if any of the weights is negative.
69      * @throws DimensionMismatchException if not all components have the same
70      * number of variables.
71      */
MixtureMultivariateRealDistribution(RandomGenerator rng, List<Pair<Double, T>> components)72     public MixtureMultivariateRealDistribution(RandomGenerator rng,
73                                                List<Pair<Double, T>> components) {
74         super(rng, components.get(0).getSecond().getDimension());
75 
76         final int numComp = components.size();
77         final int dim = getDimension();
78         double weightSum = 0;
79         for (int i = 0; i < numComp; i++) {
80             final Pair<Double, T> comp = components.get(i);
81             if (comp.getSecond().getDimension() != dim) {
82                 throw new DimensionMismatchException(comp.getSecond().getDimension(), dim);
83             }
84             if (comp.getFirst() < 0) {
85                 throw new NotPositiveException(comp.getFirst());
86             }
87             weightSum += comp.getFirst();
88         }
89 
90         // Check for overflow.
91         if (Double.isInfinite(weightSum)) {
92             throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
93         }
94 
95         // Store each distribution and its normalized weight.
96         distribution = new ArrayList<T>();
97         weight = new double[numComp];
98         for (int i = 0; i < numComp; i++) {
99             final Pair<Double, T> comp = components.get(i);
100             weight[i] = comp.getFirst() / weightSum;
101             distribution.add(comp.getSecond());
102         }
103     }
104 
105     /** {@inheritDoc} */
density(final double[] values)106     public double density(final double[] values) {
107         double p = 0;
108         for (int i = 0; i < weight.length; i++) {
109             p += weight[i] * distribution.get(i).density(values);
110         }
111         return p;
112     }
113 
114     /** {@inheritDoc} */
115     @Override
sample()116     public double[] sample() {
117         // Sampled values.
118         double[] vals = null;
119 
120         // Determine which component to sample from.
121         final double randomValue = random.nextDouble();
122         double sum = 0;
123 
124         for (int i = 0; i < weight.length; i++) {
125             sum += weight[i];
126             if (randomValue <= sum) {
127                 // pick model i
128                 vals = distribution.get(i).sample();
129                 break;
130             }
131         }
132 
133         if (vals == null) {
134             // This should never happen, but it ensures we won't return a null in
135             // case the loop above has some floating point inequality problem on
136             // the final iteration.
137             vals = distribution.get(weight.length - 1).sample();
138         }
139 
140         return vals;
141     }
142 
143     /** {@inheritDoc} */
144     @Override
reseedRandomGenerator(long seed)145     public void reseedRandomGenerator(long seed) {
146         // Seed needs to be propagated to underlying components
147         // in order to maintain consistency between runs.
148         super.reseedRandomGenerator(seed);
149 
150         for (int i = 0; i < distribution.size(); i++) {
151             // Make each component's seed different in order to avoid
152             // using the same sequence of random numbers.
153             distribution.get(i).reseedRandomGenerator(i + 1 + seed);
154         }
155     }
156 
157     /**
158      * Gets the distributions that make up the mixture model.
159      *
160      * @return the component distributions and associated weights.
161      */
getComponents()162     public List<Pair<Double, T>> getComponents() {
163         final List<Pair<Double, T>> list = new ArrayList<Pair<Double, T>>(weight.length);
164 
165         for (int i = 0; i < weight.length; i++) {
166             list.add(new Pair<Double, T>(weight[i], distribution.get(i)));
167         }
168 
169         return list;
170     }
171 }
172