1 /*
2     Copyright (C) 2008, 2009 William Hart
3     Copyright (C) 2010, 2011 Sebastian Pancratz
4 
5     This file is part of FLINT.
6 
7     FLINT is free software: you can redistribute it and/or modify it under
8     the terms of the GNU Lesser General Public License (LGPL) as published
9     by the Free Software Foundation; either version 2.1 of the License, or
10     (at your option) any later version.  See <http://www.gnu.org/licenses/>.
11 */
12 
13 #include <stdlib.h>
14 #include <gmp.h>
15 #include "flint.h"
16 #include "fmpz.h"
17 #include "fmpz_vec.h"
18 #include "fmpz_mod_poly.h"
19 
20 static void
__fmpz_mod_poly_divrem_divconquer(fmpz * Q,fmpz * R,const fmpz * A,slong lenA,const fmpz * B,slong lenB,const fmpz_t invB,const fmpz_t p)21 __fmpz_mod_poly_divrem_divconquer(fmpz * Q, fmpz * R,
22     const fmpz * A, slong lenA, const fmpz * B, slong lenB,
23     const fmpz_t invB, const fmpz_t p)
24 {
25     if (lenA < 2 * lenB - 1)
26     {
27         /*
28            Convert unbalanced division into a 2 n1 - 1 by n1 division
29          */
30 
31         const slong n1 = lenA - lenB + 1;
32         const slong n2 = lenB - n1;
33 
34         const fmpz * p1 = A + n2;
35         const fmpz * d1 = B + n2;
36         const fmpz * d2 = B;
37 
38         fmpz * W = _fmpz_vec_init((2 * n1 - 1) + lenB - 1);
39 
40         fmpz * d1q1 = R + n2;
41         fmpz * d2q1 = W + (2 * n1 - 1);
42 
43         _fmpz_mod_poly_divrem_divconquer_recursive(Q, d1q1, W, p1, d1, n1,
44                                                    invB, p);
45 
46         /*
47            Compute d2q1 = Q d2, of length lenB - 1
48          */
49 
50         if (n1 >= n2)
51             _fmpz_mod_poly_mul(d2q1, Q, n1, d2, n2, p);
52         else
53             _fmpz_mod_poly_mul(d2q1, d2, n2, Q, n1, p);
54 
55         /*
56            Compute BQ = d1q1 * x^n1 + d2q1, of length lenB - 1;
57            then compute R = A - BQ
58          */
59 
60         _fmpz_vec_swap(R, d2q1, n2);
61         _fmpz_mod_poly_add(R + n2, R + n2, n1 - 1, d2q1 + n2, n1 - 1, p);
62         _fmpz_mod_poly_sub(R, A, lenA, R, lenA, p);
63 
64         _fmpz_vec_clear(W, (2 * n1 - 1) + lenB - 1);
65     }
66     else  /* lenA = 2 * lenB - 1 */
67     {
68         fmpz * W = _fmpz_vec_init(lenA);
69 
70         _fmpz_mod_poly_divrem_divconquer_recursive(Q, R, W,
71                                                    A, B, lenB, invB, p);
72 
73         _fmpz_mod_poly_sub(R, A, lenB - 1, R, lenB - 1, p);
74 
75         _fmpz_vec_clear(W, lenA);
76     }
77 }
78 
_fmpz_mod_poly_divrem_divconquer(fmpz * Q,fmpz * R,const fmpz * A,slong lenA,const fmpz * B,slong lenB,const fmpz_t invB,const fmpz_t p)79 void _fmpz_mod_poly_divrem_divconquer(fmpz *Q, fmpz *R,
80     const fmpz *A, slong lenA, const fmpz *B, slong lenB,
81     const fmpz_t invB, const fmpz_t p)
82 {
83     if (lenA <= 2 * lenB - 1)
84     {
85         fmpz * W = _fmpz_vec_init(lenA);
86 
87         __fmpz_mod_poly_divrem_divconquer(Q, W, A, lenA, B, lenB, invB, p);
88 
89         _fmpz_vec_set(R, W, lenB - 1);
90         _fmpz_vec_clear(W, lenA);
91     }
92     else  /* lenA > 2 * lenB - 1 */
93     {
94         slong shift, n = 2 * lenB - 1, len1;
95         fmpz *QB, *W, *S;
96 
97         len1 = 2 * n + lenA;
98         W = _fmpz_vec_init(len1);
99         S = W + 2*n;
100         _fmpz_vec_set(S, A, lenA);
101         QB = W + n;
102 
103         while (lenA >= n)
104         {
105             shift = lenA - n;
106             _fmpz_mod_poly_divrem_divconquer_recursive(Q + shift, QB,
107                 W, S + shift, B, lenB, invB, p);
108             _fmpz_mod_poly_sub(S + shift, S + shift, n, QB, n, p);
109             lenA -= lenB;
110         }
111 
112         if (lenA >= lenB)
113         {
114             __fmpz_mod_poly_divrem_divconquer(Q, W, S, lenA, B, lenB, invB, p);
115             _fmpz_vec_swap(W, S, lenA);
116         }
117 
118         _fmpz_vec_set(R, S, lenB - 1);
119         _fmpz_vec_clear(W, len1);
120     }
121 }
122 
123 void
fmpz_mod_poly_divrem_divconquer(fmpz_mod_poly_t Q,fmpz_mod_poly_t R,const fmpz_mod_poly_t A,const fmpz_mod_poly_t B)124 fmpz_mod_poly_divrem_divconquer(fmpz_mod_poly_t Q, fmpz_mod_poly_t R,
125     const fmpz_mod_poly_t A, const fmpz_mod_poly_t B)
126 {
127     const slong lenA = A->length;
128     const slong lenB = B->length;
129     const slong lenQ = lenA - lenB + 1;
130 
131     fmpz *q, *r;
132     fmpz_t invB;
133 
134     if (lenB == 0)
135     {
136         if (fmpz_is_one(fmpz_mod_poly_modulus(B)))
137         {
138             fmpz_mod_poly_set(Q, A);
139             fmpz_mod_poly_zero(R);
140             return;
141         } else
142         {
143             flint_printf("Exception (fmpz_mod_poly_div_basecase). Division by zero.\n");
144             flint_abort();
145         }
146     }
147 
148     if (lenA < lenB)
149     {
150         fmpz_mod_poly_set(R, A);
151         fmpz_mod_poly_zero(Q);
152         return;
153     }
154 
155 	if (B->length < 8)
156 	{
157         fmpz_mod_poly_divrem_basecase(Q, R, A, B);
158         return;
159     }
160 
161     fmpz_init(invB);
162     fmpz_invmod(invB, fmpz_mod_poly_lead(B), &(B->p));
163 
164     if (Q == A || Q == B)
165     {
166         q = _fmpz_vec_init(lenQ);
167     }
168     else
169     {
170         fmpz_mod_poly_fit_length(Q, lenQ);
171         q = Q->coeffs;
172     }
173 
174     if (R == A || R == B)
175     {
176         r = _fmpz_vec_init(lenB - 1);
177     }
178     else
179     {
180         fmpz_mod_poly_fit_length(R, lenB - 1);
181         r = R->coeffs;
182     }
183 
184     _fmpz_mod_poly_divrem_divconquer(q, r, A->coeffs, lenA,
185                                            B->coeffs, lenB, invB, &(B->p));
186 
187     if (Q == A || Q == B)
188     {
189         _fmpz_vec_clear(Q->coeffs, Q->alloc);
190         Q->coeffs = q;
191         Q->alloc  = lenQ;
192         Q->length = lenQ;
193     }
194     else
195     {
196         _fmpz_mod_poly_set_length(Q, lenQ);
197     }
198 
199     if (R == A || R == B)
200     {
201         _fmpz_vec_clear(R->coeffs, R->alloc);
202         R->coeffs = r;
203         R->alloc  = lenB - 1;
204         R->length = lenB - 1;
205     }
206 
207     _fmpz_mod_poly_set_length(R, lenB - 1);
208     _fmpz_mod_poly_normalise(R);
209 
210     fmpz_clear(invB);
211 }
212