1# -*- coding: utf-8 -*-
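"""
Example script exercising Tukey's HSD multiple comparisons in statsmodels.

Runs ``MultiComparison.tukeyhsd`` on several small data sets, prints the
results and checks some of them against reference values from R and SAS.

Created on Wed Mar 28 15:34:18 2012

Author: Josef Perktold
"""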
8from io import StringIO
9
10import numpy as np
11from numpy.testing import assert_almost_equal, assert_equal
12from scipy import stats
13
14from statsmodels.stats.libqsturng import qsturng
15from statsmodels.stats.multicomp import tukeyhsd
16import statsmodels.stats.multicomp as multi
17
18
# Whitespace-delimited table; columns are Rust, Brand, Replication
# (see the ``names`` passed when it is parsed into ``dta`` below).
ss = '''\
  43.9  1   1
  39.0  1   2
  46.7  1   3
  43.8  1   4
  44.2  1   5
  47.7  1   6
  43.6  1   7
  38.9  1   8
  43.6  1   9
  40.0  1  10
  89.8  2   1
  87.1  2   2
  92.7  2   3
  90.6  2   4
  87.7  2   5
  92.4  2   6
  86.1  2   7
  88.1  2   8
  90.8  2   9
  89.1  2  10
  68.4  3   1
  69.3  3   2
  68.5  3   3
  66.4  3   4
  70.0  3   5
  68.1  3   6
  70.6  3   7
  65.2  3   8
  63.8  3   9
  69.2  3  10
  36.2  4   1
  45.2  4   2
  40.7  4   3
  40.5  4   4
  39.3  4   5
  40.3  4   6
  43.2  4   7
  38.7  4   8
  40.9  4   9
  39.7  4  10'''

#idx   Treatment StressReduction
ss2 = '''\
1     mental               2
2     mental               2
3     mental               3
4     mental               4
5     mental               4
6     mental               5
7     mental               3
8     mental               4
9     mental               4
10    mental               4
11  physical               4
12  physical               4
13  physical               3
14  physical               5
15  physical               4
16  physical               1
17  physical               1
18  physical               2
19  physical               3
20  physical               3
21   medical               1
22   medical               2
23   medical               2
24   medical               2
25   medical               3
26   medical               2
27   medical               3
28   medical               1
29   medical               3
30   medical               1'''

# Whitespace-delimited table; columns are Brand, Relief.
ss3 = '''\
1 24.5
1 23.5
1 26.4
1 27.1
1 29.9
2 28.4
2 34.2
2 29.5
2 32.2
2 30.1
3 26.1
3 28.3
3 24.3
3 26.2
3 27.8'''

# Response values (4, 6 or 8) and the matching group label for each
# observation; presumably cylinder counts by country of origin.
cylinders = np.array([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 4, 8, 8, 8, 8, 8, 8, 8, 8, 8, 4, 6, 6, 6, 4, 4,
                    4, 4, 4, 4, 6, 8, 8, 8, 8, 4, 4, 4, 4, 8, 8, 8, 8, 6, 6, 6, 6, 4, 4, 4, 4, 6, 6,
                    6, 6, 4, 4, 4, 4, 4, 8, 4, 6, 6, 8, 8, 8, 8, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
                    4, 4, 4, 4, 4, 4, 4, 6, 6, 4, 6, 4, 4, 4, 4, 4, 4, 4, 4])
cyl_labels = np.array(['USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'France',
    'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'Japan', 'USA', 'USA', 'USA', 'Japan',
    'Germany', 'France', 'Germany', 'Sweden', 'Germany', 'USA', 'USA', 'USA', 'USA', 'USA', 'Germany',
    'USA', 'USA', 'France', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'Germany',
    'Japan', 'USA', 'USA', 'USA', 'USA', 'Germany', 'Japan', 'Japan', 'USA', 'Sweden', 'USA', 'France',
    'Japan', 'Germany', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA',
    'Germany', 'Japan', 'Japan', 'USA', 'USA', 'Japan', 'Japan', 'Japan', 'Japan', 'Japan', 'Japan', 'USA',
    'USA', 'USA', 'USA', 'Japan', 'USA', 'USA', 'USA', 'Germany', 'USA', 'USA', 'USA'])


def _read_table(text, names):
    """Parse whitespace-delimited ``text`` into a record array.

    ``np.recfromtxt`` is no longer available from the main NumPy namespace
    in NumPy 2.0, so this spells out what it did: ``np.genfromtxt`` with
    ``dtype=None`` (per-column type detection) viewed as ``np.recarray``.
    ``encoding=None`` makes text columns come back as ``str`` rather than
    ``bytes`` regardless of the NumPy version's default.
    """
    table = np.genfromtxt(StringIO(text), dtype=None, encoding=None,
                          names=names)
    return table.view(np.recarray)


dta = _read_table(ss, names=("Rust", "Brand", "Replication"))
dta2 = _read_table(ss2, names=("idx", "Treatment", "StressReduction"))
dta3 = _read_table(ss3, names=("Brand", "Relief"))
127
128#print tukeyhsd(dta['Brand'], dta['Rust'])
129
def get_thsd(mci):
    """Recompute Tukey's HSD for ``mci`` from its group summary statistics.

    Parameters
    ----------
    mci : object
        Comparison object exposing ``groupstats`` (with ``groupmean``,
        ``groupnobs``, ``groupdemean()`` and ``groupvarwithin()``) and
        ``groupsunique``; this script passes ``MultiComparison`` instances.

    Returns
    -------
    resi
        Whatever the function-level ``tukeyhsd`` returns for the group
        means, group sizes and pooled variance. Element 4 of it is printed
        as a side effect.

    Raises
    ------
    AssertionError
        If the two ways of computing the pooled within-group variance
        disagree at 14 decimals.
    """
    group_stats = mci.groupstats

    # Variance of the group-demeaned data, with one degree of freedom
    # removed per group.
    pooled_var = np.var(group_stats.groupdemean(), ddof=len(mci.groupsunique))
    group_means = group_stats.groupmean
    group_sizes = group_stats.groupnobs
    df_per_group = group_sizes - 1

    # Critical value is supplied explicitly for the 0.95 quantile.
    q_critical = qsturng(0.95, len(group_means), df_per_group.sum())
    resi = tukeyhsd(group_means, group_sizes, pooled_var, df=None,
                    alpha=0.05, q_crit=q_critical)
    print(resi[4])

    # Cross-check: df-weighted average of the within-group variances.
    weighted_within = group_stats.groupvarwithin() * df_per_group
    pooled_var_check = weighted_within.sum() / df_per_group.sum()
    assert_almost_equal(pooled_var, pooled_var_check, decimal=14)
    return resi
140
# --- ss data: Rust grouped by Brand -------------------------------------
mc = multi.MultiComparison(dta['Rust'], dta['Brand'])
res = mc.tukeyhsd()
print(res)

# --- ss2 data: StressReduction grouped by Treatment ---------------------
mc2 = multi.MultiComparison(dta2['StressReduction'], dta2['Treatment'])
res2 = mc2.tukeyhsd()
print(res2)

# Same data restricted to rows 3..28: dropping the first three and the last
# observation leaves groups of unequal size.
mc2s = multi.MultiComparison(
    dta2['StressReduction'][3:29],
    dta2['Treatment'][3:29],
)
res2s = mc2s.tukeyhsd()
print(res2s)
res2s_001 = mc2s.tukeyhsd(alpha=0.01)

# R result, listed column by column (hence the Fortran-order reshape).
# Columns 1 and 2 are the interval bounds compared against ``confint``;
# columns 0 and 3 are not used here.
tukeyhsd2s = np.array([
    1.888889, 0.8888889, -1,
    0.2658549, -0.5908785, -2.587133,
    3.511923, 2.368656, 0.5871331,
    0.002837638, 0.150456, 0.1266072,
]).reshape(3, 4, order='F')
assert_almost_equal(res2s_001.confint, tukeyhsd2s[:, 1:3], decimal=3)

# --- ss3 data: Relief grouped by Brand ----------------------------------
mc3 = multi.MultiComparison(dta3['Relief'], dta3['Brand'])
res3 = mc3.tukeyhsd()
print(res3)
160
# --- cylinders data with an explicit group order ------------------------
tukeyhsd4 = multi.MultiComparison(
    cylinders,
    cyl_labels,
    group_order=["Sweden", "Japan", "Germany", "France", "USA"],
)
res4 = tukeyhsd4.tukeyhsd()
print(res4)

# Plotting is best effort: if matplotlib is missing or the plot fails, the
# error is printed and the script carries on.
try:
    import matplotlib.pyplot as plt
    fig = res4.plot_simultaneous("USA")
    plt.show()
except Exception as err:
    print(err)

# Re-derive each result from summary statistics; ``mci`` is left bound to
# the last comparison (mc3) afterwards.
for mci in (mc, mc2, mc3):
    get_thsd(mci)

# All pairwise t-tests on the ss2 data; only the first element of the
# returned value is printed.
print(mc2.allpairtest(stats.ttest_ind, method='b')[0])
175
'''same as SAS (session below is for mc3, i.e. the ss3 data: 3 groups, 12 within-group df):
177>>> np.var(mci.groupstats.groupdemean(), ddof=3)
1784.6773333333333351
179>>> var_ = np.var(mci.groupstats.groupdemean(), ddof=3)
180>>> tukeyhsd(means, nobs, var_, df=None, alpha=0.05, q_crit=qsturng(0.95, 3, 12))[4]
181array([[ 0.95263648,  8.24736352],
182       [-3.38736352,  3.90736352],
183       [-7.98736352, -0.69263648]])
184>>> tukeyhsd(means, nobs, var_, df=None, alpha=0.05, q_crit=3.77278)[4]
185array([[ 0.95098508,  8.24901492],
186       [-3.38901492,  3.90901492],
187       [-7.98901492, -0.69098508]])
188'''
189
# Reference output (SAS, per the ``sas_`` comparison that uses this table)
# for the ss3 data. The original listing was headed:
#
#   Comparisons significant at the 0.05 level are indicated by ***.
#   BRAND
#   Comparison | Difference Between Means |
#   Simultaneous 95% Confidence Limits | Sign.
#
# Only the data rows are kept. Fields are TAB separated, written with
# explicit ``\t`` escapes so the delimiter survives editors that expand
# tabs. The last row has an empty significance field (trailing tab).
ss5 = (
    "2 - 3\t4.340\t0.691\t7.989\t***\n"
    "2 - 1\t4.600\t0.951\t8.249\t***\n"
    "3 - 2\t-4.340\t-7.989\t-0.691\t***\n"
    "3 - 1\t0.260\t-3.389\t3.909\t -\n"
    "1 - 2\t-4.600\t-8.249\t-0.951\t***\n"
    "1 - 3\t-0.260\t-3.909\t3.389\t"
)

# ``np.recfromtxt`` is no longer available from the main NumPy namespace in
# NumPy 2.0; this is its equivalent. ``encoding=None`` returns the text
# columns ('pair', 'sig') as ``str`` so they compare equal to str literals.
dta5 = np.genfromtxt(
    StringIO(ss5),
    dtype=None,
    encoding=None,
    names=('pair', 'mean', 'lower', 'upper', 'sig'),
    delimiter='\t',
).view(np.recarray)
212
# Rows 1, 3 and 2 of the reference table are the "2 - 1", "3 - 1" and
# "3 - 2" comparisons; presumably this is the order in which ``res3``
# reports its pairs, since the values are compared element-wise below.
sas_ = dta5[[1,3,2]]

# Confidence limits. The two columns are stacked explicitly: since NumPy
# 1.16 a multi-field index such as ``sas_[['lower', 'upper']]`` is a padded
# view of the full record, so reinterpreting it with ``.view(float)`` no
# longer yields an (n, 2) float array.
confint1 = res3.confint
confint2 = np.column_stack((sas_['lower'], sas_['upper']))
assert_almost_equal(confint1, confint2, decimal=2)

# Rejection flags. ``astype(str)`` makes the comparison work whether the
# text reader produced a bytes or a str column.
reject1 = res3.reject
reject2 = sas_['sig'].astype(str) == '***'
assert_equal(reject1, reject2)

# Mean differences.
meandiff1 = res3.meandiffs
meandiff2 = sas_['mean']
assert_almost_equal(meandiff1, meandiff2, decimal=14)
223