1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 /*
18  * rotmg generator
19  */
20 //#define DEBUG_ROTMG
21 
22 #include <string.h>
23 #include <stdio.h>
24 #include <assert.h>
25 #include <clblas_stddef.h>
26 #include <clBLAS.h>
27 #include <blas_mempat.h>
28 #include <clkern.h>
29 #include <clblas-internal.h>
30 #include "blas_kgen.h"
31 #include <kprintf.hpp>
32 #include <rotmg.clT>
33 #include <solution_seq.h>
34 
// Declared with C linkage; not referenced anywhere in the code visible here.
extern "C"
unsigned int dtypeSize(DataType type);


// BLAS precision prefix character per DataType ('S', 'D', 'C', 'Z').
// Filled in by initRotmgRegisterPattern() and indexed with
// extraFlags->dtype by generator() when instantiating the kernel template.
static char Prefix[4];
40 
41 static SolverFlags
solverFlags(void)42 solverFlags(void)
43 {
44     return (SF_WSPACE_1D);
45 }
46 
/*
 * Forward declarations of the callbacks referenced by rotmgOps below and
 * of the pattern registration entry point.
 */

// Computes the global work size (threads[0], threads[1]) for the kernel.
static void
calcNrThreads(
    size_t threads[2],
    const SubproblemDim *subdims,
    const PGranularity *pgran,
    const void *args,
    const void *extra);

// Writes the kernel source into buf, or returns the needed size if buf is NULL.
static ssize_t
generator(
   char *buf,
   size_t buflen,
   const struct SubproblemDim *subdims,
   const struct PGranularity *pgran,
   void *extra);


// Binds CLBlasKargs fields to the kernel's argument slots.
static void
assignKargs(KernelArg *args, const void *params, const void* extra );

// Registers the rotmg memory pattern (C linkage).
extern "C"
void initRotmgRegisterPattern(MemoryPattern *mempat);

// Appends precision-dependent compiler options to buildOptStr.
static void
setBuildOpts(
    char * buildOptStr,
    const void *kArgs);
74 
/*
 * Solver operations table for the rotmg pattern. The initializer is
 * positional, so entries must stay in SolverOps member order; NULL slots
 * are operations this pattern leaves unset.
 */
