# -*- coding: utf-8 -*-
"""Example script comparing Tukey HSD results with reference output.

Runs ``MultiComparison.tukeyhsd`` on several small data sets and checks the
confidence intervals, rejection decisions and mean differences against
numbers obtained from R and SAS.

Created on Wed Mar 28 15:34:18 2012

Author: Josef Perktold
"""
from io import StringIO

import numpy as np
from numpy.testing import assert_almost_equal, assert_equal
from scipy import stats

from statsmodels.stats.libqsturng import qsturng
from statsmodels.stats.multicomp import tukeyhsd
import statsmodels.stats.multicomp as multi


# columns: Rust Brand Replication
ss = '''\
43.9 1 1
39.0 1 2
46.7 1 3
43.8 1 4
44.2 1 5
47.7 1 6
43.6 1 7
38.9 1 8
43.6 1 9
40.0 1 10
89.8 2 1
87.1 2 2
92.7 2 3
90.6 2 4
87.7 2 5
92.4 2 6
86.1 2 7
88.1 2 8
90.8 2 9
89.1 2 10
68.4 3 1
69.3 3 2
68.5 3 3
66.4 3 4
70.0 3 5
68.1 3 6
70.6 3 7
65.2 3 8
63.8 3 9
69.2 3 10
36.2 4 1
45.2 4 2
40.7 4 3
40.5 4 4
39.3 4 5
40.3 4 6
43.2 4 7
38.7 4 8
40.9 4 9
39.7 4 10'''

#idx Treatment StressReduction
ss2 = '''\
1 mental 2
2 mental 2
3 mental 3
4 mental 4
5 mental 4
6 mental 5
7 mental 3
8 mental 4
9 mental 4
10 mental 4
11 physical 4
12 physical 4
13 physical 3
14 physical 5
15 physical 4
16 physical 1
17 physical 1
18 physical 2
19 physical 3
20 physical 3
21 medical 1
22 medical 2
23 medical 2
24 medical 2
25 medical 3
26 medical 2
27 medical 3
28 medical 1
29 medical 3
30 medical 1'''

# columns: Brand Relief
ss3 = '''\
1 24.5
1 23.5
1 26.4
1 27.1
1 29.9
2 28.4
2 34.2
2 29.5
2 32.2
2 30.1
3 26.1
3 28.3
3 24.3
3 26.2
3 27.8'''

cylinders = np.array([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 4, 8, 8, 8, 8, 8, 8, 8, 8, 8, 4, 6, 6, 6, 4, 4,
                      4, 4, 4, 4, 6, 8, 8, 8, 8, 4, 4, 4, 4, 8, 8, 8, 8, 6, 6, 6, 6, 4, 4, 4, 4, 6, 6,
                      6, 6, 4, 4, 4, 4, 4, 8, 4, 6, 6, 8, 8, 8, 8, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
                      4, 4, 4, 4, 4, 4, 4, 6, 6, 4, 6, 4, 4, 4, 4, 4, 4, 4, 4])
cyl_labels = np.array(['USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'France',
                       'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'Japan', 'USA', 'USA', 'USA', 'Japan',
                       'Germany', 'France', 'Germany', 'Sweden', 'Germany', 'USA', 'USA', 'USA', 'USA', 'USA', 'Germany',
                       'USA', 'USA', 'France', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'Germany',
                       'Japan', 'USA', 'USA', 'USA', 'USA', 'Germany', 'Japan', 'Japan', 'USA', 'Sweden', 'USA', 'France',
                       'Japan', 'Germany', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA', 'USA',
                       'Germany', 'Japan', 'Japan', 'USA', 'USA', 'Japan', 'Japan', 'Japan', 'Japan', 'Japan', 'Japan', 'USA',
                       'USA', 'USA', 'USA', 'Japan', 'USA', 'USA', 'USA', 'Germany', 'USA', 'USA', 'USA'])


def _read_table(text, names, **kwargs):
    """Parse an in-memory text table into a record array.

    Stand-in for the deprecated ``np.recfromtxt``: column types are inferred
    (``dtype=None``) and text columns are decoded to ``str`` rather than
    bytes (``encoding=None``), so they compare equal to ordinary string
    literals.

    Parameters
    ----------
    text : str
        Table contents, one row per line.
    names : sequence of str
        Field names for the columns.
    **kwargs
        Forwarded to ``np.genfromtxt`` (e.g. ``delimiter``).

    Returns
    -------
    np.recarray
        Structured data with one field per column.
    """
    parsed = np.genfromtxt(StringIO(text), names=names, dtype=None,
                           encoding=None, **kwargs)
    return parsed.view(np.recarray)


dta = _read_table(ss, names=("Rust", "Brand", "Replication"))
dta2 = _read_table(ss2, names=("idx", "Treatment", "StressReduction"))
dta3 = _read_table(ss3, names=("Brand", "Relief"))


