!
! Copyright (c) 2012-2018, NVIDIA CORPORATION.  All rights reserved.
!
! Licensed under the Apache License, Version 2.0 (the "License");
! you may not use this file except in compliance with the License.
! You may obtain a copy of the License at
!
!     http://www.apache.org/licenses/LICENSE-2.0
!
! Unless required by applicable law or agreed to in writing, software
! distributed under the License is distributed on an "AS IS" BASIS,
! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
! See the License for the specific language governing permissions and
! limitations under the License.
!


! directives.h -- contains preprocessor directives for F90 rte files

#include "mmul_dir.h"

!> Complex*16 matrix multiply with a non-transposed a and a transposed
!> (not conjugated) b:
!>
!>     c(i, j) = alpha * sum_k( a(i, k) * b(j, k) ) + beta * c(i, j)
!>
!> for i = 1..mra, j = 1..ncb, k = 1..kab.
!>
!> Arguments (all declarations come from pgf90_mmul_cmplx16.h, which is
!> not part of this file):
!>   mra    number of rows of a and of c that are processed
!>   ncb    number of rows of b, i.e. number of columns of c
!>   kab    number of columns of both a and b (the summation length)
!>   alpha  scalar applied to the product
!>   a, lda first operand and the leading dimension handed to the packer
!>   b, ldb second operand and the leading dimension handed to the packer
!>   beta   scalar applied to the incoming c; when beta compares equal to
!>          zero the incoming contents of c are never read
!>   c, ldc result; ldc is not referenced in the executable code below
!>          (presumably it only dimensions c in the header -- TODO confirm)
!>
!> Problems whose mra*ncb*kab product is below min_blocked_mult use a
!> direct triple loop; larger ones go through two packed work buffers.
subroutine ftn_mnaxtb_cmplx16( mra, ncb, kab, alpha, a, lda, b, ldb, beta, &
     & c, ldc )
  implicit none