static SolverOps rotmgOps = {
    generator,
    assignKargs,
    NULL,
    NULL, // Prepare Translate Dims
    NULL, // Inner Decomposition Axis
    calcNrThreads,
    NULL,
    solverFlags,
	NULL,
	NULL,
	NULL,
	setBuildOpts,
	NULL
};
90 
91 static void
setBuildOpts(char * buildOptStr,const void * args)92 setBuildOpts(
93     char * buildOptStr,
94     const void *args)
95 {
96 	const SolutionStep *step = (const SolutionStep *)args;
97     const CLBlasKargs *kargs = (const CLBlasKargs *)(&step->args);
98 	if ( kargs->dtype == TYPE_DOUBLE || kargs->dtype == TYPE_COMPLEX_DOUBLE)
99 	{
100 		addBuildOpt( buildOptStr, BUILD_OPTS_MAXLEN, "-DDOUBLE_PRECISION");
101 	}
102 
103 	return;
104 }
105 
106 
// Memory-pattern extra data, populated in initRotmgRegisterPattern() and
// attached via mempat->extra. Static storage keeps that pointer valid after
// registration returns.
static CLBLASMpatExtra mpatExtra;
108 
109 extern "C"
initRotmgRegisterPattern(MemoryPattern * mempat)110 void initRotmgRegisterPattern(MemoryPattern *mempat)
111 {
112 	#ifdef DEBUG_ROTMG
113 	printf("initRegPattern called with mempat = 0x%p\n", mempat);
114 	#endif
115 
116 	fflush(stdout);
117     mempat->name = "Register accumulation based swap";
118     mempat->nrLevels = 2;
119     mempat->cuLevel = 0;
120     mempat->thLevel = 1;
121     mempat->sops = &rotmgOps;
122 
123     mpatExtra.aMset = CLMEM_LEVEL_L2;
124     mpatExtra.bMset = CLMEM_LEVEL_L2;
125     mpatExtra.mobjA = CLMEM_GLOBAL_MEMORY;
126     mpatExtra.mobjB = CLMEM_GLOBAL_MEMORY;
127     mempat->extra = &mpatExtra;
128 
129 	Prefix[TYPE_FLOAT] = 'S';
130 	Prefix[TYPE_DOUBLE] = 'D';
131 	Prefix[TYPE_COMPLEX_FLOAT] = 'C';
132 	Prefix[TYPE_COMPLEX_DOUBLE] = 'Z';
133 }
134 
135 static void
calcNrThreads(size_t threads[2],const SubproblemDim * subdims,const PGranularity * pgran,const void * args,const void * _extra)136 calcNrThreads(
137     size_t threads[2],
138     const SubproblemDim *subdims,
139     const PGranularity *pgran,
140     const void *args,
141     const void *_extra)
142 {
143 	DUMMY_ARGS_USAGE_3(subdims, _extra, args);
144 	int BLOCKSIZE = pgran->wgSize[0] * pgran->wgSize[1]; // 1D Block
145 
146 	size_t blocks = 1;  // Only 1 work-group is enough
147 	#ifdef DEBUG_ROTMG
148 	printf("blocks : %d\n", blocks);
149 	#endif
150 
151 	threads[0] = blocks * BLOCKSIZE;
152 	#ifdef DEBUG_ROTMG
153 	printf("pgran-wgSize[0] : %d, globalthreads[0]  : %d\n", pgran->wgSize[0], threads[0]);
154 	#endif
155 	threads[1] = 1;
156 }
157 
158 //
159 // FIXME: Report correct return value - Needs change in KPRINTF
160 //
161 static ssize_t
generator(char * buf,size_t buflen,const struct SubproblemDim * subdims,const struct PGranularity * pgran,void * extra)162 generator(
163    char *buf,
164    size_t buflen,
165    const struct SubproblemDim *subdims,
166    const struct PGranularity *pgran,
167    void *extra)
168 {
169 
170 	DUMMY_ARGS_USAGE_2(subdims, pgran);
171 	CLBLASKernExtra *extraFlags = ( CLBLASKernExtra *)extra;
172 	char tempTemplate[32*1024];
173 
174 	if ( buf == NULL) // return buffer size
175 	{
176 		buflen = (32 * 1024 * sizeof(char));
177         return (ssize_t)buflen;
178 	}
179 
180 	#ifdef DEBUG_ROTMG
181 	printf("dataType : %c\n", Prefix[extraFlags->dtype]);
182 	#endif
183 
184     strcpy( tempTemplate, (char*)rotmg_kernel );
185 
186 	kprintf kobj( Prefix[extraFlags->dtype], 1, false, false);
187     kobj.spit((char*)buf, tempTemplate);
188 
189     return (32 * 1024 * sizeof(char));
190 }
191 
/*
 * Signature of the generated OpenCL kernel; assignKargs() below sets its
 * arguments in this order:
 *
__kernel void %PREFIXrotmg_kernel( __global %TYPE *_D1, __global %TYPE *_D2, __global %TYPE *_X1,
                                __global %TYPE *_Y1, __global %TYPE *_param,
                                uint offD1, uint offD2, uint offX1, uint offY1, uint offParam )

*/
198 static void
assignKargs(KernelArg * args,const void * params,const void *)199 assignKargs(KernelArg *args, const void *params, const void* )
200 {
201     CLBlasKargs *blasArgs = (CLBlasKargs*)params;
202 
203     INIT_KARG(&args[0], blasArgs->A);
204 	INIT_KARG(&args[1], blasArgs->B);
205 	INIT_KARG(&args[2], blasArgs->C);
206     INIT_KARG(&args[3], blasArgs->D);
207     INIT_KARG(&args[4], blasArgs->E);
208     initSizeKarg(&args[5], blasArgs->offa);
209     initSizeKarg(&args[6], blasArgs->offb);
210     initSizeKarg(&args[7], blasArgs->offc);
211     initSizeKarg(&args[8], blasArgs->offd);
212     initSizeKarg(&args[9], blasArgs->offe);
213 
214 	return;
215 }
216