!=======================================================================!
! Copyright (c) Intel Corporation - All rights reserved.                !
! This file is part of the LIBXSMM library.                             !
!                                                                       !
! For information on the license, see the LICENSE file.                 !
! Further information: https://github.com/hfp/libxsmm/                  !
! SPDX-License-Identifier: BSD-3-Clause                                 !
!=======================================================================!
! Hans Pabst (Intel Corp.), Alexander Heinecke (Intel Corp.), and
! Maxwell Hutchinson (University of Chicago)
!=======================================================================!

!> Streams a batch of rank-3 tensors a(m,n,k) through three successive
!> matrix products with dx(mm,m), dy(n,nn) and dz(k,kk), producing
!> c(mm,nn,kk) per stream element, and times four implementations:
!> BLAS, mxmf2, auto-dispatched LIBXSMM and specialized LIBXSMM kernels.
!>
!> Command line (all optional, positional):
!>   1: m     (default 8)
!>   2: n     (default k)
!>   3: k     (default m)
!>   4: mm    (default 10, also used when 0 is given)
!>   5: nn    (default mm, also used when 0 is given)
!>   6: kk    (default mm, also used when 0 is given)
!>   7: batch size: 0 -> 2048 MB, >0 -> size in MB,
!>      <0 -> number of stream elements (rounded up to whole MB)
!>   8: total workload: >=0 -> size in MB (at least the batch size),
!>      <0 -> number of stream elements (rounded up to whole MB)
!>
!> Environment:
!>   CHECK: a non-zero integer enables a MATMUL-based reference result;
!>          the program stops with code 1 if the largest squared
!>          difference exceeds 1.
PROGRAM stpm
  USE :: LIBXSMM, libxsmm_mmcall => libxsmm_dmmcall_abc
  USE :: STREAM_UPDATE_KERNELS

!$ USE omp_lib
  IMPLICIT NONE

  INTEGER, PARAMETER :: T = KIND(0D0)
  REAL(T), PARAMETER :: alpha = 1, beta = 0

  REAL(T), ALLOCATABLE, DIMENSION(:,:,:,:), TARGET :: a, c, d
!DIR$ ATTRIBUTES ALIGN:64 :: a, c, d
  REAL(T), ALLOCATABLE, TARGET :: dx(:,:), dy(:,:), dz(:,:)
  ! Per-thread scratch tensors (allocated inside each parallel region).
  REAL(T), ALLOCATABLE, TARGET, SAVE :: tm1(:,:,:)
  REAL(T), ALLOCATABLE, TARGET, SAVE :: tm2(:,:,:)
  REAL(T), ALLOCATABLE, TARGET, SAVE :: tm3(:,:,:)
