1 /*========================================================================= 2 * 3 * Copyright Insight Software Consortium 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0.txt 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 *=========================================================================*/ 18 #ifndef itkExpectationMaximizationMixtureModelEstimator_h 19 #define itkExpectationMaximizationMixtureModelEstimator_h 20 21 #include "itkMixtureModelComponentBase.h" 22 #include "itkGaussianMembershipFunction.h" 23 #include "itkSimpleDataObjectDecorator.h" 24 25 namespace itk 26 { 27 namespace Statistics 28 { 29 /** \class ExpectationMaximizationMixtureModelEstimator 30 * \brief This class generates the parameter estimates for a mixture 31 * model using expectation maximization strategy. 32 * 33 * The first template argument is the type of the target sample 34 * data. This estimator expects one or more mixture model component 35 * objects of the classes derived from the 36 * MixtureModelComponentBase. The actual component (or module) 37 * parameters are updated by each component. Users can think this 38 * class as a strategy or a integration point for the EM 39 * procedure. The initial proportion (SetInitialProportions), the 40 * input sample (SetSample), the mixture model components 41 * (AddComponent), and the maximum iteration (SetMaximumIteration) are 42 * required. The EM procedure terminates when the current iteration 43 * reaches the maximum iteration or the model parameters converge. 44 * 45 * <b>Recent API changes:</b> 46 * The static const macro to get the length of a measurement vector, 47 * \c MeasurementVectorSize has been removed to allow the length of a measurement 48 * vector to be specified at run time. It is now obtained at run time from the 49 * sample set as input. Please use the function 50 * GetMeasurementVectorSize() to get the length. 51 * 52 * \sa MixtureModelComponentBase, GaussianMixtureModelComponent 53 * \ingroup ITKStatistics 54 * 55 * \wiki 56 * \wikiexample{Statistics/ExpectationMaximizationMixtureModelEstimator_2D,2D Gaussian Mixture Model Expectation Maximization} 57 * \endwiki 58 */ 59 60 template< typename TSample > 61 class ITK_TEMPLATE_EXPORT ExpectationMaximizationMixtureModelEstimator:public Object 62 { 63 public: 64 /** Standard class type alias */ 65 using Self = ExpectationMaximizationMixtureModelEstimator; 66 using Superclass = Object; 67 using Pointer = SmartPointer< Self >; 68 using ConstPointer = SmartPointer< const Self >; 69 70 /** Standard macros */ 71 itkTypeMacro(ExpectationMaximizationMixtureModelEstimator, 72 Object); 73 itkNewMacro(Self); 74 75 /** TSample template argument related type alias */ 76 using SampleType = TSample; 77 using MeasurementType = typename TSample::MeasurementType; 78 using MeasurementVectorType = typename TSample::MeasurementVectorType; 79 80 /** Typedef requried to generate dataobject decorated output that can 81 * be plugged into SampleClassifierFilter */ 82 using GaussianMembershipFunctionType = GaussianMembershipFunction<MeasurementVectorType>; 83 84 using GaussianMembershipFunctionPointer = typename GaussianMembershipFunctionType::Pointer; 85 86 using MembershipFunctionType = MembershipFunctionBase< MeasurementVectorType >; 87 using MembershipFunctionPointer = typename MembershipFunctionType::ConstPointer; 88 using MembershipFunctionVectorType = std::vector< MembershipFunctionPointer >; 89 using MembershipFunctionVectorObjectType = SimpleDataObjectDecorator<MembershipFunctionVectorType>; 90 using MembershipFunctionVectorObjectPointer = typename MembershipFunctionVectorObjectType::Pointer; 91 92 /** Type of the mixture model component base class */ 93 using ComponentType = MixtureModelComponentBase< TSample >; 94 95 /** Type of the component pointer storage */ 96 using ComponentVectorType = std::vector< ComponentType * >; 97 98 /** Type of the membership function base class */ 99 using ComponentMembershipFunctionType = MembershipFunctionBase<MeasurementVectorType>; 100 101 /** Type of the array of the proportion values */ 102 using ProportionVectorType = Array< double >; 103 104 /** Sets the target data that will be classified by this */ 105 void SetSample(const TSample *sample); 106 107 /** Returns the target data */ 108 const TSample * GetSample() const; 109 110 /** Set/Gets the initial proportion values. The size of proportion 111 * vector should be same as the number of component (or classes) */ 112 void SetInitialProportions(ProportionVectorType & propotion); 113 114 const ProportionVectorType & GetInitialProportions() const; 115 116 /** Gets the result proportion values */ 117 const ProportionVectorType & GetProportions() const; 118 119 /** type alias for decorated array of proportion */ 120 using MembershipFunctionsWeightsArrayObjectType = SimpleDataObjectDecorator<ProportionVectorType>; 121 using MembershipFunctionsWeightsArrayPointer = typename MembershipFunctionsWeightsArrayObjectType::Pointer; 122 123 /** Get method for data decorated Membership functions weights array */ 124 const MembershipFunctionsWeightsArrayObjectType * GetMembershipFunctionsWeightsArray() const; 125 126 /** Set/Gets the maximum number of iterations. When the optimization 127 * process reaches the maximum number of interations, even if the 128 * class parameters aren't converged, the optimization process 129 * stops. */ 130 void SetMaximumIteration(int numberOfIterations); 131 132 int GetMaximumIteration() const; 133 134 /** Gets the current iteration. */ GetCurrentIteration()135 int GetCurrentIteration() 136 { 137 return m_CurrentIteration; 138 } 139 140 /** Adds a new component (or class). */ 141 int AddComponent(ComponentType *component); 142 143 /** Gets the total number of classes currently plugged in. */ 144 unsigned int GetNumberOfComponents() const; 145 146 /** Runs the optimization process. */ 147 void Update(); 148 149 /** Termination status after running optimization */ 150 enum TERMINATION_CODE { CONVERGED = 0, NOT_CONVERGED = 1 }; 151 152 /** Gets the termination status */ 153 TERMINATION_CODE GetTerminationCode() const; 154 155 /** Gets the membership function specified by componentIndex 156 argument. */ 157 ComponentMembershipFunctionType * GetComponentMembershipFunction(int componentIndex) const; 158 159 /** Output Membership function vector containing the membership functions with 160 * the final optimized parameters */ 161 const MembershipFunctionVectorObjectType * GetOutput() const; 162 163 protected: 164 ExpectationMaximizationMixtureModelEstimator(); 165 ~ExpectationMaximizationMixtureModelEstimator() override = default; 166 void PrintSelf(std::ostream & os, Indent indent) const override; 167 168 bool CalculateDensities(); 169 170 double CalculateExpectation() const; 171 172 bool UpdateComponentParameters(); 173 174 bool UpdateProportions(); 175 176 /** Starts the estimation process */ 177 void GenerateData(); 178 179 private: 180 /** Target data sample pointer*/ 181 const TSample *m_Sample; 182 183 int m_MaxIteration{100}; 184 int m_CurrentIteration{0}; 185 186 TERMINATION_CODE m_TerminationCode; 187 ComponentVectorType m_ComponentVector; 188 ProportionVectorType m_InitialProportions; 189 ProportionVectorType m_Proportions; 190 191 MembershipFunctionVectorObjectPointer m_MembershipFunctionsObject; 192 MembershipFunctionsWeightsArrayPointer m_MembershipFunctionsWeightArrayObject; 193 }; // end of class 194 } // end of namespace Statistics 195 } // end of namespace itk 196 197 #ifndef ITK_MANUAL_INSTANTIATION 198 #include "itkExpectationMaximizationMixtureModelEstimator.hxx" 199 #endif 200 201 #endif 202