1!
2! Copyright (c) 2012-2018, NVIDIA CORPORATION.  All rights reserved.
3!
4! Licensed under the Apache License, Version 2.0 (the "License");
5! you may not use this file except in compliance with the License.
6! You may obtain a copy of the License at
7!
8!     http://www.apache.org/licenses/LICENSE-2.0
9!
10! Unless required by applicable law or agreed to in writing, software
11! distributed under the License is distributed on an "AS IS" BASIS,
12! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13! See the License for the specific language governing permissions and
14! limitations under the License.
15!
16
17
! mmul_dir.h -- contains preprocessor directives for F90 rte files
19
20#include "mmul_dir.h"
21
subroutine ftn_mtaxtb_cmplx16( mra, ncb, kab, alpha, a, lda, b, ldb, beta, &
     & c, ldc )
  implicit none
#include "pgf90_mmul_cmplx16.h"

  ! Complex*16 matrix product with both operands referenced transposed.
  ! Mathematically, for i = 1..mra and j = 1..ncb:
  !
  !   c(i, j) = sum_{k=1..kab} alpha * a(k, i) * b(j, k)        (beta .eq. 0)
  !   c(i, j) = beta * c(i, j) + sum_{k=1..kab} alpha * a(k, i) * b(j, k)
  !
  ! When beta .eq. 0 the old contents of c are never read, only written.
  !
  ! All argument and local declarations (kinds, array shapes, bufrows,
  ! bufcols, min_blocked_mult, ta, tb, the temporaries) come from
  ! pgf90_mmul_cmplx16.h, which is not visible here. lda/ldb/ldc are
  ! presumably the leading dimensions used by those declarations.
  !
  ! Two code paths:
  !  * colsa * rowsa * colsb < min_blocked_mult: direct loops that build the
  !    complex product from real and imaginary parts.
  !  * otherwise: op(b) is copied chunk by chunk into bufferb by
  !    ftn_transpose_cmplx16. alpha is passed to that routine, and the
  !    blocked loops never multiply by alpha themselves, so the copy
  !    presumably applies alpha. Each chunk is then multiplied against a,
  !    with the i loop (rows of c) unrolled by colsa_chunk = 4. In this path
  !    a is conjugated when ta .eq. 2, and tb is forwarded to
  !    ftn_transpose_cmplx16.
  !
  ! Schematic of the chunking (generic picture; in this routine the
  ! transposition buffer holds b):
  !
  !                 rowsa                           <-bufcb->
  !                                                   colsb
  !              i = 1, m  -ac->                   j = 1, n --bc->
  !      |    +-----------------+   ^          +----------+----+  ^
  !      |    |                 |   |          |          x    |  |
  !      |    |                 |   |          |          x    |  |
  !    ak     |  A              | rowchunks=2  |    B     x    |  |
  !      |    |                 |   |          |          x    |  |
  !      |    |                 |   |    |     |          x    | ka = 1, k
  !      |    |                 |   |    |     |          x    |  |
  !      |    |                 |   |   br     |    a     x    |  |
  !      v    +xxxxxxxxxxxxxxxxx+   |    |     +xxxxxxxxxxxxxxx+  |
  !      |    |                 |   |    v     |          x    |  |
  !      |    |                 |   |          |          x    |  |
  !           |                 |   |          |          x    |  |
  !      |    |                 |   |          |    b     x    |  |
  !      V    +-----------------+   V          +----------+----+  V
  !            <--colachunks=2-->
  !     x's mark buffer boundaries on the transposed matrices
  !     For this case, bufca(1) = bufcols, bufr(1) = bufrows

  colsa = kab
  rowsb = kab
  rowsa = mra
  colsb = ncb

  if (colsa * rowsa * colsb < min_blocked_mult) then
    ! Small problem: no buffering. For each c(i, j) the four real sums are,
    ! mathematically,
    !   temprr0 = sum Re(alpha*a) * Re(b)    tempii0 = sum Im(alpha*a) * Im(b)
    !   tempir0 = sum Re(alpha*a) * Im(b)    tempri0 = sum Im(alpha*a) * Re(b)
    ! so that sum alpha*a*b = (temprr0 - tempii0) + i*(tempri0 + tempir0).
    ! All four are accumulated in a single pass over k. Each sum keeps its
    ! own k = 1..colsa order, so the rounding matches four separate loops.
    ! NOTE(review): this path never consults ta or tb, whereas the blocked
    ! path conjugates a when ta .eq. 2 and forwards tb to
    ! ftn_transpose_cmplx16. If conjugation can be requested for small
    ! problems, they get the unconjugated product -- confirm against callers.
    do j = 1, colsb
      do i = 1, rowsa
        temprr0 = 0.0
        tempri0 = 0.0
        tempir0 = 0.0
        tempii0 = 0.0
        do k = 1, colsa
          temprr0 = temprr0 + real(alpha) * real(a(k, i)) * real(b(j, k)) &
               & - aimag(alpha) * aimag(a(k, i)) * real(b(j, k))
          tempii0 = tempii0 + real(alpha) * aimag(a(k, i)) * aimag(b(j, k)) &
               & + aimag(alpha) * real(a(k, i)) * aimag(b(j, k))
          tempir0 = tempir0 + real(alpha) * real(a(k, i)) * aimag(b(j, k)) &
               & - aimag(alpha) * aimag(a(k, i)) * aimag(b(j, k))
          tempri0 = tempri0 + real(alpha) * aimag(a(k, i)) * real(b(j, k)) &
               & + aimag(alpha) * real(a(k, i)) * real(b(j, k))
        enddo
        ! beta .eq. 0 means "overwrite": the old c(i, j) is not read.
        if( beta .eq. 0.0 ) then
          c(i, j) = dcmplx((temprr0 - tempii0), (tempri0 + tempir0))
        else
          c(i, j) = beta * c(i, j) + &
               & dcmplx((temprr0 - tempii0), (tempri0 + tempir0))
        endif
      enddo
    enddo
  else
    allocate( bufferb( bufrows * bufcols ) )

    ! Full chunk sizes: bufr along k (rows of op(b)) and bufcb along j
    ! (columns of c). The last chunk in each direction may be shorter. The
    ! *_sav copies keep the full size so every chunk can be re-clipped.
    bufr = min( bufrows, rowsb )
    bufr_sav = bufr
    rowchunks = ( rowsb + bufr - 1 ) / bufr

    bufcb = min( bufcols, colsb )
    bufcb_sav = bufcb
    colbchunks = ( colsb + bufcb - 1 ) / bufcb

    ! Rows of c (i = 1..mra) are processed colsa_chunk at a time. Rows
    ! colsa_strt..mra, left over by the unrolling, go through the scalar
    ! clean-up loops.
    colsa_chunk = 4
    colsa_chunks = mra / colsa_chunk
    colsa_end = colsa_chunks * colsa_chunk
    colsa_strt = colsa_end + 1

    ! br: first k index of the current chunk. bc: first j index.
    br = 1
    do rowchunk = 1, rowchunks
      bc = 1
      do colbchunk = 1, colbchunks
        ! a( ak + k, i ), k = 1..bufr, pairs with k index br + k - 1.
        ak = br - 1
        ! Clip the chunk to what is left of op(b), then fill bufferb. The
        ! loops below read chunk element (k, j) at bufferb( ndxb + k ), where
        ! ndxb = (j - bc) * bufr, so the buffer is bufr x bufcb, column by
        ! column. This does not depend on ta or br, so it is done once here
        ! for all four branches.
        bufcb = min( bufcb_sav, colsb - bc + 1 )
        bufr = min( bufr_sav, rowsb - br + 1 )
        call ftn_transpose_cmplx16( tb, b( bc, br ), ldb, alpha, bufferb, &
             & bufr, bufcb )
        ! Only the first k chunk (br .eq. 1) applies beta (or overwrites c
        ! when beta .eq. 0). Later chunks add their partial sums into c.
        if( ta .eq. 2 )then !conjugate matrix a; b conjugated in transpose
          if( br .eq. 1 ) then
            if( beta .eq. 0.0 ) then
              do i = 1, colsa_end, colsa_chunk
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * conjg( a( ak + k, i     ) )
                    temp1 = temp1 + bufbtemp * conjg( a( ak + k, i + 1 ) )
                    temp2 = temp2 + bufbtemp * conjg( a( ak + k, i + 2 ) )
                    temp3 = temp3 + bufbtemp * conjg( a( ak + k, i + 3 ) )
                  enddo
                  c( i, j )     = temp0
                  c( i + 1, j ) = temp1
                  c( i + 2, j ) = temp2
                  c( i + 3, j ) = temp3
                  ndxb = ndxb + bufr
                enddo
              enddo
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * conjg( a( ak + k, i ) )
                  enddo
                  c( i, j ) = temp
                  ndxb = ndxb + bufr
                enddo
              enddo
            else
              do i = 1, colsa_end, colsa_chunk
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * conjg( a( ak + k, i     ) )
                    temp1 = temp1 + bufbtemp * conjg( a( ak + k, i + 1 ) )
                    temp2 = temp2 + bufbtemp * conjg( a( ak + k, i + 2 ) )
                    temp3 = temp3 + bufbtemp * conjg( a( ak + k, i + 3 ) )
                  enddo
                  c( i, j )     = beta * c( i, j )     + temp0
                  c( i + 1, j ) = beta * c( i + 1, j ) + temp1
                  c( i + 2, j ) = beta * c( i + 2, j ) + temp2
                  c( i + 3, j ) = beta * c( i + 3, j ) + temp3
                  ndxb = ndxb + bufr
                enddo
              enddo
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * conjg( a( ak + k, i ) )
                  enddo
                  c( i, j ) = beta * c( i, j ) + temp
                  ndxb = ndxb + bufr
                enddo
              enddo
            endif
          else
            do i = 1, colsa_end, colsa_chunk
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp0 = 0.0
                temp1 = 0.0
                temp2 = 0.0
                temp3 = 0.0
                do k = 1, bufr
                  bufbtemp = bufferb( ndxb + k )
                  temp0 = temp0 + bufbtemp * conjg( a( ak + k, i     ) )
                  temp1 = temp1 + bufbtemp * conjg( a( ak + k, i + 1 ) )
                  temp2 = temp2 + bufbtemp * conjg( a( ak + k, i + 2 ) )
                  temp3 = temp3 + bufbtemp * conjg( a( ak + k, i + 3 ) )
                enddo
                c( i, j )     = c( i, j )     + temp0
                c( i + 1, j ) = c( i + 1, j ) + temp1
                c( i + 2, j ) = c( i + 2, j ) + temp2
                c( i + 3, j ) = c( i + 3, j ) + temp3
                ndxb = ndxb + bufr
              enddo
            enddo
            ! Now clean up whatever is left from the loop unrolling
            do i = colsa_strt, mra
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp = 0.0
                do k = 1, bufr
                  temp = temp + bufferb( ndxb + k ) * conjg( a( ak + k, i ) )
                enddo
                c( i, j ) = c( i, j ) + temp
                ndxb = ndxb + bufr
              enddo
            enddo
          endif
        else
          if( br .eq. 1 ) then
            if( beta .eq. 0.0 ) then
              do i = 1, colsa_end, colsa_chunk
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * a( ak + k, i )
                    temp1 = temp1 + bufbtemp * a( ak + k, i + 1 )
                    temp2 = temp2 + bufbtemp * a( ak + k, i + 2 )
                    temp3 = temp3 + bufbtemp * a( ak + k, i + 3 )
                  enddo
                  c( i, j )     = temp0
                  c( i + 1, j ) = temp1
                  c( i + 2, j ) = temp2
                  c( i + 3, j ) = temp3
                  ndxb = ndxb + bufr
                enddo
              enddo
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * a( ak + k, i )
                  enddo
                  c( i, j ) = temp
                  ndxb = ndxb + bufr
                enddo
              enddo
            else
              do i = 1, colsa_end, colsa_chunk
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * a( ak + k, i )
                    temp1 = temp1 + bufbtemp * a( ak + k, i + 1 )
                    temp2 = temp2 + bufbtemp * a( ak + k, i + 2 )
                    temp3 = temp3 + bufbtemp * a( ak + k, i + 3 )
                  enddo
                  c( i, j )     = beta * c( i, j )     + temp0
                  c( i + 1, j ) = beta * c( i + 1, j ) + temp1
                  c( i + 2, j ) = beta * c( i + 2, j ) + temp2
                  c( i + 3, j ) = beta * c( i + 3, j ) + temp3
                  ndxb = ndxb + bufr
                enddo
              enddo
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * a( ak + k, i )
                  enddo
                  c( i, j ) = beta * c( i, j ) + temp
                  ndxb = ndxb + bufr
                enddo
              enddo
            endif
          else
            do i = 1, colsa_end, colsa_chunk
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp0 = 0.0
                temp1 = 0.0
                temp2 = 0.0
                temp3 = 0.0
                do k = 1, bufr
                  bufbtemp = bufferb( ndxb + k )
                  temp0 = temp0 + bufbtemp * a( ak + k, i )
                  temp1 = temp1 + bufbtemp * a( ak + k, i + 1 )
                  temp2 = temp2 + bufbtemp * a( ak + k, i + 2 )
                  temp3 = temp3 + bufbtemp * a( ak + k, i + 3 )
                enddo
                c( i, j )     = c( i, j )     + temp0
                c( i + 1, j ) = c( i + 1, j ) + temp1
                c( i + 2, j ) = c( i + 2, j ) + temp2
                c( i + 3, j ) = c( i + 3, j ) + temp3
                ndxb = ndxb + bufr
              enddo
            enddo
            ! Now clean up whatever is left from the loop unrolling
            do i = colsa_strt, mra
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp = 0.0
                do k = 1, bufr
                  temp = temp + bufferb( ndxb + k ) * a( ak + k, i )
                enddo
                c( i, j ) = c( i, j ) + temp
                ndxb = ndxb + bufr
              enddo
            enddo
          endif
        endif
        ! advance to the next chunk of columns of c (j direction)
        bc = bc + bufcb
      enddo
      ! advance to the next chunk of k (rows of op(b)); the number of passes
      ! of this loop is rowchunks, the number of bufferb chunks along k
      br = br + bufr
    enddo
    deallocate( bufferb )
  endif
  return
end subroutine ftn_mtaxtb_cmplx16
386
387