!
! Copyright (c) 2012-2018, NVIDIA CORPORATION.  All rights reserved.
!
! Licensed under the Apache License, Version 2.0 (the "License");
! you may not use this file except in compliance with the License.
! You may obtain a copy of the License at
!
!     http://www.apache.org/licenses/LICENSE-2.0
!
! Unless required by applicable law or agreed to in writing, software
! distributed under the License is distributed on an "AS IS" BASIS,
! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
! See the License for the specific language governing permissions and
! limitations under the License.
!


! directives.h -- contains preprocessor directives for F90 rte files

20#include "mmul_dir.h"
21
! ftn_mtaxtb_cmplx8 -- complex matrix multiply with both operands read
! transposed.
!
! For i = 1..mra and j = 1..ncb the routine forms
!
!     sum over k = 1..kab of  alpha * a( k, i ) * b( j, k )
!
! and either stores it in c( i, j ) (beta == 0) or adds it to
! beta * c( i, j ).
!
! Arguments (all declarations come from pgf90_mmul_cmplx8.h, which is not
! visible here, so types and intents are not restated):
!   mra    number of rows of c that are written; a is indexed a( k, i )
!          with i = 1..mra
!   ncb    number of columns of c that are written; b is indexed
!          b( j, k ) with j = 1..ncb
!   kab    length of the inner (k) summation
!   alpha  scalar multiplying the product
!   a      left operand, read as a( k, i )
!   lda    not referenced in the executable code below; presumably the
!          leading dimension used by the header declarations
!   b      right operand, read as b( j, k )
!   ldb    passed through to ftn_transpose_cmplx8
!   beta   scalar multiplying the incoming c; beta == 0 means c is
!          overwritten without being read
!   c      result matrix, updated in place
!   ldc    not referenced in the executable code below; presumably the
!          leading dimension used by the header declarations
!
! NOTE(review): ta and tb are used in the blocked path but are not dummy
! arguments of this routine. They must be provided by the header or the
! preprocessor; confirm that they hold defined values on entry.
subroutine ftn_mtaxtb_cmplx8( mra, ncb, kab, alpha, a, lda, b, ldb, beta, &
     & c, ldc )
  implicit none
