1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkExpectationMaximizationMixtureModelEstimator_h
19 #define itkExpectationMaximizationMixtureModelEstimator_h
20 
21 #include "itkMixtureModelComponentBase.h"
22 #include "itkGaussianMembershipFunction.h"
23 #include "itkSimpleDataObjectDecorator.h"
24 
25 namespace itk
26 {
27 namespace Statistics
28 {
29 /** \class ExpectationMaximizationMixtureModelEstimator
 *  \brief This class generates the parameter estimates for a mixture
 *  model using the expectation maximization (EM) strategy.
32  *
33  * The first template argument is the type of the target sample
34  * data. This estimator expects one or more mixture model component
35  * objects of the classes derived from the
36  * MixtureModelComponentBase. The actual component (or module)
 * parameters are updated by each component. Users can think of this
 * class as a strategy or an integration point for the EM
39  * procedure. The initial proportion (SetInitialProportions), the
40  * input sample (SetSample), the mixture model components
41  * (AddComponent), and the maximum iteration (SetMaximumIteration) are
42  * required. The EM procedure terminates when the current iteration
43  * reaches the maximum iteration or the model parameters converge.
44  *
45  * <b>Recent API changes:</b>
 * The static const macro to get the length of a measurement vector,
 * \c MeasurementVectorSize, has been removed to allow the length of a measurement
48  * vector to be specified at run time. It is now obtained at run time from the
49  * sample set as input. Please use the function
50  * GetMeasurementVectorSize() to get the length.
51  *
52  * \sa MixtureModelComponentBase, GaussianMixtureModelComponent
53  * \ingroup ITKStatistics
54  *
55  * \wiki
56  * \wikiexample{Statistics/ExpectationMaximizationMixtureModelEstimator_2D,2D Gaussian Mixture Model Expectation Maximization}
57  * \endwiki
58  */
59 
60 template< typename TSample >
61 class ITK_TEMPLATE_EXPORT ExpectationMaximizationMixtureModelEstimator:public Object
62 {
63 public:
64   /** Standard class type alias */
65   using Self = ExpectationMaximizationMixtureModelEstimator;
66   using Superclass = Object;
67   using Pointer = SmartPointer< Self >;
68   using ConstPointer = SmartPointer< const Self >;
69 
70   /** Standard macros */
71   itkTypeMacro(ExpectationMaximizationMixtureModelEstimator,
72                Object);
73   itkNewMacro(Self);
74 
75   /** TSample template argument related type alias */
76   using SampleType = TSample;
77   using MeasurementType = typename TSample::MeasurementType;
78   using MeasurementVectorType = typename TSample::MeasurementVectorType;
79 
80   /** Typedef requried to generate dataobject decorated output that can
81     * be plugged into SampleClassifierFilter */
82   using GaussianMembershipFunctionType = GaussianMembershipFunction<MeasurementVectorType>;
83 
84   using GaussianMembershipFunctionPointer = typename GaussianMembershipFunctionType::Pointer;
85 
86   using MembershipFunctionType = MembershipFunctionBase< MeasurementVectorType >;
87   using MembershipFunctionPointer = typename MembershipFunctionType::ConstPointer;
88   using MembershipFunctionVectorType = std::vector< MembershipFunctionPointer >;
89   using MembershipFunctionVectorObjectType = SimpleDataObjectDecorator<MembershipFunctionVectorType>;
90   using MembershipFunctionVectorObjectPointer = typename MembershipFunctionVectorObjectType::Pointer;
91 
92   /** Type of the mixture model component base class */
93   using ComponentType = MixtureModelComponentBase< TSample >;
94 
95   /** Type of the component pointer storage */
96   using ComponentVectorType = std::vector< ComponentType * >;
97 
98   /** Type of the membership function base class */
99   using ComponentMembershipFunctionType = MembershipFunctionBase<MeasurementVectorType>;
100 
101   /** Type of the array of the proportion values */
102   using ProportionVectorType = Array< double >;
103 
104   /** Sets the target data that will be classified by this */
105   void SetSample(const TSample *sample);
106 
107   /** Returns the target data */
108   const TSample * GetSample() const;
109 
110   /** Set/Gets the initial proportion values. The size of proportion
111    * vector should be same as the number of component (or classes) */
112   void SetInitialProportions(ProportionVectorType & propotion);
113 
114   const ProportionVectorType & GetInitialProportions() const;
115 
116   /** Gets the result proportion values */
117   const ProportionVectorType & GetProportions() const;
118 
119   /** type alias for decorated array of proportion */
120   using MembershipFunctionsWeightsArrayObjectType = SimpleDataObjectDecorator<ProportionVectorType>;
121   using MembershipFunctionsWeightsArrayPointer = typename MembershipFunctionsWeightsArrayObjectType::Pointer;
122 
123   /** Get method for data decorated Membership functions weights array */
124   const MembershipFunctionsWeightsArrayObjectType * GetMembershipFunctionsWeightsArray() const;
125 
126   /** Set/Gets the maximum number of iterations. When the optimization
127    * process reaches the maximum number of interations, even if the
128    * class parameters aren't converged, the optimization process
129    * stops. */
130   void SetMaximumIteration(int numberOfIterations);
131 
132   int GetMaximumIteration() const;
133 
134   /** Gets the current iteration. */
GetCurrentIteration()135   int GetCurrentIteration()
136   {
137     return m_CurrentIteration;
138   }
139 
140   /** Adds a new component (or class). */
141   int AddComponent(ComponentType *component);
142 
143   /** Gets the total number of classes currently plugged in. */
144   unsigned int GetNumberOfComponents() const;
145 
146   /** Runs the optimization process. */
147   void Update();
148 
149   /** Termination status after running optimization */
150   enum TERMINATION_CODE { CONVERGED = 0, NOT_CONVERGED = 1 };
151 
152   /** Gets the termination status */
153   TERMINATION_CODE GetTerminationCode() const;
154 
155   /** Gets the membership function specified by componentIndex
156   argument. */
157   ComponentMembershipFunctionType * GetComponentMembershipFunction(int componentIndex) const;
158 
159   /** Output Membership function vector containing the membership functions with
160     * the final optimized parameters */
161   const MembershipFunctionVectorObjectType * GetOutput() const;
162 
163 protected:
164   ExpectationMaximizationMixtureModelEstimator();
165   ~ExpectationMaximizationMixtureModelEstimator() override = default;
166   void PrintSelf(std::ostream & os, Indent indent) const override;
167 
168   bool CalculateDensities();
169 
170   double CalculateExpectation() const;
171 
172   bool UpdateComponentParameters();
173 
174   bool UpdateProportions();
175 
176   /** Starts the estimation process */
177   void GenerateData();
178 
179 private:
180   /** Target data sample pointer*/
181   const TSample *m_Sample;
182 
183   int m_MaxIteration{100};
184   int m_CurrentIteration{0};
185 
186   TERMINATION_CODE     m_TerminationCode;
187   ComponentVectorType  m_ComponentVector;
188   ProportionVectorType m_InitialProportions;
189   ProportionVectorType m_Proportions;
190 
191   MembershipFunctionVectorObjectPointer  m_MembershipFunctionsObject;
192   MembershipFunctionsWeightsArrayPointer m_MembershipFunctionsWeightArrayObject;
193 };  // end of class
194 } // end of namespace Statistics
195 } // end of namespace itk
196 
197 #ifndef ITK_MANUAL_INSTANTIATION
198 #include "itkExpectationMaximizationMixtureModelEstimator.hxx"
199 #endif
200 
201 #endif
202