!$OMP THREADPRIVATE(tm1, tm2, tm3)
  TYPE(LIBXSMM_DMMFUNCTION) :: xmm1, xmm2, xmm3
  DOUBLE PRECISION :: duration, max_diff
  INTEGER :: argc, m, n, k, check, mm, nn, kk
  INTEGER(8) :: i, j, ix, iy, iz, r, s
  INTEGER(8) :: size0, size1, size
  INTEGER(8) :: repetitions, start
  CHARACTER(32) :: argv

  ! ------------------------------------------------------------------
  ! Command line. Note the order: k is read before n because n
  ! defaults to k, and k defaults to m.
  ! ------------------------------------------------------------------
  argc = COMMAND_ARGUMENT_COUNT()
  IF (1 <= argc) THEN
    CALL GET_COMMAND_ARGUMENT(1, argv)
    READ(argv, "(I32)") m
  ELSE
    m = 8
  END IF
  IF (3 <= argc) THEN
    CALL GET_COMMAND_ARGUMENT(3, argv)
    READ(argv, "(I32)") k
  ELSE
    k = m
  END IF
  IF (2 <= argc) THEN
    CALL GET_COMMAND_ARGUMENT(2, argv)
    READ(argv, "(I32)") n
  ELSE
    n = k
  END IF
  mm = 0
  IF (4 <= argc) THEN
    CALL GET_COMMAND_ARGUMENT(4, argv)
    READ(argv, "(I32)") mm
  END IF
  mm = MERGE(10, mm, 0.EQ.mm)
  nn = 0
  IF (5 <= argc) THEN
    CALL GET_COMMAND_ARGUMENT(5, argv)
    READ(argv, "(I32)") nn
  END IF
  nn = MERGE(mm, nn, 0.EQ.nn)
  kk = 0
  IF (6 <= argc) THEN
    CALL GET_COMMAND_ARGUMENT(6, argv)
    READ(argv, "(I32)") kk
  END IF
  kk = MERGE(mm, kk, 0.EQ.kk)
  IF (7 <= argc) THEN
    CALL GET_COMMAND_ARGUMENT(7, argv)
    READ(argv, "(I32)") size1
  ELSE
    size1 = 0
  END IF
  IF (8 <= argc) THEN
    CALL GET_COMMAND_ARGUMENT(8, argv)
    READ(argv, "(I32)") size
  ELSE
    size = 0 ! 1 repetition by default
  END IF

  ! Initialize LIBXSMM
  CALL libxsmm_init()

  ! workload is about 2 GByte in memory by default
  size0 = ((m * n * k) + (nn * mm * kk)) * T ! size of single stream element in Byte
  ! Batch size in MB: 0 -> 2048, positive -> as given,
  ! negative -> |size1| elements rounded up to whole MB.
  size1 = MERGE(2048_8, MERGE(size1, ISHFT(ABS(size0 * size1) &
  &       + ISHFT(1, 20) - 1, -20), 0.LE.size1), 0.EQ.size1)
  ! Total workload converted to a number of stream elements.
  size = ISHFT(MERGE(MAX(size, size1), ISHFT(ABS(size) * size0 &
  &      + ISHFT(1, 20) - 1, -20), 0.LE.size), 20) / size0
  ! Number of stream elements per batch. Clamp to one element so that a
  ! single element larger than the requested batch neither yields empty
  ! arrays nor a division by zero below.
  s = MAX(1_8, ISHFT(size1, 20) / size0)
  ! Always process at least one full batch (at least one repetition).
  size = MAX(size, s)
  repetitions = size / s
  duration = 0
  max_diff = 0

  ALLOCATE(a(m,n,k,s))
  ALLOCATE(c(mm,nn,kk,s))
  ALLOCATE(dx(mm,m), dy(n,nn), dz(k,kk))

  ! Initialize (in parallel over the stream; first index innermost to
  ! follow column-major storage)
!$OMP PARALLEL DO PRIVATE(i, ix, iy, iz) DEFAULT(NONE) &
!$OMP   SHARED(a, m, n, k, s)
  DO i = 1, s
    DO iz = 1, k
      DO iy = 1, n
        DO ix = 1, m
          a(ix,iy,iz,i) = ix + iy*m + iz*m*n
        END DO
      END DO
    END DO
  END DO
!$OMP PARALLEL DO PRIVATE(i) DEFAULT(NONE) SHARED(c, s)
  DO i = 1, s
    c(:,:,:,i) = 0.0_T
  END DO
  dx = 1.0_T
  dy = 1.0_T
  dz = 1.0_T

  WRITE(*, "(6(A,I0),A,I0,A,I0,A,I0)") &
  &  "m=", m, " n=", n, " k=", k, &
  &  " mm=", mm, " nn=", nn, " kk=", kk, &
  &  " elements=", UBOUND(a, 4), &
  &  " size=", size1, "MB repetitions=", repetitions

  ! An unset variable yields a blank string, which reads as zero.
  CALL GET_ENVIRONMENT_VARIABLE("CHECK", argv)
  READ(argv, "(I32)") check
  IF (0.NE.check) THEN
    ALLOCATE(d(mm,nn,kk,s))
!$OMP PARALLEL DO PRIVATE(i) DEFAULT(NONE) SHARED(d, s)
    DO i = 1, s
      d(:,:,:,i) = 0.0_T
    END DO

    WRITE(*, "(A)") "Calculating check..."
    ! The reference only depends on a, dx, dy and dz; it is therefore
    ! computed once rather than once per repetition.
!$OMP PARALLEL PRIVATE(i, j) DEFAULT(NONE) &
!$OMP   SHARED(a, dx, dy, dz, d, m, n, k, mm, nn, kk)
    ALLOCATE(tm1(mm,n,k), tm2(mm,nn,k))
    tm1 = 0
    tm2 = 0
