1 /*
2 Copyright (C) 2016 Aaditya Thakkar
3
4 This file is part of FLINT.
5
6 FLINT is free software: you can redistribute it and/or modify it under
7 the terms of the GNU Lesser General Public License (LGPL) as published
8 by the Free Software Foundation; either version 2.1 of the License, or
9 (at your option) any later version. See <http://www.gnu.org/licenses/>.
10 */
11
12 #include "fmpz_mat.h"
13 #include "fmpz.h"
14 #include "fmpz_vec.h"
15 #include "flint.h"
16
/*
    Sets C to the matrix product A * B.

    With A of size m x k and B having n columns, the leading
    2*(m/2) x 2*(k/2) part of A and 2*(k/2) x 2*(n/2) part of B are split
    into 2 x 2 blocks, and the matching part of C is assembled from seven
    block products and fifteen block additions/subtractions.  Every block
    product is delegated to fmpz_mat_mul.  When m, k or n is odd, the
    left-over column of B, row of A and inner slice are handled by three
    additional fmpz_mat_mul calls at the end.

    If any of m, k, n is at most 4 the whole product is handed straight to
    fmpz_mat_mul.

    Only A->r, A->c and B->c are read here; B is assumed to have A->c rows
    and C is assumed to be A->r by B->c (not checked in this function).

    NOTE(review): blocks of C are overwritten while blocks of A and B are
    still being read, so C presumably must not alias A or B -- confirm
    against the library documentation.
*/
void fmpz_mat_mul_strassen(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)
{
    slong m, k, n;      /* A is m x k, B has n columns */
    slong hm, hk, hn;   /* halves of m, k, n, rounded down */
    slong m2, k2, n2;   /* 2*hm, 2*hk, 2*hn: extent covered by the blocks */
    slong tw;           /* number of columns S is allocated with */

    fmpz_mat_t A11, A12, A21, A22;
    fmpz_mat_t B11, B12, B21, B22;
    fmpz_mat_t C11, C12, C21, C22;
    fmpz_mat_t S, T;    /* scratch matrices */

    m = A->r;
    k = A->c;
    n = B->c;

    /* Some dimension is at most 4: no block splitting, just multiply. */
    if (!(m > 4 && k > 4 && n > 4))
    {
        fmpz_mat_mul(C, A, B);
        return;
    }

    hm = m / 2;
    hk = k / 2;
    hn = n / 2;

    m2 = 2 * hm;
    k2 = 2 * hk;
    n2 = 2 * hn;

    tw = FLINT_MAX(hn, hk);

    /* Quadrants of A: each hm x hk. */
    fmpz_mat_window_init(A11, A, 0, 0, hm, hk);
    fmpz_mat_window_init(A12, A, 0, hk, hm, k2);
    fmpz_mat_window_init(A21, A, hm, 0, m2, hk);
    fmpz_mat_window_init(A22, A, hm, hk, m2, k2);

    /* Quadrants of B: each hk x hn. */
    fmpz_mat_window_init(B11, B, 0, 0, hk, hn);
    fmpz_mat_window_init(B12, B, 0, hn, hk, n2);
    fmpz_mat_window_init(B21, B, hk, 0, k2, hn);
    fmpz_mat_window_init(B22, B, hk, hn, k2, n2);

    /* Quadrants of C: each hm x hn. */
    fmpz_mat_window_init(C11, C, 0, 0, hm, hn);
    fmpz_mat_window_init(C12, C, 0, hn, hm, n2);
    fmpz_mat_window_init(C21, C, hm, 0, m2, hn);
    fmpz_mat_window_init(C22, C, hm, hn, m2, n2);

    /*
        S is used first with hk columns (combinations of A blocks) and later
        with hn columns (the product A11 * B11), so it is allocated with the
        larger of the two widths and its column count is adjusted in place.
    */
    fmpz_mat_init(S, hm, tw);
    fmpz_mat_init(T, hk, hn);

    S->c = hk;

    /* C21 = (A11 - A21) * (B22 - B12) */
    fmpz_mat_sub(S, A11, A21);
    fmpz_mat_sub(T, B22, B12);
    fmpz_mat_mul(C21, S, T);

    /* C22 = (A21 + A22) * (B12 - B11) */
    fmpz_mat_add(S, A21, A22);
    fmpz_mat_sub(T, B12, B11);
    fmpz_mat_mul(C22, S, T);

    /* C12 = (A21 + A22 - A11) * (B22 - B12 + B11) */
    fmpz_mat_sub(S, S, A11);
    fmpz_mat_sub(T, B22, T);
    fmpz_mat_mul(C12, S, T);

    /* C11 = (A12 - A21 - A22 + A11) * B22 */
    fmpz_mat_sub(S, A12, S);
    fmpz_mat_mul(C11, S, B22);

    /* From here on S holds the hm x hn product A11 * B11. */
    S->c = hn;
    fmpz_mat_mul(S, A11, B11);

    /* Combine the partial products held in the quadrants of C. */
    fmpz_mat_add(C12, S, C12);
    fmpz_mat_add(C21, C12, C21);
    fmpz_mat_add(C12, C12, C22);
    fmpz_mat_add(C22, C21, C22);
    fmpz_mat_add(C12, C12, C11);

    /* C11 is reused as scratch: A22 * (B22 - B12 + B11 - B21) */
    fmpz_mat_sub(T, T, B21);
    fmpz_mat_mul(C11, A22, T);

    fmpz_mat_clear(T);

    fmpz_mat_sub(C21, C21, C11);

    /* C11 = A11 * B11 + A12 * B21 */
    fmpz_mat_mul(C11, A12, B21);
    fmpz_mat_add(C11, S, C11);

    /*
        Put back the column count S was initialised with before clearing it
        (presumably so fmpz_mat_clear sees the allocated dimensions).
    */
    S->c = tw;
    fmpz_mat_clear(S);

    fmpz_mat_window_clear(A11);
    fmpz_mat_window_clear(A12);
    fmpz_mat_window_clear(A21);
    fmpz_mat_window_clear(A22);

    fmpz_mat_window_clear(B11);
    fmpz_mat_window_clear(B12);
    fmpz_mat_window_clear(B21);
    fmpz_mat_window_clear(B22);

    fmpz_mat_window_clear(C11);
    fmpz_mat_window_clear(C12);
    fmpz_mat_window_clear(C21);
    fmpz_mat_window_clear(C22);

    /* n odd: columns n2..n-1 of C are A times the same columns of B. */
    if (n > n2)
    {
        fmpz_mat_t Bc, Cc;

        fmpz_mat_window_init(Bc, B, 0, n2, k, n);
        fmpz_mat_window_init(Cc, C, 0, n2, m, n);
        fmpz_mat_mul(Cc, A, Bc);
        fmpz_mat_window_clear(Bc);
        fmpz_mat_window_clear(Cc);
    }

    /* m odd: rows m2..m-1 of C are the same rows of A times B. */
    if (m > m2)
    {
        fmpz_mat_t Ar, Cr;

        fmpz_mat_window_init(Ar, A, m2, 0, m, k);
        fmpz_mat_window_init(Cr, C, m2, 0, m, n);
        fmpz_mat_mul(Cr, Ar, B);
        fmpz_mat_window_clear(Ar);
        fmpz_mat_window_clear(Cr);
    }

    /*
        k odd: the block part of C still lacks the contribution of the
        trailing column(s) of A times the trailing row(s) of B.  It is
        computed into a temporary and added to the top-left m2 x n2 of C.
    */
    if (k > k2)
    {
        fmpz_mat_t Ac, Br, Cb, tmp;

        fmpz_mat_window_init(Ac, A, 0, k2, m2, k);
        fmpz_mat_window_init(Br, B, k2, 0, k, n2);
        fmpz_mat_window_init(Cb, C, 0, 0, m2, n2);

        fmpz_mat_init(tmp, Ac->r, Br->c);
        fmpz_mat_mul(tmp, Ac, Br);
        fmpz_mat_add(Cb, Cb, tmp);
        fmpz_mat_clear(tmp);

        fmpz_mat_window_clear(Ac);
        fmpz_mat_window_clear(Br);
        fmpz_mat_window_clear(Cb);
    }
}
153