1 /*
2     Copyright (C) 2010 William Hart
3     Copyright (C) 2010 Sebastian Pancratz
4 
5     This file is part of FLINT.
6 
7     FLINT is free software: you can redistribute it and/or modify it under
8     the terms of the GNU Lesser General Public License (LGPL) as published
9     by the Free Software Foundation; either version 2.1 of the License, or
10     (at your option) any later version.  See <https://www.gnu.org/licenses/>.
11 */
12 
13 #include <stdlib.h>
14 #include <gmp.h>
15 #include "flint.h"
16 #include "nmod_vec.h"
17 #include "nmod_poly.h"
18 
19 void
_nmod_poly_mullow_KS(mp_ptr out,mp_srcptr in1,slong len1,mp_srcptr in2,slong len2,flint_bitcnt_t bits,slong n,nmod_t mod)20 _nmod_poly_mullow_KS(mp_ptr out, mp_srcptr in1, slong len1,
21             mp_srcptr in2, slong len2, flint_bitcnt_t bits, slong n, nmod_t mod)
22 {
23     slong limbs1, limbs2;
24     mp_ptr tmp, mpn1, mpn2, res;
25     int squaring;
26     TMP_INIT;
27 
28     len1 = FLINT_MIN(len1, n);
29     len2 = FLINT_MIN(len2, n);
30 
31     squaring = (in1 == in2 && len1 == len2);
32 
33     if (bits == 0)
34     {
35         flint_bitcnt_t bits1, bits2, loglen;
36 
37         /* Look at the actual bits of the input? This slows down the generic
38         case. Are there situations where we care enough about special input? */
39 #if 0
40         bits1  = _nmod_vec_max_bits2(in1, len1);
41         bits2  = squaring ? bits1 : _nmod_vec_max_bits2(in2, len2);
42 #else
43         bits1 = FLINT_BITS - (slong) mod.norm;
44         bits2 = bits1;
45 #endif
46         loglen = FLINT_BIT_COUNT(len2);
47         bits = bits1 + bits2 + loglen;
48     }
49 
50     limbs1 = (len1 * bits - 1) / FLINT_BITS + 1;
51     limbs2 = (len2 * bits - 1) / FLINT_BITS + 1;
52 
53     TMP_START;
54     tmp = TMP_ALLOC(sizeof(mp_limb_t) * (limbs1 + limbs2 + limbs1 + (squaring ? 0 : limbs2)));
55     res = tmp;
56     mpn1 = tmp + limbs1 + limbs2;
57     mpn2 = squaring ? mpn1 : (mpn1 + limbs1);
58 
59     _nmod_poly_bit_pack(mpn1, in1, len1, bits);
60     if (!squaring)
61         _nmod_poly_bit_pack(mpn2, in2, len2, bits);
62 
63     if (squaring)
64         mpn_sqr(res, mpn1, limbs1);
65     else
66         mpn_mul(res, mpn1, limbs1, mpn2, limbs2);
67 
68     _nmod_poly_bit_unpack(out, n, res, bits, mod);
69 
70     TMP_END;
71 }
72 
73 void
nmod_poly_mullow_KS(nmod_poly_t res,const nmod_poly_t poly1,const nmod_poly_t poly2,flint_bitcnt_t bits,slong n)74 nmod_poly_mullow_KS(nmod_poly_t res,
75                  const nmod_poly_t poly1, const nmod_poly_t poly2,
76                  flint_bitcnt_t bits, slong n)
77 {
78     slong len_out;
79 
80     if ((poly1->length == 0) || (poly2->length == 0) || n == 0)
81     {
82         nmod_poly_zero(res);
83         return;
84     }
85 
86     len_out = poly1->length + poly2->length - 1;
87     if (n > len_out)
88         n = len_out;
89 
90     if (res == poly1 || res == poly2)
91     {
92         nmod_poly_t temp;
93         nmod_poly_init2_preinv(temp, poly1->mod.n, poly1->mod.ninv, len_out);
94         if (poly1->length >= poly2->length)
95             _nmod_poly_mullow_KS(temp->coeffs, poly1->coeffs, poly1->length,
96                               poly2->coeffs, poly2->length, bits,
97                               n, poly1->mod);
98         else
99             _nmod_poly_mullow_KS(temp->coeffs, poly2->coeffs, poly2->length,
100                               poly1->coeffs, poly1->length, bits,
101                               n, poly1->mod);
102         nmod_poly_swap(res, temp);
103         nmod_poly_clear(temp);
104     }
105     else
106     {
107         nmod_poly_fit_length(res, len_out);
108         if (poly1->length >= poly2->length)
109             _nmod_poly_mullow_KS(res->coeffs, poly1->coeffs, poly1->length,
110                               poly2->coeffs, poly2->length, bits,
111                               n, poly1->mod);
112         else
113             _nmod_poly_mullow_KS(res->coeffs, poly2->coeffs, poly2->length,
114                               poly1->coeffs, poly1->length, bits,
115                               n, poly1->mod);
116     }
117 
118     res->length = n;
119     _nmod_poly_normalise(res);
120 }
121