!$OMP DO
    DO i = LBOUND(a, 4), UBOUND(a, 4)
      tm1 = RESHAPE( &
      &  MATMUL(dx, RESHAPE(a(:,:,:,i), (/m,n*k/))), &
      &  (/mm, n, k/)) ! [mm,m]x[m,n*k]->[mm,n*k]
      DO j = 1, k
        tm2(:,:,j) = MATMUL(tm1(:,:,j), dy) ! [mm,n]x[n,nn]->[mm,nn]
      END DO
      ! because we can't RESHAPE d
      d(:,:,:,i) = RESHAPE( &
      &  MATMUL(RESHAPE(tm2, (/mm*nn, k/)), dz), &
      &  (/mm,nn,kk/)) ! [mm*nn,k]x[k,kk]->[mm*nn,kk]
    END DO
    ! Deallocate thread-local arrays
    DEALLOCATE(tm1, tm2)
!$OMP END PARALLEL
  END IF

  ! ------------------------------------------------------------------
  ! Each timed region below follows the same pattern: every thread
  ! allocates its scratch tensors, the master thread takes the start
  ! time, all threads share the stream loop, and the master thread
  ! takes the duration after a barrier.
  ! NOTE(review): tm3 is seeded with 3 while tm1/tm2 are zeroed; the
  ! kernels are presumably expected to overwrite it (beta = 0) -
  ! confirm whether the value is intentional.
  ! ------------------------------------------------------------------
  WRITE(*, "(A)") "Streamed... (BLAS)"
!$OMP PARALLEL PRIVATE(i, j, r, start) DEFAULT(NONE) &
!$OMP   SHARED(a, dx, dy, dz, c, m, n, k, mm, nn, kk, &
!$OMP          duration, repetitions)
  ALLOCATE(tm1(mm,n,k), tm2(mm,nn,k), tm3(mm,nn,kk))
  tm1 = 0
  tm2 = 0
  tm3 = 3
!$OMP MASTER
  start = libxsmm_timer_tick()
!$OMP END MASTER
!$OMP BARRIER
  DO r = 1, repetitions
!$OMP DO
    DO i = LBOUND(a, 4), UBOUND(a, 4)
      ! PGI: cannot deduce generic procedure (libxsmm_blas_gemm)
      CALL libxsmm_blas_dgemm(m=mm, n=n*k, k=m, &
      &  a=dx, b=a(:,:,1,i), c=tm1(:,:,1), &
      &  alpha=alpha, beta=beta)
      DO j = 1, k
        ! PGI: cannot deduce generic procedure (libxsmm_blas_gemm)
        CALL libxsmm_blas_dgemm(m=mm, n=nn, k=n, &
        &  a=tm1(:,:,j), b=dy, c=tm2(:,:,j), &
        &  alpha=alpha, beta=beta)
      END DO
      ! PGI: cannot deduce generic procedure (libxsmm_blas_gemm)
      CALL libxsmm_blas_dgemm(m=mm*nn, n=kk, k=k, &
      &  a=tm2(:,:,1), b=dz, c=tm3(:,:,1), &
      &  alpha=alpha, beta=beta)
      CALL stream_vector_copy(tm3(1,1,1), c(1,1,1,i), mm*nn*kk)
    END DO
  END DO
!$OMP BARRIER
!$OMP MASTER
  duration = libxsmm_timer_duration(start, libxsmm_timer_tick())
!$OMP END MASTER
  ! Deallocate thread-local arrays
  DEALLOCATE(tm1, tm2, tm3)
!$OMP END PARALLEL

  CALL performance(duration, m, n, k, mm, nn, kk, size)
  IF (check.NE.0) max_diff = MAX(max_diff, validate(c, d))

  WRITE(*, "(A)") "Streamed... (mxm)"
!$OMP PARALLEL PRIVATE(i, j, r, start) DEFAULT(NONE) &
!$OMP   SHARED(a, dx, dy, dz, c, m, n, k, mm, nn, kk, &
!$OMP          duration, repetitions)
  ALLOCATE(tm1(mm,n,k), tm2(mm,nn,k), tm3(mm,nn,kk))
  tm1 = 0
  tm2 = 0
  tm3 = 3
!$OMP MASTER
  start = libxsmm_timer_tick()
!$OMP END MASTER
!$OMP BARRIER
  DO r = 1, repetitions
