1!
2! Copyright (c) 2012-2018, NVIDIA CORPORATION.  All rights reserved.
3!
4! Licensed under the Apache License, Version 2.0 (the "License");
5! you may not use this file except in compliance with the License.
6! You may obtain a copy of the License at
7!
8!     http://www.apache.org/licenses/LICENSE-2.0
9!
10! Unless required by applicable law or agreed to in writing, software
11! distributed under the License is distributed on an "AS IS" BASIS,
12! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13! See the License for the specific language governing permissions and
14! limitations under the License.
15!
16
17
18! directives.h -- contains preprocessor directives for F90 rte files
19
20#include "mmul_dir.h"
21
!
! ftn_mnaxtb_cmplx16 -- blocked matrix-multiply kernel, cmplx16 flavor.
!
! The direct (small-problem) path below computes, for i = 1..mra and
! j = 1..ncb:
!
!     c(i,j) = beta * c(i,j) + sum_{k=1..kab} alpha * a(i,k) * b(j,k)
!
! i.e. C = alpha * A * B**T + beta * C with no conjugation. When
! beta == 0, c(i,j) is assigned without being read.
!
! The blocked path is intended to produce the same result. How a and b
! are copied into the work buffers (including any transposition or
! conjugation selected by the ta/tb flags) is delegated to
! ftn_transpose_cmplx16 -- TODO(review): confirm that routine keeps the
! two paths consistent.
!
! Arguments (declared in pgf90_mmul_cmplx16.h, not visible here):
!   mra     - number of rows of a and of c (range of i)
!   ncb     - number of rows of b and columns of c (range of j)
!   kab     - shared inner dimension, columns of a and of b (range of k)
!   alpha   - scale applied to the a * b**T product
!   beta    - scale applied to the incoming contents of c
!   a, lda  - matrix a and the leading dimension handed to the packer
!   b, ldb  - matrix b and the leading dimension handed to the packer
!   c, ldc  - result matrix c, updated in place; ldc is presumably its
!             leading dimension (it is not referenced in this body)
!
! All locals (indices, tile sizes, accumulators, buffers) are declared in
! the included header as well. min_blocked_mult, bufrows and bufcols are
! defined outside this file (presumably mmul_dir.h or the header).
!
subroutine ftn_mnaxtb_cmplx16( mra, ncb, kab, alpha, a, lda, b, ldb, beta, &
     & c, ldc )
  implicit none
