1 // This file is part of the FXT library.
2 // Copyright (C) 2010, 2011, 2012, 2017, 2018 Joerg Arndt
3 // License: GNU General Public License version 3 or later,
4 // see the file COPYING.txt in the main directory.
5 
6 #include "fxttypes.h"
7 #include "mod/mod.h"
8 #include "mod/factor.h"
9 #include "mod/numtheory.h"
10 //#include "mod/modarith.h"
11 
12 
13 #include "fxtio.h"
14 
15 
// Debug switch for this file: uncomment the define below to compile in
// the consistency checks and progress output guarded by INIT_DEBUG.
//#define  INIT_DEBUG  // define for debug
#ifdef INIT_DEBUG
#include "jjassert.h"
#warning '******** INIT_DEBUG'
// MI_ASSERT(c, s): if condition c is false, print s to cout and then
// fail via jjassert(0).
#define  MI_ASSERT(c, s)  { if ( !(c) ) { cout << s << endl; jjassert(0); } }
#else
// Without INIT_DEBUG the check expands to an empty block, so neither
// c nor s is evaluated.
#define  MI_ASSERT(c, s)  { ; }
#endif
24 
25 
26 //class mod_cleanup
27 //{
28 //public:
29 //    // no data
30 //    mod_cleanup()  { ; }  // nop
31 //    ~mod_cleanup()
32 //    {
33 //        mod_reset();
34 //    }
35 //};
36 //// -------------------------
37 //
38 //static mod_cleanup  do_mod_cleanup;
39 
40 void
mod_reset()41 mod_reset()
42 {
43 #ifdef CLASS_MOD_USE_M1DD
44     mod::m1dd = 0.0;
45 #endif
46     mod::modfact.reset();
47 
48     mod::maxorder = 0;
49     mod::max2pow = 0;
50 
51     mod::phi = 0;
52     mod::phifact.reset();
53 
54     mod::zero.x_ = 0;
55     mod::one.x_ = 0;
56     mod::maxordelem.x_ = 0;
57 
58     if ( mod::mtab != nullptr )  delete [] mod::mtab;
59 //    mod::root_2pow = 0;
60 //    mod::root_m2pow = 0;
61 }
62 // -------------------------
63 
64 
65 bool
mod_initialize(umod_t m,umod_t * primes)66 mod_initialize(umod_t m, umod_t *primes/*=nullptr*/)
67 {
68     mod_reset();
69 
70     mod::modulus = m;
71 #ifdef CLASS_MOD_USE_M1DD
72     mod::m1dd = (long double)1/(long double)m;
73 #endif
74 
75 #ifdef INIT_DEBUG
76     mod_info0();
77 #endif
78     MI_ASSERT( !(m&((umod_t)3<<62)), " modulus must have less than 62 bits " );
79 //    MI_ASSERT( !(m&((umod_t)1<<63)), " modulus must have less than 64 bits " );
80 
81 
82     // +++++ only after this point we can multiply mods !
83 #ifdef INIT_DEBUG
84     cout << "MOD_INIT(): can multiply mods" << endl;
85 #endif
86 
87     mod::zero = (uint)0;
88     mod::one =  (uint)1;
89 
90     mod::modfact.make_factorization(m, primes);
91 #ifdef INIT_DEBUG
92     cout << mod::modfact << endl;
93 #endif
94     MI_ASSERT( mod::modfact.check(), "factorization of the modulus failed" );
95     MI_ASSERT( mod::modfact.is_factorization_of(m), "factorization of the modulus failed" );
96 
97     // +++++ have modulus, modfact
98 #ifdef INIT_DEBUG
99     cout << "MOD_INIT(): have modulus, modfact" << endl;
100     mod_info1();
101     mod_info1b();
102 #endif
103 
104 
105     if ( mod::modfact.is_prime() )  mod::phi = m-1;
106     else  mod::phi = euler_phi(mod::modfact);
107 
108     (mod::phifact).make_factorization(mod::phi);
109     MI_ASSERT( mod::modfact.check(), "factorization of the phi failed" );
110     MI_ASSERT( mod::phifact.is_factorization_of(mod::phi), "factorization of phi failed" );
111 
112     // +++++ have phi, phifact
113 #ifdef INIT_DEBUG
114     cout << "MOD_INIT(): have phi, phifact" << endl;
115 #endif
116 
117     mod::maxorder = maxorder_mod(mod::modfact);
118 
119 
120     // +++++ have maxorder
121 #ifdef INIT_DEBUG
122     cout << "MOD_INIT(): have maxorder" << endl;
123     mod_info2();
124 #endif
125 
126 
127 #ifdef INIT_DEBUG
128     cout << "MOD_INIT(): ping 1." << endl;
129 #endif
130     mod::maxordelem = maxorder_element_mod(mod::modfact, mod::phifact);
131 #ifdef INIT_DEBUG
132     cout << "MOD_INIT(): ping 2." << endl;
133 #endif
134 
135 #ifdef INIT_DEBUG
136     umod_t rr = (mod::maxordelem).order();
137 #endif
138 
139     MI_ASSERT( rr != 0, "oops, order of primitive root is ==0" );
140     MI_ASSERT( rr == mod::maxorder,
141                "oops, order of primitive root is != maxorder" );
142 
143 
144     // +++++ have element of maximal order (primitive root if m cyclic)
145 #ifdef INIT_DEBUG
146     cout << "MOD_INIT(): have element of maximal order" << endl;
147     mod_info3();
148 #endif
149 
150 
151     mod::max2pow = 0;
152     {
153         umod_t t = mod::maxorder;
154         while ( 0==(t&1) )  { ++(mod::max2pow);  t>>=1; }
155     }
156 
157     const int m2 = (int)mod::max2pow;  // jjcast
158     const umod_t z = ((mod::maxorder) >> m2);
159 //    cout << "m2=" << m2 << endl;
160 //    cout << "z=" << z << endl;
161 
162     {
163         const ulong nn = (ulong)(m2+1);  // jjcast
164 
165         mod::mtab = new mod[6*nn];
166         mod *p = mod::mtab;
167 
168         mod::root_2pow = p;
169         mod::root_m2pow = p + 1*nn;
170         mod::cos = p + 2*nn;
171         mod::isin = p + 3*nn;
172         mod::cosm = p + 4*nn;
173         mod::isinm = p + 5*nn;
174     }
175 
176 
177     // set up roots of order 2**(+-k):
178     mod t2;
179 #ifdef INIT_DEBUG
180     cout << "MOD_INIT(): searching maxorder-element ..." << endl;
181 #endif
182     t2 = mod::maxordelem.pow(z);
183 #ifdef INIT_DEBUG
184     cout << "MOD_INIT(): searching maxorder-element done." << endl;
185 #endif
186 //    cout << "t=" << t << endl;
187     for (int k=m2; k>=0; --k)
188     {
189         mod::root_2pow[k] = t2;
190         mod::root_m2pow[k] = t2.inv();
191         t2 *= t2;
192     }
193 
194     // set up sin/cos of order 2**(+-k):
195     for (int k=0; k<=m2; ++k)
196     {
197         mod tr = mod::root2pow(k);
198         mod tr2 = tr + tr;
199         mod tc = (tr * tr + mod::one) / tr2;
200         mod::cos[k] = tc;
201         mod::isin[k] = tr - tc;
202     }
203 
204     for (int k=1; k<=m2; ++k)
205     {
206         mod tr = mod::root2pow(-k);
207         mod tr2 = tr + tr;
208         mod tc = (tr * tr + mod::one) / tr2;
209         mod::cosm[k] = tc;
210         mod::isinm[k] = tr - tc;
211     }
212 
213 
214 
215 
216     // +++++ from now on we can do NTTs ...
217 #ifdef INIT_DEBUG
218     cout << "MOD_INIT(): can do NTTs" << endl;
219     mod_info4();
220 #endif
221 
222 #ifdef INIT_DEBUG
223     for (int k=0; k<=m2; ++k)
224     {
225         umod_t r;  // only used if MI_ASSERT is defined
226         umod_t p2k = (1ULL << (uint)k);
227 
228         r = (mod::root2pow(k)).order();
229         MI_ASSERT( r==p2k, "order(root_2pow(k)) is != 2**k" );
230 
231         r = (mod::root2pow(-k)).order();
232         MI_ASSERT( r==p2k, "order(root2pow(-k)) is != 2**k" );
233 
234         mod t = (mod::root2pow(k))*(mod::root2pow(-k));
235         MI_ASSERT( t==mod::one,  "root2pow(k) * root2pow(-k) is != 1" );
236     }
237 #endif
238 
239 #ifdef INIT_DEBUG
240     mod_info99();
241 #endif
242 
243     return  true;
244 }
245 // -------------------------
246