!$OMP DO
    DO i = LBOUND(a, 4), UBOUND(a, 4)
      CALL mxmf2(dx, mm, a(:,:,:,i), m, tm1, n*k)
      DO j = 1, k
        CALL mxmf2(tm1(:,:,j), mm, dy, n, tm2(:,:,j), nn)
      END DO
      CALL mxmf2(tm2, mm*nn, dz, k, tm3, kk)
      CALL stream_vector_copy(tm3(1,1,1), c(1,1,1,i), mm*nn*kk)
    END DO
  END DO
!$OMP BARRIER
!$OMP MASTER
  duration = libxsmm_timer_duration(start, libxsmm_timer_tick())
!$OMP END MASTER
  ! Deallocate thread-local arrays
  DEALLOCATE(tm1, tm2, tm3)
!$OMP END PARALLEL

  CALL performance(duration, m, n, k, mm, nn, kk, size)
  IF (check.NE.0) max_diff = MAX(max_diff, validate(c, d))

  WRITE(*, "(A)") "Streamed... (auto-dispatched)"
!$OMP PARALLEL PRIVATE(i, j, r, start) DEFAULT(NONE) &
!$OMP   SHARED(a, dx, dy, dz, c, m, n, k, mm, nn, kk, &
!$OMP          duration, repetitions)
  ALLOCATE(tm1(mm,n,k), tm2(mm,nn,k), tm3(mm,nn,kk))
  tm1 = 0
  tm2 = 0
  tm3 = 3
!$OMP MASTER
  start = libxsmm_timer_tick()
!$OMP END MASTER
!$OMP BARRIER
  DO r = 1, repetitions
!$OMP DO
    DO i = LBOUND(a, 4), UBOUND(a, 4)
      ! PGI: cannot deduce generic procedure (libxsmm_gemm)
      CALL libxsmm_dgemm(m=mm, n=n*k, k=m, &
      &  a=dx, b=a(:,:,1,i), c=tm1(:,:,1), &
      &  alpha=alpha, beta=beta)
      DO j = 1, k
        ! PGI: cannot deduce generic procedure (libxsmm_gemm)
        CALL libxsmm_dgemm(m=mm, n=nn, k=n, &
        &  a=tm1(:,:,j), b=dy, c=tm2(:,:,j), &
        &  alpha=alpha, beta=beta)
      END DO
      ! PGI: cannot deduce generic procedure (libxsmm_gemm)
      CALL libxsmm_dgemm(m=mm*nn, n=kk, k=k, &
      &  a=tm2(:,:,1), b=dz, c=tm3(:,:,1), &
      &  alpha=alpha, beta=beta)
      CALL stream_vector_copy(tm3(1,1,1), c(1,1,1,i), mm*nn*kk)
    END DO
  END DO
!$OMP BARRIER
!$OMP MASTER
  duration = libxsmm_timer_duration(start, libxsmm_timer_tick())
!$OMP END MASTER
  ! Deallocate thread-local arrays
  DEALLOCATE(tm1, tm2, tm3)
!$OMP END PARALLEL

  CALL performance(duration, m, n, k, mm, nn, kk, size)
  IF (check.NE.0) max_diff = MAX(max_diff, validate(c, d))

  WRITE(*, "(A)") "Streamed... (specialized)"
  ! Request one kernel per product shape; the timed run only happens
  ! if all three kernels are available.
  CALL libxsmm_dispatch(xmm1, mm, n*k, m, &
  &  alpha=alpha, beta=beta)
  CALL libxsmm_dispatch(xmm2, mm, nn, n, &
  &  alpha=alpha, beta=beta)
  CALL libxsmm_dispatch(xmm3, mm*nn, kk, k, &
  &  alpha=alpha, beta=beta)
  IF (libxsmm_available(xmm1).AND. &
  &   libxsmm_available(xmm2).AND. &
  &   libxsmm_available(xmm3)) &
  & THEN
    ! NOTE(review): DEFAULT(NONE) is deliberately left disabled here,
    ! as in the original; the reason is not visible in this file.
!$OMP PARALLEL PRIVATE(i, j, r, start) & !DEFAULT(NONE)
!$OMP   SHARED(a, dx, dy, dz, c, m, n, k, mm, nn, kk, &
!$OMP          duration, repetitions, xmm1, xmm2, xmm3)
    ALLOCATE(tm1(mm,n,k), tm2(mm,nn,k), tm3(mm,nn,kk))
    tm1 = 0
    tm2 = 0
    tm3 = 3