#include "pgf90_mmul_cmplx16.h"

  !
  ! The main idea here is that bufrows defines the usage of the
  ! L1 cache. We reference the same column or columns multiple times while
  ! accessing multiple partial rows of matrix a transposed in the buffer.

  !           Remember that everything is buffer centric
  !
  !
  !           <- bufca(1)>< (2)>                         <-bufcb->
  !               i = 1, m                                j = 1, n
  !                rowsa                                  colsb
  !                  ar --->                                bc --->
  !      ^    +----------+------+   ^              +----------+----+  ^
  !      |    |          x      |   |              |   b      x    |  |
  !      |    |          x      |   |              |   u      x    |  |
  !  bufr(1)  |  A**T    x      | rowchunks=2      |   f  a   x c  |  |
  !      |    |          x      |   |              |   f      x    |  |
  !      |    | buffera  x      |   |              |   e      x    | kab = 1, k
  !      |    |          x      |   |          br  |   r      x    |  |
  !      |    |    I     x III  |   |          |   |   b      x    |  |
  !      v    +xxxxxxxxxxxxxxxxx+   |          |   +xxxxxxxxxx+xxxx|  |
  !      ^    |          x      |   |          v   |          x    |  |
  !      |    |   II     x IV   |   |              |   B  b   x d  |  |
  !   bufr(2) |          x      |   |              |          x    |  |
  !      |    |          x      |   |              |          x    |  |
  !      V    +----------+------+   V              +----------+----+  V
  !            <--colchunks=2-->
  !     x marks buffer boundaries on the transposed matrices
  !     For this case, bufca(1) = bufcols, bufr(1) = bufrows
  !
  !    Algorithmically, we perform dot products of (I,a), (III,a), (II,b)
  !    and (IV,b). The partial dot products of (I,a) are added to those
  !    of (II,b) and those of (III,a) are added to those of (IV,b)
  !
  ! Iterations over the "chunks" are buffer based
  ! while iterations over i and j are matrix based and keep track of where
  ! we are in the larger scheme of things
  ! Iterations over i and j are bounded by buffer dimensions
  !
  colsa = kab
  rowsb = kab
  rowsa = mra
  colsb = ncb
  ! Problems with fewer than min_blocked_mult scalar multiply-adds (the
  ! trip count of the direct triple loop) take the direct path. Larger
  ! ones use the blocked, buffered path in the else branch.
  ! NOTE(review): the product is formed in whatever integer kind the header
  ! declares; if that is a 32-bit kind it can overflow for large
  ! dimensions -- confirm.
  if (colsa * rowsa * colsb < min_blocked_mult) then
    ! beta == 0: c(i,j) is assigned without being read, so its previous
    ! contents never reach the result.
    if( beta .eq. 0.0 ) then
      do j = 1, colsb
         do i = 1, rowsa
            ! Four real partial sums of (alpha*a(i,k)) * b(j,k):
            !   temprr0 = sum re(alpha*a) * re(b)
            !   tempii0 = sum im(alpha*a) * im(b)
            !   tempir0 = sum re(alpha*a) * im(b)
            !   tempri0 = sum im(alpha*a) * re(b)
            temprr0 = 0.0
            tempri0 = 0.0
            tempir0 = 0.0
            tempii0 = 0.0
            do k = 1, colsa
                temprr0 = temprr0 + real(alpha) * real(a(i, k)) * real(b(j, k)) - aimag(alpha) * aimag(a(i, k)) * real(b(j, k))
            enddo
            do k = 1, colsa
                tempii0 = tempii0 + real(alpha) * aimag(a(i, k)) * aimag(b(j, k)) + aimag(alpha) * real(a(i, k)) * aimag(b(j, k))
            enddo
            do k = 1, colsa
                tempir0 = tempir0 + real(alpha) * real(a(i, k)) * aimag(b(j, k)) - aimag(alpha) * aimag(a(i, k)) * aimag(b(j, k))
            enddo
            do k = 1, colsa
                tempri0 = tempri0 + real(alpha) * aimag(a(i, k)) * real(b(j, k)) + aimag(alpha) * real(a(i, k)) * real(b(j, k))
            enddo
            ! Recombine into the complex sum: real part rr - ii,
            ! imaginary part ri + ir.
            c(i, j) = dcmplx((temprr0 - tempii0), (tempri0 + tempir0))
         enddo
      enddo
    else
      do j = 1, colsb
         do i = 1, rowsa
            ! Same four real partial sums as the beta == 0 branch above.
            temprr0 = 0.0
            tempri0 = 0.0
            tempir0 = 0.0
            tempii0 = 0.0
            do k = 1, colsa
                temprr0 = temprr0 + real(alpha) * real(a(i, k)) * real(b(j, k)) - aimag(alpha) * aimag(a(i, k)) * real(b(j, k))
            enddo
            do k = 1, colsa
                tempii0 = tempii0 + real(alpha) * aimag(a(i, k)) * aimag(b(j, k)) + aimag(alpha) * real(a(i, k)) * aimag(b(j, k))
            enddo
            do k = 1, colsa
                tempir0 = tempir0 + real(alpha) * real(a(i, k)) * aimag(b(j, k)) - aimag(alpha) * aimag(a(i, k)) * aimag(b(j, k))
            enddo
            do k = 1, colsa
                tempri0 = tempri0 + real(alpha) * aimag(a(i, k)) * real(b(j, k)) + aimag(alpha) * real(a(i, k)) * real(b(j, k))
            enddo

            ! Scale the existing c(i,j) by beta and add the product.
            c(i, j) = beta * c(i, j) + dcmplx((temprr0 - tempii0), (tempri0 + tempir0))
         enddo
      enddo
    endif
  else
    ! Blocked path. Each packing buffer holds bufrows * bufcols elements,
    ! which bounds every bufr * bufca (a) and bufr * bufcb (b) tile below.
    allocate( buffera( bufrows * bufcols ) )
    allocate( bufferb( bufrows * bufcols ) )

    ! for algorithmic purposes, kab is the number of columns in matrix b, which
    ! is also the number of columns in matrix a.


    ! Tile extents: bufr along k (shared inner dimension), bufca along i
    ! (rows of a and c), bufcb along j (rows of b, columns of c). The *_sav
    ! copies keep the full tile size so edge tiles can be clipped from it.
    bufr = min( bufrows, colsa )
    bufr_sav = bufr
    bufca = min( bufcols, rowsa )
    bufca_sav = bufca
    bufcb = min( bufcols, colsb )
    bufcb_sav = bufcb
    ar_sav = 1
    ac_sav = 1
    bc = 1
    br = 1
    ! both rowchunks and colchunks are buffer centric
    rowchunks = ( colsa + bufr - 1 )/bufr
    colachunks = ( rowsa + bufca - 1 )/bufca
    colbchunks = ( colsb + bufcb - 1 )/bufcb
    ! these are for loop unrolling
    colsb_chunk = 4

    ! Loop nest: k tiles (rowchunk) outermost, then j tiles (colbchunk),
    ! then i tiles (colachunk). br and ac_sav both hold the current k
    ! offset (into b and a respectively), bc the current j offset and ar
    ! the current i offset.
    do rowchunk = 1, rowchunks
       ! NOTE(review): bufcb is recomputed from bufcb_sav at the top of the
       ! j loop before any use, so this reset appears redundant.
       bufcb = bufcb_sav
       do colbchunk = 1, colbchunks
          ! Clip the tile to the matrix edge, then pack the b tile into
          ! bufferb. alpha is handed to the packer here; the note below
          ! says a is packed with 1.0 so alpha is applied only once.
          bufcb = min( bufcb_sav, colsb - bc + 1 )
          bufr = min( bufr_sav, rowsb - br + 1 )
          call ftn_transpose_cmplx16( tb, b( bc, br ), ldb, alpha, bufferb, &
               & bufr, bufcb )
          !       ar = ar_sav
          !       ac = 1
          ar = 1
          ac = ac_sav
          do colachunk = 1, colachunks
             ! First k tile initializes c (overwrite when beta == 0,
             ! otherwise scale by beta). Later k tiles add their partial
             ! sums to what earlier tiles stored in c.
             if( br .eq. 1 )then
                ! Note: alpha is 1.0 for matrix a to avoid multiplying by
                ! alpha * alpha
                bufca = min( bufca_sav, rowsa - ar + 1 )
                call ftn_transpose_cmplx16( ta, a( ar, ac ), lda, one, buffera, &
                     & bufr, bufca )
                ! Both buffers store k contiguously. Row i of the a tile is
                ! buffera(ndxa+1 : ndxa+bufr) and column j of the b tile is
                ! bufferb(ndxb+1 : ndxb+bufr), with ndxa/ndxb stepping by
                ! bufr. ndxb0..ndxb3 address four consecutive j columns.
                ndxb0 = 0
                ndxb1 = bufr
                ndxb2 = ndxb1 + bufr
                ndxb3 = ndxb2 + bufr
                ! Columns of c are processed colsb_chunk (= 4) at a time.
                ! The leftover bufcb - colsb_chunks * colsb_chunk columns
                ! are handled one by one from colsb_strt to jend.
                colsb_chunks = bufcb/colsb_chunk
                colsb_end = bc + colsb_chunks * colsb_chunk - 1
                colsb_strt = colsb_end + 1
                jend = bc + bufcb - 1
                j = bc
                if( beta .eq. 0.0 ) then
                   do jb = 1, colsb_chunks
                      ndxa = 0
                      do i = ar, ar + bufca - 1
                         temp0 = 0.0
                         temp1 = 0.0
                         temp2 = 0.0
                         temp3 = 0.0
                         ! One load of the a element feeds four columns.
                         do k = 1, bufr
                            bufatemp = buffera( ndxa + k )
                            temp0 = temp0 + bufferb( ndxb0 + k ) * bufatemp
                            temp1 = temp1 + bufferb( ndxb1 + k ) * bufatemp
                            temp2 = temp2 + bufferb( ndxb2 + k ) * bufatemp
                            temp3 = temp3 + bufferb( ndxb3 + k ) * bufatemp
                         enddo
                         c( i, j )     = temp0
                         c( i, j + 1 ) = temp1
                         c( i, j + 2 ) = temp2
                         c( i, j + 3 ) = temp3
                         ndxa = ndxa + bufr
                      enddo
                      ndxa = 0
                      ndxb0 = ndxb0 + bufr * colsb_chunk
                      ndxb1 = ndxb1 + bufr * colsb_chunk
                      ndxb2 = ndxb2 + bufr * colsb_chunk
                      ndxb3 = ndxb3 + bufr * colsb_chunk
                      ! NOTE(review): the step 4 and the four accumulators
                      ! hard-code the unroll factor; keep in sync with
                      ! colsb_chunk.
                      j = j + 4
                   enddo
                   ! Remainder columns of this j tile, one at a time.
                   ndxb = bufr * colsb_chunks * colsb_chunk
                   do j = colsb_strt, jend
                      ndxa = 0
                      do i = ar, ar + bufca - 1
                         temp = 0.0
                         do k = 1, bufr
                            temp = temp + bufferb( ndxb + k ) * &
                                 & buffera( ndxa + k )
                         enddo
                         c( i, j ) = temp
                         ndxa = ndxa + bufr
                      enddo
                      ndxb = ndxb + bufr
                   enddo
                   !             ac = ac + bufca
                   ! Advance to the next i tile.
                   ar = ar + bufca
                   !           print *, "ac: ", ac
                else
                   ! Same traversal as above, but the first k tile scales
                   ! the existing c by beta before adding.
                   do jb = 1, colsb_chunks
                      ndxa = 0
                      do i = ar, ar + bufca - 1
                         temp0 = 0.0
                         temp1 = 0.0
                         temp2 = 0.0
                         temp3 = 0.0
                         do k = 1, bufr
                            bufatemp = buffera( ndxa + k )
                            temp0 = temp0 + bufferb( ndxb0 + k ) * bufatemp
                            temp1 = temp1 + bufferb( ndxb1 + k ) * bufatemp
                            temp2 = temp2 + bufferb( ndxb2 + k ) * bufatemp
                            temp3 = temp3 + bufferb( ndxb3 + k ) * bufatemp
                         enddo
                         c( i, j )     = beta * c( i, j )     + temp0
                         c( i, j + 1 ) = beta * c( i, j + 1 ) + temp1
                         c( i, j + 2 ) = beta * c( i, j + 2 ) + temp2
                         c( i, j + 3 ) = beta * c( i, j + 3 ) + temp3
                         ndxa = ndxa + bufr
                      enddo
                      ndxa = 0
                      ndxb0 = ndxb0 + bufr * colsb_chunk
                      ndxb1 = ndxb1 + bufr * colsb_chunk
                      ndxb2 = ndxb2 + bufr * colsb_chunk
                      ndxb3 = ndxb3 + bufr * colsb_chunk
                      j = j + 4
                   enddo
                   ! Remainder columns of this j tile, one at a time.
                   ndxb = bufr * colsb_chunks * colsb_chunk
                   do j = colsb_strt, jend
                      ndxa = 0
                      do i = ar, ar + bufca - 1
                         temp = 0.0
                         do k = 1, bufr
                            temp = temp + bufferb( ndxb + k ) * &
                                 & buffera( ndxa + k )
                         enddo
                         c( i, j ) = beta * c( i, j ) + temp
                         ndxa = ndxa + bufr
                      enddo
                      ndxb = ndxb + bufr
                   enddo
                   !             ac = ac + bufca
                   ! Advance to the next i tile.
                   ar = ar + bufca
                   !           print *, "ac: ", ac
                endif
             else
                ! Later k tile: pack the matching a tile and accumulate the
                ! partial dot products onto c (beta was already applied).
                bufca = min( bufca_sav, rowsa - ar + 1 )
                call ftn_transpose_cmplx16( ta, a( ar, ac ), lda, one  , buffera, &
                     & bufr, bufca )
                ndxb0 = 0
                ndxb1 = bufr
                ndxb2 = ndxb1 + bufr
                ndxb3 = ndxb2 + bufr
                colsb_chunks = bufcb/colsb_chunk
                colsb_end = bc + colsb_chunks * colsb_chunk - 1
                colsb_strt = colsb_end + 1
                jend = bc + bufcb - 1
                j = bc
                do jb = 1, colsb_chunks
                   ndxa = 0
                   do i = ar, ar + bufca - 1
                      temp0 = 0.0
                      temp1 = 0.0
                      temp2 = 0.0
                      temp3 = 0.0
                      do k = 1, bufr
                         bufatemp = buffera( ndxa + k )
                         temp0 = temp0 + bufferb( ndxb0 + k ) * bufatemp
                         temp1 = temp1 + bufferb( ndxb1 + k ) * bufatemp
                         temp2 = temp2 + bufferb( ndxb2 + k ) * bufatemp
                         temp3 = temp3 + bufferb( ndxb3 + k ) * bufatemp
                      enddo
                      c( i, j )     = c( i, j )     + temp0
                      c( i, j + 1 ) = c( i, j + 1 ) + temp1
                      c( i, j + 2 ) = c( i, j + 2 ) + temp2
                      c( i, j + 3 ) = c( i, j + 3 ) + temp3
                      ndxa = ndxa + bufr
                   enddo
                   ndxa = 0
                   ndxb0 = ndxb0 + bufr * colsb_chunk
                   ndxb1 = ndxb1 + bufr * colsb_chunk
                   ndxb2 = ndxb2 + bufr * colsb_chunk
                   ndxb3 = ndxb3 + bufr * colsb_chunk
                   j = j + 4
                enddo
                ! Remainder columns of this j tile, one at a time.
                ndxb = bufr * colsb_chunks * colsb_chunk
                do j = colsb_strt, jend
                   ndxa = 0
                   do i = ar, ar + bufca - 1
                      temp = 0.0
                      do k = 1, bufr
                         temp = temp + bufferb( ndxb + k ) * buffera( ndxa + k )
                      enddo
                      c( i, j ) = c( i, j ) + temp
                      ndxa = ndxa + bufr
                   enddo
                   ndxb = ndxb + bufr
                enddo
                ! Advance to the next i tile.
                ar = ar + bufca
             endif
          enddo

          ! Advance to the next j tile.
          bc = bc + bufcb
       enddo
       ! Advance both k offsets to the next k tile and restart at the
       ! first j tile.
       br = br + bufr
       ac_sav = ac_sav + bufr
       bc = 1
    enddo
    deallocate( buffera )
    deallocate( bufferb )
  endif
  return
end subroutine ftn_mnaxtb_cmplx16
328