1 /*
2     Copyright (C) 2008, Martin Albrecht
3     Copyright (C) 2008, 2009 William Hart.
4     Copyright (C) 2010, Fredrik Johansson
5 
6     This file is part of FLINT.
7 
8     FLINT is free software: you can redistribute it and/or modify it under
9     the terms of the GNU Lesser General Public License (LGPL) as published
10     by the Free Software Foundation; either version 2.1 of the License, or
11     (at your option) any later version.  See <https://www.gnu.org/licenses/>.
12 */
13 
14 #include <stdlib.h>
15 #include <gmp.h>
16 #include "flint.h"
17 #include "nmod_vec.h"
18 #include "nmod_mat.h"
19 
20 
21 void
nmod_mat_mul_strassen(nmod_mat_t C,const nmod_mat_t A,const nmod_mat_t B)22 nmod_mat_mul_strassen(nmod_mat_t C, const nmod_mat_t A, const nmod_mat_t B)
23 {
24     slong a, b, c;
25     slong anr, anc, bnr, bnc;
26 
27     nmod_mat_t A11, A12, A21, A22;
28     nmod_mat_t B11, B12, B21, B22;
29     nmod_mat_t C11, C12, C21, C22;
30     nmod_mat_t X1, X2;
31 
32     a = A->r;
33     b = A->c;
34     c = B->c;
35 
36     if (a <= 4 || b <= 4 || c <= 4)
37     {
38         nmod_mat_mul(C, A, B);
39         return;
40     }
41 
42     anr = a / 2;
43     anc = b / 2;
44     bnr = anc;
45     bnc = c / 2;
46 
47     nmod_mat_window_init(A11, A, 0, 0, anr, anc);
48     nmod_mat_window_init(A12, A, 0, anc, anr, 2*anc);
49     nmod_mat_window_init(A21, A, anr, 0, 2*anr, anc);
50     nmod_mat_window_init(A22, A, anr, anc, 2*anr, 2*anc);
51 
52     nmod_mat_window_init(B11, B, 0, 0, bnr, bnc);
53     nmod_mat_window_init(B12, B, 0, bnc, bnr, 2*bnc);
54     nmod_mat_window_init(B21, B, bnr, 0, 2*bnr, bnc);
55     nmod_mat_window_init(B22, B, bnr, bnc, 2*bnr, 2*bnc);
56 
57     nmod_mat_window_init(C11, C, 0, 0, anr, bnc);
58     nmod_mat_window_init(C12, C, 0, bnc, anr, 2*bnc);
59     nmod_mat_window_init(C21, C, anr, 0, 2*anr, bnc);
60     nmod_mat_window_init(C22, C, anr, bnc, 2*anr, 2*bnc);
61 
62     nmod_mat_init(X1, anr, FLINT_MAX(bnc, anc), A->mod.n);
63     nmod_mat_init(X2, anc, bnc, A->mod.n);
64 
65     X1->c = anc;
66 
67     /*
68         See Jean-Guillaume Dumas, Clement Pernet, Wei Zhou; "Memory
69         efficient scheduling of Strassen-Winograd's matrix multiplication
70         algorithm"; https://arxiv.org/pdf/0707.2347v3 for reference on the
71         used operation scheduling.
72     */
73 
74     nmod_mat_sub(X1, A11, A21);
75     nmod_mat_sub(X2, B22, B12);
76     nmod_mat_mul(C21, X1, X2);
77 
78     nmod_mat_add(X1, A21, A22);
79     nmod_mat_sub(X2, B12, B11);
80     nmod_mat_mul(C22, X1, X2);
81 
82     nmod_mat_sub(X1, X1, A11);
83     nmod_mat_sub(X2, B22, X2);
84     nmod_mat_mul(C12, X1, X2);
85 
86     nmod_mat_sub(X1, A12, X1);
87     nmod_mat_mul(C11, X1, B22);
88 
89     X1->c = bnc;
90     nmod_mat_mul(X1, A11, B11);
91 
92     nmod_mat_add(C12, X1, C12);
93     nmod_mat_add(C21, C12, C21);
94     nmod_mat_add(C12, C12, C22);
95     nmod_mat_add(C22, C21, C22);
96     nmod_mat_add(C12, C12, C11);
97     nmod_mat_sub(X2, X2, B21);
98     nmod_mat_mul(C11, A22, X2);
99 
100     nmod_mat_clear(X2);
101 
102     nmod_mat_sub(C21, C21, C11);
103     nmod_mat_mul(C11, A12, B21);
104 
105     nmod_mat_add(C11, X1, C11);
106 
107     nmod_mat_clear(X1);
108 
109     nmod_mat_window_clear(A11);
110     nmod_mat_window_clear(A12);
111     nmod_mat_window_clear(A21);
112     nmod_mat_window_clear(A22);
113 
114     nmod_mat_window_clear(B11);
115     nmod_mat_window_clear(B12);
116     nmod_mat_window_clear(B21);
117     nmod_mat_window_clear(B22);
118 
119     nmod_mat_window_clear(C11);
120     nmod_mat_window_clear(C12);
121     nmod_mat_window_clear(C21);
122     nmod_mat_window_clear(C22);
123 
124     if (c > 2*bnc) /* A by last col of B -> last col of C */
125     {
126         nmod_mat_t Bc, Cc;
127         nmod_mat_window_init(Bc, B, 0, 2*bnc, b, c);
128         nmod_mat_window_init(Cc, C, 0, 2*bnc, a, c);
129         nmod_mat_mul(Cc, A, Bc);
130         nmod_mat_window_clear(Bc);
131         nmod_mat_window_clear(Cc);
132     }
133 
134     if (a > 2*anr) /* last row of A by B -> last row of C */
135     {
136         nmod_mat_t Ar, Cr;
137         nmod_mat_window_init(Ar, A, 2*anr, 0, a, b);
138         nmod_mat_window_init(Cr, C, 2*anr, 0, a, c);
139         nmod_mat_mul(Cr, Ar, B);
140         nmod_mat_window_clear(Ar);
141         nmod_mat_window_clear(Cr);
142     }
143 
144     if (b > 2*anc) /* last col of A by last row of B -> C */
145     {
146         nmod_mat_t Ac, Br, Cb;
147         nmod_mat_window_init(Ac, A, 0, 2*anc, 2*anr, b);
148         nmod_mat_window_init(Br, B, 2*bnr, 0, b, 2*bnc);
149         nmod_mat_window_init(Cb, C, 0, 0, 2*anr, 2*bnc);
150         nmod_mat_addmul(Cb, Cb, Ac, Br);
151         nmod_mat_window_clear(Ac);
152         nmod_mat_window_clear(Br);
153         nmod_mat_window_clear(Cb);
154     }
155 }
156