#include "pgf90_mmul_cmplx8.h"

  ! Everything herein is focused on how the transposition buffer maps
  ! to the matrix a. The size of the buffer is bufrows * bufcols
  ! Since once transposed data will be read from the buffer down the rows,
  ! bufrows corresponds to the columns of a while bufcols corresponds to
  ! the rows of a. A bit confusing, but correct, I think
  ! NOTE(review): the paragraph above speaks of matrix a, but in this
  ! routine the buffer (bufferb) is filled from tiles of b.
  ! There are 4 cases to consider:
  ! 1. rowsa <= bufcols AND colsa <= bufrows
  ! 2. rowsa <= bufcols ( corresponds to a wide matrix )
  ! 3. colsa <= bufrows ( corresponds to a high matrix )
  ! 4. Both dimensions of a exceed both dimensions of the buffer
  !
  ! The main idea here is that the bufrows will define the usage of the
  ! L1 cache. We reference the same column or columns multiply while
  ! accessing multiple partial rows of a transposed in the buffer.

  !
  !                 rowsa                           <-bufcb->
  !                                                   colsb
  !              i = 1, m  -ac->                   j = 1, n --bc->
  !      |    +-----------------+   ^          +----------+----+  ^
  !      |    |                 |   |          |          x    |  |
  !      |    |                 |   |          |          x    |  |
  !    ak     |  A              | rowchunks=2  |    B     x    |  |
  !      |    |                 |   |          |          x    |  |
  !      |    |                 |   |    |     |          x    | ka = 1, k
  !      |    |                 |   |    |     |          x    |  |
  !      |    |                 |   |   br     |    a     x    |  |
  !      v    +xxxxxxxxxxxxxxxxx+   |    |     +xxxxxxxxxxxxxxx+  |
  !      |    |                 |   |    v     |          x    |  |
  !      |    |                 |   |          |          x    |  |
  !           |                 |   |          |          x    |  |
  !      |    |                 |   |          |    b     x    |  |
  !      V    +-----------------+   V          +----------+----+  V
  !            <--colachunks=2-->
  !     The x characters mark buffer boundaries on the transposed matrices
  !     For this case, bufca(1) = bufcols, bufr(1) = bufrows

  colsa = kab
  rowsb = kab
  rowsa = mra
  colsb = ncb

  if (colsa * rowsa * colsb < min_blocked_mult) then
    ! ---------------------------------------------------------------
    ! Small problem: plain triple loop, no buffering.
    ! The complex product is split into four real accumulators so that
    ! each inner loop is a pure real reduction over k.
    ! NOTE(review): this path never consults ta or tb, so no conjugation
    ! is applied here; confirm callers do not rely on it for small sizes.
    ! ---------------------------------------------------------------
    do j = 1, colsb
      do i = 1, rowsa
        temprr0 = 0.0
        tempri0 = 0.0
        tempir0 = 0.0
        tempii0 = 0.0
        ! Re( alpha * a ) * Re( b )
        do k = 1, colsa
          temprr0 = temprr0 + real(alpha) * real(a(k, i)) * real(b(j, k)) &
               &            - aimag(alpha) * aimag(a(k, i)) * real(b(j, k))
        end do
        ! Im( alpha * a ) * Im( b )
        do k = 1, colsa
          tempii0 = tempii0 + real(alpha) * aimag(a(k, i)) * aimag(b(j, k)) &
               &            + aimag(alpha) * real(a(k, i)) * aimag(b(j, k))
        end do
        ! Re( alpha * a ) * Im( b )
        do k = 1, colsa
          tempir0 = tempir0 + real(alpha) * real(a(k, i)) * aimag(b(j, k)) &
               &            - aimag(alpha) * aimag(a(k, i)) * aimag(b(j, k))
        end do
        ! Im( alpha * a ) * Re( b )
        do k = 1, colsa
          tempri0 = tempri0 + real(alpha) * aimag(a(k, i)) * real(b(j, k)) &
               &            + aimag(alpha) * real(a(k, i)) * real(b(j, k))
        end do

        ! Only the store depends on beta; with beta == 0 the old value of
        ! c is never read.
        if( beta == 0.0 ) then
          c(i, j) = cmplx((temprr0 - tempii0), (tempri0 + tempir0))
        else
          c(i, j) = beta * c(i, j) + cmplx((temprr0 - tempii0), (tempri0 + tempir0))
        end if
      end do
    end do
  else
    ! ---------------------------------------------------------------
    ! Blocked path: b is processed one tile at a time through bufferb.
    ! ---------------------------------------------------------------
    allocate( bufferb( bufrows * bufcols ) )

    ! set the number of buffer row chunks we will work on
    bufr = min( bufrows, rowsb )
    bufr_sav = bufr
    rowchunks = ( rowsb + bufr - 1 )/bufr

    bufcb = min( bufcols, colsb )
    bufcb_sav = bufcb
    colbchunks = ( colsb + bufcb - 1)/bufcb
    ! Note that the starting column index into matrix a (ac) is the same as
    ! starting index into matrix b. But we need 1 less than that so we can
    ! add an index to it
    br = 1
    ac = 1
    bc = 1
    ak = 0
    ! The i loop is unrolled by colsa_chunk. colsa_end is the largest
    ! multiple of colsa_chunk not exceeding mra (assuming integer
    ! division); i = colsa_strt..mra is handled by the clean-up loops.
    colsa_chunk = 4
    colsa_chunks = mra/colsa_chunk
    colsa_end = colsa_chunks * colsa_chunk
    colsa_strt = colsa_end + 1

    do rowchunk = 1, rowchunks
      bc = 1
      do colbchunk = 1, colbchunks
        ! a( ak + k, i ) addresses the k-rows matching the current b tile
        ak = br - 1

        ! Clip the tile to what is left of b, then load it into bufferb.
        ! The kernels below read bufferb( ndxb + k ), k = 1..bufr, with
        ! ndxb advancing by bufr for each j, and never apply alpha
        ! themselves.
        ! NOTE(review): presumably ftn_transpose_cmplx8 stores the tile
        ! starting at b( bc, br ) with k contiguous, scaled by alpha and
        ! conjugated according to tb; confirm against that routine.
        bufcb = min( bufcb_sav, colsb - bc + 1 )
        bufr = min( bufr_sav, rowsb - br + 1 )
        call ftn_transpose_cmplx8( tb, b( bc, br ), ldb, alpha, bufferb, &
             & bufr, bufcb )

        if( ta == 2 ) then !conjugate matrix a; b conjugated in transpose
          if( br == 1 ) then
            ! First row chunk: this is where beta is applied to c.
            if( beta == 0.0 ) then
              do i = 1, colsa_end, colsa_chunk
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * conjg( a( ak + k, i     ) )
                    temp1 = temp1 + bufbtemp * conjg( a( ak + k, i + 1 ) )
                    temp2 = temp2 + bufbtemp * conjg( a( ak + k, i + 2 ) )
                    temp3 = temp3 + bufbtemp * conjg( a( ak + k, i + 3 ) )
                  end do
                  c( i, j )     = temp0
                  c( i + 1, j ) = temp1
                  c( i + 2, j ) = temp2
                  c( i + 3, j ) = temp3
                  ndxb = ndxb + bufr
                end do
              end do
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * &
                         & conjg( a( ak + k, i ) )
                  end do
                  c( i, j ) = temp
                  ndxb = ndxb + bufr
                end do
              end do
            else
              do i = 1, colsa_end, colsa_chunk
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * conjg( a( ak + k, i     ) )
                    temp1 = temp1 + bufbtemp * conjg( a( ak + k, i + 1 ) )
                    temp2 = temp2 + bufbtemp * conjg( a( ak + k, i + 2 ) )
                    temp3 = temp3 + bufbtemp * conjg( a( ak + k, i + 3 ) )
                  end do
                  c( i, j )     = beta * c( i, j )     + temp0
                  c( i + 1, j ) = beta * c( i + 1, j ) + temp1
                  c( i + 2, j ) = beta * c( i + 2, j ) + temp2
                  c( i + 3, j ) = beta * c( i + 3, j ) + temp3
                  ndxb = ndxb + bufr
                end do
              end do
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * &
                         & conjg( a( ak + k, i ) )
                  end do
                  c( i, j ) = beta * c( i, j ) + temp
                  ndxb = ndxb + bufr
                end do
              end do
            end if
          else
            ! Later row chunks: beta was already applied, just accumulate.
            do i = 1, colsa_end, colsa_chunk
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp0 = 0.0
                temp1 = 0.0
                temp2 = 0.0
                temp3 = 0.0
                do k = 1, bufr
                  bufbtemp = bufferb( ndxb + k )
                  temp0 = temp0 + bufbtemp * conjg( a( ak + k, i     ) )
                  temp1 = temp1 + bufbtemp * conjg( a( ak + k, i + 1 ) )
                  temp2 = temp2 + bufbtemp * conjg( a( ak + k, i + 2 ) )
                  temp3 = temp3 + bufbtemp * conjg( a( ak + k, i + 3 ) )
                end do
                c( i, j )     = c( i, j )     + temp0
                c( i + 1, j ) = c( i + 1, j ) + temp1
                c( i + 2, j ) = c( i + 2, j ) + temp2
                c( i + 3, j ) = c( i + 3, j ) + temp3
                ndxb = ndxb + bufr
              end do
            end do
            ! Now clean up whatever is left from the loop unrolling
            do i = colsa_strt, mra
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp = 0.0
                do k = 1, bufr
                  temp = temp + bufferb( ndxb + k ) * conjg( a( ak + k, i ) )
                end do
                c( i, j ) = c( i, j ) + temp
                ndxb = ndxb + bufr
              end do
            end do
          end if
        else
          ! a is used as stored (no conjugation)
          if( br == 1 ) then
            ! First row chunk: this is where beta is applied to c.
            if( beta == 0.0 ) then
              do i = 1, colsa_end, colsa_chunk
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * a( ak + k, i )
                    temp1 = temp1 + bufbtemp * a( ak + k, i + 1 )
                    temp2 = temp2 + bufbtemp * a( ak + k, i + 2 )
                    temp3 = temp3 + bufbtemp * a( ak + k, i + 3 )
                  end do
                  c( i, j )     = temp0
                  c( i + 1, j ) = temp1
                  c( i + 2, j ) = temp2
                  c( i + 3, j ) = temp3
                  ndxb = ndxb + bufr
                end do
              end do
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * a( ak + k, i )
                  end do
                  c( i, j ) = temp
                  ndxb = ndxb + bufr
                end do
              end do
            else
              do i = 1, colsa_end, colsa_chunk
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * a( ak + k, i )
                    temp1 = temp1 + bufbtemp * a( ak + k, i + 1 )
                    temp2 = temp2 + bufbtemp * a( ak + k, i + 2 )
                    temp3 = temp3 + bufbtemp * a( ak + k, i + 3 )
                  end do
                  c( i, j )     = beta * c( i, j )     + temp0
                  c( i + 1, j ) = beta * c( i + 1, j ) + temp1
                  c( i + 2, j ) = beta * c( i + 2, j ) + temp2
                  c( i + 3, j ) = beta * c( i + 3, j ) + temp3
                  ndxb = ndxb + bufr
                end do
              end do
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * a( ak + k, i )
                  end do
                  c( i, j ) = beta * c( i, j ) + temp
                  ndxb = ndxb + bufr
                end do
              end do
            end if
          else
            ! Later row chunks: beta was already applied, just accumulate.
            do i = 1, colsa_end, colsa_chunk
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp0 = 0.0
                temp1 = 0.0
                temp2 = 0.0
                temp3 = 0.0
                do k = 1, bufr
                  bufbtemp = bufferb( ndxb + k )
                  temp0 = temp0 + bufbtemp * a( ak + k, i )
                  temp1 = temp1 + bufbtemp * a( ak + k, i + 1 )
                  temp2 = temp2 + bufbtemp * a( ak + k, i + 2 )
                  temp3 = temp3 + bufbtemp * a( ak + k, i + 3 )
                end do
                c( i, j )     = c( i, j )     + temp0
                c( i + 1, j ) = c( i + 1, j ) + temp1
                c( i + 2, j ) = c( i + 2, j ) + temp2
                c( i + 3, j ) = c( i + 3, j ) + temp3
                ndxb = ndxb + bufr
              end do
            end do
            ! Now clean up whatever is left from the loop unrolling
            do i = colsa_strt, mra
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp = 0.0
                do k = 1, bufr
                  temp = temp + bufferb( ndxb + k ) * a( ak + k, i )
                end do
                c( i, j ) = c( i, j ) + temp
                ndxb = ndxb + bufr
              end do
            end do
          end if
        end if

        ! advance to the next tile along the first index of b
        ! (the columns of c)
        bc = bc + bufcb
      end do
      ! advance to the next chunk along the k direction; the number of
      ! such steps is rowchunks
      br = br + bufr
    end do

    deallocate( bufferb )
  end if
end subroutine ftn_mtaxtb_cmplx8
387
388