def get_thsd(mci):
    """Run the low-level ``tukeyhsd`` function on a MultiComparison instance.

    The pooled within-group variance is computed in two ways (from the
    demeaned data and from the per-group variances) and the two are asserted
    to agree before the result is returned.

    Parameters
    ----------
    mci : MultiComparison
        Instance providing ``groupstats`` and ``groupsunique``.

    Returns
    -------
    tuple
        Whatever ``tukeyhsd`` returns; element 4 is printed (the SAS
        comparison notes further down treat it as the confidence intervals).

    Raises
    ------
    AssertionError
        If the two variance computations differ at 14 decimals.
    """
    # variance of the demeaned data, with one degree of freedom lost per group
    var_ = np.var(mci.groupstats.groupdemean(), ddof=len(mci.groupsunique))
    means = mci.groupstats.groupmean
    nobs = mci.groupstats.groupnobs
    resi = tukeyhsd(means, nobs, var_, df=None, alpha=0.05,
                    q_crit=qsturng(0.95, len(means), (nobs - 1).sum()))
    print(resi[4])
    # same pooled variance, built from the within-group variances
    var2 = ((mci.groupstats.groupvarwithin() * (nobs - 1)).sum()
            / (nobs - 1).sum())
    assert_almost_equal(var_, var2, decimal=14)
    return resi


mc = multi.MultiComparison(dta['Rust'], dta['Brand'])
res = mc.tukeyhsd()
print(res)

mc2 = multi.MultiComparison(dta2['StressReduction'], dta2['Treatment'])
res2 = mc2.tukeyhsd()
print(res2)

# unbalanced subsample of the second data set
mc2s = multi.MultiComparison(dta2['StressReduction'][3:29],
                             dta2['Treatment'][3:29])
res2s = mc2s.tukeyhsd()
print(res2s)
res2s_001 = mc2s.tukeyhsd(alpha=0.01)
#R result
tukeyhsd2s = np.array([
    1.888889, 0.8888889, -1,
    0.2658549, -0.5908785, -2.587133,
    3.511923, 2.368656, 0.5871331,
    0.002837638, 0.150456, 0.1266072,
]).reshape(3, 4, order='F')
# columns 1 and 2 of the R table are compared with the confidence limits
assert_almost_equal(res2s_001.confint, tukeyhsd2s[:, 1:3], decimal=3)

mc3 = multi.MultiComparison(dta3['Relief'], dta3['Brand'])
res3 = mc3.tukeyhsd()
print(res3)

tukeyhsd4 = multi.MultiComparison(
    cylinders, cyl_labels,
    group_order=["Sweden", "Japan", "Germany", "France", "USA"])
res4 = tukeyhsd4.tukeyhsd()
print(res4)
# Plotting is best effort: report the problem (e.g. matplotlib missing)
# and carry on with the numerical checks.
try:
    import matplotlib.pyplot as plt
    fig = res4.plot_simultaneous("USA")
    plt.show()
except Exception as e:
    print(e)

for mci in [mc, mc2, mc3]:
    get_thsd(mci)

print(mc2.allpairtest(stats.ttest_ind, method='b')[0])

# same as SAS:
# >>> np.var(mci.groupstats.groupdemean(), ddof=3)
# 4.6773333333333351
# >>> var_ = np.var(mci.groupstats.groupdemean(), ddof=3)
# >>> tukeyhsd(means, nobs, var_, df=None, alpha=0.05, q_crit=qsturng(0.95, 3, 12))[4]
# array([[ 0.95263648,  8.24736352],
#        [-3.38736352,  3.90736352],
#        [-7.98736352, -0.69263648]])
# >>> tukeyhsd(means, nobs, var_, df=None, alpha=0.05, q_crit=3.77278)[4]
# array([[ 0.95098508,  8.24901492],
#        [-3.38901492,  3.90901492],
#        [-7.98901492, -0.69098508]])

# Raw SAS output from which ``ss5`` below was extracted:
#
#   Comparisons significant at the 0.05 level are indicated by ***.
#   BRAND
#   Comparison    Difference
#   Between
#   Means    Simultaneous 95% Confidence Limits    Sign.
#   2 - 3    4.340    0.691    7.989    ***
#   2 - 1    4.600    0.951    8.249    ***
#   3 - 2    -4.340    -7.989    -0.691    ***
#   3 - 1    0.260    -3.389    3.909    -
#   1 - 2    -4.600    -8.249    -0.951    ***
#   1 - 3    -0.260    -3.909    3.389

# Tab-separated; the tabs are written as escapes so they survive editors
# and copy/paste. The last row has an empty significance field.
ss5 = '''\
2 - 3\t4.340\t0.691\t7.989\t***
2 - 1\t4.600\t0.951\t8.249\t***
3 - 2\t-4.340\t-7.989\t-0.691\t***
3 - 1\t0.260\t-3.389\t3.909\t-
1 - 2\t-4.600\t-8.249\t-0.951\t***
1 - 3\t-0.260\t-3.909\t3.389\t'''

dta5 = _read_table(ss5, names=('pair', 'mean', 'lower', 'upper', 'sig'),
                   delimiter='\t')

# pick the SAS rows (2-1, 3-1, 3-2) to line up with the rows of res3
sas_ = dta5[[1, 3, 2]]
confint1 = res3.confint
# stack the two fields explicitly; a multi-field selection cannot be
# reinterpreted with .view(float) on current NumPy
confint2 = np.column_stack((sas_['lower'], sas_['upper']))
assert_almost_equal(confint1, confint2, decimal=2)
reject1 = res3.reject
reject2 = sas_['sig'] == '***'
assert_equal(reject1, reject2)
meandiff1 = res3.meandiffs
meandiff2 = sas_['mean']
assert_almost_equal(meandiff1, meandiff2, decimal=14)