!
! Copyright (c) 2012-2018, NVIDIA CORPORATION.  All rights reserved.
!
! Licensed under the Apache License, Version 2.0 (the "License");
! you may not use this file except in compliance with the License.
! You may obtain a copy of the License at
!
!     http://www.apache.org/licenses/LICENSE-2.0
!
! Unless required by applicable law or agreed to in writing, software
! distributed under the License is distributed on an "AS IS" BASIS,
! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
! See the License for the specific language governing permissions and
! limitations under the License.
!


! directives.h -- contains preprocessor directives for F90 rte files

#include "mmul_dir.h"

! ftn_mnaxnb_cmplx16 -- complex matrix multiply, a not transposed,
! b not transposed:
!
!     c = alpha * a * b              when beta is zero (c is never read)
!     c = beta * c + alpha * a * b   otherwise
!
! Arguments:
!   mra   - number of rows of a (and of c)
!   ncb   - number of columns of b (and of c)
!   kab   - number of columns of a, which is also the number of rows of b
!   alpha - scalar applied to the product a * b
!   a     - left operand, referenced as a(1:mra, 1:kab)
!   lda   - leading dimension of a; passed on to ftn_transpose_cmplx16
!   b     - right operand, referenced as b(1:kab, 1:ncb)
!   ldb   - not referenced by the executable statements of this routine;
!           presumably the leading dimension of b used by the declarations
!           in the included header -- confirm
!   beta  - scalar applied to the incoming contents of c
!   c     - result matrix, written as c(1:mra, 1:ncb)
!   ldc   - not referenced by the executable statements of this routine;
!           presumably the leading dimension of c used by the declarations
!           in the included header -- confirm
!
! The declarations of the dummy arguments and of every local variable, as
! well as the constants min_blocked_mult, bufrows and bufcols, are not in
! this file; with implicit none in effect they are expected to come from
! pgf90_mmul_cmplx16.h.
!
! Two code paths exist:
!   * colsa * rowsa * colsb < min_blocked_mult: a plain triple loop.
!   * otherwise: a is processed in blocks that ftn_transpose_cmplx16 copies
!     into the work buffer buffera, and each block is multiplied against
!     four columns of b at a time.
subroutine ftn_mnaxnb_cmplx16( mra, ncb, kab, alpha, a, lda, b, ldb, &
     & beta, c, ldc )
  implicit none
