1 /*
2     Copyright (C) 2013 Tom Bachmann
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <http://www.gnu.org/licenses/>.
10 */
11 
12 #ifndef NMOD_POLY_MATXX_H
13 #define NMOD_POLY_MATXX_H
14 
15 #include "nmod_poly_mat.h"
16 
17 #include "nmod_matxx.h"
18 #include "nmod_polyxx.h"
19 #include "permxx.h"
20 
21 #include "flintxx/matrix.h"
22 #include "flintxx/stdmath.h"
23 
24 // NOTE: it is *not* valid to use empty nmod_poly_matxx matrices!
25 // TODO nullspace member
26 
27 namespace flint {
28 FLINT_DEFINE_UNOP(sqr_interpolate)
FLINT_DEFINE_BINOP(mul_interpolate)29 FLINT_DEFINE_BINOP(mul_interpolate)
30 
31 namespace detail {
32 template<class Mat>
33 struct nmod_poly_matxx_traits : matrices::generic_traits<Mat> { };
34 } // detail
35 
36 template<class Operation, class Data>
37 class nmod_poly_matxx_expression
38     : public expression<derived_wrapper<nmod_poly_matxx_expression>, Operation, Data>
39 {
40 public:
41     typedef expression<derived_wrapper< ::flint::nmod_poly_matxx_expression>,
42               Operation, Data> base_t;
43     typedef detail::nmod_poly_matxx_traits<nmod_poly_matxx_expression> traits_t;
44 
45     FLINTXX_DEFINE_BASICS(nmod_poly_matxx_expression)
FLINTXX_DEFINE_CTORS(nmod_poly_matxx_expression)46     FLINTXX_DEFINE_CTORS(nmod_poly_matxx_expression)
47     FLINTXX_DEFINE_C_REF(nmod_poly_matxx_expression, nmod_poly_mat_struct, _mat)
48 
49     // These only make sense with immediates
50     nmodxx_ctx_srcref _ctx() const
51     {
52         return nmodxx_ctx_srcref::make(nmod_poly_mat_entry(_mat(), 0, 0)->mod);
53     }
54 
55     // These work on any expression without evaluation
estimate_ctx()56     nmodxx_ctx_srcref estimate_ctx() const
57     {
58         return tools::find_nmodxx_ctx(*this);
59     }
modulus()60     mp_limb_t modulus() const {return estimate_ctx().n();}
61 
62     template<class Expr>
create_temporary_rowscols(const Expr & e,slong rows,slong cols)63     static evaluated_t create_temporary_rowscols(
64             const Expr& e, slong rows, slong cols)
65     {
66         return evaluated_t(rows, cols, tools::find_nmodxx_ctx(e).n());
67     }
68     FLINTXX_DEFINE_MATRIX_METHODS(traits_t)
69 
FLINTXX_DEFINE_FORWARD_STATIC(from_ground)70     FLINTXX_DEFINE_FORWARD_STATIC(from_ground)
71 
72     static nmod_poly_matxx_expression randtest(slong rows, slong cols,
73             mp_limb_t M, frandxx& state, slong len)
74     {
75         nmod_poly_matxx_expression res(rows, cols, M);
76         res.set_randtest(state, len);
77         return res;
78     }
randtest_sparse(slong rows,slong cols,mp_limb_t M,frandxx & state,slong len,float density)79     static nmod_poly_matxx_expression randtest_sparse(slong rows, slong cols,
80             mp_limb_t M, frandxx& state, slong len, float density)
81     {
82         nmod_poly_matxx_expression res(rows, cols, M);
83         res.set_randtest_sparse(state, len, density);
84         return res;
85     }
86 
zero(slong rows,slong cols,mp_limb_t n)87     static nmod_poly_matxx_expression zero(slong rows, slong cols, mp_limb_t n)
88         {return nmod_poly_matxx_expression(rows, cols, n);}
one(slong rows,slong cols,mp_limb_t n)89     static nmod_poly_matxx_expression one(slong rows, slong cols, mp_limb_t n)
90     {
91         nmod_poly_matxx_expression res(rows, cols, n);
92         res.set_one();
93         return res;
94     }
95 
96     // these only make sense with targets
set_randtest(frandxx & state,slong len)97     void set_randtest(frandxx& state, slong len)
98         {nmod_poly_mat_randtest(_mat(), state._data(), len);}
set_randtest_sparse(frandxx & state,slong len,float density)99     void set_randtest_sparse(frandxx& state, slong len, float density)
100         {nmod_poly_mat_randtest_sparse(_mat(), state._data(), len, density);}
set_zero()101     void set_zero() {nmod_poly_mat_zero(_mat());}
set_one()102     void set_one() {nmod_poly_mat_one(_mat());}
103 
104     // these cause evaluation
is_zero()105     bool is_zero() const
106         {return nmod_poly_mat_is_zero(this->evaluate()._mat());}
is_one()107     bool is_one() const
108         {return nmod_poly_mat_is_one(this->evaluate()._mat());}
is_square()109     bool is_square() const
110         {return nmod_poly_mat_is_square(this->evaluate()._mat());}
is_empty()111     bool is_empty() const
112         {return nmod_poly_mat_is_empty(this->evaluate()._mat());}
max_length()113     slong max_length() const
114         {return nmod_poly_mat_max_length(this->evaluate()._mat());}
rank()115     slong rank() const {return nmod_poly_mat_rank(this->evaluate()._mat());}
find_pivot_any(slong start,slong end,slong c)116     slong find_pivot_any(slong start, slong end, slong c) const
117     {
118         return nmod_poly_mat_find_pivot_any(
119                 this->evaluate()._mat(), start, end, c);
120     }
find_pivot_partial(slong start,slong end,slong c)121     slong find_pivot_partial(slong start, slong end, slong c) const
122     {
123         return nmod_poly_mat_find_pivot_partial(
124                 this->evaluate()._mat(), start, end, c);
125     }
126 
127     // lazy members
128     FLINTXX_DEFINE_MEMBER_UNOP_RTYPE(nmod_polyxx, det)
129     FLINTXX_DEFINE_MEMBER_UNOP_RTYPE(nmod_polyxx, det_fflu)
130     FLINTXX_DEFINE_MEMBER_UNOP_RTYPE(nmod_polyxx, det_interpolate)
131     FLINTXX_DEFINE_MEMBER_UNOP_RTYPE(nmod_polyxx, trace)
132     FLINTXX_DEFINE_MEMBER_UNOP(sqr)
133     FLINTXX_DEFINE_MEMBER_UNOP(sqr_classical)
134     FLINTXX_DEFINE_MEMBER_UNOP(sqr_interpolate)
135     FLINTXX_DEFINE_MEMBER_UNOP(sqr_KS)
136     FLINTXX_DEFINE_MEMBER_UNOP(transpose)
137     //FLINTXX_DEFINE_MEMBER_UNOP_RTYPE(???, nullspace) // TODO
138     FLINTXX_DEFINE_MEMBER_BINOP_(operator(), compeval)
139     FLINTXX_DEFINE_MEMBER_BINOP(solve)
140     FLINTXX_DEFINE_MEMBER_BINOP(solve_fflu)
141     FLINTXX_DEFINE_MEMBER_BINOP(mul_classical)
142     FLINTXX_DEFINE_MEMBER_BINOP(mul_interpolate)
143     FLINTXX_DEFINE_MEMBER_BINOP(mul_KS)
144     FLINTXX_DEFINE_MEMBER_BINOP(pow)
145 
146     //FLINTXX_DEFINE_MEMBER_UNOP_RTYPE(???, rref) // TODO
147 
148     FLINTXX_DEFINE_MEMBER_FFLU
149 };
150 
151 namespace detail {
152 struct nmod_poly_mat_data;
153 } // detail
154 
155 typedef nmod_poly_matxx_expression<
156     operations::immediate, detail::nmod_poly_mat_data> nmod_poly_matxx;
157 typedef nmod_poly_matxx_expression<operations::immediate,
158             flint_classes::ref_data<
159                 nmod_poly_matxx, nmod_poly_mat_struct> > nmod_poly_matxx_ref;
160 typedef nmod_poly_matxx_expression<operations::immediate,
161         flint_classes::srcref_data<
162             nmod_poly_matxx, nmod_poly_matxx_ref,
163             nmod_poly_mat_struct> > nmod_poly_matxx_srcref;
164 
165 template<>
166 struct matrix_traits<nmod_poly_matxx>
167 {
168     template<class M> static slong rows(const M& m)
169     {
170         return nmod_poly_mat_nrows(m._mat());
171     }
172     template<class M> static slong cols(const M& m)
173     {
174         return nmod_poly_mat_ncols(m._mat());
175     }
176 
177     template<class M> static nmod_polyxx_srcref at(const M& m, slong i, slong j)
178     {
179         return nmod_polyxx_srcref::make(nmod_poly_mat_entry(m._mat(), i, j));
180     }
181     template<class M> static nmod_polyxx_ref at(M& m, slong i, slong j)
182     {
183         return nmod_polyxx_ref::make(nmod_poly_mat_entry(m._mat(), i, j));
184     }
185 };
186 
187 namespace traits {
188 template<> struct has_nmodxx_ctx<nmod_poly_matxx> : mp::true_ { };
189 template<> struct has_nmodxx_ctx<nmod_poly_matxx_ref> : mp::true_ { };
190 template<> struct has_nmodxx_ctx<nmod_poly_matxx_srcref> : mp::true_ { };
191 } // traits
192 
193 namespace detail {
194 template<>
195 struct nmod_poly_matxx_traits<nmod_poly_matxx_srcref>
196     : matrices::generic_traits_srcref<nmod_polyxx_srcref> { };
197 template<>
198 struct nmod_poly_matxx_traits<nmod_poly_matxx_ref>
199     : matrices::generic_traits_ref<nmod_polyxx_ref> { };
200 template<> struct nmod_poly_matxx_traits<nmod_poly_matxx>
201     : matrices::generic_traits_nonref<nmod_polyxx_ref, nmod_polyxx_srcref> { };
202 
203 struct nmod_poly_mat_data
204 {
205     typedef nmod_poly_mat_t& data_ref_t;
206     typedef const nmod_poly_mat_t& data_srcref_t;
207 
208     nmod_poly_mat_t inner;
209 
210     nmod_poly_mat_data(slong m, slong n, mp_limb_t modulus)
211     {
212         nmod_poly_mat_init(inner, m, n, modulus);
213     }
214 
215     nmod_poly_mat_data(const nmod_poly_mat_data& o)
216     {
217         nmod_poly_mat_init_set(inner, o.inner);
218     }
219 
220     nmod_poly_mat_data(nmod_poly_matxx_srcref o)
221     {
222         nmod_poly_mat_init_set(inner, o._data().inner);
223     }
224 
225     ~nmod_poly_mat_data() {nmod_poly_mat_clear(inner);}
226 
227     template<class Nmod_mat>
228     static nmod_poly_mat_data _from_ground(const Nmod_mat& m)
229     {
230         nmod_poly_mat_data res(m.rows(), m.cols(), m.modulus());
231         for(slong i = 0;i < m.rows();++i)
232             for(slong j = 0;j < m.cols();++j)
233                 nmod_poly_set_coeff_ui(nmod_poly_mat_entry(res.inner, i, j), 0,
234                         nmod_mat_entry(m._mat(), i, j));
235         return res;
236     }
237     template<class Nmod_mat>
238     static nmod_poly_mat_data from_ground(const Nmod_mat& m,
239             typename mp::enable_if<traits::is_nmod_matxx<Nmod_mat> >::type* = 0)
240     {
241         return _from_ground(m.evaluate());
242     }
243 };
244 } // detail
245 
246 // temporary instantiation stuff
247 FLINTXX_DEFINE_TEMPORARY_RULES(nmod_poly_matxx)
248 
249 #define NMOD_POLY_MATXX_COND_S FLINTXX_COND_S(nmod_poly_matxx)
250 #define NMOD_POLY_MATXX_COND_T FLINTXX_COND_T(nmod_poly_matxx)
251 
252 namespace rules {
253 FLINT_DEFINE_DOIT_COND2(assignment, NMOD_POLY_MATXX_COND_T, NMOD_POLY_MATXX_COND_S,
254         nmod_poly_mat_set(to._mat(), from._mat()))
255 
256 FLINTXX_DEFINE_SWAP(nmod_poly_matxx, nmod_poly_mat_swap(e1._mat(), e2._mat()))
257 
258 FLINTXX_DEFINE_EQUALS(nmod_poly_matxx, nmod_poly_mat_equal(e1._mat(), e2._mat()))
259 
260 FLINT_DEFINE_PRINT_PRETTY_COND_2(NMOD_POLY_MATXX_COND_S, const char*,
261         (nmod_poly_mat_print(from._mat(), extra), 1))
262 
263 FLINT_DEFINE_THREEARY_EXPR_COND3(mat_at_op, nmod_polyxx,
264         NMOD_POLY_MATXX_COND_S, traits::fits_into_slong, traits::fits_into_slong,
265         nmod_poly_set(to._poly(), nmod_poly_mat_entry(e1._mat(), e2, e3)))
266 
267 FLINT_DEFINE_BINARY_EXPR_COND2(times, nmod_poly_matxx,
268         NMOD_POLY_MATXX_COND_S, NMOD_POLY_MATXX_COND_S,
269         nmod_poly_mat_mul(to._mat(), e1._mat(), e2._mat()))
270 FLINT_DEFINE_CBINARY_EXPR_COND2(times, nmod_poly_matxx,
271         NMOD_POLY_MATXX_COND_S, NMOD_POLYXX_COND_S,
272         nmod_poly_mat_scalar_mul_nmod_poly(to._mat(), e1._mat(), e2._poly()))
273 FLINT_DEFINE_CBINARY_EXPR_COND2(times, nmod_poly_matxx,
274         NMOD_POLY_MATXX_COND_S, NMODXX_COND_S,
275         nmod_poly_mat_scalar_mul_nmod(to._mat(), e1._mat(), e2._limb()))
276 
277 FLINT_DEFINE_BINARY_EXPR_COND2(plus, nmod_poly_matxx,
278         NMOD_POLY_MATXX_COND_S, NMOD_POLY_MATXX_COND_S,
279         nmod_poly_mat_add(to._mat(), e1._mat(), e2._mat()))
280 FLINT_DEFINE_BINARY_EXPR_COND2(minus, nmod_poly_matxx,
281         NMOD_POLY_MATXX_COND_S, NMOD_POLY_MATXX_COND_S,
282         nmod_poly_mat_sub(to._mat(), e1._mat(), e2._mat()))
283 
284 FLINT_DEFINE_UNARY_EXPR_COND(negate, nmod_poly_matxx, NMOD_POLY_MATXX_COND_S,
285         nmod_poly_mat_neg(to._mat(), from._mat()))
286 
287 namespace rdetail {
288 inline void nmod_poly_mat_transpose(nmod_poly_mat_t to,
289     const nmod_poly_mat_t from)
290 {
291     if(from == to) // guaranteed to be square
292     {
293         for(slong i = 0;i < nmod_poly_mat_nrows(to) - 1;++i)
294           for(slong j = i + 1;j < nmod_poly_mat_ncols(to);++j)
295               nmod_poly_swap(nmod_poly_mat_entry(to, i, j),
296                       nmod_poly_mat_entry(to, j, i));
297     }
298     else
299     {
300         for(slong i = 0;i < nmod_poly_mat_nrows(to);++i)
301           for(slong j = 0;j < nmod_poly_mat_ncols(to);++j)
302             nmod_poly_set(nmod_poly_mat_entry(to, i, j),
303                 nmod_poly_mat_entry(from, j, i));
304     }
305 }
306 }
307 // TODO update this when nmod_poly_mat has transpose
308 FLINT_DEFINE_UNARY_EXPR_COND(transpose_op, nmod_poly_matxx, NMOD_POLY_MATXX_COND_S,
309         rdetail::nmod_poly_mat_transpose(to._mat(), from._mat()))
310 FLINT_DEFINE_UNARY_EXPR_COND(trace_op, nmod_polyxx, NMOD_POLY_MATXX_COND_S,
311         nmod_poly_mat_trace(to._poly(), from._mat()))
312 
313 FLINT_DEFINE_BINARY_EXPR_COND2(evaluate_op, nmod_matxx,
314         NMOD_POLY_MATXX_COND_S, NMODXX_COND_S,
315         nmod_poly_mat_evaluate_nmod(to._mat(), e1._mat(), e2._limb()))
316 
317 #define NMOD_POLY_MATXX_DEFINE_MUL(name) \
318 FLINT_DEFINE_BINARY_EXPR_COND2(name##_op, nmod_poly_matxx, \
319         NMOD_POLY_MATXX_COND_S, NMOD_POLY_MATXX_COND_S, \
320         nmod_poly_mat_##name(to._mat(), e1._mat(), e2._mat()))
321 NMOD_POLY_MATXX_DEFINE_MUL(mul_classical)
322 NMOD_POLY_MATXX_DEFINE_MUL(mul_KS)
323 NMOD_POLY_MATXX_DEFINE_MUL(mul_interpolate)
324 
325 FLINT_DEFINE_UNARY_EXPR_COND(sqr_op, nmod_poly_matxx, NMOD_POLY_MATXX_COND_S,
326         nmod_poly_mat_sqr(to._mat(), from._mat()))
327 FLINT_DEFINE_UNARY_EXPR_COND(sqr_KS_op, nmod_poly_matxx, NMOD_POLY_MATXX_COND_S,
328         nmod_poly_mat_sqr_KS(to._mat(), from._mat()))
329 FLINT_DEFINE_UNARY_EXPR_COND(sqr_classical_op, nmod_poly_matxx,
330         NMOD_POLY_MATXX_COND_S,
331         nmod_poly_mat_sqr_classical(to._mat(), from._mat()))
332 FLINT_DEFINE_UNARY_EXPR_COND(sqr_interpolate_op, nmod_poly_matxx,
333         NMOD_POLY_MATXX_COND_S,
334         nmod_poly_mat_sqr_interpolate(to._mat(), from._mat()))
335 
336 FLINT_DEFINE_BINARY_EXPR_COND2(pow_op, nmod_poly_matxx,
337         NMOD_POLY_MATXX_COND_S, traits::is_unsigned_integer,
338         nmod_poly_mat_pow(to._mat(), e1._mat(), e2))
339 
340 FLINT_DEFINE_UNARY_EXPR_COND(det_op, nmod_polyxx, NMOD_POLY_MATXX_COND_S,
341         nmod_poly_mat_det(to._poly(), from._mat()))
342 FLINT_DEFINE_UNARY_EXPR_COND(det_fflu_op, nmod_polyxx, NMOD_POLY_MATXX_COND_S,
343         nmod_poly_mat_det_fflu(to._poly(), from._mat()))
344 FLINT_DEFINE_UNARY_EXPR_COND(det_interpolate_op, nmod_polyxx,
345         NMOD_POLY_MATXX_COND_S,
346         nmod_poly_mat_det_interpolate(to._poly(), from._mat()))
347 
348 namespace rdetail {
349 typedef make_ltuple<mp::make_tuple<bool, nmod_poly_matxx, nmod_polyxx>::type >::type
350     nmod_poly_mat_inv_rt;
351 } // rdetail
352 
353 FLINT_DEFINE_UNARY_EXPR_COND(inv_op, rdetail::nmod_poly_mat_inv_rt,
354         NMOD_POLY_MATXX_COND_S,
355         to.template get<0>() = nmod_poly_mat_inv(to.template get<1>()._mat(),
356             to.template get<2>()._poly(), from._mat()))
357 
358 namespace rdetail {
359 typedef make_ltuple<mp::make_tuple<slong, nmod_poly_matxx>::type >::type
360     nmod_poly_mat_nullspace_rt;
361 } // rdetail
362 FLINT_DEFINE_UNARY_EXPR_COND(nullspace_op, rdetail::nmod_poly_mat_nullspace_rt,
363         NMOD_POLY_MATXX_COND_S, to.template get<0>() = nmod_poly_mat_nullspace(
364             to.template get<1>()._mat(), from._mat()))
365 
366 FLINT_DEFINE_BINARY_EXPR_COND2(solve_op, rdetail::nmod_poly_mat_inv_rt,
367         NMOD_POLY_MATXX_COND_S, NMOD_POLY_MATXX_COND_S,
368         to.template get<0>() = nmod_poly_mat_solve(to.template get<1>()._mat(),
369             to.template get<2>()._poly(), e1._mat(), e2._mat()))
370 FLINT_DEFINE_BINARY_EXPR_COND2(solve_fflu_op, rdetail::nmod_poly_mat_inv_rt,
371         NMOD_POLY_MATXX_COND_S, NMOD_POLY_MATXX_COND_S,
372         to.template get<0>() = nmod_poly_mat_solve_fflu(
373             to.template get<1>()._mat(),
374             to.template get<2>()._poly(), e1._mat(), e2._mat()))
375 
376 namespace rdetail {
377 typedef make_ltuple<mp::make_tuple<slong, nmod_poly_matxx, nmod_polyxx>::type>::type
378     nmod_poly_matxx_fflu_rt;
379 } // rdetail
380 
381 FLINT_DEFINE_THREEARY_EXPR_COND3(fflu_op, rdetail::nmod_poly_matxx_fflu_rt,
382         NMOD_POLY_MATXX_COND_S, traits::is_maybe_perm, tools::is_bool,
383         to.template get<0>() = nmod_poly_mat_fflu(to.template get<1>()._mat(),
384             to.template get<2>()._poly(), maybe_perm_data(e2), e1._mat(), e3))
385 
386 FLINT_DEFINE_UNARY_EXPR_COND(rref_op, rdetail::nmod_poly_matxx_fflu_rt,
387         NMOD_POLY_MATXX_COND_S,
388         to.template get<0>() = nmod_poly_mat_rref(to.template get<1>()._mat(),
389             to.template get<2>()._poly(), from._mat()))
390 
391 FLINT_DEFINE_THREEARY_EXPR_COND3(solve_fflu_precomp_op, nmod_poly_matxx,
392         traits::is_permxx, NMOD_POLY_MATXX_COND_S, NMOD_POLY_MATXX_COND_S,
393         nmod_poly_mat_solve_fflu_precomp(to._mat(), e1._data(),
394             e2._mat(), e3._mat()))
395 } // rules
396 } // flint
397 
398 #endif
399