1!=======================================================================!
2! Copyright (c) Intel Corporation - All rights reserved.                !
3! This file is part of the LIBXSMM library.                             !
4!                                                                       !
5! For information on the license, see the LICENSE file.                 !
6! Further information: https://github.com/hfp/libxsmm/                  !
7! SPDX-License-Identifier: BSD-3-Clause                                 !
8!=======================================================================!
9! Hans Pabst (Intel Corp.), Alexander Heinecke (Intel Corp.), and
10! Maxwell Hutchinson (University of Chicago)
11!=======================================================================!
12
13      PROGRAM stpm
14        USE :: LIBXSMM, libxsmm_mmcall => libxsmm_dmmcall_abc
15        USE :: STREAM_UPDATE_KERNELS
16
17        !$ USE omp_lib
18        IMPLICIT NONE
19
20        INTEGER, PARAMETER :: T = KIND(0D0)
21        REAL(T), PARAMETER :: alpha = 1, beta = 0
22
23        REAL(T), ALLOCATABLE, DIMENSION(:,:,:,:), TARGET :: a, c, d
24        !DIR$ ATTRIBUTES ALIGN:64 :: a, c, d
25        REAL(T), ALLOCATABLE, TARGET :: dx(:,:), dy(:,:), dz(:,:)
26        REAL(T), ALLOCATABLE, TARGET, SAVE :: tm1(:,:,:)
27        REAL(T), ALLOCATABLE, TARGET, SAVE :: tm2(:,:,:)
28        REAL(T), ALLOCATABLE, TARGET, SAVE :: tm3(:,:,:)
29        !$OMP THREADPRIVATE(tm1, tm2, tm3)
30        TYPE(LIBXSMM_DMMFUNCTION) :: xmm1, xmm2, xmm3
31        DOUBLE PRECISION :: duration, max_diff
32        INTEGER :: argc, m, n, k, routine, check, mm, nn, kk
33        INTEGER(8) :: i, j, ix, iy, iz, r, s
34        INTEGER(8) :: size0, size1, size
35        INTEGER(8) :: repetitions, start
36        CHARACTER(32) :: argv
37
38        argc = COMMAND_ARGUMENT_COUNT()
39        IF (1 <= argc) THEN
40          CALL GET_COMMAND_ARGUMENT(1, argv)
41          READ(argv, "(I32)") m
42        ELSE
43          m = 8
44        END IF
45        IF (3 <= argc) THEN
46          CALL GET_COMMAND_ARGUMENT(3, argv)
47          READ(argv, "(I32)") k
48        ELSE
49          k = m
50        END IF
51        IF (2 <= argc) THEN
52          CALL GET_COMMAND_ARGUMENT(2, argv)
53          READ(argv, "(I32)") n
54        ELSE
55          n = k
56        END IF
57        mm = 0
58        IF (4 <= argc) THEN
59          CALL GET_COMMAND_ARGUMENT(4, argv)
60          READ(argv, "(I32)") mm
61        END IF
62        mm = MERGE(10, mm, 0.EQ.mm)
63        nn = 0
64        IF (5 <= argc) THEN
65          CALL GET_COMMAND_ARGUMENT(5, argv)
66          READ(argv, "(I32)") nn
67        END IF
68        nn = MERGE(mm, nn, 0.EQ.nn)
69        kk = 0
70        IF (6 <= argc) THEN
71          CALL GET_COMMAND_ARGUMENT(6, argv)
72          READ(argv, "(I32)") kk
73        END IF
74        kk = MERGE(mm, kk, 0.EQ.kk)
75        IF (7 <= argc) THEN
76          CALL GET_COMMAND_ARGUMENT(7, argv)
77          READ(argv, "(I32)") size1
78        ELSE
79          size1 = 0
80        END IF
81        IF (8 <= argc) THEN
82          CALL GET_COMMAND_ARGUMENT(8, argv)
83          READ(argv, "(I32)") size
84        ELSE
85          size = 0 ! 1 repetition by default
86        END IF
87
88        ! Initialize LIBXSMM
89        CALL libxsmm_init()
90
91        ! workload is about 2 GByte in memory by default
92        size0 = ((m * n * k) + (nn * mm * kk)) * T ! size of single stream element in Byte
93        size1 = MERGE(2048_8, MERGE(size1, ISHFT(ABS(size0 * size1)     &
94     &          + ISHFT(1, 20) - 1, -20), 0.LE.size1), 0.EQ.size1)
95        size = ISHFT(MERGE(MAX(size, size1), ISHFT(ABS(size) * size0    &
96     &          + ISHFT(1, 20) - 1, -20), 0.LE.size), 20) / size0
97        s = ISHFT(size1, 20) / size0
98        repetitions = size / s
99        duration = 0
100        max_diff = 0
101
102        ALLOCATE(a(m,n,k,s))
103        ALLOCATE(c(mm,nn,kk,s))
104        ALLOCATE(dx(mm,m), dy(n,nn), dz(k,kk))
105
106        ! Initialize
107        !$OMP PARALLEL DO PRIVATE(i, ix, iy, iz) DEFAULT(NONE) &
108        !$OMP   SHARED(a, m, mm, n, nn, k, kk, s)
109        DO i = 1, s
110          DO ix = 1, m
111            DO iy = 1, n
112              DO iz = 1, k
113                a(ix,iy,iz,i) = ix + iy*m + iz*m*n
114              END DO
115            END DO
116          END DO
117        END DO
118        !$OMP PARALLEL DO PRIVATE(i, ix, iy, iz) DEFAULT(NONE) &
119        !$OMP   SHARED(c, m, mm, n, nn, k, kk, s)
120        DO i = 1, s
121          DO ix = 1, mm
122            DO iy = 1, nn
123              DO iz = 1, kk
124                c(ix,iy,iz,i) = 0.0
125              END DO
126            END DO
127          END DO
128        END DO
129        dx = 1.
130        dy = 1.
131        dz = 1.
132
133        WRITE(*, "(6(A,I0),A,I0,A,I0,A,I0)")                            &
134     &    "m=", m, " n=", n, " k=", k,                                  &
135     &    " mm=", mm, " nn=", nn, " kk=", kk,                           &
136     &    " elements=", UBOUND(a, 4),                                   &
137     &    " size=", size1, "MB repetitions=", repetitions
138
139        CALL GETENV("CHECK", argv)
140        READ(argv, "(I32)") check
141        IF (0.NE.check) THEN
142          ALLOCATE(d(mm,nn,kk,s))
143          !$OMP PARALLEL DO PRIVATE(i, ix, iy, iz) DEFAULT(NONE) &
144          !$OMP   SHARED(d, m, mm, n, nn, k, kk, s)
145          DO i = 1, s
146            DO ix = 1, mm
147              DO iy = 1, nn
148                DO iz = 1, kk
149                  d(ix,iy,iz,i) = 0.0
150                END DO
151              END DO
152            END DO
153          END DO
154
155          WRITE(*, "(A)") "Calculating check..."
156          !$OMP PARALLEL PRIVATE(i, j, r) DEFAULT(NONE) &
157          !$OMP   SHARED(a, dx, dy, dz, d, m, n, k, mm, nn, kk, &
158          !$OMP          repetitions)
159          ALLOCATE(tm1(mm,n,k), tm2(mm,nn,k))
160          tm1 = 0; tm2 = 0;
161          DO r = 1, repetitions
162            !$OMP DO
163            DO i = LBOUND(a, 4), UBOUND(a, 4)
164              tm1 = RESHAPE(                                            &
165     &                MATMUL(dx, RESHAPE(a(:,:,:,i), (/m,n*k/))),       &
166     &                (/mm, n, k/)) ! [mm,m]x[m,n*k]->[mm,n*k]
167              DO j = 1, k
168                tm2(:,:,j) = MATMUL(tm1(:,:,j), dy) ! [mm,n]x[n,nn]->[mm,nn]
169              END DO
170              ! because we can't RESHAPE d
171              d(:,:,:,i) = RESHAPE(                                     &
172     &                        MATMUL(RESHAPE(tm2, (/mm*nn, k/)), dz),   &
173     &                        (/mm,nn,kk/)) ! [mm*nn,k]x[k,kk]->[mm*nn,kk]
174            END DO
175          END DO
176          ! Deallocate thread-local arrays
177          DEALLOCATE(tm1, tm2)
178          !$OMP END PARALLEL
179        END IF
180
181        WRITE(*, "(A)") "Streamed... (BLAS)"
182        !$OMP PARALLEL PRIVATE(i, j, r, start) DEFAULT(NONE) &
183        !$OMP   SHARED(a, dx, dy, dz, c, m, n, k, mm, nn, kk, &
184        !$OMP          duration, repetitions)
185        ALLOCATE(tm1(mm,n,k), tm2(mm,nn,k), tm3(mm,nn,kk))
186        tm1 = 0; tm2 = 0; tm3 = 3
187        !$OMP MASTER
188        start = libxsmm_timer_tick()
189        !$OMP END MASTER
190        !$OMP BARRIER
191        DO r = 1, repetitions
192          !$OMP DO
193          DO i = LBOUND(a, 4), UBOUND(a, 4)
194            ! PGI: cannot deduce generic procedure (libxsmm_blas_gemm)
195            CALL libxsmm_blas_dgemm(m=mm, n=n*k, k=m,                   &
196     &              a=dx, b=a(:,:,1,i), c=tm1(:,:,1),                   &
197     &              alpha=alpha, beta=beta)
198            DO j = 1, k
199              ! PGI: cannot deduce generic procedure (libxsmm_blas_gemm)
200              CALL libxsmm_blas_dgemm(m=mm, n=nn, k=n,                  &
201     &              a=tm1(:,:,j), b=dy, c=tm2(:,:,j),                   &
202     &              alpha=alpha, beta=beta)
203            END DO
204            ! PGI: cannot deduce generic procedure (libxsmm_blas_gemm)
205            CALL libxsmm_blas_dgemm(m=mm*nn, n=kk, k=k,                 &
206     &              a=tm2(:,:,1), b=dz, c=tm3(:,:,1),                   &
207     &              alpha=alpha, beta=beta)
208            CALL stream_vector_copy(tm3(1,1,1), c(1,1,1,i), mm*nn*kk)
209          END DO
210        END DO
211        !$OMP BARRIER
212        !$OMP MASTER
213        duration = libxsmm_timer_duration(start, libxsmm_timer_tick())
214        !$OMP END MASTER
215        ! Deallocate thread-local arrays
216        DEALLOCATE(tm1, tm2, tm3)
217        !$OMP END PARALLEL
218
219        CALL performance(duration, m, n, k, mm, nn, kk, size)
220        IF (check.NE.0) max_diff = MAX(max_diff, validate(c, d))
221
222        WRITE(*, "(A)") "Streamed... (mxm)"
223        !$OMP PARALLEL PRIVATE(i, j, r, start) DEFAULT(NONE) &
224        !$OMP   SHARED(a, dx, dy, dz, c, m, n, k, mm, nn, kk, &
225        !$OMP          duration, repetitions)
226        ALLOCATE(tm1(mm,n,k), tm2(mm,nn,k), tm3(mm,nn,kk))
227        tm1 = 0; tm2 = 0; tm3 = 3
228        !$OMP MASTER
229        start = libxsmm_timer_tick()
230        !$OMP END MASTER
231        !$OMP BARRIER
232        DO r = 1, repetitions
233          !$OMP DO
234          DO i = LBOUND(a, 4), UBOUND(a, 4)
235            CALL mxmf2(dx, mm, a(:,:,:,i), m, tm1, n*k)
236            DO j = 1, k
237              CALL mxmf2(tm1(:,:,j), mm, dy, n, tm2(:,:,j), nn)
238            END DO
239            CALL mxmf2(tm2, mm*nn, dz, k, tm3, kk)
240            CALL stream_vector_copy(tm3(1,1,1), c(1,1,1,i), mm*nn*kk)
241          END DO
242        END DO
243        !$OMP BARRIER
244        !$OMP MASTER
245        duration = libxsmm_timer_duration(start, libxsmm_timer_tick())
246        !$OMP END MASTER
247        ! Deallocate thread-local arrays
248        DEALLOCATE(tm1, tm2, tm3)
249        !$OMP END PARALLEL
250
251        CALL performance(duration, m, n, k, mm, nn, kk, size)
252        IF (check.NE.0) max_diff = MAX(max_diff, validate(c, d))
253
254        WRITE(*, "(A)") "Streamed... (auto-dispatched)"
255        !$OMP PARALLEL PRIVATE(i, j, r, start) DEFAULT(NONE) &
256        !$OMP   SHARED(a, dx, dy, dz, c, m, n, k, mm, nn, kk, &
257        !$OMP          duration, repetitions)
258        ALLOCATE(tm1(mm,n,k), tm2(mm,nn,k), tm3(mm,nn,kk))
259        tm1 = 0; tm2 = 0; tm3 = 3
260        !$OMP MASTER
261        start = libxsmm_timer_tick()
262        !$OMP END MASTER
263        !$OMP BARRIER
264        DO r = 1, repetitions
265          !$OMP DO
266          DO i = LBOUND(a, 4), UBOUND(a, 4)
267            ! PGI: cannot deduce generic procedure (libxsmm_gemm)
268            CALL libxsmm_dgemm(m=mm, n=n*k, k=m,                        &
269     &              a=dx, b=a(:,:,1,i), c=tm1(:,:,1),                   &
270     &              alpha=alpha, beta=beta)
271            DO j = 1, k
272              ! PGI: cannot deduce generic procedure (libxsmm_gemm)
273              CALL libxsmm_dgemm(m=mm, n=nn, k=n,                       &
274     &              a=tm1(:,:,j), b=dy, c=tm2(:,:,j),                   &
275     &              alpha=alpha, beta=beta)
276            END DO
277            ! PGI: cannot deduce generic procedure (libxsmm_gemm)
278            CALL libxsmm_dgemm(m=mm*nn, n=kk, k=k,                      &
279     &              a=tm2(:,:,1), b=dz, c=tm3(:,:,1),                   &
280     &              alpha=alpha, beta=beta)
281            CALL stream_vector_copy(tm3(1,1,1), c(1,1,1,i), mm*nn*kk)
282          END DO
283        END DO
284        !$OMP BARRIER
285        !$OMP MASTER
286        duration = libxsmm_timer_duration(start, libxsmm_timer_tick())
287        !$OMP END MASTER
288        ! Deallocate thread-local arrays
289        DEALLOCATE(tm1, tm2, tm3)
290        !$OMP END PARALLEL
291
292        CALL performance(duration, m, n, k, mm, nn, kk, size)
293        IF (check.NE.0) max_diff = MAX(max_diff, validate(c, d))
294
295        WRITE(*, "(A)") "Streamed... (specialized)"
296        CALL libxsmm_dispatch(xmm1, mm, n*k, m,                         &
297     &          alpha=alpha, beta=beta)
298        CALL libxsmm_dispatch(xmm2, mm, nn, n,                          &
299     &          alpha=alpha, beta=beta)
300        CALL libxsmm_dispatch(xmm3, mm*nn, kk, k,                       &
301     &          alpha=alpha, beta=beta)
302        IF (libxsmm_available(xmm1).AND.                                &
303     &      libxsmm_available(xmm2).AND.                                &
304     &      libxsmm_available(xmm3))                                    &
305     &  THEN
306          !$OMP PARALLEL PRIVATE(i, j, r, start) & !DEFAULT(NONE)
307          !$OMP   SHARED(a, dx, dy, dz, c, m, n, k, mm, nn, kk, &
308          !$OMP          duration, repetitions, xmm1, xmm2, xmm3)
309          ALLOCATE(tm1(mm,n,k), tm2(mm,nn,k), tm3(mm,nn,kk))
310          tm1 = 0; tm2 = 0; tm3 = 3
311          !$OMP MASTER
312          start = libxsmm_timer_tick()
313          !$OMP END MASTER
314          !$OMP BARRIER
315          DO r = 1, repetitions
316            !$OMP DO
317            DO i = LBOUND(a, 4), UBOUND(a, 4)
318              ! [mm,m]x[m,n*k]->[mm,n*k]
319              CALL libxsmm_mmcall(xmm1, dx, a(1,1,1,i), tm1)
320              DO j = 1, k ! [mm,n]x[n,nn]->[mm,nn]
321                CALL libxsmm_mmcall(xmm2, tm1(1,1,j), dy, tm2(1,1,j))
322              END DO
323              ! [mm*nn,k]x[k,kk]->[mm*nn,kk]
324              CALL libxsmm_mmcall(xmm3, tm2, dz, tm3(1,1,1))
325              CALL stream_vector_copy(                                  &
326     &                tm3(1,1,1), c(1,1,1,i), mm*nn*kk)
327            END DO
328          END DO
329          !$OMP BARRIER
330          !$OMP MASTER
331          duration = libxsmm_timer_duration(start, libxsmm_timer_tick())
332          !$OMP END MASTER
333          ! Deallocate thread-local arrays
334          DEALLOCATE(tm1, tm2, tm3)
335          !$OMP END PARALLEL
336
337          CALL performance(duration, m, n, k, mm, nn, kk, size)
338          IF (check.NE.0) max_diff = MAX(max_diff, validate(c, d))
339        ELSE
340          WRITE(*,*) "Could not build specialized function(s)!"
341        END IF
342
343        ! Deallocate global arrays
344        IF (check.NE.0) DEALLOCATE(d)
345        DEALLOCATE(dx, dy, dz)
346        DEALLOCATE(a, c)
347
348        ! finalize LIBXSMM
349        CALL libxsmm_finalize()
350
351        IF ((0.NE.check).AND.(1.LT.max_diff)) STOP 1
352
353      CONTAINS
354        FUNCTION validate(ref, test) RESULT(diff)
355          REAL(T), DIMENSION(:,:,:,:), intent(in) :: ref, test
356          REAL(T) :: diff
357          diff = MAXVAL((ref - test) * (ref - test))
358          WRITE(*, "(1A,A,F10.1,A)") CHAR(9), "diff:       ", diff
359        END FUNCTION
360
361        SUBROUTINE performance(duration, m, n, k, mm, nn, kk, size)
362          DOUBLE PRECISION, INTENT(IN) :: duration
363          INTEGER, INTENT(IN)    :: m, n, k, mm, nn, kk
364          INTEGER(8), INTENT(IN) :: size
365          IF (0.LT.duration) THEN
366            WRITE(*, "(1A,A,F10.1,A)") CHAR(9), "performance:", (size   &
367     &        * ((2*m-1)*mm*n*k + mm*(2*n-1)*nn*k + mm*nn*(2*k-1)*kk)   &
368     &        * 1D-9 / duration), " GFLOPS/s"
369            WRITE(*, "(1A,A,F10.1,A)") CHAR(9), "bandwidth:  ", (size   &
370     &        * ((m*n*k) + (mm*nn*kk))                                  &
371     &        * T / (duration * LSHIFT(1_8, 30))), " GB/s"
372          END IF
373          WRITE(*, "(1A,A,F10.1,A)") CHAR(9), "duration:   ",           &
374     &      (1D3 * duration) / repetitions, " ms"
375        END SUBROUTINE
376      END PROGRAM
377
378