1 // -*- C++ -*-
2 /**
3 * @brief Factory for Dirichlet distribution
4 *
5 * Copyright 2005-2021 Airbus-EDF-IMACS-ONERA-Phimeca
6 *
7 * This library is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU Lesser General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * This library is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU Lesser General Public License for more details.
16 *
17 * You should have received a copy of the GNU Lesser General Public License
18 * along with this library. If not, see <http://www.gnu.org/licenses/>.
19 *
20 */
#include <cmath>

#include "openturns/DirichletFactory.hxx"
#include "openturns/ResourceMap.hxx"
#include "openturns/SpecFunc.hxx"
#include "openturns/PersistentObjectFactory.hxx"
25
26 BEGIN_NAMESPACE_OPENTURNS
27
/* Class-name bookkeeping for DirichletFactory.
 * NOTE(review): CLASSNAMEINIT is a project macro defined outside this file;
 * presumably it defines the static class-name storage — confirm in its header. */
CLASSNAMEINIT(DirichletFactory)

/* File-local Factory<DirichletFactory> instance.
 * NOTE(review): presumably its construction registers DirichletFactory with the
 * persistence mechanism declared in PersistentObjectFactory.hxx — TODO confirm. */
static const Factory<DirichletFactory> Factory_DirichletFactory;
31
32 /* Default constructor */
DirichletFactory()33 DirichletFactory::DirichletFactory():
34 DistributionFactoryImplementation()
35 {
36 // Nothing to do
37 }
38
39 /* Virtual constructor */
clone() const40 DirichletFactory * DirichletFactory::clone() const
41 {
42 return new DirichletFactory(*this);
43 }
44
45
/* Here is the interface that all derived classes must implement */
47
build(const Sample & sample) const48 Distribution DirichletFactory::build(const Sample & sample) const
49 {
50 return buildAsDirichlet(sample).clone();
51 }
52
build(const Point & parameters) const53 Distribution DirichletFactory::build(const Point & parameters) const
54 {
55 return buildAsDirichlet(parameters).clone();
56 }
57
build() const58 Distribution DirichletFactory::build() const
59 {
60 return buildAsDirichlet().clone();
61 }
62
buildAsDirichlet(const Sample & sample) const63 Dirichlet DirichletFactory::buildAsDirichlet(const Sample & sample) const
64 {
65 const UnsignedInteger size = sample.getSize();
66 if (size < 2) throw InvalidArgumentException(HERE) << "Error: cannot build a Dirichlet distribution from a sample of size < 2";
67 const UnsignedInteger dimension = sample.getDimension();
68 // Check that the points lie in the simplex x_1+...+x_d < 1, x_k > 0
69 // and precompute the sufficient statistics
70 Point meanLog(dimension + 1);
71 Point sumX(dimension, 0.0);
72 Point sumX2(dimension, 0.0);
73 for (UnsignedInteger i = 0; i < size; ++i)
74 {
75 Scalar sum = 0.0;
76 for (UnsignedInteger j = 0; j < dimension; ++j)
77 {
78 const Scalar xIJ = sample(i, j);
79 if (!(xIJ > 0.0)) throw InvalidArgumentException(HERE) << "Error: the sample contains points not in the unit simplex: x=" << sample[i];
80 sum += xIJ;
81 meanLog[j] += std::log(xIJ);
82 sumX[j] += xIJ;
83 sumX2[j] += xIJ * xIJ;
84 }
85 if (!(sum < 1.0)) throw InvalidArgumentException(HERE) << "Error: the sample contains points not in the unit simplex: x=" << sample[i];
86 meanLog[dimension] += log1p(-sum);
87 }
88 // Normalize the sum of the logarithms
89 meanLog = meanLog * (1.0 / size);
90 // Find the maximum likelihood estimate using a fixed-point strategy
91 // First, compute a reasonable initial guess using moments
92 Point theta(dimension + 1, 0.0);
93 Scalar sumTheta = 0.0;
94 // Estimate the sum of parameters
95 for (UnsignedInteger i = 0; i < dimension; ++i)
96 {
97 const Scalar sumXI = sumX[i];
98 const Scalar sumX2I = sumX2[i];
99 const Scalar numerator = sumXI - sumX2I;
100 const Scalar denominator = sumX2I - sumXI * sumXI / size;
101 if (denominator == 0.0) throw InvalidArgumentException(HERE) << "Error: the component " << i << " of the sample is constant (equal to " << sumXI / size << "). Impossible to estimate a Dirichlet distribution.";
102 sumTheta += numerator / denominator;
103 }
104 sumTheta /= dimension;
105 // Estimate the parameters from the mean of the sample
106 Scalar lastTheta = sumTheta;
107 for (UnsignedInteger i = 0; i < dimension; ++i)
108 {
109 const Scalar thetaI = (sumX[i] / size) * sumTheta;
110 // If the estimate is positive, use it, if not, use a default value of ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" )
111 theta[i] = (thetaI > 0.0 ? thetaI : ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" ));
112 lastTheta -= theta[i];
113 }
114 // If the estimate is positive, use it, if not, use a default value of ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" )
115 theta[dimension] = (lastTheta > 0.0 ? lastTheta : ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" ));
116 Bool convergence = false;
117 UnsignedInteger iteration = 0;
118 while (!convergence && (iteration < ResourceMap::GetAsUnsignedInteger( "DirichletFactory-MaximumIteration" )))
119 {
120 // Newton iteration
121 ++iteration;
122 sumTheta = 0.0;
123 for (UnsignedInteger i = 0; i <= dimension; ++i) sumTheta += theta[i];
124 const Scalar diGammaSumTheta = SpecFunc::DiGamma(sumTheta);
125 const Scalar triGammaSumTheta = SpecFunc::TriGamma(sumTheta);
126 Point g(dimension + 1);
127 Point q(dimension + 1);
128 Scalar numerator = 0.0;
129 Scalar denominator = 0.0;
130 for (UnsignedInteger i = 0; i <= dimension; ++i)
131 {
132 g[i] = meanLog[i] - SpecFunc::DiGamma(theta[i]) + diGammaSumTheta;
133 q[i] = -SpecFunc::TriGamma(theta[i]);
134 numerator += g[i] / q[i];
135 denominator += 1.0 / q[i];
136 }
137 const Scalar b = numerator / (1.0 / triGammaSumTheta + denominator);
138 Point delta(dimension + 1);
139 for (UnsignedInteger i = 0; i <= dimension; ++i) delta[i] = (g[i] - b) / q[i];
140 // Newton update
141 theta = theta - delta;
142 convergence = (delta.norm() < dimension * ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" ));
143 }
144 // Fixed point algorithm, works but is slow. Should never go there, as the Newton iteration should converge
145 iteration = 0;
146 while (!convergence && (iteration < ResourceMap::GetAsUnsignedInteger( "DirichletFactory-MaximumIteration" )))
147 {
148 ++ iteration;
149 sumTheta = 0.0;
150 for (UnsignedInteger i = 0; i <= dimension; ++i) sumTheta += theta[i];
151 const Scalar psiSumTheta = SpecFunc::DiGamma(sumTheta);
152 Scalar delta = 0.0;
153 for (UnsignedInteger i = 0; i <= dimension; ++i)
154 {
155 const Scalar thetaI = SpecFunc::DiGammaInv(psiSumTheta + meanLog[i]);
156 delta += std::abs(theta[i] - thetaI);
157 theta[i] = thetaI;
158 }
159 convergence = (delta < dimension * ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" ));
160 }
161 Dirichlet result(theta);
162 result.setDescription(sample.getDescription());
163 return result;
164 }
165
buildAsDirichlet(const Point & parameters) const166 Dirichlet DirichletFactory::buildAsDirichlet(const Point & parameters) const
167 {
168 try
169 {
170 Dirichlet distribution;
171 distribution.setParameter(parameters);
172 return distribution;
173 }
174 catch (InvalidArgumentException &)
175 {
176 throw InvalidArgumentException(HERE) << "Error: cannot build a Dirichlet distribution from the given parameters";
177 }
178 }
179
buildAsDirichlet() const180 Dirichlet DirichletFactory::buildAsDirichlet() const
181 {
182 return Dirichlet();
183 }
184
185 END_NAMESPACE_OPENTURNS
186