#include "pgf90_mmul_cmplx16.h"

  !
  ! The main idea here is that the bufrows will define the usage of the
  ! L1 cache. We reference the same column or columns multiply while
  ! accessing multiple partial rows of matrix a transposed in the buffer.
  !
  ! Remember that everything is buffer centric
  !
  !
  !           <- bufca(1)>< (2)>                    <-bufcb->
  !               i = 1, m                          j = 1, n
  !                rowsa                             colsb
  !              ar --->                              bc --->
  !    ^    +----------+------+   ^             +----------+----+  ^
  !    |    |          x      |   |             |  b       x    |  |
  !    |    |          x      |   |             |  u       x    |  |
  !  bufr(1)|  A**T    x      | rowchunks=2     |  f  a    x  c |  |
  !    |    |          x      |   |             |  f       x    |  |
  !    |    | buffera  x      |   |             |  e       x    | kab = 1, k
  !    |    |          x      |   |         br  |  r       x    |  |
  !    |    |    I     x III  |   |         |   |  b       x    |  |
  !    v    +xxxxxxxxxxxxxxxxx+   |         |   +xxxxxxxxxx+xxxx|  |
  !    ^    |          x      |   |         v   |          x    |  |
  !    |    |    II    x  IV  |   |             | B   b    x  d |  |
  !  bufr(2)|          x      |   |             |          x    |  |
  !    |    |          x      |   |             |          x    |  |
  !    V    +----------+------+   V             +----------+----+  V
  !          <--colchunks=2-->
  !    x's mark buffer boundaries on the transposed matrices
  !    For this case, bufca(1) = bufcols, bufr(1) = bufrows
  !
  !  Algorithmically, we perform dot products of (I,a), (III,a), (II,b)
  !  and (IV,b). The partial dot products of (I,a) are added to those
  !  of (II,b) and those of (III,a) are added to those of (IV,b)
  !
  !  Iterations over the "chunks" are buffer based
  !  while iterations over i and j are matrix based and keep track of where
  !  we are in the larger scheme of things
  !  Iterations over i and j are bounded by buffer dimensions
  !
  colsa = kab
  rowsb = kab
  rowsa = mra
  colsb = ncb

  if ( colsa * rowsa * colsb < min_blocked_mult ) then
    !
    ! Small problem: direct evaluation, one element of c at a time.
    ! The complex product (alpha * a) * b is built from four real
    ! running sums:
    !   temprr0 = sum  real(alpha*a) *  real(b)
    !   tempii0 = sum aimag(alpha*a) * aimag(b)
    !   tempir0 = sum  real(alpha*a) * aimag(b)
    !   tempri0 = sum aimag(alpha*a) *  real(b)
    ! so that the result is (rr - ii) + i*(ri + ir).
    !
    do j = 1, colsb
      do i = 1, rowsa
        temprr0 = 0.0
        tempri0 = 0.0
        tempir0 = 0.0
        tempii0 = 0.0
        ! Each sum keeps its own accumulator, so sharing one k loop
        ! leaves the order of additions of every sum untouched.
        do k = 1, colsa
          temprr0 = temprr0 + real(alpha) * real(a(i, k)) * real(b(j, k)) &
               & - aimag(alpha) * aimag(a(i, k)) * real(b(j, k))
          tempii0 = tempii0 + real(alpha) * aimag(a(i, k)) * aimag(b(j, k)) &
               & + aimag(alpha) * real(a(i, k)) * aimag(b(j, k))
          tempir0 = tempir0 + real(alpha) * real(a(i, k)) * aimag(b(j, k)) &
               & - aimag(alpha) * aimag(a(i, k)) * aimag(b(j, k))
          tempri0 = tempri0 + real(alpha) * aimag(a(i, k)) * real(b(j, k)) &
               & + aimag(alpha) * real(a(i, k)) * real(b(j, k))
        end do

        if ( beta == 0.0 ) then
          ! Plain store: the previous value of c(i, j) is not read.
          c(i, j) = dcmplx((temprr0 - tempii0), (tempri0 + tempir0))
        else
          c(i, j) = beta * c(i, j) + dcmplx((temprr0 - tempii0), (tempri0 + tempir0))
        end if
      end do
    end do
  else
    !
    ! Blocked problem: tiles of a and b are packed into two work buffers
    ! by ftn_transpose_cmplx16 (defined elsewhere) and multiplied from there.
    !
    allocate( buffera( bufrows * bufcols ) )
    allocate( bufferb( bufrows * bufcols ) )

    ! for algorithmic purposes, kab is the number of columns in matrix b,
    ! which is also the number of columns in matrix a.

    ! Nominal tile extents, clipped to the problem size. The *_sav copies
    ! keep the nominal values because the working extents shrink on the
    ! last tile of each direction.
    bufr = min( bufrows, colsa )
    bufr_sav = bufr
    bufca = min( bufcols, rowsa )
    bufca_sav = bufca
    bufcb = min( bufcols, colsb )
    bufcb_sav = bufcb
    ar_sav = 1
    ac_sav = 1
    bc = 1
    br = 1

    ! both rowchunks and colchunks are buffer centric
    rowchunks = ( colsa + bufr - 1 ) / bufr
    colachunks = ( rowsa + bufca - 1 ) / bufca
    colbchunks = ( colsb + bufcb - 1 ) / bufcb

    ! unroll factor for the columns of c; the unrolled loops below are
    ! written out for exactly four columns
    colsb_chunk = 4

    do rowchunk = 1, rowchunks            ! tiles along the summation index
      bufcb = bufcb_sav
      do colbchunk = 1, colbchunks        ! tiles along the columns of c
        bufcb = min( bufcb_sav, colsb - bc + 1 )
        bufr = min( bufr_sav, rowsb - br + 1 )
        ! Pack the current tile of b; alpha is handed to the packer here.
        call ftn_transpose_cmplx16( tb, b( bc, br ), ldb, alpha, bufferb, &
             & bufr, bufcb )
        ar = 1
        ac = ac_sav
        do colachunk = 1, colachunks      ! tiles along the rows of c
          ! Note: alpha is 1.0 for matrix a to avoid multiplying by
          ! alpha * alpha
          bufca = min( bufca_sav, rowsa - ar + 1 )
          call ftn_transpose_cmplx16( ta, a( ar, ac ), lda, one, buffera, &
               & bufr, bufca )

          ! Offsets in bufferb of four consecutive packed columns, each
          ! bufr entries long.
          ndxb0 = 0
          ndxb1 = bufr
          ndxb2 = ndxb1 + bufr
          ndxb3 = ndxb2 + bufr

          ! Columns bc .. colsb_end are handled four at a time, the rest
          ! (colsb_strt .. jend) one at a time.
          colsb_chunks = bufcb / colsb_chunk
          colsb_end = bc + colsb_chunks * colsb_chunk - 1
          colsb_strt = colsb_end + 1
          jend = bc + bufcb - 1
          j = bc

          if ( br /= 1 ) then
            !
            ! Later tile along the summation index: c already holds the
            ! beta-scaled start value plus earlier partial sums, so this
            ! tile is simply added in.
            !
            do jb = 1, colsb_chunks
              ndxa = 0
              do i = ar, ar + bufca - 1
                temp0 = 0.0
                temp1 = 0.0
                temp2 = 0.0
                temp3 = 0.0
                do k = 1, bufr
                  bufatemp = buffera( ndxa + k )
                  temp0 = temp0 + bufferb( ndxb0 + k ) * bufatemp
                  temp1 = temp1 + bufferb( ndxb1 + k ) * bufatemp
                  temp2 = temp2 + bufferb( ndxb2 + k ) * bufatemp
                  temp3 = temp3 + bufferb( ndxb3 + k ) * bufatemp
                end do
                c( i, j ) = c( i, j ) + temp0
                c( i, j + 1 ) = c( i, j + 1 ) + temp1
                c( i, j + 2 ) = c( i, j + 2 ) + temp2
                c( i, j + 3 ) = c( i, j + 3 ) + temp3
                ndxa = ndxa + bufr
              end do
              ndxa = 0
              ndxb0 = ndxb0 + bufr * colsb_chunk
              ndxb1 = ndxb1 + bufr * colsb_chunk
              ndxb2 = ndxb2 + bufr * colsb_chunk
              ndxb3 = ndxb3 + bufr * colsb_chunk
              j = j + colsb_chunk
            end do
            ndxb = bufr * colsb_chunks * colsb_chunk
            do j = colsb_strt, jend
              ndxa = 0
              do i = ar, ar + bufca - 1
                temp = 0.0
                do k = 1, bufr
                  temp = temp + bufferb( ndxb + k ) * buffera( ndxa + k )
                end do
                c( i, j ) = c( i, j ) + temp
                ndxa = ndxa + bufr
              end do
              ndxb = ndxb + bufr
            end do

          else if ( beta == 0.0 ) then
            !
            ! First tile along the summation index with beta equal to
            ! zero: c is overwritten and its old contents are not read.
            !
            do jb = 1, colsb_chunks
              ndxa = 0
              do i = ar, ar + bufca - 1
                temp0 = 0.0
                temp1 = 0.0
                temp2 = 0.0
                temp3 = 0.0
                do k = 1, bufr
                  bufatemp = buffera( ndxa + k )
                  temp0 = temp0 + bufferb( ndxb0 + k ) * bufatemp
                  temp1 = temp1 + bufferb( ndxb1 + k ) * bufatemp
                  temp2 = temp2 + bufferb( ndxb2 + k ) * bufatemp
                  temp3 = temp3 + bufferb( ndxb3 + k ) * bufatemp
                end do
                c( i, j ) = temp0
                c( i, j + 1 ) = temp1
                c( i, j + 2 ) = temp2
                c( i, j + 3 ) = temp3
                ndxa = ndxa + bufr
              end do
              ndxa = 0
              ndxb0 = ndxb0 + bufr * colsb_chunk
              ndxb1 = ndxb1 + bufr * colsb_chunk
              ndxb2 = ndxb2 + bufr * colsb_chunk
              ndxb3 = ndxb3 + bufr * colsb_chunk
              j = j + colsb_chunk
            end do
            ndxb = bufr * colsb_chunks * colsb_chunk
            do j = colsb_strt, jend
              ndxa = 0
              do i = ar, ar + bufca - 1
                temp = 0.0
                do k = 1, bufr
                  temp = temp + bufferb( ndxb + k ) * buffera( ndxa + k )
                end do
                c( i, j ) = temp
                ndxa = ndxa + bufr
              end do
              ndxb = ndxb + bufr
            end do

          else
            !
            ! First tile along the summation index with a non-zero beta:
            ! the incoming c is scaled by beta exactly once, here.
            !
            do jb = 1, colsb_chunks
              ndxa = 0
              do i = ar, ar + bufca - 1
                temp0 = 0.0
                temp1 = 0.0
                temp2 = 0.0
                temp3 = 0.0
                do k = 1, bufr
                  bufatemp = buffera( ndxa + k )
                  temp0 = temp0 + bufferb( ndxb0 + k ) * bufatemp
                  temp1 = temp1 + bufferb( ndxb1 + k ) * bufatemp
                  temp2 = temp2 + bufferb( ndxb2 + k ) * bufatemp
                  temp3 = temp3 + bufferb( ndxb3 + k ) * bufatemp
                end do
                c( i, j ) = beta * c( i, j ) + temp0
                c( i, j + 1 ) = beta * c( i, j + 1 ) + temp1
                c( i, j + 2 ) = beta * c( i, j + 2 ) + temp2
                c( i, j + 3 ) = beta * c( i, j + 3 ) + temp3
                ndxa = ndxa + bufr
              end do
              ndxa = 0
              ndxb0 = ndxb0 + bufr * colsb_chunk
              ndxb1 = ndxb1 + bufr * colsb_chunk
              ndxb2 = ndxb2 + bufr * colsb_chunk
              ndxb3 = ndxb3 + bufr * colsb_chunk
              j = j + colsb_chunk
            end do
            ndxb = bufr * colsb_chunks * colsb_chunk
            do j = colsb_strt, jend
              ndxa = 0
              do i = ar, ar + bufca - 1
                temp = 0.0
                do k = 1, bufr
                  temp = temp + bufferb( ndxb + k ) * buffera( ndxa + k )
                end do
                c( i, j ) = beta * c( i, j ) + temp
                ndxa = ndxa + bufr
              end do
              ndxb = ndxb + bufr
            end do
          end if

          ! Step to the next tile of rows of a and c.
          ar = ar + bufca
        end do

        bc = bc + bufcb
      end do
      ! Step to the next tile along the summation index and restart the
      ! sweep over the columns of c.
      br = br + bufr
      ac_sav = ac_sav + bufr
      bc = 1
    end do

    deallocate( buffera )
    deallocate( bufferb )
  end if
  return
end subroutine ftn_mnaxtb_cmplx16