!$OMP MASTER
    start = libxsmm_timer_tick()
!$OMP END MASTER
!$OMP BARRIER
    DO r = 1, repetitions
!$OMP DO
      DO i = LBOUND(a, 4), UBOUND(a, 4)
        ! [mm,m]x[m,n*k]->[mm,n*k]
        CALL libxsmm_mmcall(xmm1, dx, a(1,1,1,i), tm1)
        DO j = 1, k ! [mm,n]x[n,nn]->[mm,nn]
          CALL libxsmm_mmcall(xmm2, tm1(1,1,j), dy, tm2(1,1,j))
        END DO
        ! [mm*nn,k]x[k,kk]->[mm*nn,kk]
        CALL libxsmm_mmcall(xmm3, tm2, dz, tm3(1,1,1))
        CALL stream_vector_copy( &
        &  tm3(1,1,1), c(1,1,1,i), mm*nn*kk)
      END DO
    END DO
!$OMP BARRIER
!$OMP MASTER
    duration = libxsmm_timer_duration(start, libxsmm_timer_tick())
!$OMP END MASTER
    ! Deallocate thread-local arrays
    DEALLOCATE(tm1, tm2, tm3)
!$OMP END PARALLEL

    CALL performance(duration, m, n, k, mm, nn, kk, size)
    IF (check.NE.0) max_diff = MAX(max_diff, validate(c, d))
  ELSE
    WRITE(*,*) "Could not build specialized function(s)!"
  END IF

  ! Deallocate global arrays
  IF (check.NE.0) DEALLOCATE(d)
  DEALLOCATE(dx, dy, dz)
  DEALLOCATE(a, c)

  ! finalize LIBXSMM
  CALL libxsmm_finalize()

  ! Signal a failed validation through the exit code.
  IF ((0.NE.check).AND.(1.LT.max_diff)) STOP 1

CONTAINS
  !> Returns the largest squared element-wise difference between
  !> ref and test, and prints it (tab-indented) to standard output.
  FUNCTION validate(ref, test) RESULT(diff)
    REAL(T), DIMENSION(:,:,:,:), INTENT(IN) :: ref, test
    REAL(T) :: diff
    diff = MAXVAL((ref - test) * (ref - test))
    WRITE(*, "(1A,A,F10.1,A)") CHAR(9), "diff: ", diff
  END FUNCTION validate

  !> Prints throughput figures for one timed run.
  !>   duration        : measured time, as returned by
  !>                     libxsmm_timer_duration
  !>   m,n,k,mm,nn,kk  : tensor extents of one stream element
  !>   size            : total number of stream elements processed
  !> The per-repetition time uses the host variable repetitions.
  !> Rates are only printed for a positive duration.
  SUBROUTINE performance(duration, m, n, k, mm, nn, kk, size)
    DOUBLE PRECISION, INTENT(IN) :: duration
    INTEGER, INTENT(IN) :: m, n, k, mm, nn, kk
    INTEGER(8), INTENT(IN) :: size
    INTEGER(8) :: flops, elements
    IF (0.LT.duration) THEN
      ! Evaluate in 64-bit: the per-element operation count exceeds
      ! the default INTEGER range already for moderate extents.
      flops = (2_8 * m - 1) * mm * n * k &
      &     + INT(mm, 8) * (2 * n - 1) * nn * k &
      &     + INT(mm, 8) * nn * (2 * k - 1) * kk
      elements = INT(m, 8) * n * k + INT(mm, 8) * nn * kk
      WRITE(*, "(1A,A,F10.1,A)") CHAR(9), "performance:", (size &
      &  * flops &
      &  * 1D-9 / duration), " GFLOPS/s"
      WRITE(*, "(1A,A,F10.1,A)") CHAR(9), "bandwidth:  ", (size &
      &  * elements &
      &  * T / (duration * ISHFT(1_8, 30))), " GB/s"
    END IF
    WRITE(*, "(1A,A,F10.1,A)") CHAR(9), "duration:   ", &
    &  (1D3 * duration) / repetitions, " ms"
  END SUBROUTINE performance
END PROGRAM stpm