1# -*- coding: utf-8 -*-
2"""
3Created on Fri Sep 15 13:38:13 2017
4
5Author: Josef Perktold
6"""
7
8import numpy as np
9from numpy.testing import assert_allclose
10import pytest
11
12from statsmodels.discrete.discrete_model import Poisson
13import statsmodels.discrete._diagnostics_count as dia
14
15
class TestCountDiagnostic(object):
    """Tests for the Poisson count-model diagnostics in ``_diagnostics_count``.

    A ``Poisson`` model is fit once per class on simulated data. The test
    methods then run the zero-inflation score tests and the chi-square test
    on predicted probabilities against that shared fitted result.
    """

    @classmethod
    def setup_class(cls):
        """Simulate Poisson counts and fit the ``Poisson`` model shared by tests.

        Sets the class attributes ``exog``, ``endog``, ``res`` (the fit
        result) and ``nobs`` used by the test methods.
        """
        # Only the first two entries are used as mean parameters below; the
        # trailing 0.5 is not used by this Poisson data-generating process.
        expected_params = [1, 1, 0.5]
        # Use a private RandomState instead of seeding numpy's global RNG, so
        # this fixture does not alter the random state seen by other tests.
        # RandomState(seed).poisson draws the same legacy MT19937 stream as
        # np.random.seed(seed) followed by np.random.poisson, so the
        # regression numbers in the tests below are unaffected.
        rng = np.random.RandomState(987123)
        nobs = 500
        # Design: a constant column plus a binary regressor that is 0 for the
        # first half of the observations and 1 for the second half.
        exog = np.ones((nobs, 2))
        exog[:nobs // 2, 1] = 0
        mu_true = np.exp(exog.dot(expected_params[:-1]))

        # The mean is scaled down by 5 before drawing, giving small counts.
        endog_poi = rng.poisson(mu_true / 5)

        model_poi = Poisson(endog_poi, exog)
        # disp=False only silences the optimizer's convergence printout.
        res_poi = model_poi.fit(method='bfgs', maxiter=5000, disp=False)
        cls.exog = exog
        cls.endog = endog_poi
        cls.res = res_poi
        cls.nobs = nobs

    def test_count(self):
        """Check the Poisson zero-inflation tests against each other and a reference."""
        # partially smoke
        tzi1 = dia.test_poisson_zeroinflation(self.res)

        tzi2 = dia.test_poisson_zeroinflation_brock(self.res)
        # Compare the two implementations in the special case without extra
        # regressors: the first statistic equals the square of the second's,
        # and the second elements agree.
        assert_allclose(tzi1[:2], (tzi2[0]**2, tzi2[1]), rtol=1e-5)

        tzi3 = dia.test_poisson_zeroinflation(self.res, self.exog)

        # regression test
        tzi3_1 = (0.79863597832443878, 0.67077736750318928, 2, 2)
        assert_allclose(tzi3, tzi3_1, rtol=5e-4)

    @pytest.mark.matplotlib
    def test_probs(self, close_figures):
        """Regression-test the chi-square probability test; smoke-test plot_probs."""
        nobs = self.nobs
        probs = self.res.predict_prob()
        # Observed relative frequency of each count value 0, 1, ..., max.
        freq = np.bincount(self.endog) / nobs

        # Only the first two predicted-probability columns are tested.
        tzi = dia.test_chisquare_prob(self.res, probs[:, :2])
        # regression numbers
        tzi1 = (0.387770845, 0.5334734738)
        assert_allclose(tzi[:2], tzi1, rtol=5e-5)

        # smoke test for plot
        dia.plot_probs(freq, probs.mean(0))
72