1 /*
2     Copyright (C) 2016 Aaditya Thakkar
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <http://www.gnu.org/licenses/>.
10 */
11 
12 #include "fmpz_mat.h"
13 #include "fmpz.h"
14 #include "fmpz_vec.h"
15 #include "flint.h"
16 
fmpz_mat_mul_strassen(fmpz_mat_t C,const fmpz_mat_t A,const fmpz_mat_t B)17 void fmpz_mat_mul_strassen(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)
18 {
19     slong a, b, c;
20     slong anr, anc, bnr, bnc;
21 
22     fmpz_mat_t A11, A12, A21, A22;
23     fmpz_mat_t B11, B12, B21, B22;
24     fmpz_mat_t C11, C12, C21, C22;
25     fmpz_mat_t X1, X2;
26 
27     a = A->r;
28     b = A->c;
29     c = B->c;
30 
31     if (a <= 4 || b <= 4 || c <= 4)
32     {
33         fmpz_mat_mul(C, A, B);
34         return;
35     }
36 
37     anr = a / 2;
38     anc = b / 2;
39     bnr = anc;
40     bnc = c / 2;
41 
42     fmpz_mat_window_init(A11, A, 0, 0, anr, anc);
43     fmpz_mat_window_init(A12, A, 0, anc, anr, 2*anc);
44     fmpz_mat_window_init(A21, A, anr, 0, 2*anr, anc);
45     fmpz_mat_window_init(A22, A, anr, anc, 2*anr, 2*anc);
46 
47     fmpz_mat_window_init(B11, B, 0, 0, bnr, bnc);
48     fmpz_mat_window_init(B12, B, 0, bnc, bnr, 2*bnc);
49     fmpz_mat_window_init(B21, B, bnr, 0, 2*bnr, bnc);
50     fmpz_mat_window_init(B22, B, bnr, bnc, 2*bnr, 2*bnc);
51 
52     fmpz_mat_window_init(C11, C, 0, 0, anr, bnc);
53     fmpz_mat_window_init(C12, C, 0, bnc, anr, 2*bnc);
54     fmpz_mat_window_init(C21, C, anr, 0, 2*anr, bnc);
55     fmpz_mat_window_init(C22, C, anr, bnc, 2*anr, 2*bnc);
56 
57     fmpz_mat_init(X1, anr, FLINT_MAX(bnc, anc));
58     fmpz_mat_init(X2, anc, bnc);
59 
60     X1->c = anc;
61 
62     fmpz_mat_sub(X1, A11, A21);
63     fmpz_mat_sub(X2, B22, B12);
64     fmpz_mat_mul(C21, X1, X2);
65 
66     fmpz_mat_add(X1, A21, A22);
67     fmpz_mat_sub(X2, B12, B11);
68     fmpz_mat_mul(C22, X1, X2);
69 
70     fmpz_mat_sub(X1, X1, A11);
71     fmpz_mat_sub(X2, B22, X2);
72     fmpz_mat_mul(C12, X1, X2);
73 
74     fmpz_mat_sub(X1, A12, X1);
75     fmpz_mat_mul(C11, X1, B22);
76 
77     X1->c = bnc;
78     fmpz_mat_mul(X1, A11, B11);
79     fmpz_mat_add(C12, X1, C12);
80     fmpz_mat_add(C21, C12, C21);
81     fmpz_mat_add(C12, C12, C22);
82     fmpz_mat_add(C22, C21, C22);
83     fmpz_mat_add(C12, C12, C11);
84     fmpz_mat_sub(X2, X2, B21);
85     fmpz_mat_mul(C11, A22, X2);
86 
87     fmpz_mat_clear(X2);
88 
89     fmpz_mat_sub(C21, C21, C11);
90     fmpz_mat_mul(C11, A12, B21);
91 
92     fmpz_mat_add(C11, X1, C11);
93 
94     X1->c = FLINT_MAX(bnc, anc);
95     fmpz_mat_clear(X1);
96 
97     fmpz_mat_window_clear(A11);
98     fmpz_mat_window_clear(A12);
99     fmpz_mat_window_clear(A21);
100     fmpz_mat_window_clear(A22);
101 
102     fmpz_mat_window_clear(B11);
103     fmpz_mat_window_clear(B12);
104     fmpz_mat_window_clear(B21);
105     fmpz_mat_window_clear(B22);
106 
107     fmpz_mat_window_clear(C11);
108     fmpz_mat_window_clear(C12);
109     fmpz_mat_window_clear(C21);
110     fmpz_mat_window_clear(C22);
111 
112     if (c > 2*bnc)
113     {
114         fmpz_mat_t Bc, Cc;
115         fmpz_mat_window_init(Bc, B, 0, 2*bnc, b, c);
116         fmpz_mat_window_init(Cc, C, 0, 2*bnc, a, c);
117         fmpz_mat_mul(Cc, A, Bc);
118         fmpz_mat_window_clear(Bc);
119         fmpz_mat_window_clear(Cc);
120     }
121 
122     if (a > 2*anr)
123     {
124         fmpz_mat_t Ar, Cr;
125         fmpz_mat_window_init(Ar, A, 2*anr, 0, a, b);
126         fmpz_mat_window_init(Cr, C, 2*anr, 0, a, c);
127         fmpz_mat_mul(Cr, Ar, B);
128         fmpz_mat_window_clear(Ar);
129         fmpz_mat_window_clear(Cr);
130     }
131 
132     if (b > 2*anc)
133     {
134         fmpz_mat_t Ac, Br, Cb, tmp;
135         slong mt, nt;
136 
137         fmpz_mat_window_init(Ac, A, 0, 2*anc, 2*anr, b);
138         fmpz_mat_window_init(Br, B, 2*bnr, 0, b, 2*bnc);
139         fmpz_mat_window_init(Cb, C, 0, 0, 2*anr, 2*bnc);
140 
141         mt = Ac->r;
142         nt = Br->c;
143 
144         fmpz_mat_init(tmp, mt, nt);
145         fmpz_mat_mul(tmp, Ac, Br);
146         fmpz_mat_add(Cb, Cb, tmp);
147         fmpz_mat_clear(tmp);
148         fmpz_mat_window_clear(Ac);
149         fmpz_mat_window_clear(Br);
150         fmpz_mat_window_clear(Cb);
151     }
152 }
153