1 //                                               -*- C++ -*-
2 /**
3  *  @brief Factory for Dirichlet distribution
4  *
5  *  Copyright 2005-2021 Airbus-EDF-IMACS-ONERA-Phimeca
6  *
7  *  This library is free software: you can redistribute it and/or modify
8  *  it under the terms of the GNU Lesser General Public License as published by
9  *  the Free Software Foundation, either version 3 of the License, or
10  *  (at your option) any later version.
11  *
12  *  This library is distributed in the hope that it will be useful,
13  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
14  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  *  GNU Lesser General Public License for more details.
16  *
17  *  You should have received a copy of the GNU Lesser General Public License
18  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
19  *
20  */
#include "openturns/DirichletFactory.hxx"
#include <cmath>
#include "openturns/ResourceMap.hxx"
#include "openturns/SpecFunc.hxx"
#include "openturns/PersistentObjectFactory.hxx"
25 
26 BEGIN_NAMESPACE_OPENTURNS
27 
// Class-name boilerplate for DirichletFactory, generated by the CLASSNAMEINIT macro
CLASSNAMEINIT(DirichletFactory)

// File-local Factory<DirichletFactory> instance, constructed during static initialization.
// NOTE(review): presumably this registers the class with the persistence mechanism
// declared in PersistentObjectFactory.hxx -- confirm against that header.
static const Factory<DirichletFactory> Factory_DirichletFactory;
31 
32 /* Default constructor */
DirichletFactory()33 DirichletFactory::DirichletFactory():
34   DistributionFactoryImplementation()
35 {
36   // Nothing to do
37 }
38 
39 /* Virtual constructor */
clone() const40 DirichletFactory * DirichletFactory::clone() const
41 {
42   return new DirichletFactory(*this);
43 }
44 
45 
/* Here is the interface that all derived classes must implement */
47 
build(const Sample & sample) const48 Distribution DirichletFactory::build(const Sample & sample) const
49 {
50   return buildAsDirichlet(sample).clone();
51 }
52 
build(const Point & parameters) const53 Distribution DirichletFactory::build(const Point & parameters) const
54 {
55   return buildAsDirichlet(parameters).clone();
56 }
57 
build() const58 Distribution DirichletFactory::build() const
59 {
60   return buildAsDirichlet().clone();
61 }
62 
buildAsDirichlet(const Sample & sample) const63 Dirichlet DirichletFactory::buildAsDirichlet(const Sample & sample) const
64 {
65   const UnsignedInteger size = sample.getSize();
66   if (size < 2) throw InvalidArgumentException(HERE) << "Error: cannot build a Dirichlet distribution from a sample of size < 2";
67   const UnsignedInteger dimension = sample.getDimension();
68   // Check that the points lie in the simplex x_1+...+x_d < 1, x_k > 0
69   // and precompute the sufficient statistics
70   Point meanLog(dimension + 1);
71   Point sumX(dimension, 0.0);
72   Point sumX2(dimension, 0.0);
73   for (UnsignedInteger i = 0; i < size; ++i)
74   {
75     Scalar sum = 0.0;
76     for (UnsignedInteger j = 0; j < dimension; ++j)
77     {
78       const Scalar xIJ = sample(i, j);
79       if (!(xIJ > 0.0)) throw InvalidArgumentException(HERE) << "Error: the sample contains points not in the unit simplex: x=" << sample[i];
80       sum += xIJ;
81       meanLog[j] += std::log(xIJ);
82       sumX[j] += xIJ;
83       sumX2[j] += xIJ * xIJ;
84     }
85     if (!(sum < 1.0)) throw InvalidArgumentException(HERE) << "Error: the sample contains points not in the unit simplex: x=" << sample[i];
86     meanLog[dimension] += log1p(-sum);
87   }
88   // Normalize the sum of the logarithms
89   meanLog = meanLog * (1.0 / size);
90   // Find the maximum likelihood estimate using a fixed-point strategy
91   // First, compute a reasonable initial guess using moments
92   Point theta(dimension + 1, 0.0);
93   Scalar sumTheta = 0.0;
94   // Estimate the sum of parameters
95   for (UnsignedInteger i = 0; i < dimension; ++i)
96   {
97     const Scalar sumXI = sumX[i];
98     const Scalar sumX2I = sumX2[i];
99     const Scalar numerator = sumXI - sumX2I;
100     const Scalar denominator = sumX2I - sumXI * sumXI / size;
101     if (denominator == 0.0) throw InvalidArgumentException(HERE) << "Error: the component " << i << " of the sample is constant (equal to " << sumXI / size << "). Impossible to estimate a Dirichlet distribution.";
102     sumTheta += numerator / denominator;
103   }
104   sumTheta /= dimension;
105   // Estimate the parameters from the mean of the sample
106   Scalar lastTheta = sumTheta;
107   for (UnsignedInteger i = 0; i < dimension; ++i)
108   {
109     const Scalar thetaI = (sumX[i] / size) * sumTheta;
110     // If the estimate is positive, use it, if not, use a default value of ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" )
111     theta[i] = (thetaI > 0.0 ? thetaI : ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" ));
112     lastTheta -= theta[i];
113   }
114   // If the estimate is positive, use it, if not, use a default value of ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" )
115   theta[dimension] = (lastTheta > 0.0 ? lastTheta : ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" ));
116   Bool convergence = false;
117   UnsignedInteger iteration = 0;
118   while (!convergence && (iteration < ResourceMap::GetAsUnsignedInteger( "DirichletFactory-MaximumIteration" )))
119   {
120     // Newton iteration
121     ++iteration;
122     sumTheta = 0.0;
123     for (UnsignedInteger i = 0; i <= dimension; ++i) sumTheta += theta[i];
124     const Scalar diGammaSumTheta = SpecFunc::DiGamma(sumTheta);
125     const Scalar triGammaSumTheta = SpecFunc::TriGamma(sumTheta);
126     Point g(dimension + 1);
127     Point q(dimension + 1);
128     Scalar numerator = 0.0;
129     Scalar denominator = 0.0;
130     for (UnsignedInteger i = 0; i <= dimension; ++i)
131     {
132       g[i] = meanLog[i] - SpecFunc::DiGamma(theta[i]) + diGammaSumTheta;
133       q[i] = -SpecFunc::TriGamma(theta[i]);
134       numerator += g[i] / q[i];
135       denominator += 1.0 / q[i];
136     }
137     const Scalar b = numerator / (1.0 / triGammaSumTheta + denominator);
138     Point delta(dimension + 1);
139     for (UnsignedInteger i = 0; i <= dimension; ++i) delta[i] = (g[i] - b) / q[i];
140     // Newton update
141     theta = theta - delta;
142     convergence = (delta.norm() < dimension * ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" ));
143   }
144   // Fixed point algorithm, works but is slow. Should never go there, as the Newton iteration should converge
145   iteration = 0;
146   while (!convergence && (iteration < ResourceMap::GetAsUnsignedInteger( "DirichletFactory-MaximumIteration" )))
147   {
148     ++ iteration;
149     sumTheta = 0.0;
150     for (UnsignedInteger i = 0; i <= dimension; ++i) sumTheta += theta[i];
151     const Scalar psiSumTheta = SpecFunc::DiGamma(sumTheta);
152     Scalar delta = 0.0;
153     for (UnsignedInteger i = 0; i <= dimension; ++i)
154     {
155       const Scalar thetaI = SpecFunc::DiGammaInv(psiSumTheta + meanLog[i]);
156       delta += std::abs(theta[i] - thetaI);
157       theta[i] = thetaI;
158     }
159     convergence = (delta < dimension * ResourceMap::GetAsScalar( "DirichletFactory-ParametersEpsilon" ));
160   }
161   Dirichlet result(theta);
162   result.setDescription(sample.getDescription());
163   return result;
164 }
165 
buildAsDirichlet(const Point & parameters) const166 Dirichlet DirichletFactory::buildAsDirichlet(const Point & parameters) const
167 {
168   try
169   {
170     Dirichlet distribution;
171     distribution.setParameter(parameters);
172     return distribution;
173   }
174   catch (InvalidArgumentException &)
175   {
176     throw InvalidArgumentException(HERE) << "Error: cannot build a Dirichlet distribution from the given parameters";
177   }
178 }
179 
buildAsDirichlet() const180 Dirichlet DirichletFactory::buildAsDirichlet() const
181 {
182   return Dirichlet();
183 }
184 
185 END_NAMESPACE_OPENTURNS
186