1"""Base class for mixture models."""
2
3# Author: Wei Xue <xuewei4d@gmail.com>
4# Modified by Thierry Guillemot <thierry.guillemot.work@gmail.com>
5# License: BSD 3 clause
6
7import warnings
8from abc import ABCMeta, abstractmethod
9from time import time
10
11import numpy as np
12from scipy.special import logsumexp
13
14from .. import cluster
15from ..base import BaseEstimator
16from ..base import DensityMixin
17from ..exceptions import ConvergenceWarning
18from ..utils import check_random_state
19from ..utils.validation import check_is_fitted
20
21
22def _check_shape(param, param_shape, name):
23    """Validate the shape of the input parameter 'param'.
24
25    Parameters
26    ----------
27    param : array
28
29    param_shape : tuple
30
31    name : str
32    """
33    param = np.array(param)
34    if param.shape != param_shape:
35        raise ValueError(
36            "The parameter '%s' should have the shape of %s, but got %s"
37            % (name, param_shape, param.shape)
38        )


class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta):
    """Base class for mixture models.

    This abstract class specifies an interface for all mixture classes and
    provides basic common methods for mixture models.
    """

    def __init__(
        self,
        n_components,
        tol,
        reg_covar,
        max_iter,
        n_init,
        init_params,
        random_state,
        warm_start,
        verbose,
        verbose_interval,
    ):
        self.n_components = n_components
        self.tol = tol
        self.reg_covar = reg_covar
        self.max_iter = max_iter
        self.n_init = n_init
        self.init_params = init_params
        self.random_state = random_state
        self.warm_start = warm_start
        self.verbose = verbose
        self.verbose_interval = verbose_interval

    def _check_initial_parameters(self, X):
        """Check values of the basic parameters.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        """
        if self.n_components < 1:
            raise ValueError(
                "Invalid value for 'n_components': %d "
                "Estimation requires at least one component"
                % self.n_components
            )

        if self.tol < 0.0:
            raise ValueError(
                "Invalid value for 'tol': %.5f "
                "Tolerance used by the EM must be non-negative"
                % self.tol
            )

        if self.n_init < 1:
            raise ValueError(
                "Invalid value for 'n_init': %d Estimation requires at least one run"
                % self.n_init
            )

        if self.max_iter < 1:
            raise ValueError(
                "Invalid value for 'max_iter': %d "
                "Estimation requires at least one iteration"
                % self.max_iter
            )

        if self.reg_covar < 0.0:
            raise ValueError(
                "Invalid value for 'reg_covar': %.5f "
                "regularization on covariance must be "
                "non-negative"
                % self.reg_covar
            )

        # Check all the parameter values of the derived class
        self._check_parameters(X)

    @abstractmethod
    def _check_parameters(self, X):
        """Check initial parameters of the derived class.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        """
        pass

    def _initialize_parameters(self, X, random_state):
        """Initialize the model parameters.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        random_state : RandomState
            A random number generator instance that controls the random seed
            used for the method chosen to initialize the parameters.
        """
        n_samples, _ = X.shape
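
        # Build the initial responsibility matrix: a one-hot encoding of the
        # k-means labels when ``init_params == "kmeans"``, or row-normalized
        # random values when ``init_params == "random"``.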
        if self.init_params == "kmeans":
            resp = np.zeros((n_samples, self.n_components))
            label = (
                cluster.KMeans(
                    n_clusters=self.n_components, n_init=1, random_state=random_state
                )
                .fit(X)
                .labels_
            )
            resp[np.arange(n_samples), label] = 1
        elif self.init_params == "random":
            resp = random_state.rand(n_samples, self.n_components)
            resp /= resp.sum(axis=1)[:, np.newaxis]
        else:
            raise ValueError(
                "Unimplemented initialization method '%s'" % self.init_params
            )

        self._initialize(X, resp)

    @abstractmethod
    def _initialize(self, X, resp):
        """Initialize the model parameters of the derived class.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        resp : array-like of shape (n_samples, n_components)
        """
        pass

    def fit(self, X, y=None):
        """Estimate model parameters with the EM algorithm.

        The method fits the model ``n_init`` times and keeps the parameters
        with which the model has the largest likelihood or lower bound. Within
        each trial, the method iterates between the E-step and the M-step for
        at most ``max_iter`` times, stopping when the change of the likelihood
        or lower bound is less than ``tol``; otherwise, a
        ``ConvergenceWarning`` is raised.
        If ``warm_start`` is ``True``, then ``n_init`` is ignored and a single
        initialization is performed upon the first call. Upon consecutive
        calls, training starts where it left off.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        self : object
            The fitted mixture.
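
        Examples
        --------
        A minimal sketch with :class:`~sklearn.mixture.GaussianMixture`, a
        concrete subclass of this base class:

        >>> import numpy as np
        >>> from sklearn.mixture import GaussianMixture
        >>> X = np.array([[0, 0], [1, 1], [10, 10], [11, 11]])
        >>> gm = GaussianMixture(n_components=2, random_state=0).fit(X)
        >>> gm.means_.shape
        (2, 2)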
        """
        self.fit_predict(X, y)
        return self

    def fit_predict(self, X, y=None):
        """Estimate model parameters using X and predict the labels for X.

        The method fits the model n_init times and keeps the parameters with
        which the model has the largest likelihood or lower bound. Within each
        trial, the method iterates between the E-step and the M-step for at
        most `max_iter` times, stopping when the change of the likelihood or
        lower bound is less than `tol`; otherwise, a
        :class:`~sklearn.exceptions.ConvergenceWarning` is raised. After
        fitting, it predicts the most probable label for the input data
        points.

        .. versionadded:: 0.20

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        labels : array, shape (n_samples,)
            Component labels.
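
        Examples
        --------
        A minimal sketch with :class:`~sklearn.mixture.GaussianMixture`, a
        concrete subclass of this base class; the returned labels agree with
        a subsequent call to `predict` on the same data:

        >>> import numpy as np
        >>> from sklearn.mixture import GaussianMixture
        >>> X = np.array([[0, 0], [1, 1], [10, 10], [11, 11]])
        >>> gm = GaussianMixture(n_components=2, random_state=0)
        >>> labels = gm.fit_predict(X)
        >>> labels.shape
        (4,)
        >>> np.array_equal(labels, gm.predict(X))
        True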
        """
        X = self._validate_data(
            X, dtype=[np.float64, np.float32], ensure_min_samples=2
        )
        if X.shape[0] < self.n_components:
            raise ValueError(
                "Expected n_samples >= n_components "
                f"but got n_components = {self.n_components}, "
                f"n_samples = {X.shape[0]}"
            )
        self._check_initial_parameters(X)

        # If warm_start is enabled, a single initialization is performed on
        # the first call to fit only.
        do_init = not (self.warm_start and hasattr(self, "converged_"))
        n_init = self.n_init if do_init else 1

        max_lower_bound = -np.inf
        self.converged_ = False

        random_state = check_random_state(self.random_state)

        n_samples, _ = X.shape
        for init in range(n_init):
            self._print_verbose_msg_init_beg(init)

            if do_init:
                self._initialize_parameters(X, random_state)

            lower_bound = -np.inf if do_init else self.lower_bound_

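            # EM loop: alternate the E-step (computing responsibilities) and
            # the M-step (updating parameters), tracking the lower bound on
            # the log-likelihood and stopping once its change drops below
            # ``tol``.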
            for n_iter in range(1, self.max_iter + 1):
                prev_lower_bound = lower_bound

                log_prob_norm, log_resp = self._e_step(X)
                self._m_step(X, log_resp)
                lower_bound = self._compute_lower_bound(log_resp, log_prob_norm)

                change = lower_bound - prev_lower_bound
                self._print_verbose_msg_iter_end(n_iter, change)

                if abs(change) < self.tol:
                    self.converged_ = True
                    break

            self._print_verbose_msg_init_end(lower_bound)

            if lower_bound > max_lower_bound or max_lower_bound == -np.inf:
                max_lower_bound = lower_bound
                best_params = self._get_parameters()
                best_n_iter = n_iter

        if not self.converged_:
            warnings.warn(
                "Initialization %d did not converge. "
                "Try different init parameters, "
                "or increase max_iter, tol "
                "or check for degenerate data." % (init + 1),
                ConvergenceWarning,
            )

        self._set_parameters(best_params)
        self.n_iter_ = best_n_iter
        self.lower_bound_ = max_lower_bound

        # Always do a final e-step to guarantee that the labels returned by
        # fit_predict(X) are always consistent with fit(X).predict(X)
        # for any value of max_iter and tol (and any random_state).
        _, log_resp = self._e_step(X)

        return log_resp.argmax(axis=1)

    def _e_step(self, X):
        """E step.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        log_prob_norm : float
            Mean of the logarithms of the probabilities of each sample in X.

        log_responsibility : array, shape (n_samples, n_components)
            Logarithm of the posterior probabilities (or responsibilities) of
            each sample in X.
        """
        log_prob_norm, log_resp = self._estimate_log_prob_resp(X)
        return np.mean(log_prob_norm), log_resp

    @abstractmethod
    def _m_step(self, X, log_resp):
        """M step.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        log_resp : array-like of shape (n_samples, n_components)
            Logarithm of the posterior probabilities (or responsibilities) of
            each sample in X.
        """
        pass

    @abstractmethod
    def _get_parameters(self):
        pass

    @abstractmethod
    def _set_parameters(self, params):
        pass

    def score_samples(self, X):
        """Compute the log-likelihood of each sample.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.

        Returns
        -------
        log_prob : array, shape (n_samples,)
            Log-likelihood of each sample in `X` under the current model.
        """
        check_is_fitted(self)
        X = self._validate_data(X, reset=False)

        return logsumexp(self._estimate_weighted_log_prob(X), axis=1)

    def score(self, X, y=None):
        """Compute the per-sample average log-likelihood of the given data X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        log_likelihood : float
            Per-sample average log-likelihood of `X` under the current model.
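
        Examples
        --------
        A small sketch with :class:`~sklearn.mixture.GaussianMixture`, a
        concrete subclass of this base class; the score is the mean of
        `score_samples`:

        >>> import numpy as np
        >>> from sklearn.mixture import GaussianMixture
        >>> X = np.array([[0, 0], [1, 1], [10, 10], [11, 11]])
        >>> gm = GaussianMixture(n_components=2, random_state=0).fit(X)
        >>> np.allclose(gm.score(X), gm.score_samples(X).mean())
        True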
        """
        return self.score_samples(X).mean()

    def predict(self, X):
        """Predict the labels for the data samples in X using the trained model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.

        Returns
        -------
        labels : array, shape (n_samples,)
            Component labels.
        """
        check_is_fitted(self)
        X = self._validate_data(X, reset=False)
        return self._estimate_weighted_log_prob(X).argmax(axis=1)

    def predict_proba(self, X):
        """Evaluate the components' density for each sample.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.

        Returns
        -------
        resp : array, shape (n_samples, n_components)
            Density of each Gaussian component for each sample in X.
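
        Examples
        --------
        A small sketch with :class:`~sklearn.mixture.GaussianMixture`, a
        concrete subclass of this base class; the returned responsibilities
        sum to one over the components for each sample:

        >>> import numpy as np
        >>> from sklearn.mixture import GaussianMixture
        >>> X = np.array([[0, 0], [1, 1], [10, 10], [11, 11]])
        >>> gm = GaussianMixture(n_components=2, random_state=0).fit(X)
        >>> resp = gm.predict_proba(X)
        >>> resp.shape
        (4, 2)
        >>> np.allclose(resp.sum(axis=1), 1.0)
        True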
        """
        check_is_fitted(self)
        X = self._validate_data(X, reset=False)
        _, log_resp = self._estimate_log_prob_resp(X)
        return np.exp(log_resp)

    def sample(self, n_samples=1):
        """Generate random samples from the fitted mixture distribution.

        Parameters
        ----------
        n_samples : int, default=1
            Number of samples to generate.

        Returns
        -------
        X : array, shape (n_samples, n_features)
            Randomly generated sample.

        y : array, shape (n_samples,)
            Component labels.
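
        Examples
        --------
        A small sketch with :class:`~sklearn.mixture.GaussianMixture`, a
        concrete subclass of this base class:

        >>> import numpy as np
        >>> from sklearn.mixture import GaussianMixture
        >>> X = np.array([[0, 0], [1, 1], [10, 10], [11, 11]])
        >>> gm = GaussianMixture(n_components=2, random_state=0).fit(X)
        >>> X_new, y_new = gm.sample(5)
        >>> X_new.shape, y_new.shape
        ((5, 2), (5,))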
        """
        check_is_fitted(self)

        if n_samples < 1:
            raise ValueError(
                "Invalid value for 'n_samples': %d. The sampling requires at "
                "least one sample." % (n_samples)
            )

        _, n_features = self.means_.shape
        rng = check_random_state(self.random_state)
        n_samples_comp = rng.multinomial(n_samples, self.weights_)
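
        # n_samples_comp[k] is the number of samples drawn from component k.
        # Draw them from the corresponding Gaussian: per-component covariance
        # matrices for "full", the shared covariance matrix for "tied", and a
        # scaled standard normal for the diagonal and spherical cases.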
        if self.covariance_type == "full":
            X = np.vstack(
                [
                    rng.multivariate_normal(mean, covariance, int(sample))
                    for (mean, covariance, sample) in zip(
                        self.means_, self.covariances_, n_samples_comp
                    )
                ]
            )
        elif self.covariance_type == "tied":
            X = np.vstack(
                [
                    rng.multivariate_normal(mean, self.covariances_, int(sample))
                    for (mean, sample) in zip(self.means_, n_samples_comp)
                ]
            )
        else:
            X = np.vstack(
                [
                    mean + rng.randn(sample, n_features) * np.sqrt(covariance)
                    for (mean, covariance, sample) in zip(
                        self.means_, self.covariances_, n_samples_comp
                    )
                ]
            )

        y = np.concatenate(
            [np.full(sample, j, dtype=int) for j, sample in enumerate(n_samples_comp)]
        )

        return (X, y)

    def _estimate_weighted_log_prob(self, X):
        """Estimate the weighted log-probabilities, log P(X | Z) + log weights.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        weighted_log_prob : array, shape (n_samples, n_components)
        """
        return self._estimate_log_prob(X) + self._estimate_log_weights()

    @abstractmethod
    def _estimate_log_weights(self):
        """Estimate log-weights in EM algorithm, E[ log pi ] in VB algorithm.

        Returns
        -------
        log_weight : array, shape (n_components,)
        """
        pass

    @abstractmethod
    def _estimate_log_prob(self, X):
        """Estimate the log-probabilities log P(X | Z).

        Compute the log-probability of each sample under each component.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        log_prob : array, shape (n_samples, n_components)
        """
        pass

    def _estimate_log_prob_resp(self, X):
        """Estimate log probabilities and responsibilities for each sample.

        Compute the log probabilities, weighted log probabilities per
        component and responsibilities for each sample in X with respect to
        the current state of the model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        log_prob_norm : array, shape (n_samples,)
            log p(X)

        log_responsibilities : array, shape (n_samples, n_components)
            Logarithm of the responsibilities.
        """
        weighted_log_prob = self._estimate_weighted_log_prob(X)
        log_prob_norm = logsumexp(weighted_log_prob, axis=1)
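        # Bayes' rule in log space: subtracting the per-sample normalizer
        # log p(x) from the weighted log-probabilities yields the log
        # responsibilities log p(z | x).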
        with np.errstate(under="ignore"):
            # ignore underflow
            log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
        return log_prob_norm, log_resp

    def _print_verbose_msg_init_beg(self, n_init):
        """Print verbose message at the beginning of an initialization."""
        if self.verbose == 1:
            print("Initialization %d" % n_init)
        elif self.verbose >= 2:
            print("Initialization %d" % n_init)
            self._init_prev_time = time()
            self._iter_prev_time = self._init_prev_time

    def _print_verbose_msg_iter_end(self, n_iter, diff_ll):
        """Print verbose message at the end of an iteration."""
        if n_iter % self.verbose_interval == 0:
            if self.verbose == 1:
                print("  Iteration %d" % n_iter)
            elif self.verbose >= 2:
                cur_time = time()
                print(
                    "  Iteration %d\t time lapse %.5fs\t ll change %.5f"
                    % (n_iter, cur_time - self._iter_prev_time, diff_ll)
                )
                self._iter_prev_time = cur_time

    def _print_verbose_msg_init_end(self, ll):
        """Print verbose message at the end of an initialization."""
        if self.verbose == 1:
            print("Initialization converged: %s" % self.converged_)
        elif self.verbose >= 2:
            print(
                "Initialization converged: %s\t time lapse %.5fs\t ll %.5f"
                % (self.converged_, time() - self._init_prev_time, ll)
            )