1 //===--------- support.h - OpenMP GPU support functions ---------- CUDA -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Wrappers for some functions natively supported by the GPU.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef OMPTARGET_SUPPORT_H
14 #define OMPTARGET_SUPPORT_H
15 
16 #include "interface.h"
17 #include "target_impl.h"
18 
19 ////////////////////////////////////////////////////////////////////////////////
20 // Execution Parameters
21 ////////////////////////////////////////////////////////////////////////////////
// Execution mode, encoded in bit 0 of the mode flags.
// Spmd leaves the bit clear and Generic sets it. ModeMask selects exactly
// that bit, which is why it carries the same value as Generic.
enum ExecutionMode {
  Spmd = 0u,
  Generic = 1u << 0,
  ModeMask = 1u << 0,
};
27 
// Runtime-initialization state, encoded in bit 1 of the mode flags so that
// it does not overlap the bit used by ExecutionMode (0x01).
// RuntimeInitialized leaves the bit clear and RuntimeUninitialized sets it.
// RuntimeMask selects exactly that bit.
enum RuntimeMode {
  RuntimeInitialized = 0u,
  RuntimeUninitialized = 1u << 1,
  RuntimeMask = 1u << 1,
};
33 
34 DEVICE void setExecutionParameters(ExecutionMode EMode, RuntimeMode RMode);
35 DEVICE bool isGenericMode();
36 DEVICE bool isSPMDMode();
37 DEVICE bool isRuntimeUninitialized();
38 DEVICE bool isRuntimeInitialized();
39 
40 ////////////////////////////////////////////////////////////////////////////////
41 // Execution Modes based on location parameter fields
42 ////////////////////////////////////////////////////////////////////////////////
43 
44 DEVICE bool checkSPMDMode(kmp_Ident *loc);
45 DEVICE bool checkGenericMode(kmp_Ident *loc);
46 DEVICE bool checkRuntimeUninitialized(kmp_Ident *loc);
47 DEVICE bool checkRuntimeInitialized(kmp_Ident *loc);
48 
49 ////////////////////////////////////////////////////////////////////////////////
50 // get info from machine
51 ////////////////////////////////////////////////////////////////////////////////
52 
53 // get global ids to locate tread/team info (constant regardless of OMP)
54 DEVICE int GetLogicalThreadIdInBlock(bool isSPMDExecutionMode);
55 DEVICE int GetMasterThreadID();
56 DEVICE int GetNumberOfWorkersInTeam();
57 
58 // get OpenMP thread and team ids
59 DEVICE int GetOmpThreadId(int threadId,
60                           bool isSPMDExecutionMode);    // omp_thread_num
61 DEVICE int GetOmpTeamId();                              // omp_team_num
62 
63 // get OpenMP number of threads and team
64 DEVICE int GetNumberOfOmpThreads(bool isSPMDExecutionMode); // omp_num_threads
65 DEVICE int GetNumberOfOmpTeams();                           // omp_num_teams
66 
67 // get OpenMP number of procs
68 DEVICE int GetNumberOfProcsInTeam(bool isSPMDExecutionMode);
69 DEVICE int GetNumberOfProcsInDevice(bool isSPMDExecutionMode);
70 
71 // masters
72 DEVICE int IsTeamMaster(int ompThreadId);
73 
74 // Parallel level
75 DEVICE void IncParallelLevel(bool ActiveParallel, __kmpc_impl_lanemask_t Mask);
76 DEVICE void DecParallelLevel(bool ActiveParallel, __kmpc_impl_lanemask_t Mask);
77 
78 ////////////////////////////////////////////////////////////////////////////////
79 // Memory
80 ////////////////////////////////////////////////////////////////////////////////
81 
82 // safe alloc and free
83 DEVICE void *SafeMalloc(size_t size, const char *msg); // check if success
84 DEVICE void *SafeFree(void *ptr, const char *msg);
85 // pad to a alignment (power of 2 only)
86 DEVICE unsigned long PadBytes(unsigned long size, unsigned long alignment);
// Byte-granular pointer arithmetic. These macros move `_addr` forward
// (ADD_BYTES) or backward (SUB_BYTES) by `_bytes` bytes and yield a void *.
// The address is routed through void * and char *, so the offset is counted
// in bytes whatever the pointee type is.
#define ADD_BYTES(_addr, _bytes)                                               \
  ((void *)((char *)(void *)(_addr) + (_bytes)))
#define SUB_BYTES(_addr, _bytes)                                               \
  ((void *)((char *)(void *)(_addr) - (_bytes)))
91 
92 ////////////////////////////////////////////////////////////////////////////////
93 // Teams Reduction Scratchpad Helpers
94 ////////////////////////////////////////////////////////////////////////////////
95 DEVICE unsigned int *GetTeamsReductionTimestamp();
96 DEVICE char *GetTeamsReductionScratchpad();
97 
98 #endif
99