#include <omp.h>

#include <chrono>
#include <cstddef>
#include <cstdio>
#include <cstdlib>

#include "strassen.hpp"
3 
4 /*****************************************************************************
5 **
6 ** OptimizedStrassenMultiply
7 **
8 ** For large matrices A, B, and C of size MatrixSize * MatrixSize this
9 ** function performs the operation C = A x B efficiently.
10 **
11 ** INPUT:
12 **    C = (*C WRITE) Address of top left element of matrix C.
13 **    A = (*A IS READ ONLY) Address of top left element of matrix A.
14 **    B = (*B IS READ ONLY) Address of top left element of matrix B.
15 **    MatrixSize = Size of matrices (for n*n matrix, MatrixSize = n)
16 **    RowWidthA = Number of elements in memory between A[x,y] and A[x,y+1]
17 **    RowWidthB = Number of elements in memory between B[x,y] and B[x,y+1]
18 **    RowWidthC = Number of elements in memory between C[x,y] and C[x,y+1]
19 **
20 ** OUTPUT:
21 **    C = (*C WRITE) Matrix C contains A x B. (Initial value of *C undefined.)
22 **
23 *****************************************************************************/
OptimizedStrassenMultiply_omp(REAL * C,REAL * A,REAL * B,unsigned MatrixSize,unsigned RowWidthC,unsigned RowWidthA,unsigned RowWidthB,int Depth)24 void OptimizedStrassenMultiply_omp(REAL *C, REAL *A, REAL *B, unsigned MatrixSize,
25      unsigned RowWidthC, unsigned RowWidthA, unsigned RowWidthB, int Depth)
26 {
27   unsigned QuadrantSize = MatrixSize >> 1; /* MatixSize / 2 */
28   unsigned QuadrantSizeInBytes = sizeof(REAL) * QuadrantSize * QuadrantSize
29                                  + 32;
30   unsigned Column, Row;
31 
32   /************************************************************************
33   ** For each matrix A, B, and C, we'll want pointers to each quandrant
34   ** in the matrix. These quandrants will be addressed as follows:
35   **  --        --
36   **  | A11  A12 |
37   **  |          |
38   **  | A21  A22 |
39   **  --        --
40   ************************************************************************/
41   REAL /* *A11, *B11, *C11, */ *A12, *B12, *C12,
42        *A21, *B21, *C21, *A22, *B22, *C22;
43 
44   REAL *S1,*S2,*S3,*S4,*S5,*S6,*S7,*S8,*M2,*M5,*T1sMULT;
45   #define T2sMULT C22
46   #define NumberOfVariables 11
47 
48   PTR TempMatrixOffset = 0;
49   PTR MatrixOffsetA = 0;
50   PTR MatrixOffsetB = 0;
51 
52   char *Heap;
53   void *StartHeap;
54 
55   /* Distance between the end of a matrix row and the start of the next row */
56   PTR RowIncrementA = ( RowWidthA - QuadrantSize ) << 3;
57   PTR RowIncrementB = ( RowWidthB - QuadrantSize ) << 3;
58   PTR RowIncrementC = ( RowWidthC - QuadrantSize ) << 3;
59 
60   if (MatrixSize <= CUTOFF_SIZE) {
61     MultiplyByDivideAndConquer(C, A, B, MatrixSize, RowWidthC, RowWidthA, RowWidthB, 0);
62     return;
63   }
64 
65   /* Initialize quandrant matrices */
66   #define A11 A
67   #define B11 B
68   #define C11 C
69   A12 = A11 + QuadrantSize;
70   B12 = B11 + QuadrantSize;
71   C12 = C11 + QuadrantSize;
72   A21 = A + (RowWidthA * QuadrantSize);
73   B21 = B + (RowWidthB * QuadrantSize);
74   C21 = C + (RowWidthC * QuadrantSize);
75   A22 = A21 + QuadrantSize;
76   B22 = B21 + QuadrantSize;
77   C22 = C21 + QuadrantSize;
78 
79   /* Allocate Heap Space Here */
80   Heap = static_cast<char*>(malloc(QuadrantSizeInBytes * NumberOfVariables));
81   StartHeap = Heap;
82 
83   /* ensure that heap is on cache boundary */
84   if ( ((PTR) Heap) & 31)
85      Heap = (char*) ( ((PTR) Heap) + 32 - ( ((PTR) Heap) & 31) );
86 
87   /* Distribute the heap space over the variables */
88   S1 = (REAL*) Heap; Heap += QuadrantSizeInBytes;
89   S2 = (REAL*) Heap; Heap += QuadrantSizeInBytes;
90   S3 = (REAL*) Heap; Heap += QuadrantSizeInBytes;
91   S4 = (REAL*) Heap; Heap += QuadrantSizeInBytes;
92   S5 = (REAL*) Heap; Heap += QuadrantSizeInBytes;
93   S6 = (REAL*) Heap; Heap += QuadrantSizeInBytes;
94   S7 = (REAL*) Heap; Heap += QuadrantSizeInBytes;
95   S8 = (REAL*) Heap; Heap += QuadrantSizeInBytes;
96   M2 = (REAL*) Heap; Heap += QuadrantSizeInBytes;
97   M5 = (REAL*) Heap; Heap += QuadrantSizeInBytes;
98   T1sMULT = (REAL*) Heap; Heap += QuadrantSizeInBytes;
99 
100   /***************************************************************************
101   ** Step through all columns row by row (vertically)
102   ** (jumps in memory by RowWidth => bad locality)
103   ** (but we want the best locality on the innermost loop)
104   ***************************************************************************/
105   for (Row = 0; Row < QuadrantSize; Row++) {
106 
107     /*************************************************************************
108     ** Step through each row horizontally (addressing elements in each column)
109     ** (jumps linearly througn memory => good locality)
110     *************************************************************************/
111     for (Column = 0; Column < QuadrantSize; Column++) {
112 
113       /***********************************************************
114       ** Within this loop, the following holds for MatrixOffset:
115       ** MatrixOffset = (Row * RowWidth) + Column
116       ** (note: that the unit of the offset is number of reals)
117       ***********************************************************/
118       /* Element of Global Matrix, such as A, B, C */
119       #define E(Matrix)   (* (REAL*) ( ((PTR) Matrix) + TempMatrixOffset ) )
120       #define EA(Matrix)  (* (REAL*) ( ((PTR) Matrix) + MatrixOffsetA ) )
121       #define EB(Matrix)  (* (REAL*) ( ((PTR) Matrix) + MatrixOffsetB ) )
122 
123       /* FIXME - may pay to expand these out - got higher speed-ups below */
124       /* S4 = A12 - ( S2 = ( S1 = A21 + A22 ) - A11 ) */
125       E(S4) = EA(A12) - ( E(S2) = ( E(S1) = EA(A21) + EA(A22) ) - EA(A11) );
126 
127       /* S8 = (S6 = B22 - ( S5 = B12 - B11 ) ) - B21 */
128       E(S8) = ( E(S6) = EB(B22) - ( E(S5) = EB(B12) - EB(B11) ) ) - EB(B21);
129 
130       /* S3 = A11 - A21 */
131       E(S3) = EA(A11) - EA(A21);
132 
133       /* S7 = B22 - B12 */
134       E(S7) = EB(B22) - EB(B12);
135 
136       TempMatrixOffset += sizeof(REAL);
137       MatrixOffsetA += sizeof(REAL);
138       MatrixOffsetB += sizeof(REAL);
139     } /* end row loop*/
140 
141     MatrixOffsetA += RowIncrementA;
142     MatrixOffsetB += RowIncrementB;
143   } /* end column loop */
144 
145   /* M2 = A11 x B11 */
146   #pragma omp task untied
147   OptimizedStrassenMultiply_omp(M2, A11, B11, QuadrantSize, QuadrantSize, RowWidthA, RowWidthB, Depth+1);
148 
149   /* M5 = S1 * S5 */
150   #pragma omp task untied
151   OptimizedStrassenMultiply_omp(M5, S1, S5, QuadrantSize, QuadrantSize, QuadrantSize, QuadrantSize, Depth+1);
152 
153   /* Step 1 of T1 = S2 x S6 + M2 */
154   #pragma omp task untied
155   OptimizedStrassenMultiply_omp(T1sMULT, S2, S6,  QuadrantSize, QuadrantSize, QuadrantSize, QuadrantSize, Depth+1);
156 
157   /* Step 1 of T2 = T1 + S3 x S7 */
158   #pragma omp task untied
159   OptimizedStrassenMultiply_omp(C22, S3, S7, QuadrantSize, RowWidthC /*FIXME*/, QuadrantSize, QuadrantSize, Depth+1);
160 
161   /* Step 1 of C11 = M2 + A12 * B21 */
162   #pragma omp task untied
163   OptimizedStrassenMultiply_omp(C11, A12, B21, QuadrantSize, RowWidthC, RowWidthA, RowWidthB, Depth+1);
164 
165   /* Step 1 of C12 = S4 x B22 + T1 + M5 */
166   #pragma omp task untied
167   OptimizedStrassenMultiply_omp(C12, S4, B22, QuadrantSize, RowWidthC, QuadrantSize, RowWidthB, Depth+1);
168 
169   /* Step 1 of C21 = T2 - A22 * S8 */
170   #pragma omp task untied
171   OptimizedStrassenMultiply_omp(C21, A22, S8, QuadrantSize, RowWidthC, RowWidthA, QuadrantSize, Depth+1);
172 
173   /**********************************************
174   ** Synchronization Point
175   **********************************************/
176   #pragma omp taskwait
177   /***************************************************************************
178   ** Step through all columns row by row (vertically)
179   ** (jumps in memory by RowWidth => bad locality)
180   ** (but we want the best locality on the innermost loop)
181   ***************************************************************************/
182   for (Row = 0; Row < QuadrantSize; Row++) {
183     /*************************************************************************
184     ** Step through each row horizontally (addressing elements in each column)
185     ** (jumps linearly througn memory => good locality)
186     *************************************************************************/
187     for (Column = 0; Column < QuadrantSize; Column += 4) {
188       REAL LocalM5_0 = *(M5);
189       REAL LocalM5_1 = *(M5+1);
190       REAL LocalM5_2 = *(M5+2);
191       REAL LocalM5_3 = *(M5+3);
192       REAL LocalM2_0 = *(M2);
193       REAL LocalM2_1 = *(M2+1);
194       REAL LocalM2_2 = *(M2+2);
195       REAL LocalM2_3 = *(M2+3);
196       REAL T1_0 = *(T1sMULT) + LocalM2_0;
197       REAL T1_1 = *(T1sMULT+1) + LocalM2_1;
198       REAL T1_2 = *(T1sMULT+2) + LocalM2_2;
199       REAL T1_3 = *(T1sMULT+3) + LocalM2_3;
200       REAL T2_0 = *(C22) + T1_0;
201       REAL T2_1 = *(C22+1) + T1_1;
202       REAL T2_2 = *(C22+2) + T1_2;
203       REAL T2_3 = *(C22+3) + T1_3;
204       (*(C11))   += LocalM2_0;
205       (*(C11+1)) += LocalM2_1;
206       (*(C11+2)) += LocalM2_2;
207       (*(C11+3)) += LocalM2_3;
208       (*(C12))   += LocalM5_0 + T1_0;
209       (*(C12+1)) += LocalM5_1 + T1_1;
210       (*(C12+2)) += LocalM5_2 + T1_2;
211       (*(C12+3)) += LocalM5_3 + T1_3;
212       (*(C22))   = LocalM5_0 + T2_0;
213       (*(C22+1)) = LocalM5_1 + T2_1;
214       (*(C22+2)) = LocalM5_2 + T2_2;
215       (*(C22+3)) = LocalM5_3 + T2_3;
216       (*(C21  )) = (- *(C21  )) + T2_0;
217       (*(C21+1)) = (- *(C21+1)) + T2_1;
218       (*(C21+2)) = (- *(C21+2)) + T2_2;
219       (*(C21+3)) = (- *(C21+3)) + T2_3;
220       M5 += 4;
221       M2 += 4;
222       T1sMULT += 4;
223       C11 += 4;
224       C12 += 4;
225       C21 += 4;
226       C22 += 4;
227     }
228     C11 = (REAL*) ( ((PTR) C11 ) + RowIncrementC);
229     C12 = (REAL*) ( ((PTR) C12 ) + RowIncrementC);
230     C21 = (REAL*) ( ((PTR) C21 ) + RowIncrementC);
231     C22 = (REAL*) ( ((PTR) C22 ) + RowIncrementC);
232   }
233   free(StartHeap);
234 }
235 
strassen_omp(unsigned num_threads,REAL * A,REAL * B,REAL * C,int n)236 void strassen_omp(unsigned num_threads, REAL *A, REAL *B, REAL *C, int n) {
237   omp_set_num_threads(num_threads);
238 	#pragma omp parallel
239   {
240 	  #pragma omp single
241     {
242 	    #pragma omp task untied
243       {
244 	    	OptimizedStrassenMultiply_omp(C, A, B, n, n, n, n, 1);
245       }
246     }
247   }
248 }
249 
measure_time_omp(unsigned num_threads,REAL * A,REAL * B,REAL * C,int n)250 std::chrono::microseconds measure_time_omp(unsigned num_threads, REAL *A, REAL *B, REAL *C, int n) {
251   auto beg = std::chrono::high_resolution_clock::now();
252   strassen_omp(num_threads, A, B, C, n);
253   auto end = std::chrono::high_resolution_clock::now();
254   return std::chrono::duration_cast<std::chrono::microseconds>(end - beg);
255 }
256 
257