1!
2! Copyright (c) 2012-2018, NVIDIA CORPORATION.  All rights reserved.
3!
4! Licensed under the Apache License, Version 2.0 (the "License");
5! you may not use this file except in compliance with the License.
6! You may obtain a copy of the License at
7!
8!     http://www.apache.org/licenses/LICENSE-2.0
9!
10! Unless required by applicable law or agreed to in writing, software
11! distributed under the License is distributed on an "AS IS" BASIS,
12! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13! See the License for the specific language governing permissions and
14! limitations under the License.
15!
16
17
18! directives.h -- contains preprocessor directives for F90 rte files
19
20#include "mmul_dir.h"
21
!
! ftn_mtaxnb_cmplx8 -- complex matrix product  C = alpha * A**T * B + beta * C
!
! Shapes, as implied by the indexing a(k, i), b(k, j), c(i, j) below:
!   a      : accessed as kab x mra and used transposed (A**T is mra x kab)
!   b      : accessed as kab x ncb
!   c      : mra x ncb; every element in that range is assigned
!   mra    : rows of C          (rowsa)
!   ncb    : columns of C       (colsb)
!   kab    : summation length   (colsa = rowsb)
!   alpha  : complex scale on the product (real/aimag parts are taken below)
!   beta   : scale on the incoming c. When beta == 0, c is assigned without
!            being read, so its incoming contents are ignored.
!   lda, ldb, ldc : presumably the leading dimensions of a, b, c -- declared
!            in the included header, confirm there.
!
! This routine declares nothing itself: with implicit none, every argument
! and local must come from the included header pgf90_mmul_cmplx8.h (not
! visible here). min_blocked_mult, bufrows, bufcols and ta are also taken
! from there -- presumably named constants; confirm in the header.
!
! Two strategies:
!   * small problems (colsa*rowsa*colsb < min_blocked_mult): direct loops
!   * otherwise: A is copied tile by tile into a work buffer and the product
!     is built up one k-chunk at a time (see the diagram below).
!
22subroutine ftn_mtaxnb_cmplx8( mra, ncb, kab, alpha, a, lda, b, ldb, beta, &
23     & c, ldc )
24  implicit none
25#include "pgf90_mmul_cmplx8.h"
26
27  !
28  !                 rowsa
29  !           <-bufca(1)>< (2) >                       colsb
30  !              i = 1, m  -ar->                   j = 1, n
31  !      ^    +----------+------+   ^  bk = 0->+--------------------+  ^
32  !      |    |          x      |   |          |                    |  |
33  !      |    |          x      |   |          |                    |  |
34  !  bufr(1)  |  A**T    x      | rowchunks=2  |                    |  |
35  !      |    |          x      |   |          |         B          |  |
36  !  |   |    | buffera  x      |   |          |                    | ka = 1, k
37  !  |   |    |          x      |   |          |                    |  |
38  ! ac   |    |    I     x III  |   |          |                    |  |
39  !  |   v    +xxxxxxxxxxxxxxxxx+   |  bk = bk>+xxxxxxxxxxxxxxxxxxxx+  |
40  !  v   ^    |          x      |   |   + bufr |                    |  |
41  !      |    |          x      |   |          |                    |  |
42  !   bufr(2) |          x      |   |          |                    |  |
43  !      |    |   II     x IV   |   |          |                    |  |
44  !      V    +----------+------+   V          +--------------------+  V
45  !            <--colachunks=2-->
46  !     x marks buffer boundaries on the transposed matrix A and on the
47  !     part of B that is multiplied by buffera
48  !
49
50
51  ! Naming note: because A is used transposed, the row/column names for a
52  ! are swapped -- colsa (= kab) counts the rows of a as accessed (the
53  ! summation length) and rowsa (= mra) counts its columns (the rows of C).
54  ! The original author noted this swap fits plain DO loops, not the blocked code.
55
56  colsa = kab
57  rowsb = kab
58  rowsa = mra
59  colsb = ncb
  ! Choose direct loops for small problems, blocking otherwise.
  ! NOTE(review): if these counts are default integers, the triple product
  ! can overflow for large operands -- confirm their kinds in the header.
