!
! Copyright (c) 2012-2018, NVIDIA CORPORATION.  All rights reserved.
!
! Licensed under the Apache License, Version 2.0 (the "License");
! you may not use this file except in compliance with the License.
! You may obtain a copy of the License at
!
!     http://www.apache.org/licenses/LICENSE-2.0
!
! Unless required by applicable law or agreed to in writing, software
! distributed under the License is distributed on an "AS IS" BASIS,
! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
! See the License for the specific language governing permissions and
! limitations under the License.
!


! mmul_dir.h -- preprocessor directives for the F90 rte matmul files

#include "mmul_dir.h"

! ftn_mtaxtb_cmplx8 -- complex matrix multiply, "transpose a times transpose b":
!
!   c(i, j) = alpha * sum_{k=1..kab} a(k, i) * b(j, k) + beta * c(i, j)
!
! for i = 1..mra, j = 1..ncb  (i.e. C = alpha * A**T * B**T + beta * C).
! When beta .eq. 0.0 the previous contents of c are never read, only
! overwritten.
!
! Arguments (all declared in pgf90_mmul_cmplx8.h, which is not shown here):
!   mra         - number of rows of c (a is indexed a(k, i), i = 1..mra)
!   ncb         - number of columns of c (b is indexed b(j, k), j = 1..ncb)
!   kab         - inner (summation) dimension
!   alpha, beta - scale factors
!   a, lda      - matrix a; lda presumably its leading dimension (confirm in header)
!   b, ldb      - matrix b; ldb is passed to ftn_transpose_cmplx8 with b
!   c, ldc      - result matrix; ldc presumably its leading dimension
!
! Small products (colsa * rowsa * colsb < min_blocked_mult) use a plain
! triple loop.  Larger ones use a tiled algorithm that copies tiles of b
! into a transposition buffer (bufferb) and multiplies them against a.
!
! NOTE(review): ta and tb are not dummy arguments of this routine.  They are
! presumably declared (and given values) via pgf90_mmul_cmplx8.h -- confirm.
! Only the blocked path consults them.
subroutine ftn_mtaxtb_cmplx8( mra, ncb, kab, alpha, a, lda, b, ldb, beta, &
  & c, ldc )
  implicit none
