1 /*
2     Copyright (C) 2008, 2009, William Hart
3     Copyright (C) 2010 Fredrik Johansson
4 
5     This file is part of FLINT.
6 
7     FLINT is free software: you can redistribute it and/or modify it under
8     the terms of the GNU Lesser General Public License (LGPL) as published
9     by the Free Software Foundation; either version 2.1 of the License, or
10     (at your option) any later version.  See <http://www.gnu.org/licenses/>.
11 */
12 
13 #include <gmp.h>
14 #include "flint.h"
15 #include "ulong_extras.h"
16 #include "fmpz.h"
17 #include "nmod_vec.h"
18 #include "fmpz_vec.h"
19 
20 
/*
    Initialises temp as scratch space shaped like the tree held in comb.

    temp->n is set to comb->n, and for 0 <= i < n, temp->comb_temp[i] is
    a fresh vector of 2^(n - 1 - i) fmpz entries (the same level sizes that
    fmpz_comb_init uses for comb->comb and comb->res). temp->temp and
    temp->temp2 are initialised as two single scratch fmpz values.

    When comb->n == 0 (fmpz_comb_init was given no primes) no level vectors
    are created. The top-level size is only computed for n > 0, because
    WORD(1) << (n - 1) would otherwise shift by a negative amount, which is
    undefined behaviour in C.

    NOTE(review): flint_malloc is still called with size 0 when n == 0, as
    before; confirm the matching clear routine expects a pointer here.
*/
void fmpz_comb_temp_init(fmpz_comb_temp_t temp, const fmpz_comb_t comb)
{
    slong n, i, j;

    /* Allocate one pointer per level of the comb */
    temp->n = n = comb->n;
    temp->comb_temp = (fmpz **) flint_malloc(n * sizeof(fmpz *));

    /* Size of the top level; guarded since (n - 1) is negative when n == 0 */
    j = (n > 0) ? (WORD(1) << (n - 1)) : 0;

    /* Each level has half as many entries as the one before it */
    for (i = 0; i < n; i++)
    {
        temp->comb_temp[i] = _fmpz_vec_init(j);
        j /= 2;
    }

    fmpz_init(temp->temp);
    fmpz_init(temp->temp2);
}
39 
40 void
fmpz_comb_init(fmpz_comb_t comb,mp_srcptr primes,slong num_primes)41 fmpz_comb_init(fmpz_comb_t comb, mp_srcptr primes, slong num_primes)
42 {
43     slong i, j;
44     slong n, num;
45     ulong log_comb, log_res;
46     fmpz_t temp, temp2;
47 
48     comb->primes = primes;
49     comb->num_primes = num_primes;
50 
51     n = FLINT_BIT_COUNT(num_primes);
52     comb->n = n;
53 
54     /* Create nmod_poly modulus information */
55 	comb->mod = (nmod_t *) flint_malloc(sizeof(nmod_t) * num_primes);
56     for (i = 0; i < num_primes; i++)
57         nmod_init(&comb->mod[i], primes[i]);
58 
59     /* Nothing to do */
60 	if (n == 0)
61         return;
62 
63 	/* Allocate space for comb and res */
64     comb->comb = (fmpz **) flint_malloc(n * sizeof(fmpz *));
65     comb->res = (fmpz **) flint_malloc(n * sizeof(fmpz *));
66 
67     /* Size of top level */
68     j = (WORD(1) << (n - 1));
69 
70     /* Initialise arrays at each level */
71 	for (i = 0; i < n; i++)
72     {
73         comb->comb[i] = _fmpz_vec_init(j);
74         comb->res[i] = _fmpz_vec_init(j);
75         j /= 2;
76     }
77 
78 	/* Compute products of pairs of primes and place in comb */
79     for (i = 0, j = 0; i + 2 <= num_primes; i += 2, j++)
80     {
81         fmpz_set_ui(comb->comb[0] + j, primes[i]);
82         fmpz_mul_ui(comb->comb[0] + j, comb->comb[0] + j, primes[i+1]);
83     }
84 
85     /* In case number of primes is odd */
86     if (i < num_primes)
87     {
88         fmpz_set_ui(comb->comb[0] + j, primes[i]);
89 	    i += 2;
90 	    j++;
91 	}
92 
93     /* Set the rest of the entries on that row of the comb to 1 */
94     num = (WORD(1) << n);
95 	for (; i < num; i += 2, j++)
96     {
97         fmpz_one(comb->comb[0] + j);
98     }
99 
100     /* Compute rest of comb by multiplying in pairs */
101     log_comb = 1;
102     num /= 2;
103     while (num >= 2)
104     {
105         for (i = 0, j = 0; i < num; i += 2, j++)
106         {
107             fmpz_mul(comb->comb[log_comb] + j, comb->comb[log_comb-1] + i,
108                 comb->comb[log_comb-1] + i + 1);
109         }
110         log_comb++;
111         num /= 2;
112     }
113 
114     /* Compute inverses from pairs of primes */
115     fmpz_init(temp);
116     fmpz_init(temp2);
117 
118     for (i = 0, j = 0; i + 2 <= num_primes; i += 2, j++)
119     {
120         fmpz_set_ui(temp, primes[i]);
121         fmpz_set_ui(temp2, primes[i+1]);
122         fmpz_invmod(comb->res[0] + j, temp, temp2);
123     }
124 
125     fmpz_clear(temp);
126     fmpz_clear(temp2);
127 
128     /* Compute remaining inverses, each level
129        combining pairs from the level below */
130 	log_res = 1;
131     num = (WORD(1) << (n - 1));
132 
133     while (log_res < n)
134     {
135         for (i = 0, j = 0; i < num; i += 2, j++)
136         {
137             fmpz_invmod(comb->res[log_res] + j, comb->comb[log_res-1] + i,
138                 comb->comb[log_res-1] + i + 1);
139         }
140         log_res++;
141         num /= 2;
142     }
143 }
144