60  if (colsa * rowsa * colsb < min_blocked_mult) then
    ! Direct path. With p = alpha * a(k, i), the four real sums are
    !   temprr0 = sum Re(p)*Re(b)   tempii0 = sum Im(p)*Im(b)
    !   tempir0 = sum Re(p)*Im(b)   tempri0 = sum Im(p)*Re(b)
    ! so Re(result) = temprr0 - tempii0 and Im(result) = tempri0 + tempir0.
    ! The innermost loops run over the first subscript of a and b.
61    if( beta .eq. 0.0 ) then
      ! beta == 0: assign c(i, j) without reading it.
62      do j = 1, colsb
63         do i = 1, rowsa
64            temprr0 = 0.0
65            tempri0 = 0.0
66            tempir0 = 0.0
67            tempii0 = 0.0
68            do k = 1, colsa
69                temprr0 = temprr0 + real(alpha) * real(a(k, i)) * real(b(k, j)) - aimag(alpha) * aimag(a(k, i)) * real(b(k, j))
70            enddo
71            do k = 1, colsa
72                tempii0 = tempii0 + real(alpha) * aimag(a(k, i)) * aimag(b(k, j)) + aimag(alpha) * real(a(k, i)) * aimag(b(k, j))
73            enddo
74            do k = 1, colsa
75                tempir0 = tempir0 + real(alpha) * real(a(k, i)) * aimag(b(k, j)) - aimag(alpha) * aimag(a(k, i)) * aimag(b(k, j))
76            enddo
77            do k = 1, colsa
78                tempri0 = tempri0 + real(alpha) * aimag(a(k, i)) * real(b(k, j)) + aimag(alpha) * real(a(k, i)) * real(b(k, j))
79            enddo
80            c(i, j) = cmplx((temprr0 - tempii0), (tempri0 + tempir0))
81         enddo
82      enddo
83    else
      ! beta /= 0: scale the existing c(i, j) by beta and add the product.
84      do j = 1, colsb
85         do i = 1, rowsa
86            temprr0 = 0.0
87            tempri0 = 0.0
88            tempir0 = 0.0
89            tempii0 = 0.0
90            do k = 1, colsa
91                temprr0 = temprr0 + real(alpha) * real(a(k, i)) * real(b(k, j)) - aimag(alpha) * aimag(a(k, i)) * real(b(k, j))
92            enddo
93            do k = 1, colsa
94                tempii0 = tempii0 + real(alpha) * aimag(a(k, i)) * aimag(b(k, j)) + aimag(alpha) * real(a(k, i)) * aimag(b(k, j))
95            enddo
96            do k = 1, colsa
97                tempir0 = tempir0 + real(alpha) * real(a(k, i)) * aimag(b(k, j)) - aimag(alpha) * aimag(a(k, i)) * aimag(b(k, j))
98            enddo
99            do k = 1, colsa
100                tempri0 = tempri0 + real(alpha) * aimag(a(k, i)) * real(b(k, j)) + aimag(alpha) * real(a(k, i)) * real(b(k, j))
101            enddo
102
103            c(i, j) = beta * c(i, j) + cmplx((temprr0 - tempii0), (tempri0 + tempir0))
104         enddo
105      enddo
106    endif
107  else
    ! Blocked path: one work buffer of bufrows * bufcols elements holds the
    ! current tile of A; it is released at the end of this branch.
108    allocate( buffera( bufrows * bufcols ) )
109
    ! bufca: columns of a (rows of C) per tile; colachunks tiles cover rowsa.
110    bufca = min( rowsa, bufcols )
111    bufca_sav = bufca
112    colachunks = ( rowsa + bufca - 1)/bufca
113    ! set the number of buffer row chunks we will work on
    ! bufr: rows of a (k values) per tile; rowchunks tiles cover colsa.
