1 // This file is part of the FXT library.
2 // Copyright (C) 2010, 2011, 2012, 2017, 2018 Joerg Arndt
3 // License: GNU General Public License version 3 or later,
4 // see the file COPYING.txt in the main directory.
5
6 #include "fxttypes.h"
7 #include "mod/mod.h"
8 #include "mod/factor.h"
9 #include "mod/numtheory.h"
10 //#include "mod/modarith.h"
11
12
13 #include "fxtio.h"
14
15
// Define INIT_DEBUG to enable the consistency checks and the
// progress messages of mod_initialize().
//#define INIT_DEBUG  // define for debug
#ifdef INIT_DEBUG
#include "jjassert.h"
#warning '******** INIT_DEBUG'
// MI_ASSERT(c, s): if condition c is false, print s and call jjassert(0).
// Wrapped in do { } while (0) so that "MI_ASSERT(...);" is exactly one
// statement and can be used safely in an unbraced if/else.
#define MI_ASSERT(c, s)  do { if ( !(c) ) { cout << s << endl;  jjassert(0); } } while (0)
#else
// Checks disabled: the arguments are discarded without being evaluated
// (they may refer to variables that exist only with INIT_DEBUG).
#define MI_ASSERT(c, s)  do { } while (0)
#endif
24
25
26 //class mod_cleanup
27 //{
28 //public:
29 // // no data
30 // mod_cleanup() { ; } // nop
31 // ~mod_cleanup()
32 // {
33 // mod_reset();
34 // }
35 //};
36 //// -------------------------
37 //
38 //static mod_cleanup do_mod_cleanup;
39
40 void
mod_reset()41 mod_reset()
42 {
43 #ifdef CLASS_MOD_USE_M1DD
44 mod::m1dd = 0.0;
45 #endif
46 mod::modfact.reset();
47
48 mod::maxorder = 0;
49 mod::max2pow = 0;
50
51 mod::phi = 0;
52 mod::phifact.reset();
53
54 mod::zero.x_ = 0;
55 mod::one.x_ = 0;
56 mod::maxordelem.x_ = 0;
57
58 if ( mod::mtab != nullptr ) delete [] mod::mtab;
59 // mod::root_2pow = 0;
60 // mod::root_m2pow = 0;
61 }
62 // -------------------------
63
64
65 bool
mod_initialize(umod_t m,umod_t * primes)66 mod_initialize(umod_t m, umod_t *primes/*=nullptr*/)
67 {
68 mod_reset();
69
70 mod::modulus = m;
71 #ifdef CLASS_MOD_USE_M1DD
72 mod::m1dd = (long double)1/(long double)m;
73 #endif
74
75 #ifdef INIT_DEBUG
76 mod_info0();
77 #endif
78 MI_ASSERT( !(m&((umod_t)3<<62)), " modulus must have less than 62 bits " );
79 // MI_ASSERT( !(m&((umod_t)1<<63)), " modulus must have less than 64 bits " );
80
81
82 // +++++ only after this point we can multiply mods !
83 #ifdef INIT_DEBUG
84 cout << "MOD_INIT(): can multiply mods" << endl;
85 #endif
86
87 mod::zero = (uint)0;
88 mod::one = (uint)1;
89
90 mod::modfact.make_factorization(m, primes);
91 #ifdef INIT_DEBUG
92 cout << mod::modfact << endl;
93 #endif
94 MI_ASSERT( mod::modfact.check(), "factorization of the modulus failed" );
95 MI_ASSERT( mod::modfact.is_factorization_of(m), "factorization of the modulus failed" );
96
97 // +++++ have modulus, modfact
98 #ifdef INIT_DEBUG
99 cout << "MOD_INIT(): have modulus, modfact" << endl;
100 mod_info1();
101 mod_info1b();
102 #endif
103
104
105 if ( mod::modfact.is_prime() ) mod::phi = m-1;
106 else mod::phi = euler_phi(mod::modfact);
107
108 (mod::phifact).make_factorization(mod::phi);
109 MI_ASSERT( mod::modfact.check(), "factorization of the phi failed" );
110 MI_ASSERT( mod::phifact.is_factorization_of(mod::phi), "factorization of phi failed" );
111
112 // +++++ have phi, phifact
113 #ifdef INIT_DEBUG
114 cout << "MOD_INIT(): have phi, phifact" << endl;
115 #endif
116
117 mod::maxorder = maxorder_mod(mod::modfact);
118
119
120 // +++++ have maxorder
121 #ifdef INIT_DEBUG
122 cout << "MOD_INIT(): have maxorder" << endl;
123 mod_info2();
124 #endif
125
126
127 #ifdef INIT_DEBUG
128 cout << "MOD_INIT(): ping 1." << endl;
129 #endif
130 mod::maxordelem = maxorder_element_mod(mod::modfact, mod::phifact);
131 #ifdef INIT_DEBUG
132 cout << "MOD_INIT(): ping 2." << endl;
133 #endif
134
135 #ifdef INIT_DEBUG
136 umod_t rr = (mod::maxordelem).order();
137 #endif
138
139 MI_ASSERT( rr != 0, "oops, order of primitive root is ==0" );
140 MI_ASSERT( rr == mod::maxorder,
141 "oops, order of primitive root is != maxorder" );
142
143
144 // +++++ have element of maximal order (primitive root if m cyclic)
145 #ifdef INIT_DEBUG
146 cout << "MOD_INIT(): have element of maximal order" << endl;
147 mod_info3();
148 #endif
149
150
151 mod::max2pow = 0;
152 {
153 umod_t t = mod::maxorder;
154 while ( 0==(t&1) ) { ++(mod::max2pow); t>>=1; }
155 }
156
157 const int m2 = (int)mod::max2pow; // jjcast
158 const umod_t z = ((mod::maxorder) >> m2);
159 // cout << "m2=" << m2 << endl;
160 // cout << "z=" << z << endl;
161
162 {
163 const ulong nn = (ulong)(m2+1); // jjcast
164
165 mod::mtab = new mod[6*nn];
166 mod *p = mod::mtab;
167
168 mod::root_2pow = p;
169 mod::root_m2pow = p + 1*nn;
170 mod::cos = p + 2*nn;
171 mod::isin = p + 3*nn;
172 mod::cosm = p + 4*nn;
173 mod::isinm = p + 5*nn;
174 }
175
176
177 // set up roots of order 2**(+-k):
178 mod t2;
179 #ifdef INIT_DEBUG
180 cout << "MOD_INIT(): searching maxorder-element ..." << endl;
181 #endif
182 t2 = mod::maxordelem.pow(z);
183 #ifdef INIT_DEBUG
184 cout << "MOD_INIT(): searching maxorder-element done." << endl;
185 #endif
186 // cout << "t=" << t << endl;
187 for (int k=m2; k>=0; --k)
188 {
189 mod::root_2pow[k] = t2;
190 mod::root_m2pow[k] = t2.inv();
191 t2 *= t2;
192 }
193
194 // set up sin/cos of order 2**(+-k):
195 for (int k=0; k<=m2; ++k)
196 {
197 mod tr = mod::root2pow(k);
198 mod tr2 = tr + tr;
199 mod tc = (tr * tr + mod::one) / tr2;
200 mod::cos[k] = tc;
201 mod::isin[k] = tr - tc;
202 }
203
204 for (int k=1; k<=m2; ++k)
205 {
206 mod tr = mod::root2pow(-k);
207 mod tr2 = tr + tr;
208 mod tc = (tr * tr + mod::one) / tr2;
209 mod::cosm[k] = tc;
210 mod::isinm[k] = tr - tc;
211 }
212
213
214
215
216 // +++++ from now on we can do NTTs ...
217 #ifdef INIT_DEBUG
218 cout << "MOD_INIT(): can do NTTs" << endl;
219 mod_info4();
220 #endif
221
222 #ifdef INIT_DEBUG
223 for (int k=0; k<=m2; ++k)
224 {
225 umod_t r; // only used if MI_ASSERT is defined
226 umod_t p2k = (1ULL << (uint)k);
227
228 r = (mod::root2pow(k)).order();
229 MI_ASSERT( r==p2k, "order(root_2pow(k)) is != 2**k" );
230
231 r = (mod::root2pow(-k)).order();
232 MI_ASSERT( r==p2k, "order(root2pow(-k)) is != 2**k" );
233
234 mod t = (mod::root2pow(k))*(mod::root2pow(-k));
235 MI_ASSERT( t==mod::one, "root2pow(k) * root2pow(-k) is != 1" );
236 }
237 #endif
238
239 #ifdef INIT_DEBUG
240 mod_info99();
241 #endif
242
243 return true;
244 }
245 // -------------------------
246