#include "pgf90_mmul_cmplx16.h"

  ! Everything herein is focused on how the transposition buffer maps
  ! to the matrix a. The size of the buffer is bufrows * bufcols
  ! Since once transposed data will be read from the buffer down the rows,
  ! bufrows corresponds to the columns of a while bufcols corresponds to
  ! the rows of a. A bit confusing, but correct, I think
  ! There are 4 cases to consider:
  ! 1. rowsa <= bufcols AND colsa <= bufrows
  ! 2. rowsa <= bufcols ( corresponds to a wide matrix )
  ! 3. colsa <= bufrows ( corresponds to a high matrix )
  ! 4. Both dimensions of a exceed both dimensions of the buffer
  !
  ! The main idea here is that the bufrows will define the usage of the
  ! L1 cache. We reference the same column or columns multiple times while
  ! accessing multiple partial rows of a transposed in the buffer.


  !
  !                    rowsa                              colsb
  !             <-bufca(1)>< (2) >                  <-bufcb(1)><(2)>
  !             i = 1, m  -ar->                     j = 1, n --br->
  !       ^     +----------+------+      ^          +----------+----+  ^
  !       |     |          x      |      |          |          x    |  |
  !       |     |          x      |      |          |          x    |  |
  !    bufr(1)  |  A**T    x      | rowchunks=2     |    B     x    |  |
  !       |     |          x      |      |          |          x    |  |
  !  |    |     | buffera  x      |      |       |  | bufferb  x    | ka = 1, k
  !  |    |     |          x      |      |       |  |          x    |  |
  ! ac    |     |    I     x III  |      |      bc  |    a     x c  |  |
  !  |    v     +xxxxxxxxxxxxxxxxx+      |       |  +xxxxxxxxxxxxxxx+  |
  !  v    ^     |          x      |      |       v  |          x    |  |
  !       |     |          x      |      |          |          x    |  |
  !    bufr(2)  |          x      |      |          |          x    |  |
  !       |     |    II    x  IV  |      |          |    b     x d  |  |
  !       V     +----------+------+      V          +----------+----+  V
  !             <--colachunks=2-->                  <-colbchunks=2>
  ! x's mark buffer boundaries on the transposed matrices
  ! For this case, bufca(1) = bufcols, bufr(1) = bufrows

  ! The structure of this code came from mnaxnb_real.
  colsa = kab
  rowsb = kab
  rowsa = mra
  colsb = ncb
  if (colsa * rowsa * colsb < min_blocked_mult) then
    ! Small problem: plain triple loop, no transposition buffer is used.
    ! The complex product (alpha * a) * b is built from four real sums:
    !   temprr0 = sum of Re(alpha * a) * Re(b)
    !   tempii0 = sum of Im(alpha * a) * Im(b)
    !   tempir0 = sum of Re(alpha * a) * Im(b)
    !   tempri0 = sum of Im(alpha * a) * Re(b)
    ! so that the real part is temprr0 - tempii0 and the imaginary part is
    ! tempri0 + tempir0.
    if( beta .eq. 0.0 ) then
      ! beta is zero: c is overwritten without being read.
      do j = 1, colsb
        do i = 1, rowsa
          temprr0 = 0.0
          tempri0 = 0.0
          tempir0 = 0.0
          tempii0 = 0.0
          do k = 1, colsa
            temprr0 = temprr0 + real(alpha) * real(a(i, k)) * real(b(k, j)) - aimag(alpha) * aimag(a(i, k)) * real(b(k, j))
          enddo
          do k = 1, colsa
            tempii0 = tempii0 + real(alpha) * aimag(a(i, k)) * aimag(b(k, j)) + aimag(alpha) * real(a(i, k)) * aimag(b(k, j))
          enddo
          do k = 1, colsa
            tempir0 = tempir0 + real(alpha) * real(a(i, k)) * aimag(b(k, j)) - aimag(alpha) * aimag(a(i, k)) * aimag(b(k, j))
          enddo
          do k = 1, colsa
            tempri0 = tempri0 + real(alpha) * aimag(a(i, k)) * real(b(k, j)) + aimag(alpha) * real(a(i, k)) * real(b(k, j))
          enddo
          c(i, j) = dcmplx((temprr0 - tempii0), (tempri0 + tempir0))
        enddo
      enddo
    else
      ! beta is nonzero: same four sums, added to beta * c.
      do j = 1, colsb
        do i = 1, rowsa
          temprr0 = 0.0
          tempri0 = 0.0
          tempir0 = 0.0
          tempii0 = 0.0
          do k = 1, colsa
            temprr0 = temprr0 + real(alpha) * real(a(i, k)) * real(b(k, j)) - aimag(alpha) * aimag(a(i, k)) * real(b(k, j))
          enddo
          do k = 1, colsa
            tempii0 = tempii0 + real(alpha) * aimag(a(i, k)) * aimag(b(k, j)) + aimag(alpha) * real(a(i, k)) * aimag(b(k, j))
          enddo
          do k = 1, colsa
            tempir0 = tempir0 + real(alpha) * real(a(i, k)) * aimag(b(k, j)) - aimag(alpha) * aimag(a(i, k)) * aimag(b(k, j))
          enddo
          do k = 1, colsa
            tempri0 = tempri0 + real(alpha) * aimag(a(i, k)) * real(b(k, j)) + aimag(alpha) * real(a(i, k)) * real(b(k, j))
          enddo

          c(i, j) = beta * c(i, j) + dcmplx((temprr0 - tempii0), (tempri0 + tempir0))
        enddo
      enddo
    endif
  else
    ! Blocked path: work buffer of bufrows * bufcols elements, released
    ! again just before returning.
    allocate( buffera( bufrows * bufcols ) )

    ! Number of rows of a handled per block, and number of such blocks.
    bufca = min( bufcols, rowsa )

    colachunks = ( rowsa + bufcols - 1)/bufcols
    ! set the number of buffer row chunks we will work on
    bufr = min( bufrows, colsa )
    bufr_sav = bufr
    rowchunks = ( colsa + bufr - 1 )/bufr
    bufca_sav = bufca
    ac = 1 ! column index in matrix a for transpose
    ! lor = colsa - bufr ! left-over rows adjusts for the fact that
    ! colsa/bufr * bufr may not be equal to colsa
    ! Note that the starting column index into matrix a (ac) is the same as
    ! starting index into matrix b. But we need 1 less than that so we can
    ! add an index to it
    ! Columns of b are processed four at a time; columns colsb_strt..colsb
    ! are the remainder handled by the single-column loops.
    colsb_chunks = 4
    colsb_end = colsb/colsb_chunks * colsb_chunks
    colsb_strt = colsb_end + 1
    ! Outer loop: chunks along the inner (k) dimension, starting at column
    ! ac of a. Inner loop: chunks of rows of a, starting at row ar.
    do rowchunk = 1, rowchunks
      ar = 1
      do colachunk = 1, colachunks
        ! Clip the block to what is left of a in each direction.
        bufca = min( bufca_sav, rowsa - ar + 1 )
        bufr = min( bufr_sav, colsa - ac + 1 )
        ! ftn_transpose_cmplx16 is defined elsewhere. The kernels below read
        ! buffera as bufca groups of bufr consecutive elements, one group
        ! per row i of the block, and never multiply by alpha themselves,
        ! so the buffer presumably holds alpha times the transposed block
        ! of a -- confirm against the transpose routine.
        ! NOTE(review): ta is not assigned anywhere in the visible code of
        ! this routine; presumably it is declared and given a value by the
        ! included header -- confirm.
        call ftn_transpose_cmplx16( ta, a( ar, ac ), lda, alpha, buffera, &
             & bufr, bufca )
        if ( ac .eq. 1 )then
          ! First chunk along k: this is where c gets its initial value.

          if( beta .eq. 0.0 ) then
            ! beta is zero: c is overwritten without being read. Four
            ! columns of b per pass, with separate real accumulators.
            ! Note that here tempri holds real(a) * aimag(b) and tempir
            ! holds aimag(a) * real(b), the opposite naming of the small
            ! path; only their sum is used.
            do j = 1, colsb_end, colsb_chunks
              ndxa = 0
              do i = ar, ar + bufca - 1
                bk = ac - 1
                temprr0 = 0.0
                temprr1 = 0.0
                temprr2 = 0.0
                temprr3 = 0.0
                tempii0 = 0.0
                tempii1 = 0.0
                tempii2 = 0.0
                tempii3 = 0.0
                tempri0 = 0.0
                tempri1 = 0.0
                tempri2 = 0.0
                tempri3 = 0.0
                tempir0 = 0.0
                tempir1 = 0.0
                tempir2 = 0.0
                tempir3 = 0.0
                do k = 1, bufr ! dot product of real(a) * real(b)
                  bufatempr = real( buffera( ndxa + k ) )
                  temprr0 = temprr0 + bufatempr * &
                       & real( b( bk + k, j ) )
                  temprr1 = temprr1 + bufatempr * &
                       & real( b( bk + k, j + 1 ) )
                  temprr2 = temprr2 + bufatempr * &
                       & real( b( bk + k, j + 2 ) )
                  temprr3 = temprr3 + bufatempr * &
                       & real( b( bk + k, j + 3 ) )
                  ! temp4 = temp4 + bufatemp * b( bk + k, j + 4 )
                  ! temp5 = temp5 + bufatemp * b( bk + k, j + 5 )
                  ! temp6 = temp6 + bufatemp * b( bk + k, j + 6 )
                  ! temp7 = temp7 + bufatemp * b( bk + k, j + 7 )
                enddo
                do k = 1, bufr ! dot product of aimag(a) * aimag(b)
                  bufatempi = aimag( buffera( ndxa + k ) )
                  tempii0 = tempii0 + bufatempi * &
                       & aimag( b( bk + k, j ) )
                  tempii1 = tempii1 + bufatempi * &
                       & aimag( b( bk + k, j + 1 ) )
                  tempii2 = tempii2 + bufatempi * &
                       & aimag( b( bk + k, j + 2 ) )
                  tempii3 = tempii3 + bufatempi * &
                       & aimag( b( bk + k, j + 3 ) )
                  ! temp4 = temp4 + bufatemp * b( bk + k, j + 4 )
                  ! temp5 = temp5 + bufatemp * b( bk + k, j + 5 )
                  ! temp6 = temp6 + bufatemp * b( bk + k, j + 6 )
                  ! temp7 = temp7 + bufatemp * b( bk + k, j + 7 )
                enddo
                do k = 1, bufr ! cross dot product of real(a) * aimag(b)
                  bufatempr = real( buffera( ndxa + k ) )
                  tempri0 = tempri0 + bufatempr * &
                       & aimag( b( bk + k, j ) )
                  tempri1 = tempri1 + bufatempr * &
                       & aimag( b( bk + k, j + 1 ) )
                  tempri2 = tempri2 + bufatempr * &
                       & aimag( b( bk + k, j + 2 ) )
                  tempri3 = tempri3 + bufatempr * &
                       & aimag( b( bk + k, j + 3 ) )
                  ! temp4 = temp4 + bufatemp * b( bk + k, j + 4 )
                  ! temp5 = temp5 + bufatemp * b( bk + k, j + 5 )
                  ! temp6 = temp6 + bufatemp * b( bk + k, j + 6 )
                  ! temp7 = temp7 + bufatemp * b( bk + k, j + 7 )
                enddo
                do k = 1, bufr ! cross dot product of aimag(a) * real(b)
                  bufatempi = aimag( buffera( ndxa + k ) )
                  tempir0 = tempir0 + bufatempi * &
                       & real( b( bk + k, j ) )
                  tempir1 = tempir1 + bufatempi * &
                       & real( b( bk + k, j + 1 ) )
                  tempir2 = tempir2 + bufatempi * &
                       & real( b( bk + k, j + 2 ) )
                  tempir3 = tempir3 + bufatempi * &
                       & real( b( bk + k, j + 3 ) )
                  ! temp4 = temp4 + bufatemp * b( bk + k, j + 4 )
                  ! temp5 = temp5 + bufatemp * b( bk + k, j + 5 )
                  ! temp6 = temp6 + bufatemp * b( bk + k, j + 6 )
                  ! temp7 = temp7 + bufatemp * b( bk + k, j + 7 )
                enddo
                c(i, j ) = DCMPLX(temprr0 - tempii0, tempri0 + tempir0)
                c(i, j + 1) = DCMPLX(temprr1 - tempii1, tempri1 + tempir1)
                c(i, j + 2) = DCMPLX(temprr2 - tempii2, tempri2 + tempir2)
                c(i, j + 3) = DCMPLX(temprr3 - tempii3, tempri3 + tempir3)
                ! step to the buffer entries of the next row of the block
                ndxa = ndxa + bufr
              enddo
            enddo

            ! This takes care of the last
            ! colsb - colsb/colsb_chunks*colsb_chunks cases
            do j = colsb_strt, colsb
              ndxa = 0
              bk = ac - 1
              do i = ar, ar + bufca - 1
                temp = 0.0
                do k = 1, bufr
                  temp = temp + buffera( ndxa + k ) * b( bk + k, j )
                enddo
                c( i, j ) = temp
                ndxa = ndxa + bufr
              enddo
            enddo
          else
            ! beta is nonzero: scale the incoming c by beta once, here on
            ! the first chunk, and add the partial product.
            do j = 1, colsb_end, colsb_chunks
              ndxa = 0
              do i = ar, ar + bufca - 1
                bk = ac - 1
                temp0 = 0.0
                temp1 = 0.0
                temp2 = 0.0
                temp3 = 0.0
                do k = 1, bufr
                  bufatemp = buffera( ndxa + k )
                  temp0 = temp0 + bufatemp * b( bk + k, j )
                  temp1 = temp1 + bufatemp * b( bk + k, j + 1 )
                  temp2 = temp2 + bufatemp * b( bk + k, j + 2 )
                  temp3 = temp3 + bufatemp * b( bk + k, j + 3 )
                  ! temp4 = temp4 + bufatemp * b( bk + k, j + 4 )
                  ! temp5 = temp5 + bufatemp * b( bk + k, j + 5 )
                  ! temp6 = temp6 + bufatemp * b( bk + k, j + 6 )
                  ! temp7 = temp7 + bufatemp * b( bk + k, j + 7 )
                enddo
                c( i, j ) = beta * c( i, j ) + temp0
                c( i, j + 1 ) = beta * c( i, j + 1 ) + temp1
                c( i, j + 2 ) = beta * c( i, j + 2 ) + temp2
                c( i, j + 3 ) = beta * c( i, j + 3 ) + temp3
                ndxa = ndxa + bufr
              enddo
            enddo

            ! This takes care of the last
            ! colsb - colsb/colsb_chunks*colsb_chunks cases
            do j = colsb_strt, colsb
              ndxa = 0
              bk = ac - 1
              do i = ar, ar + bufca - 1
                temp = 0.0
                do k = 1, bufr
                  temp = temp + buffera( ndxa + k ) * b( bk + k, j )
                enddo
                c( i, j ) = beta * c( i, j ) + temp
                ndxa = ndxa + bufr
              enddo
            enddo
          endif
        else
          ! Later chunks along k: c already holds the contribution of the
          ! earlier chunks, so the partial product is simply accumulated.
          do j = 1, colsb_end, colsb_chunks
            ndxa = 0
            do i = ar, ar + bufca - 1
              bk = ac - 1
              temp0 = 0.0
              temp1 = 0.0
              temp2 = 0.0
              temp3 = 0.0
              do k = 1, bufr
                bufatemp = buffera( ndxa + k )
                temp0 = temp0 + bufatemp * b( bk + k, j )
                temp1 = temp1 + bufatemp * b( bk + k, j + 1 )
                temp2 = temp2 + bufatemp * b( bk + k, j + 2 )
                temp3 = temp3 + bufatemp * b( bk + k, j + 3 )
                ! temp4 = temp4 + bufatemp * b( bk + k, j + 4 )
                ! temp5 = temp5 + bufatemp * b( bk + k, j + 5 )
                ! temp6 = temp6 + bufatemp * b( bk + k, j + 6 )
                ! temp7 = temp7 + bufatemp * b( bk + k, j + 7 )
              enddo
              c( i, j ) = c( i, j ) + temp0
              c( i, j + 1 ) = c( i, j + 1 ) + temp1
              c( i, j + 2 ) = c( i, j + 2 ) + temp2
              c( i, j + 3 ) = c( i, j + 3 ) + temp3
              ndxa = ndxa + bufr
            enddo
          enddo

          ! This takes care of the last colsb - colsb/colsb_chunks*colsb_chunks
          ! cases
          do j = colsb_strt, colsb
            ndxa = 0
            bk = ac - 1
            do i = ar, ar + bufca - 1
              temp = 0.0
              do k = 1, bufr
                temp = temp + buffera( ndxa + k ) * b( bk + k, j )
              enddo
              c( i, j ) = c( i, j ) + temp
              ndxa = ndxa + bufr
            enddo
          enddo
        endif
        ! adjust the boundaries in the direction of the columns of a

        ar = ar + bufca
      enddo
      ! adjust the row values
      ac = ac + bufr
    enddo
    deallocate( buffera )
  endif
  return
end subroutine ftn_mnaxnb_cmplx16