114    bufr = min( colsa, bufrows )
115    bufr_sav = bufr
116    rowchunks = ( colsa + bufr - 1 )/bufr
117
118    ac = 1   ! first subscript of a(ac, ar): start of the current k-chunk
119    ! ac is also the starting row of the current k-chunk in b. The loops
120    ! below use bk = ac - 1 as an offset so that b(bk + k, j) runs over
121    ! that chunk for k = 1, bufr.
122    ar = 1   ! redundant: ar is reset at the top of every rowchunk below
    ! Columns of B and C are processed 4 at a time (1 .. colsb_end); the
    ! remaining columns colsb_strt .. colsb are handled one by one.
123    colsb_chunk = 4
124    colsb_chunks = colsb/colsb_chunk
125    colsb_end = colsb_chunks * colsb_chunk
126    colsb_strt = colsb_end + 1
127
128    do rowchunk = 1, rowchunks ! This will set the values over k
129       ar = 1 ! second subscript of a(ac, ar); also the first row of C in this tile
130       !     loc = rowsa - bufca
131       do colachunk = 1, colachunks ! this over m
          ! The first k-chunk (ac == 1) initializes C (applying beta); later
          ! k-chunks add their partial products to the C built so far.
132          if( ac .eq. 1 ) then
             ! Tile extents: the last tile in each direction may be smaller.
133             bufca = min( bufca_sav, rowsa - ar + 1 )
134             bufr = min( bufr_sav, colsa - ac + 1 )
             ! NOTE(review): alpha is never applied elsewhere in this path, so
             ! ftn_gather_cmplx8 presumably copies the bufr x bufca tile of a
             ! starting at a(ac, ar) into buffera scaled by alpha -- confirm
             ! against its source. The loops below read buffera as bufca
             ! consecutive segments of length bufr, one segment per row i of C.
135             call ftn_gather_cmplx8( ta, a( ac, ar ), lda, alpha,  buffera, &
136                  & bufr, bufca )
137             bk = ac - 1
138             if( beta .eq. 0.0 ) then
                ! beta == 0: assign c without reading it; 4 columns at a time.
139                do j = 1, colsb_end, colsb_chunk
140                   ndxa = 0
141                   do i = ar, ar + bufca - 1
142                      temp0 = 0
143                      temp1 = 0
144                      temp2 = 0
145                      temp3 = 0
146                      do k = 1, bufr
147                         bufatemp = buffera( ndxa + k )
148                         temp0 = temp0 + bufatemp * b( bk + k, j )
149                         temp1 = temp1 + bufatemp * b( bk + k, j + 1 )
150                         temp2 = temp2 + bufatemp * b( bk + k, j + 2 )
151                         temp3 = temp3 + bufatemp * b( bk + k, j + 3 )
152                      enddo
153                      c( i, j )     = temp0
154                      c( i, j + 1 ) = temp1
155                      c( i, j + 2 ) = temp2
156                      c( i, j + 3 ) = temp3
157                      ndxa = ndxa + bufr
158                   enddo
159                enddo
                ! Remaining columns (fewer than 4), one at a time.
160                do j = colsb_strt, colsb
161                   ndxa = 0
162                   do i = ar, ar + bufca - 1
163                      temp = 0.0
164                      do k = 1, bufr
165                         temp = temp + buffera( ndxa + k ) * b( bk + k, j )
166                      enddo
167                      c( i, j ) = temp
168                      ndxa = ndxa + bufr
169                   enddo
170                enddo
171             else
                ! beta /= 0: scale the existing c by beta and add the partial
                ! product; 4 columns at a time, then the remainder.
