1!
2! Copyright (c) 2012-2018, NVIDIA CORPORATION.  All rights reserved.
3!
4! Licensed under the Apache License, Version 2.0 (the "License");
5! you may not use this file except in compliance with the License.
6! You may obtain a copy of the License at
7!
8!     http://www.apache.org/licenses/LICENSE-2.0
9!
10! Unless required by applicable law or agreed to in writing, software
11! distributed under the License is distributed on an "AS IS" BASIS,
12! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13! See the License for the specific language governing permissions and
14! limitations under the License.
15!
16
17
18! directives.h -- contains preprocessor directives for F90 rte files
19
20#include "mmul_dir.h"
21
subroutine ftn_mnaxnb_cmplx16( mra, ncb, kab, alpha, a, lda, b, ldb, &
     & beta, c, ldc )
  implicit none
#include "pgf90_mmul_cmplx16.h"

  ! Complex*16 matrix multiply with neither operand transposed:
  !
  !     c(1:mra, 1:ncb) = alpha * a(1:mra, 1:kab) * b(1:kab, 1:ncb)
  !                     + beta  * c(1:mra, 1:ncb)
  !
  ! When beta == 0, c is only written, never read.
  !
  ! Small products (work estimate below min_blocked_mult) use a plain triple
  ! loop. Larger products copy blocks of a into a transposition buffer
  ! (ftn_transpose_cmplx16) so that each partial row of a is contiguous, then
  ! form dot products against columns of b, four columns of b at a time.
  ! NOTE(review): the blocked loops never reference alpha, so
  ! ftn_transpose_cmplx16 presumably stores alpha * a**T in buffera -- confirm.
  ! NOTE(review): ta is passed to ftn_transpose_cmplx16 but is not assigned in
  ! this routine; presumably it is defined in pgf90_mmul_cmplx16.h -- confirm.

  ! Everything herein is focused on how the transposition buffer maps
  ! to the matrix a. The size of the buffer is bufrows * bufcols.
  ! Since, once transposed, data will be read from the buffer down the rows,
  ! bufrows corresponds to the columns of a while bufcols corresponds to
  ! the rows of a. A bit confusing, but correct.
  ! There are 4 cases to consider:
  ! 1. rowsa <= bufcols AND colsa <= bufrows
  ! 2. rowsa <= bufcols ( corresponds to a wide matrix )
  ! 3. colsa <= bufrows ( corresponds to a high matrix )
  ! 4. Both dimensions of a exceed both dimensions of the buffer
  !
  ! The main idea here is that bufrows defines the usage of the
  ! L1 cache. We reference the same column or columns multiple times while
  ! accessing multiple partial rows of a transposed in the buffer.


  !
  !                 rowsa                           colsb
  !           <-bufca(1)>< (2) >                <-bufcb(1)><(2)>
  !              i = 1, m  -ar->                   j = 1, n --br->
  !      ^    +----------+------+   ^          +----------+----+  ^
  !      |    |          x      |   |          |          x    |  |
  !      |    |          x      |   |          |          x    |  |
  !  bufr(1)  |  A**T    x      | rowchunks=2  |    B     x    |  |
  !      |    |          x      |   |          |          x    |  |
  !  |   |    | buffera  x      |   |    |     | bufferb  x    | ka = 1, k
  !  |   |    |          x      |   |    |     |          x    |  |
  ! ac   |    |    I     x III  |   |   bc     |    a     x c  |  |
  !  |   v    +xxxxxxxxxxxxxxxxx+   |    |     +xxxxxxxxxxxxxxx+  |
  !  v   ^    |          x      |   |    v     |          x    |  |
  !      |    |          x      |   |          |          x    |  |
  !   bufr(2) |          x      |   |          |          x    |  |
  !      |    |   II     x IV   |   |          |    b     x d  |  |
  !      V    +----------+------+   V          +----------+----+  V
  !            <--colachunks=2-->               <-colbchunks=2>
  !     x's mark buffer boundaries on the transposed matrices
  !     For this case, bufca(1) = bufcols, bufr(1) = bufrows

  ! The structure of this code came from mnaxnb_real.
  colsa = kab
  rowsb = kab
  rowsa = mra
  colsb = ncb

  ! The work estimate is formed in double precision: the integer product
  ! colsa * rowsa * colsb can overflow default integers for large matrices,
  ! which would wrongly route a big multiply through the unblocked path.
  if ( dble( colsa ) * dble( rowsa ) * dble( colsb ) < min_blocked_mult ) then
    ! Unblocked path. With alpha*a(i,k) expanded by hand, the partial sums are
    !   temprr0 = sum_k  real(alpha*a(i,k)) *  real(b(k,j))
    !   tempii0 = sum_k aimag(alpha*a(i,k)) * aimag(b(k,j))
    !   tempir0 = sum_k  real(alpha*a(i,k)) * aimag(b(k,j))
    !   tempri0 = sum_k aimag(alpha*a(i,k)) *  real(b(k,j))
    ! All four are accumulated in a single pass over k; each sum still sees
    ! exactly the same terms in the same order.
    do j = 1, colsb
       do i = 1, rowsa
          temprr0 = 0.0
          tempri0 = 0.0
          tempir0 = 0.0
          tempii0 = 0.0
          do k = 1, colsa
             temprr0 = temprr0 + real(alpha) * real(a(i, k)) * real(b(k, j)) &
                  & - aimag(alpha) * aimag(a(i, k)) * real(b(k, j))
             tempii0 = tempii0 + real(alpha) * aimag(a(i, k)) * aimag(b(k, j)) &
                  & + aimag(alpha) * real(a(i, k)) * aimag(b(k, j))
             tempir0 = tempir0 + real(alpha) * real(a(i, k)) * aimag(b(k, j)) &
                  & - aimag(alpha) * aimag(a(i, k)) * aimag(b(k, j))
             tempri0 = tempri0 + real(alpha) * aimag(a(i, k)) * real(b(k, j)) &
                  & + aimag(alpha) * real(a(i, k)) * real(b(k, j))
          enddo
          if ( beta .eq. 0.0 ) then
             ! c is write-only when beta is zero (its old value is not read).
             c(i, j) = dcmplx((temprr0 - tempii0), (tempri0 + tempir0))
          else
             c(i, j) = beta * c(i, j) + &
                  & dcmplx((temprr0 - tempii0), (tempri0 + tempir0))
          endif
       enddo
    enddo
  else
    allocate( buffera( bufrows * bufcols ) )

    bufca = min( bufcols, rowsa )

    colachunks = ( rowsa + bufcols - 1)/bufcols
    ! set the number of buffer row chunks we will work on
    bufr = min( bufrows, colsa )
    bufr_sav = bufr
    rowchunks = ( colsa + bufr - 1 )/bufr
    bufca_sav = bufca
    ac = 1   ! column index in matrix a for transpose
    !  lor = colsa - bufr ! left-over rows adjusts for the fact that
    ! colsa/bufr * bufr may not be equal to colsa
    ! Note that the starting column index into matrix a (ac) is the same as
    ! the starting row index into matrix b. But we need 1 less than that so
    ! we can add an index to it (bk = ac - 1).
    ! Columns of b are processed colsb_chunks at a time; the remaining
    ! colsb - colsb_end columns are handled one at a time.
    colsb_chunks = 4
    colsb_end = colsb/colsb_chunks * colsb_chunks
    colsb_strt = colsb_end + 1
    do rowchunk = 1, rowchunks
       ar = 1
       do colachunk = 1, colachunks
          ! Clip the buffer extents to what is left of a.
          bufca = min( bufca_sav, rowsa - ar + 1 )
          bufr = min( bufr_sav, colsa - ac + 1 )
          call ftn_transpose_cmplx16( ta, a( ar, ac ), lda, alpha, buffera, &
               & bufr, bufca )
          if ( ac .eq. 1 )then
             ! First slice along k: c is (re)initialized here; later slices
             ! (ac > 1) accumulate into it.

             if( beta .eq. 0.0 ) then
                ! beta == 0: c is overwritten without being read. The four
                ! real/imaginary cross products are summed separately and
                ! recombined as (rr - ii) + i*(ri + ir).
                do j = 1, colsb_end, colsb_chunks
                   ndxa = 0
                   do i = ar, ar + bufca - 1
                      bk = ac - 1
                      temprr0 = 0.0
                      temprr1 = 0.0
                      temprr2 = 0.0
                      temprr3 = 0.0
                      tempii0 = 0.0
                      tempii1 = 0.0
                      tempii2 = 0.0
                      tempii3 = 0.0
                      tempri0 = 0.0
                      tempri1 = 0.0
                      tempri2 = 0.0
                      tempri3 = 0.0
                      tempir0 = 0.0
                      tempir1 = 0.0
                      tempir2 = 0.0
                      tempir3 = 0.0
                      do k = 1, bufr ! dot product of real(a) * real(b)
                         bufatempr = real( buffera( ndxa + k ) )
                         temprr0 = temprr0 + bufatempr * &
                              & real( b( bk + k, j ) )
                         temprr1 = temprr1 + bufatempr * &
                              & real( b( bk + k, j + 1 ) )
                         temprr2 = temprr2 + bufatempr * &
                              & real( b( bk + k, j + 2 ) )
                         temprr3 = temprr3 + bufatempr * &
                              & real( b( bk + k, j + 3 ) )
                      enddo
                      do k = 1, bufr ! dot product of aimag(a) * aimag(b)
                         bufatempi = aimag( buffera( ndxa + k ) )
                         tempii0 = tempii0 + bufatempi * &
                              & aimag( b( bk + k, j ) )
                         tempii1 = tempii1 + bufatempi * &
                              & aimag( b( bk + k, j + 1 ) )
                         tempii2 = tempii2 + bufatempi * &
                              & aimag( b( bk + k, j + 2 ) )
                         tempii3 = tempii3 + bufatempi * &
                              & aimag( b( bk + k, j + 3 ) )
                      enddo
                      do k = 1, bufr ! cross dot product of real(a) * aimag(b)
                         bufatempr = real( buffera( ndxa + k ) )
                         tempri0 = tempri0 + bufatempr * &
                              & aimag( b( bk + k, j ) )
                         tempri1 = tempri1 + bufatempr * &
                              & aimag( b( bk + k, j + 1 ) )
                         tempri2 = tempri2 + bufatempr * &
                              & aimag( b( bk + k, j + 2 ) )
                         tempri3 = tempri3 + bufatempr * &
                              & aimag( b( bk + k, j + 3 ) )
                      enddo
                      do k = 1, bufr ! cross dot product of aimag(a) * real(b)
                         bufatempi = aimag( buffera( ndxa + k ) )
                         tempir0 = tempir0 + bufatempi * &
                              & real( b( bk + k, j ) )
                         tempir1 = tempir1 + bufatempi * &
                              & real( b( bk + k, j + 1 ) )
                         tempir2 = tempir2 + bufatempi * &
                              & real( b( bk + k, j + 2 ) )
                         tempir3 = tempir3 + bufatempi * &
                              & real( b( bk + k, j + 3 ) )
                      enddo
                      c(i, j    ) = DCMPLX(temprr0 - tempii0, tempri0 + tempir0)
                      c(i, j + 1) = DCMPLX(temprr1 - tempii1, tempri1 + tempir1)
                      c(i, j + 2) = DCMPLX(temprr2 - tempii2, tempri2 + tempir2)
                      c(i, j + 3) = DCMPLX(temprr3 - tempii3, tempri3 + tempir3)
                      ndxa = ndxa + bufr
                   enddo
                enddo

                ! This takes care of the last
                ! colsb - colsb/colsb_chunks*colsb_chunks cases
                do j = colsb_strt, colsb
                   ndxa = 0
                   bk = ac - 1
                   do i = ar, ar + bufca - 1
                      temp = 0.0
                      do k = 1, bufr
                         temp = temp + buffera( ndxa + k ) * b( bk + k, j )
                      enddo
                      c( i, j ) = temp
                      ndxa = ndxa + bufr
                   enddo
                enddo
             else
                ! beta /= 0: scale the existing c by beta and add the
                ! first partial product.
                do j = 1, colsb_end, colsb_chunks
                   ndxa = 0
                   do i = ar, ar + bufca - 1
                      bk = ac - 1
                      temp0 = 0.0
                      temp1 = 0.0
                      temp2 = 0.0
                      temp3 = 0.0
                      do k = 1, bufr
                         bufatemp = buffera( ndxa + k )
                         temp0 = temp0 + bufatemp * b( bk + k, j )
                         temp1 = temp1 + bufatemp * b( bk + k, j + 1 )
                         temp2 = temp2 + bufatemp * b( bk + k, j + 2 )
                         temp3 = temp3 + bufatemp * b( bk + k, j + 3 )
                      enddo
                      c( i, j     ) = beta * c( i, j )     + temp0
                      c( i, j + 1 ) = beta * c( i, j + 1 ) + temp1
                      c( i, j + 2 ) = beta * c( i, j + 2 ) + temp2
                      c( i, j + 3 ) = beta * c( i, j + 3 ) + temp3
                      ndxa = ndxa + bufr
                   enddo
                enddo

                ! This takes care of the last
                ! colsb - colsb/colsb_chunks*colsb_chunks cases
                do j = colsb_strt, colsb
                   ndxa = 0
                   bk = ac - 1
                   do i = ar, ar + bufca - 1
                      temp = 0.0
                      do k = 1, bufr
                         temp = temp + buffera( ndxa + k ) * b( bk + k, j )
                      enddo
                      c( i, j ) = beta * c( i, j ) + temp
                      ndxa = ndxa + bufr
                   enddo
                enddo
             endif
          else
             ! Later slices along k: accumulate the partial product into c.
             do j = 1, colsb_end, colsb_chunks
                ndxa = 0
                do i = ar, ar + bufca - 1
                   bk = ac - 1
                   temp0 = 0.0
                   temp1 = 0.0
                   temp2 = 0.0
                   temp3 = 0.0
                   do k = 1, bufr
                      bufatemp = buffera( ndxa + k )
                      temp0 = temp0 + bufatemp * b( bk + k, j )
                      temp1 = temp1 + bufatemp * b( bk + k, j + 1 )
                      temp2 = temp2 + bufatemp * b( bk + k, j + 2 )
                      temp3 = temp3 + bufatemp * b( bk + k, j + 3 )
                   enddo
                   c( i, j     ) = c( i, j )     + temp0
                   c( i, j + 1 ) = c( i, j + 1 ) + temp1
                   c( i, j + 2 ) = c( i, j + 2 ) + temp2
                   c( i, j + 3 ) = c( i, j + 3 ) + temp3
                   ndxa = ndxa + bufr
                enddo
             enddo

             ! This takes care of the last colsb - colsb/colsb_chunks*colsb_chunks
             ! cases
             do j = colsb_strt, colsb
                ndxa = 0
                bk = ac - 1
                do i = ar, ar + bufca - 1
                   temp = 0.0
                   do k = 1, bufr
                      temp = temp + buffera( ndxa + k ) * b( bk + k, j )
                   enddo
                   c( i, j ) = c( i, j ) + temp
                   ndxa = ndxa + bufr
                enddo
             enddo
          endif
          ! adjust the boundaries in the direction of the columns of a

          ar = ar + bufca
       enddo
       ! adjust the row values
       ac = ac + bufr
    enddo
    deallocate( buffera )
  endif
  return
end subroutine ftn_mnaxnb_cmplx16
346