#include "pgf90_mmul_cmplx8.h"

  ! Blocking scheme used by the large-matrix path below:
  !
  ! The transposition buffer bufferb holds one tile of matrix b.  The tile
  ! spans up to bufrows entries of the inner (k) dimension and up to
  ! bufcols columns of c (the j index).  Data are read from the buffer
  ! "down the rows": for each j, its bufr consecutive k values are
  ! contiguous in bufferb.
  !
  ! The inner dimension is split into rowchunks tiles of height bufr, and
  ! the columns of c into colbchunks tiles of width bufcb.  The last tile
  ! in each direction is trimmed to whatever remains.  So the cases "fits
  ! in the buffer", "wide", "tall" and "exceeds both buffer dimensions"
  ! are all handled by the same loop nest.
  !
  ! The intent (per the original design comments) is that bufrows sizes
  ! the working set for the L1 cache.  The same buffered slice of b is
  ! reused against many columns of a, multiple times, while it is resident.

  colsa = kab
  rowsb = kab
  rowsa = mra
  colsb = ncb
  ! NOTE(review): depending on the integer kind declared in the header, this
  ! product could overflow for very large dimensions -- confirm.
  if (colsa * rowsa * colsb < min_blocked_mult) then
    ! Small problems: plain (untiled) triple loop.  Writing p = alpha * a(k, i),
    ! the four partial sums below are
    !   temprr0 = sum Re(p)*Re(b(j,k))     tempii0 = sum Im(p)*Im(b(j,k))
    !   tempir0 = sum Re(p)*Im(b(j,k))     tempri0 = sum Im(p)*Re(b(j,k))
    ! so cmplx(temprr0 - tempii0, tempri0 + tempir0) = sum_k p * b(j, k).
    ! NOTE(review): unlike the blocked path, this path never tests ta/tb and
    ! applies no conjugation to a or b.  If this routine is reached with
    ! ta .eq. 2, small and large problems would be computed differently --
    ! confirm that conjugation is handled elsewhere for this path.
    ! Exact comparison with 0.0 is deliberate: beta == 0 means "do not read c".
    if( beta .eq. 0.0 ) then
      do j = 1, colsb
        do i = 1, rowsa
          temprr0 = 0.0
          tempri0 = 0.0
          tempir0 = 0.0
          tempii0 = 0.0
          do k = 1, colsa
            temprr0 = temprr0 + real(alpha) * real(a(k, i)) * real(b(j, k)) - aimag(alpha) * aimag(a(k, i)) * real(b(j, k))
          enddo
          do k = 1, colsa
            tempii0 = tempii0 + real(alpha) * aimag(a(k, i)) * aimag(b(j, k)) + aimag(alpha) * real(a(k, i)) * aimag(b(j, k))
          enddo
          do k = 1, colsa
            tempir0 = tempir0 + real(alpha) * real(a(k, i)) * aimag(b(j, k)) - aimag(alpha) * aimag(a(k, i)) * aimag(b(j, k))
          enddo
          do k = 1, colsa
            tempri0 = tempri0 + real(alpha) * aimag(a(k, i)) * real(b(j, k)) + aimag(alpha) * real(a(k, i)) * real(b(j, k))
          enddo
          ! beta == 0: overwrite c without reading it
          c(i, j) = cmplx((temprr0 - tempii0), (tempri0 + tempir0))
        enddo
      enddo
    else
      do j = 1, colsb
        do i = 1, rowsa
          temprr0 = 0.0
          tempri0 = 0.0
          tempir0 = 0.0
          tempii0 = 0.0
          do k = 1, colsa
            temprr0 = temprr0 + real(alpha) * real(a(k, i)) * real(b(j, k)) - aimag(alpha) * aimag(a(k, i)) * real(b(j, k))
          enddo
          do k = 1, colsa
            tempii0 = tempii0 + real(alpha) * aimag(a(k, i)) * aimag(b(j, k)) + aimag(alpha) * real(a(k, i)) * aimag(b(j, k))
          enddo
          do k = 1, colsa
            tempir0 = tempir0 + real(alpha) * real(a(k, i)) * aimag(b(j, k)) - aimag(alpha) * aimag(a(k, i)) * aimag(b(j, k))
          enddo
          do k = 1, colsa
            tempri0 = tempri0 + real(alpha) * aimag(a(k, i)) * real(b(j, k)) + aimag(alpha) * real(a(k, i)) * real(b(j, k))
          enddo

          c(i, j) = beta * c(i, j) + cmplx((temprr0 - tempii0), (tempri0 + tempir0))
        enddo
      enddo
    endif
  else
    ! Large problems: tiled multiply through the transposition buffer.
    ! bufrows and bufcols are not declared in this routine (presumably
    ! constants from the header -- confirm).
    allocate( bufferb( bufrows * bufcols ) )


    ! Tile height along the inner (k) dimension and the number of such tiles
    ! (ceiling division, so a partial last tile is included).
    bufr = min( bufrows, rowsb )
    bufr_sav = bufr
    rowchunks = ( rowsb + bufr - 1 )/bufr

    ! Tile width along the columns of c / rows of b, and the number of tiles.
    bufcb = min( bufcols, colsb )
    bufcb_sav = bufcb
    colbchunks = ( colsb + bufcb - 1)/bufcb
    ! br is the starting k index of the current inner-dimension tile (a row
    ! of a, a column of b), and bc is the starting column of c / row of b.
    ! ak = br - 1 is "one less" than br, so that a(ak + k, i) with k = 1..bufr
    ! walks the current tile.  ac is initialised here but not read anywhere
    ! in this routine.
    br = 1
    ac = 1
    bc = 1
    ak = 0
    ! Rows of c (index i) are processed four at a time (manual unrolling),
    ! so each buffered element of b is reused for four rows.  Rows
    ! colsa_strt..mra are the remainder, handled by the cleanup loops.
    colsa_chunk = 4
    colsa_chunks = mra/colsa_chunk
    colsa_end = colsa_chunks * colsa_chunk
    colsa_strt = colsa_end + 1


    do rowchunk = 1, rowchunks
      bc = 1
      do colbchunk = 1, colbchunks
        ak = br - 1
        ! ta .eq. 2: conjugate matrix a.  Any conjugation of b is presumably
        ! done inside ftn_transpose_cmplx8 according to tb -- confirm.
        if( ta .eq. 2 )then
          if( br .eq. 1 ) then
            ! First inner-dimension tile: this is where beta is applied
            ! (or c is simply overwritten when beta == 0).
            ! Trim the tile to what remains of b.
            bufcb = min( bufcb_sav, colsb - bc + 1 )
            bufr = min( bufr_sav, rowsb - br + 1 )
            ! Copy the bufr x bufcb tile of b starting at b(bc, br) into
            ! bufferb.  alpha is passed here and is not applied in the loops
            ! below, so the buffer presumably holds alpha-scaled values.
            call ftn_transpose_cmplx8( tb, b( bc, br ), ldb, alpha, bufferb, &
              & bufr, bufcb )
            if( beta .eq. 0.0 ) then
              do i = 1, colsa_end, colsa_chunk
                ! ndxb is the offset of column j's run of bufr values in bufferb
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    ! load once, reuse for the four unrolled rows of c
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * conjg( a( ak + k, i ) )
                    temp1 = temp1 + bufbtemp * conjg( a( ak + k, i + 1 ) )
                    temp2 = temp2 + bufbtemp * conjg( a( ak + k, i + 2 ) )
                    temp3 = temp3 + bufbtemp * conjg( a( ak + k, i + 3 ) )
                  enddo
                  c( i, j ) = temp0
                  c( i + 1, j ) = temp1
                  c( i + 2, j ) = temp2
                  c( i + 3, j ) = temp3
                  ndxb = ndxb + bufr
                enddo
              enddo
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * &
                      & conjg( a( ak + k, i ) )
                  enddo
                  c( i, j ) = temp
                  ndxb = ndxb + bufr
                enddo
              enddo
            else
              do i = 1, colsa_end, colsa_chunk
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * conjg( a( ak + k, i ) )
                    temp1 = temp1 + bufbtemp * conjg( a( ak + k, i + 1 ) )
                    temp2 = temp2 + bufbtemp * conjg( a( ak + k, i + 2 ) )
                    temp3 = temp3 + bufbtemp * conjg( a( ak + k, i + 3 ) )
                  enddo
                  ! first k-tile: scale the existing c by beta exactly once
                  c( i, j ) = beta * c( i, j ) + temp0
                  c( i + 1, j ) = beta * c( i + 1, j ) + temp1
                  c( i + 2, j ) = beta * c( i + 2, j ) + temp2
                  c( i + 3, j ) = beta * c( i + 3, j ) + temp3
                  ndxb = ndxb + bufr
                enddo
              enddo
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * &
                      & conjg( a( ak + k, i ) )
                  enddo
                  c( i, j ) = beta * c( i, j ) + temp
                  ndxb = ndxb + bufr
                enddo
              enddo
            endif
          else
            ! Later inner-dimension tiles: beta has already been applied,
            ! so just accumulate the partial products into c.
            bufcb = min( bufcb_sav, colsb - bc + 1 )
            bufr = min( bufr_sav, rowsb - br + 1 )
            call ftn_transpose_cmplx8( tb, b( bc, br ), ldb, alpha, bufferb, &
              & bufr, bufcb )
            do i = 1, colsa_end, colsa_chunk
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp0 = 0.0
                temp1 = 0.0
                temp2 = 0.0
                temp3 = 0.0
                do k = 1, bufr
                  bufbtemp = bufferb( ndxb + k )
                  temp0 = temp0 + bufbtemp * conjg( a( ak + k, i ) )
                  temp1 = temp1 + bufbtemp * conjg( a( ak + k, i + 1 ) )
                  temp2 = temp2 + bufbtemp * conjg( a( ak + k, i + 2 ) )
                  temp3 = temp3 + bufbtemp * conjg( a( ak + k, i + 3 ) )
                enddo
                c( i, j ) = c( i, j ) + temp0
                c( i + 1, j ) = c( i + 1, j ) + temp1
                c( i + 2, j ) = c( i + 2, j ) + temp2
                c( i + 3, j ) = c( i + 3, j ) + temp3
                ndxb = ndxb + bufr
              enddo
            enddo
            ! Now clean up whatever is left from the loop unrolling
            do i = colsa_strt, mra
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp = 0.0
                do k = 1, bufr
                  temp = temp + bufferb( ndxb + k ) * conjg( a( ak + k, i ) )
                enddo
                c( i, j ) = c( i, j ) + temp
                ndxb = ndxb + bufr
              enddo
            enddo
          endif
        else
          ! Same structure as the ta .eq. 2 branch above, but a is used
          ! without conjugation.
          if( br .eq. 1 ) then
            bufcb = min( bufcb_sav, colsb - bc + 1 )
            bufr = min( bufr_sav, rowsb - br + 1 )
            call ftn_transpose_cmplx8( tb, b( bc, br ), ldb, alpha, bufferb, &
              & bufr, bufcb )
            if( beta .eq. 0.0 ) then
              do i = 1, colsa_end, colsa_chunk
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * a( ak + k, i )
                    temp1 = temp1 + bufbtemp * a( ak + k, i + 1 )
                    temp2 = temp2 + bufbtemp * a( ak + k, i + 2 )
                    temp3 = temp3 + bufbtemp * a( ak + k, i + 3 )
                  enddo
                  c( i, j ) = temp0
                  c( i + 1, j ) = temp1
                  c( i + 2, j ) = temp2
                  c( i + 3, j ) = temp3
                  ndxb = ndxb + bufr
                enddo
              enddo
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * a( ak + k, i )
                  enddo
                  c( i, j ) = temp
                  ndxb = ndxb + bufr
                enddo
              enddo
            else
              do i = 1, colsa_end, colsa_chunk
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp0 = 0.0
                  temp1 = 0.0
                  temp2 = 0.0
                  temp3 = 0.0
                  do k = 1, bufr
                    bufbtemp = bufferb( ndxb + k )
                    temp0 = temp0 + bufbtemp * a( ak + k, i )
                    temp1 = temp1 + bufbtemp * a( ak + k, i + 1 )
                    temp2 = temp2 + bufbtemp * a( ak + k, i + 2 )
                    temp3 = temp3 + bufbtemp * a( ak + k, i + 3 )
                  enddo
                  c( i, j ) = beta * c( i, j ) + temp0
                  c( i + 1, j ) = beta * c( i + 1, j ) + temp1
                  c( i + 2, j ) = beta * c( i + 2, j ) + temp2
                  c( i + 3, j ) = beta * c( i + 3, j ) + temp3
                  ndxb = ndxb + bufr
                enddo
              enddo
              ! Now clean up whatever is left from the loop unrolling
              do i = colsa_strt, mra
                ndxb = 0
                do j = bc, bc + bufcb - 1
                  temp = 0.0
                  do k = 1, bufr
                    temp = temp + bufferb( ndxb + k ) * a( ak + k, i )
                  enddo
                  c( i, j ) = beta * c( i, j ) + temp
                  ndxb = ndxb + bufr
                enddo
              enddo
            endif
          else
            bufcb = min( bufcb_sav, colsb - bc + 1 )
            bufr = min( bufr_sav, rowsb - br + 1 )
            call ftn_transpose_cmplx8( tb, b( bc, br ), ldb, alpha, bufferb, &
              & bufr, bufcb )
            do i = 1, colsa_end, colsa_chunk
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp0 = 0.0
                temp1 = 0.0
                temp2 = 0.0
                temp3 = 0.0
                do k = 1, bufr
                  bufbtemp = bufferb( ndxb + k )
                  temp0 = temp0 + bufbtemp * a( ak + k, i )
                  temp1 = temp1 + bufbtemp * a( ak + k, i + 1 )
                  temp2 = temp2 + bufbtemp * a( ak + k, i + 2 )
                  temp3 = temp3 + bufbtemp * a( ak + k, i + 3 )
                enddo
                c( i, j ) = c( i, j ) + temp0
                c( i + 1, j ) = c( i + 1, j ) + temp1
                c( i + 2, j ) = c( i + 2, j ) + temp2
                c( i + 3, j ) = c( i + 3, j ) + temp3
                ndxb = ndxb + bufr
              enddo
            enddo
            ! Now clean up whatever is left from the loop unrolling
            do i = colsa_strt, mra
              ndxb = 0
              do j = bc, bc + bufcb - 1
                temp = 0.0
                do k = 1, bufr
                  temp = temp + bufferb( ndxb + k ) * a( ak + k, i )
                enddo
                c( i, j ) = c( i, j ) + temp
                ndxb = ndxb + bufr
              enddo
            enddo
          endif
        endif
        ! Advance to the next tile along the columns of c (rows of b).
        bc = bc + bufcb
      enddo
      ! Advance to the next tile of the inner (k) dimension.  The number of
      ! such tiles is rowchunks, computed above.
      br = br + bufr

    enddo
    deallocate( bufferb )
  endif
  return
end subroutine ftn_mtaxtb_cmplx8