1 /*
2     Copyright (C) 2010 William Hart
3     Copyright (C) 2010 Sebastian Pancratz
4 
5     This file is part of FLINT.
6 
7     FLINT is free software: you can redistribute it and/or modify it under
8     the terms of the GNU Lesser General Public License (LGPL) as published
9     by the Free Software Foundation; either version 2.1 of the License, or
10     (at your option) any later version.  See <http://www.gnu.org/licenses/>.
11 */
12 
13 #include <stdlib.h>
14 #include <gmp.h>
15 #include "flint.h"
16 #include "nmod_vec.h"
17 #include "nmod_poly.h"
18 
19 void
_nmod_poly_mullow_KS(mp_ptr out,mp_srcptr in1,slong len1,mp_srcptr in2,slong len2,flint_bitcnt_t bits,slong n,nmod_t mod)20 _nmod_poly_mullow_KS(mp_ptr out, mp_srcptr in1, slong len1,
21             mp_srcptr in2, slong len2, flint_bitcnt_t bits, slong n, nmod_t mod)
22 {
23     slong limbs1, limbs2;
24     mp_ptr mpn1, mpn2, res;
25     int squaring;
26 
27     len1 = FLINT_MIN(len1, n);
28     len2 = FLINT_MIN(len2, n);
29 
30     squaring = (in1 == in2 && len1 == len2);
31 
32     if (bits == 0)
33     {
34         flint_bitcnt_t bits1, bits2, loglen;
35         bits1  = _nmod_vec_max_bits(in1, len1);
36         bits2  = squaring ? bits1 : _nmod_vec_max_bits(in2, len2);
37         loglen = FLINT_BIT_COUNT(len2);
38 
39         bits = bits1 + bits2 + loglen;
40     }
41 
42     limbs1 = (len1 * bits - 1) / FLINT_BITS + 1;
43     limbs2 = (len2 * bits - 1) / FLINT_BITS + 1;
44 
45     mpn1 = (mp_ptr) flint_malloc(sizeof(mp_limb_t) * limbs1);
46     mpn2 = squaring ? mpn1 : (mp_ptr) flint_malloc(sizeof(mp_limb_t) * limbs2);
47 
48     _nmod_poly_bit_pack(mpn1, in1, len1, bits);
49     if (!squaring)
50         _nmod_poly_bit_pack(mpn2, in2, len2, bits);
51 
52     res = (mp_ptr) flint_malloc(sizeof(mp_limb_t) * (limbs1 + limbs2));
53 
54     if (squaring)
55         mpn_sqr(res, mpn1, limbs1);
56     else
57         mpn_mul(res, mpn1, limbs1, mpn2, limbs2);
58 
59     _nmod_poly_bit_unpack(out, n, res, bits, mod);
60 
61     flint_free(mpn2);
62     if (!squaring)
63         flint_free(mpn1);
64 
65     flint_free(res);
66 }
67 
68 void
nmod_poly_mullow_KS(nmod_poly_t res,const nmod_poly_t poly1,const nmod_poly_t poly2,flint_bitcnt_t bits,slong n)69 nmod_poly_mullow_KS(nmod_poly_t res,
70                  const nmod_poly_t poly1, const nmod_poly_t poly2,
71                  flint_bitcnt_t bits, slong n)
72 {
73     slong len_out;
74 
75     if ((poly1->length == 0) || (poly2->length == 0) || n == 0)
76     {
77         nmod_poly_zero(res);
78         return;
79     }
80 
81     len_out = poly1->length + poly2->length - 1;
82     if (n > len_out)
83         n = len_out;
84 
85     if (res == poly1 || res == poly2)
86     {
87         nmod_poly_t temp;
88         nmod_poly_init2_preinv(temp, poly1->mod.n, poly1->mod.ninv, len_out);
89         if (poly1->length >= poly2->length)
90             _nmod_poly_mullow_KS(temp->coeffs, poly1->coeffs, poly1->length,
91                               poly2->coeffs, poly2->length, bits,
92                               n, poly1->mod);
93         else
94             _nmod_poly_mullow_KS(temp->coeffs, poly2->coeffs, poly2->length,
95                               poly1->coeffs, poly1->length, bits,
96                               n, poly1->mod);
97         nmod_poly_swap(res, temp);
98         nmod_poly_clear(temp);
99     }
100     else
101     {
102         nmod_poly_fit_length(res, len_out);
103         if (poly1->length >= poly2->length)
104             _nmod_poly_mullow_KS(res->coeffs, poly1->coeffs, poly1->length,
105                               poly2->coeffs, poly2->length, bits,
106                               n, poly1->mod);
107         else
108             _nmod_poly_mullow_KS(res->coeffs, poly2->coeffs, poly2->length,
109                               poly1->coeffs, poly1->length, bits,
110                               n, poly1->mod);
111     }
112 
113     res->length = n;
114     _nmod_poly_normalise(res);
115 }
116