172                do j = 1, colsb_end, colsb_chunk
173                   ndxa = 0
174                   do i = ar, ar + bufca - 1
175                      temp0 = 0
176                      temp1 = 0
177                      temp2 = 0
178                      temp3 = 0
179                      do k = 1, bufr
180                         bufatemp = buffera( ndxa + k )
181                         temp0 = temp0 + bufatemp * b( bk + k, j )
182                         temp1 = temp1 + bufatemp * b( bk + k, j + 1 )
183                         temp2 = temp2 + bufatemp * b( bk + k, j + 2 )
184                         temp3 = temp3 + bufatemp * b( bk + k, j + 3 )
185                      enddo
186                      c( i, j )     = beta * c( i, j )     + temp0
187                      c( i, j + 1 ) = beta * c( i, j + 1 ) + temp1
188                      c( i, j + 2 ) = beta * c( i, j + 2 ) + temp2
189                      c( i, j + 3 ) = beta * c( i, j + 3 ) + temp3
190                      ndxa = ndxa + bufr
191                   enddo
192                enddo
193                do j = colsb_strt, colsb
194                   ndxa = 0
195                   do i = ar, ar + bufca - 1
196                      temp = 0.0
197                      do k = 1, bufr
198                         temp = temp + buffera( ndxa + k ) * b( bk + k, j )
199                      enddo
200                      c( i, j ) = beta * c( i, j ) + temp
201                      ndxa = ndxa + bufr
202                   enddo
203                enddo
204             endif
205          else
             ! Later k-chunks: same tile gather as above (duplicated in both
             ! branches), then accumulate the partial product into c.
206             bufca = min( bufca_sav, rowsa - ar + 1 )
207             bufr = min( bufr_sav, colsa - ac + 1 )
208             call ftn_gather_cmplx8( ta, a( ac, ar ), lda, alpha,  buffera, &
209                  & bufr, bufca )
210             bk = ac - 1
211             do j = 1, colsb_end, colsb_chunk
212                ndxa = 0
213                do i = ar, ar + bufca - 1
214                   temp0 = 0
215                   temp1 = 0
216                   temp2 = 0
217                   temp3 = 0
218                   do k = 1, bufr
219                      bufatemp = buffera( ndxa + k )
220                      temp0 = temp0 + bufatemp * b( bk + k, j )
221                      temp1 = temp1 + bufatemp * b( bk + k, j + 1 )
222                      temp2 = temp2 + bufatemp * b( bk + k, j + 2 )
223                      temp3 = temp3 + bufatemp * b( bk + k, j + 3 )
224                   enddo
225                   c( i, j )     = c( i, j )     + temp0
226                   c( i, j + 1 ) = c( i, j + 1 ) + temp1
227                   c( i, j + 2 ) = c( i, j + 2 ) + temp2
228                   c( i, j + 3 ) = c( i, j + 3 ) + temp3
229                   ndxa = ndxa + bufr
230                enddo
231             enddo
             ! Remaining columns (fewer than 4), one at a time.
232             do j = colsb_strt, colsb
233                ndxa = 0
234                do i = ar, ar + bufca - 1
235                   temp = 0.0
236                   do k = 1, bufr
237                      temp = temp + buffera( ndxa + k ) * b( bk + k, j )
238                   enddo
239                   c( i, j ) = c( i, j ) + temp
240                   ndxa = ndxa + bufr
241                enddo
242             enddo
243          endif
          ! Advance to the next tile of rows of C.
244          ar = ar + bufca
245          !        bufr = min( bufr, lor )
246          !        lor = lor - bufr
247       enddo
       ! Advance to the next k-chunk. bufr depends only on ac, so every
       ! colachunk above used the same value.
248       ac = ac + bufr
249       !     bufca = min( bufca, loc )
250       !     loc = loc - bufca ! Note: this is not circular since the loops are
251       ! controlled by the number of buffera chunks we use.
252       !     bufr = bufr + colsa
253
254       !     lor = colsa - bufr
255    enddo
256
257    deallocate( buffera )
258    endif
259  return
260end subroutine ftn_mtaxnb_cmplx8
261