1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include <stdio.h>
25 #include <stdint.h>
26 #include <stdexcept>
27 
28 #include <directx/d3d12.h>
29 #include <dxgi1_4.h>
30 #include <gtest/gtest.h>
31 #include <wrl.h>
32 
33 #include "util/u_debug.h"
34 #include "clc_compiler.h"
35 #include "compute_test.h"
36 #include "dxcapi.h"
37 
38 #include <spirv-tools/libspirv.hpp>
39 
40 using std::runtime_error;
41 using Microsoft::WRL::ComPtr;
42 
/*
 * Bit flags controlling test behaviour; selected by name through the
 * COMPUTE_TEST_DEBUG option (see compute_debug_options).
 */
enum compute_test_debug_flags {
   /* Enable D3D12 experimental shader models; also tolerates a missing DXIL.DLL validator. */
   COMPUTE_DEBUG_EXPERIMENTAL_SHADERS = 0x1,
   /* Pick the first non-software adapter instead of WARP. */
   COMPUTE_DEBUG_USE_HW_D3D           = 0x2,
   /* Request an optimized libclc when creating the CLC context. */
   COMPUTE_DEBUG_OPTIMIZE_LIBCLC      = 0x4,
   /* Round-trip the libclc context through serialize/deserialize. */
   COMPUTE_DEBUG_SERIALIZE_LIBCLC     = 0x8,
};
49 
50 static const struct debug_named_value compute_debug_options[] = {
51    { "experimental_shaders",  COMPUTE_DEBUG_EXPERIMENTAL_SHADERS, "Enable experimental shaders" },
52    { "use_hw_d3d",            COMPUTE_DEBUG_USE_HW_D3D,           "Use a hardware D3D device"   },
53    { "optimize_libclc",       COMPUTE_DEBUG_OPTIMIZE_LIBCLC,      "Optimize the clc_libclc before using it" },
54    { "serialize_libclc",      COMPUTE_DEBUG_SERIALIZE_LIBCLC,     "Serialize and deserialize the clc_libclc" },
55    DEBUG_NAMED_VALUE_END
56 };
57 
/*
 * Defines debug_get_option_debug_compute(), used throughout this file to
 * test compute_test_debug_flags bits. It presumably parses the
 * COMPUTE_TEST_DEBUG environment variable against compute_debug_options once
 * and caches the result; the final argument (0) is the default flag set.
 */
DEBUG_GET_ONCE_FLAGS_OPTION(debug_compute, "COMPUTE_TEST_DEBUG", compute_debug_options, 0)
59 
/* clc_logger warning sink: prints msg to stderr. priv is unused. */
static void
warning_callback(void *priv, const char *msg)
{
   (void)priv;
   fprintf(stderr, "WARNING: %s\n", msg);
}
64 
/* clc_logger error sink: prints msg to stderr. priv is unused. */
static void
error_callback(void *priv, const char *msg)
{
   (void)priv;
   fprintf(stderr, "ERROR: %s\n", msg);
}
69 
70 static const struct clc_logger logger = {
71    NULL,
72    error_callback,
73    warning_callback,
74 };
75 
76 void
enable_d3d12_debug_layer()77 ComputeTest::enable_d3d12_debug_layer()
78 {
79    HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
80    if (!hD3D12Mod) {
81       fprintf(stderr, "D3D12: failed to load D3D12.DLL\n");
82       return;
83    }
84 
85    typedef HRESULT(WINAPI * PFN_D3D12_GET_DEBUG_INTERFACE)(REFIID riid,
86                                                            void **ppFactory);
87    PFN_D3D12_GET_DEBUG_INTERFACE D3D12GetDebugInterface = (PFN_D3D12_GET_DEBUG_INTERFACE)GetProcAddress(hD3D12Mod, "D3D12GetDebugInterface");
88    if (!D3D12GetDebugInterface) {
89       fprintf(stderr, "D3D12: failed to load D3D12GetDebugInterface from D3D12.DLL\n");
90       return;
91    }
92 
93    ID3D12Debug *debug;
94    if (FAILED(D3D12GetDebugInterface(__uuidof(ID3D12Debug), (void **)& debug))) {
95       fprintf(stderr, "D3D12: D3D12GetDebugInterface failed\n");
96       return;
97    }
98 
99    debug->EnableDebugLayer();
100 }
101 
102 IDXGIFactory4 *
get_dxgi_factory()103 ComputeTest::get_dxgi_factory()
104 {
105    static const GUID IID_IDXGIFactory4 = {
106       0x1bc6ea02, 0xef36, 0x464f,
107       { 0xbf, 0x0c, 0x21, 0xca, 0x39, 0xe5, 0x16, 0x8a }
108    };
109 
110    typedef HRESULT(WINAPI * PFN_CREATE_DXGI_FACTORY)(REFIID riid,
111                                                      void **ppFactory);
112    PFN_CREATE_DXGI_FACTORY CreateDXGIFactory;
113 
114    HMODULE hDXGIMod = LoadLibrary("DXGI.DLL");
115    if (!hDXGIMod)
116       throw runtime_error("Failed to load DXGI.DLL");
117 
118    CreateDXGIFactory = (PFN_CREATE_DXGI_FACTORY)GetProcAddress(hDXGIMod, "CreateDXGIFactory");
119    if (!CreateDXGIFactory)
120       throw runtime_error("Failed to load CreateDXGIFactory from DXGI.DLL");
121 
122    IDXGIFactory4 *factory = NULL;
123    HRESULT hr = CreateDXGIFactory(IID_IDXGIFactory4, (void **)&factory);
124    if (FAILED(hr))
125       throw runtime_error("CreateDXGIFactory failed");
126 
127    return factory;
128 }
129 
130 IDXGIAdapter1 *
choose_adapter(IDXGIFactory4 * factory)131 ComputeTest::choose_adapter(IDXGIFactory4 *factory)
132 {
133    IDXGIAdapter1 *ret;
134 
135    if (debug_get_option_debug_compute() & COMPUTE_DEBUG_USE_HW_D3D) {
136       for (unsigned i = 0; SUCCEEDED(factory->EnumAdapters1(i, &ret)); i++) {
137          DXGI_ADAPTER_DESC1 desc;
138          ret->GetDesc1(&desc);
139          if (!(desc.Flags & D3D_DRIVER_TYPE_SOFTWARE))
140             return ret;
141       }
142       throw runtime_error("Failed to enum hardware adapter");
143    } else {
144       if (FAILED(factory->EnumWarpAdapter(__uuidof(IDXGIAdapter1),
145          (void **)& ret)))
146          throw runtime_error("Failed to enum warp adapter");
147       return ret;
148    }
149 }
150 
151 ID3D12Device *
create_device(IDXGIAdapter1 * adapter)152 ComputeTest::create_device(IDXGIAdapter1 *adapter)
153 {
154    typedef HRESULT(WINAPI *PFN_D3D12CREATEDEVICE)(IUnknown *, D3D_FEATURE_LEVEL, REFIID, void **);
155    PFN_D3D12CREATEDEVICE D3D12CreateDevice;
156 
157    HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
158    if (!hD3D12Mod)
159       throw runtime_error("failed to load D3D12.DLL");
160 
161    if (debug_get_option_debug_compute() & COMPUTE_DEBUG_EXPERIMENTAL_SHADERS) {
162       typedef HRESULT(WINAPI *PFN_D3D12ENABLEEXPERIMENTALFEATURES)(UINT, const IID *, void *, UINT *);
163       PFN_D3D12ENABLEEXPERIMENTALFEATURES D3D12EnableExperimentalFeatures;
164       D3D12EnableExperimentalFeatures = (PFN_D3D12ENABLEEXPERIMENTALFEATURES)
165          GetProcAddress(hD3D12Mod, "D3D12EnableExperimentalFeatures");
166       if (FAILED(D3D12EnableExperimentalFeatures(1, &D3D12ExperimentalShaderModels, NULL, NULL)))
167          throw runtime_error("failed to enable experimental shader models");
168    }
169 
170    D3D12CreateDevice = (PFN_D3D12CREATEDEVICE)GetProcAddress(hD3D12Mod, "D3D12CreateDevice");
171    if (!D3D12CreateDevice)
172       throw runtime_error("failed to load D3D12CreateDevice from D3D12.DLL");
173 
174    ID3D12Device *dev;
175    if (FAILED(D3D12CreateDevice(adapter, D3D_FEATURE_LEVEL_12_0,
176        __uuidof(ID3D12Device), (void **)& dev)))
177       throw runtime_error("D3D12CreateDevice failed");
178 
179    return dev;
180 }
181 
182 ComPtr<ID3D12RootSignature>
create_root_signature(const ComputeTest::Resources & resources)183 ComputeTest::create_root_signature(const ComputeTest::Resources &resources)
184 {
185    D3D12_ROOT_PARAMETER1 root_param;
186    root_param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE;
187    root_param.DescriptorTable.NumDescriptorRanges = resources.ranges.size();
188    root_param.DescriptorTable.pDescriptorRanges = resources.ranges.data();
189    root_param.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
190 
191    D3D12_ROOT_SIGNATURE_DESC1 root_sig_desc;
192    root_sig_desc.NumParameters = 1;
193    root_sig_desc.pParameters = &root_param;
194    root_sig_desc.NumStaticSamplers = 0;
195    root_sig_desc.pStaticSamplers = NULL;
196    root_sig_desc.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
197 
198    D3D12_VERSIONED_ROOT_SIGNATURE_DESC versioned_desc;
199    versioned_desc.Version = D3D_ROOT_SIGNATURE_VERSION_1_1;
200    versioned_desc.Desc_1_1 = root_sig_desc;
201 
202    ID3DBlob *sig, *error;
203    if (FAILED(D3D12SerializeVersionedRootSignature(&versioned_desc,
204        &sig, &error)))
205       throw runtime_error("D3D12SerializeVersionedRootSignature failed");
206 
207    ComPtr<ID3D12RootSignature> ret;
208    if (FAILED(dev->CreateRootSignature(0,
209        sig->GetBufferPointer(),
210        sig->GetBufferSize(),
211        __uuidof(ret),
212        (void **)& ret)))
213       throw runtime_error("CreateRootSignature failed");
214 
215    return ret;
216 }
217 
218 ComPtr<ID3D12PipelineState>
create_pipeline_state(ComPtr<ID3D12RootSignature> & root_sig,const struct clc_dxil_object & dxil)219 ComputeTest::create_pipeline_state(ComPtr<ID3D12RootSignature> &root_sig,
220                                    const struct clc_dxil_object &dxil)
221 {
222    D3D12_COMPUTE_PIPELINE_STATE_DESC pipeline_desc = { root_sig.Get() };
223    pipeline_desc.CS.pShaderBytecode = dxil.binary.data;
224    pipeline_desc.CS.BytecodeLength = dxil.binary.size;
225 
226    ComPtr<ID3D12PipelineState> pipeline_state;
227    if (FAILED(dev->CreateComputePipelineState(&pipeline_desc,
228                                               __uuidof(pipeline_state),
229                                               (void **)& pipeline_state)))
230       throw runtime_error("Failed to create pipeline state");
231    return pipeline_state;
232 }
233 
234 ComPtr<ID3D12Resource>
create_buffer(int size,D3D12_HEAP_TYPE heap_type)235 ComputeTest::create_buffer(int size, D3D12_HEAP_TYPE heap_type)
236 {
237    D3D12_RESOURCE_DESC desc;
238    desc.Format = DXGI_FORMAT_UNKNOWN;
239    desc.Alignment = D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT;
240    desc.Width = size;
241    desc.Height = 1;
242    desc.DepthOrArraySize = 1;
243    desc.MipLevels = 1;
244    desc.SampleDesc.Count = 1;
245    desc.SampleDesc.Quality = 0;
246    desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
247    desc.Flags = heap_type == D3D12_HEAP_TYPE_DEFAULT ? D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS : D3D12_RESOURCE_FLAG_NONE;
248    desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
249 
250    D3D12_HEAP_PROPERTIES heap_pris = dev->GetCustomHeapProperties(0, heap_type);
251 
252    ComPtr<ID3D12Resource> res;
253    if (FAILED(dev->CreateCommittedResource(&heap_pris,
254        D3D12_HEAP_FLAG_NONE, &desc, D3D12_RESOURCE_STATE_COMMON,
255        NULL, __uuidof(ID3D12Resource), (void **)&res)))
256       throw runtime_error("CreateCommittedResource failed");
257 
258    return res;
259 }
260 
261 ComPtr<ID3D12Resource>
create_upload_buffer_with_data(const void * data,size_t size)262 ComputeTest::create_upload_buffer_with_data(const void *data, size_t size)
263 {
264    auto upload_res = create_buffer(size, D3D12_HEAP_TYPE_UPLOAD);
265 
266    void *ptr = NULL;
267    D3D12_RANGE res_range = { 0, (SIZE_T)size };
268    if (FAILED(upload_res->Map(0, &res_range, (void **)&ptr)))
269       throw runtime_error("Failed to map upload-buffer");
270    assert(ptr);
271    memcpy(ptr, data, size);
272    upload_res->Unmap(0, &res_range);
273    return upload_res;
274 }
275 
276 ComPtr<ID3D12Resource>
create_sized_buffer_with_data(size_t buffer_size,const void * data,size_t data_size)277 ComputeTest::create_sized_buffer_with_data(size_t buffer_size,
278                                            const void *data,
279                                            size_t data_size)
280 {
281    auto upload_res = create_upload_buffer_with_data(data, data_size);
282 
283    auto res = create_buffer(buffer_size, D3D12_HEAP_TYPE_DEFAULT);
284    resource_barrier(res, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_COPY_DEST);
285    cmdlist->CopyBufferRegion(res.Get(), 0, upload_res.Get(), 0, data_size);
286    resource_barrier(res, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_COMMON);
287    execute_cmdlist();
288 
289    return res;
290 }
291 
292 void
get_buffer_data(ComPtr<ID3D12Resource> res,void * buf,size_t size)293 ComputeTest::get_buffer_data(ComPtr<ID3D12Resource> res,
294                              void *buf, size_t size)
295 {
296    auto readback_res = create_buffer(align(size, 4), D3D12_HEAP_TYPE_READBACK);
297    resource_barrier(res, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_COPY_SOURCE);
298    cmdlist->CopyResource(readback_res.Get(), res.Get());
299    resource_barrier(res, D3D12_RESOURCE_STATE_COPY_SOURCE, D3D12_RESOURCE_STATE_COMMON);
300    execute_cmdlist();
301 
302    void *ptr = NULL;
303    D3D12_RANGE res_range = { 0, size };
304    if (FAILED(readback_res->Map(0, &res_range, &ptr)))
305       throw runtime_error("Failed to map readback-buffer");
306 
307    memcpy(buf, ptr, size);
308 
309    D3D12_RANGE empty_range = { 0, 0 };
310    readback_res->Unmap(0, &empty_range);
311 }
312 
313 void
resource_barrier(ComPtr<ID3D12Resource> & res,D3D12_RESOURCE_STATES state_before,D3D12_RESOURCE_STATES state_after)314 ComputeTest::resource_barrier(ComPtr<ID3D12Resource> &res,
315                               D3D12_RESOURCE_STATES state_before,
316                               D3D12_RESOURCE_STATES state_after)
317 {
318    D3D12_RESOURCE_BARRIER barrier;
319    barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION;
320    barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
321    barrier.Transition.pResource = res.Get();
322    barrier.Transition.Subresource = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES;
323    barrier.Transition.StateBefore = state_before;
324    barrier.Transition.StateAfter = state_after;
325    cmdlist->ResourceBarrier(1, &barrier);
326 }
327 
328 void
execute_cmdlist()329 ComputeTest::execute_cmdlist()
330 {
331    if (FAILED(cmdlist->Close()))
332       throw runtime_error("Closing ID3D12GraphicsCommandList failed");
333 
334    ID3D12CommandList *cmdlists[] = { cmdlist };
335    cmdqueue->ExecuteCommandLists(1, cmdlists);
336    cmdqueue_fence->SetEventOnCompletion(fence_value, event);
337    cmdqueue->Signal(cmdqueue_fence, fence_value);
338    fence_value++;
339    WaitForSingleObject(event, INFINITE);
340 
341    if (FAILED(cmdalloc->Reset()))
342       throw runtime_error("resetting ID3D12CommandAllocator failed");
343 
344    if (FAILED(cmdlist->Reset(cmdalloc, NULL)))
345       throw runtime_error("resetting ID3D12GraphicsCommandList failed");
346 }
347 
348 void
create_uav_buffer(ComPtr<ID3D12Resource> res,size_t width,size_t byte_stride,D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle)349 ComputeTest::create_uav_buffer(ComPtr<ID3D12Resource> res,
350                                size_t width, size_t byte_stride,
351                                D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle)
352 {
353    D3D12_UNORDERED_ACCESS_VIEW_DESC uav_desc;
354    uav_desc.Format = DXGI_FORMAT_R32_TYPELESS;
355    uav_desc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER;
356    uav_desc.Buffer.FirstElement = 0;
357    uav_desc.Buffer.NumElements = DIV_ROUND_UP(width * byte_stride, 4);
358    uav_desc.Buffer.StructureByteStride = 0;
359    uav_desc.Buffer.CounterOffsetInBytes = 0;
360    uav_desc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_RAW;
361 
362    dev->CreateUnorderedAccessView(res.Get(), NULL, &uav_desc, cpu_handle);
363 }
364 
365 void
create_cbv(ComPtr<ID3D12Resource> res,size_t size,D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle)366 ComputeTest::create_cbv(ComPtr<ID3D12Resource> res, size_t size,
367                         D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle)
368 {
369    D3D12_CONSTANT_BUFFER_VIEW_DESC cbv_desc;
370    cbv_desc.BufferLocation = res ? res->GetGPUVirtualAddress() : 0;
371    cbv_desc.SizeInBytes = size;
372 
373    dev->CreateConstantBufferView(&cbv_desc, cpu_handle);
374 }
375 
376 ComPtr<ID3D12Resource>
add_uav_resource(ComputeTest::Resources & resources,unsigned spaceid,unsigned resid,const void * data,size_t num_elems,size_t elem_size)377 ComputeTest::add_uav_resource(ComputeTest::Resources &resources,
378                               unsigned spaceid, unsigned resid,
379                               const void *data, size_t num_elems,
380                               size_t elem_size)
381 {
382    size_t size = align(elem_size * num_elems, 4);
383    D3D12_CPU_DESCRIPTOR_HANDLE handle;
384    ComPtr<ID3D12Resource> res;
385    handle = uav_heap->GetCPUDescriptorHandleForHeapStart();
386    handle = offset_cpu_handle(handle, resources.descs.size() * uav_heap_incr);
387 
388    if (size) {
389       if (data)
390          res = create_buffer_with_data(data, size);
391       else
392          res = create_buffer(size, D3D12_HEAP_TYPE_DEFAULT);
393 
394       resource_barrier(res, D3D12_RESOURCE_STATE_COMMON,
395                        D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
396    }
397    create_uav_buffer(res, num_elems, elem_size, handle);
398    resources.add(res, D3D12_DESCRIPTOR_RANGE_TYPE_UAV, spaceid, resid);
399    return res;
400 }
401 
402 ComPtr<ID3D12Resource>
add_cbv_resource(ComputeTest::Resources & resources,unsigned spaceid,unsigned resid,const void * data,size_t size)403 ComputeTest::add_cbv_resource(ComputeTest::Resources &resources,
404                               unsigned spaceid, unsigned resid,
405                               const void *data, size_t size)
406 {
407    unsigned aligned_size = align(size, 256);
408    D3D12_CPU_DESCRIPTOR_HANDLE handle;
409    ComPtr<ID3D12Resource> res;
410    handle = uav_heap->GetCPUDescriptorHandleForHeapStart();
411    handle = offset_cpu_handle(handle, resources.descs.size() * uav_heap_incr);
412 
413    if (size) {
414      assert(data);
415      res = create_sized_buffer_with_data(aligned_size, data, size);
416    }
417    create_cbv(res, aligned_size, handle);
418    resources.add(res, D3D12_DESCRIPTOR_RANGE_TYPE_CBV, spaceid, resid);
419    return res;
420 }
421 
422 void
run_shader_with_raw_args(Shader shader,const CompileArgs & compile_args,const std::vector<RawShaderArg * > & args)423 ComputeTest::run_shader_with_raw_args(Shader shader,
424                                       const CompileArgs &compile_args,
425                                       const std::vector<RawShaderArg *> &args)
426 {
427    if (args.size() < 1)
428       throw runtime_error("no inputs");
429 
430    static HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
431    if (!hD3D12Mod)
432       throw runtime_error("Failed to load D3D12.DLL");
433 
434    D3D12SerializeVersionedRootSignature = (PFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE)GetProcAddress(hD3D12Mod, "D3D12SerializeVersionedRootSignature");
435 
436    if (args.size() != shader.dxil->kernel->num_args)
437       throw runtime_error("incorrect number of inputs");
438 
439    struct clc_runtime_kernel_conf conf = { 0 };
440 
441    // Older WARP and some hardware doesn't support int64, so for these tests, unconditionally lower away int64
442    // A more complex runtime can be smarter about detecting when this needs to be done
443    conf.lower_bit_size = 64;
444 
445    if (!shader.dxil->metadata.local_size[0])
446       conf.local_size[0] = compile_args.x;
447    else
448       conf.local_size[0] = shader.dxil->metadata.local_size[0];
449 
450    if (!shader.dxil->metadata.local_size[1])
451       conf.local_size[1] = compile_args.y;
452    else
453       conf.local_size[1] = shader.dxil->metadata.local_size[1];
454 
455    if (!shader.dxil->metadata.local_size[2])
456       conf.local_size[2] = compile_args.z;
457    else
458       conf.local_size[2] = shader.dxil->metadata.local_size[2];
459 
460    if (compile_args.x % conf.local_size[0] ||
461        compile_args.y % conf.local_size[1] ||
462        compile_args.z % conf.local_size[2])
463       throw runtime_error("invalid global size must be a multiple of local size");
464 
465    std::vector<struct clc_runtime_arg_info> argsinfo(args.size());
466 
467    conf.args = argsinfo.data();
468    conf.support_global_work_id_offsets =
469       compile_args.work_props.global_offset_x != 0 ||
470       compile_args.work_props.global_offset_y != 0 ||
471       compile_args.work_props.global_offset_z != 0;
472    conf.support_workgroup_id_offsets =
473       compile_args.work_props.group_id_offset_x != 0 ||
474       compile_args.work_props.group_id_offset_y != 0 ||
475       compile_args.work_props.group_id_offset_z != 0;
476 
477    for (unsigned i = 0; i < shader.dxil->kernel->num_args; ++i) {
478       RawShaderArg *arg = args[i];
479       size_t size = arg->get_elem_size() * arg->get_num_elems();
480 
481       switch (shader.dxil->kernel->args[i].address_qualifier) {
482       case CLC_KERNEL_ARG_ADDRESS_LOCAL:
483          argsinfo[i].localptr.size = size;
484          break;
485       default:
486          break;
487       }
488    }
489 
490    configure(shader, &conf);
491    validate(shader);
492 
493    std::shared_ptr<struct clc_dxil_object> &dxil = shader.dxil;
494 
495    std::vector<uint8_t> argsbuf(dxil->metadata.kernel_inputs_buf_size);
496    std::vector<ComPtr<ID3D12Resource>> argres(shader.dxil->kernel->num_args);
497    clc_work_properties_data work_props = compile_args.work_props;
498    if (!conf.support_workgroup_id_offsets) {
499       work_props.group_count_total_x = compile_args.x / conf.local_size[0];
500       work_props.group_count_total_y = compile_args.y / conf.local_size[1];
501       work_props.group_count_total_z = compile_args.z / conf.local_size[2];
502    }
503    if (work_props.work_dim == 0)
504       work_props.work_dim = 3;
505    Resources resources;
506 
507    for (unsigned i = 0; i < dxil->kernel->num_args; ++i) {
508       RawShaderArg *arg = args[i];
509       size_t size = arg->get_elem_size() * arg->get_num_elems();
510       void *slot = argsbuf.data() + dxil->metadata.args[i].offset;
511 
512       switch (dxil->kernel->args[i].address_qualifier) {
513       case CLC_KERNEL_ARG_ADDRESS_CONSTANT:
514       case CLC_KERNEL_ARG_ADDRESS_GLOBAL: {
515          assert(dxil->metadata.args[i].size == sizeof(uint64_t));
516          uint64_t *ptr_slot = (uint64_t *)slot;
517          if (arg->get_data())
518             *ptr_slot = (uint64_t)dxil->metadata.args[i].globconstptr.buf_id << 32;
519          else
520             *ptr_slot = ~0ull;
521          break;
522       }
523       case CLC_KERNEL_ARG_ADDRESS_LOCAL: {
524          assert(dxil->metadata.args[i].size == sizeof(uint64_t));
525          uint64_t *ptr_slot = (uint64_t *)slot;
526          *ptr_slot = dxil->metadata.args[i].localptr.sharedmem_offset;
527          break;
528       }
529       case CLC_KERNEL_ARG_ADDRESS_PRIVATE: {
530          assert(size == dxil->metadata.args[i].size);
531          memcpy(slot, arg->get_data(), size);
532          break;
533       }
534       default:
535          assert(0);
536       }
537    }
538 
539    for (unsigned i = 0; i < dxil->kernel->num_args; ++i) {
540       RawShaderArg *arg = args[i];
541 
542       if (dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_GLOBAL ||
543           dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT) {
544          argres[i] = add_uav_resource(resources, 0,
545                                       dxil->metadata.args[i].globconstptr.buf_id,
546                                       arg->get_data(), arg->get_num_elems(),
547                                       arg->get_elem_size());
548       }
549    }
550 
551    if (dxil->metadata.printf.uav_id > 0)
552       add_uav_resource(resources, 0, dxil->metadata.printf.uav_id, NULL, 1024 * 1024 / 4, 4);
553 
554    for (unsigned i = 0; i < dxil->metadata.num_consts; ++i)
555       add_uav_resource(resources, 0, dxil->metadata.consts[i].uav_id,
556                        dxil->metadata.consts[i].data,
557                        dxil->metadata.consts[i].size / 4, 4);
558 
559    if (argsbuf.size())
560       add_cbv_resource(resources, 0, dxil->metadata.kernel_inputs_cbv_id,
561                        argsbuf.data(), argsbuf.size());
562 
563    add_cbv_resource(resources, 0, dxil->metadata.work_properties_cbv_id,
564                     &work_props, sizeof(work_props));
565 
566    auto root_sig = create_root_signature(resources);
567    auto pipeline_state = create_pipeline_state(root_sig, *dxil);
568 
569    cmdlist->SetDescriptorHeaps(1, &uav_heap);
570    cmdlist->SetComputeRootSignature(root_sig.Get());
571    cmdlist->SetComputeRootDescriptorTable(0, uav_heap->GetGPUDescriptorHandleForHeapStart());
572    cmdlist->SetPipelineState(pipeline_state.Get());
573 
574    cmdlist->Dispatch(compile_args.x / conf.local_size[0],
575                      compile_args.y / conf.local_size[1],
576                      compile_args.z / conf.local_size[2]);
577 
578    for (auto &range : resources.ranges) {
579       if (range.RangeType == D3D12_DESCRIPTOR_RANGE_TYPE_UAV) {
580          for (unsigned i = range.OffsetInDescriptorsFromTableStart;
581               i < range.NumDescriptors; i++) {
582             if (!resources.descs[i].Get())
583                continue;
584 
585             resource_barrier(resources.descs[i],
586                              D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
587                              D3D12_RESOURCE_STATE_COMMON);
588          }
589       }
590    }
591 
592    execute_cmdlist();
593 
594    for (unsigned i = 0; i < args.size(); i++) {
595       if (!(args[i]->get_direction() & SHADER_ARG_OUTPUT))
596          continue;
597 
598       assert(dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_GLOBAL);
599       get_buffer_data(argres[i], args[i]->get_data(),
600                       args[i]->get_elem_size() * args[i]->get_num_elems());
601    }
602 
603    ComPtr<ID3D12InfoQueue> info_queue;
604    dev->QueryInterface(info_queue.ReleaseAndGetAddressOf());
605    if (info_queue)
606    {
607       EXPECT_EQ(0, info_queue->GetNumStoredMessages());
608       for (unsigned i = 0; i < info_queue->GetNumStoredMessages(); ++i) {
609          SIZE_T message_size = 0;
610          info_queue->GetMessageA(i, nullptr, &message_size);
611          D3D12_MESSAGE* message = (D3D12_MESSAGE*)malloc(message_size);
612          info_queue->GetMessageA(i, message, &message_size);
613          FAIL() << message->pDescription;
614          free(message);
615       }
616    }
617 }
618 
619 void
SetUp()620 ComputeTest::SetUp()
621 {
622    static struct clc_libclc *compiler_ctx_g = nullptr;
623 
624    if (!compiler_ctx_g) {
625       clc_libclc_dxil_options options = { };
626       options.optimize = (debug_get_option_debug_compute() & COMPUTE_DEBUG_OPTIMIZE_LIBCLC) != 0;
627 
628       compiler_ctx_g = clc_libclc_new_dxil(&logger, &options);
629       if (!compiler_ctx_g)
630          throw runtime_error("failed to create CLC compiler context");
631 
632       if (debug_get_option_debug_compute() & COMPUTE_DEBUG_SERIALIZE_LIBCLC) {
633          void *serialized = nullptr;
634          size_t serialized_size = 0;
635          clc_libclc_serialize(compiler_ctx_g, &serialized, &serialized_size);
636          if (!serialized)
637             throw runtime_error("failed to serialize CLC compiler context");
638 
639          clc_free_libclc(compiler_ctx_g);
640          compiler_ctx_g = nullptr;
641 
642          compiler_ctx_g = clc_libclc_deserialize(serialized, serialized_size);
643          if (!compiler_ctx_g)
644             throw runtime_error("failed to deserialize CLC compiler context");
645 
646          clc_libclc_free_serialized(serialized);
647       }
648    }
649    compiler_ctx = compiler_ctx_g;
650 
651    enable_d3d12_debug_layer();
652 
653    factory = get_dxgi_factory();
654    if (!factory)
655       throw runtime_error("failed to create DXGI factory");
656 
657    adapter = choose_adapter(factory);
658    if (!adapter)
659       throw runtime_error("failed to choose adapter");
660 
661    dev = create_device(adapter);
662    if (!dev)
663       throw runtime_error("failed to create device");
664 
665    if (FAILED(dev->CreateFence(0, D3D12_FENCE_FLAG_NONE,
666                                __uuidof(cmdqueue_fence),
667                                (void **)&cmdqueue_fence)))
668       throw runtime_error("failed to create fence\n");
669 
670    D3D12_COMMAND_QUEUE_DESC queue_desc;
671    queue_desc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE;
672    queue_desc.Priority = D3D12_COMMAND_QUEUE_PRIORITY_NORMAL;
673    queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
674    queue_desc.NodeMask = 0;
675    if (FAILED(dev->CreateCommandQueue(&queue_desc,
676                                       __uuidof(cmdqueue),
677                                       (void **)&cmdqueue)))
678       throw runtime_error("failed to create command queue");
679 
680    if (FAILED(dev->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE,
681              __uuidof(cmdalloc), (void **)&cmdalloc)))
682       throw runtime_error("failed to create command allocator");
683 
684    if (FAILED(dev->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_COMPUTE,
685              cmdalloc, NULL, __uuidof(cmdlist), (void **)&cmdlist)))
686       throw runtime_error("failed to create command list");
687 
688    D3D12_DESCRIPTOR_HEAP_DESC heap_desc;
689    heap_desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
690    heap_desc.NumDescriptors = 1000;
691    heap_desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
692    heap_desc.NodeMask = 0;
693    if (FAILED(dev->CreateDescriptorHeap(&heap_desc,
694        __uuidof(uav_heap), (void **)&uav_heap)))
695       throw runtime_error("failed to create descriptor heap");
696 
697    uav_heap_incr = dev->GetDescriptorHandleIncrementSize(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
698 
699    event = CreateEvent(NULL, FALSE, FALSE, NULL);
700    if (!event)
701       throw runtime_error("Failed to create event");
702    fence_value = 1;
703 }
704 
705 void
TearDown()706 ComputeTest::TearDown()
707 {
708    CloseHandle(event);
709 
710    uav_heap->Release();
711    cmdlist->Release();
712    cmdalloc->Release();
713    cmdqueue->Release();
714    cmdqueue_fence->Release();
715    dev->Release();
716    adapter->Release();
717    factory->Release();
718 }
719 
720 PFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE ComputeTest::D3D12SerializeVersionedRootSignature;
721 
722 bool
validate_module(const struct clc_dxil_object & dxil)723 validate_module(const struct clc_dxil_object &dxil)
724 {
725    static HMODULE hmod = LoadLibrary("DXIL.DLL");
726    if (!hmod) {
727       /* Enabling experimental shaders allows us to run unsigned shader code,
728        * such as when under the debugger where we can't run the validator. */
729       if (debug_get_option_debug_compute() & COMPUTE_DEBUG_EXPERIMENTAL_SHADERS)
730          return true;
731       else
732          throw runtime_error("failed to load DXIL.DLL");
733    }
734 
735    DxcCreateInstanceProc pfnDxcCreateInstance =
736       (DxcCreateInstanceProc)GetProcAddress(hmod, "DxcCreateInstance");
737    if (!pfnDxcCreateInstance)
738       throw runtime_error("failed to load DxcCreateInstance");
739 
740    struct shader_blob : public IDxcBlob {
741       shader_blob(void *data, size_t size) : data(data), size(size) {}
742       LPVOID STDMETHODCALLTYPE GetBufferPointer() override { return data; }
743       SIZE_T STDMETHODCALLTYPE GetBufferSize() override { return size; }
744       HRESULT STDMETHODCALLTYPE QueryInterface(REFIID, void **) override { return E_NOINTERFACE; }
745       ULONG STDMETHODCALLTYPE AddRef() override { return 1; }
746       ULONG STDMETHODCALLTYPE Release() override { return 0; }
747       void *data;
748       size_t size;
749    } blob(dxil.binary.data, dxil.binary.size);
750 
751    IDxcValidator *validator;
752    if (FAILED(pfnDxcCreateInstance(CLSID_DxcValidator, __uuidof(IDxcValidator),
753                                    (void **)&validator)))
754       throw runtime_error("failed to create IDxcValidator");
755 
756    IDxcOperationResult *result;
757    if (FAILED(validator->Validate(&blob, DxcValidatorFlags_InPlaceEdit,
758                                   &result)))
759       throw runtime_error("Validate failed");
760 
761    HRESULT hr;
762    if (FAILED(result->GetStatus(&hr)) ||
763        FAILED(hr)) {
764       IDxcBlobEncoding *message;
765       result->GetErrorBuffer(&message);
766       fprintf(stderr, "D3D12: validation failed: %*s\n",
767                    (int)message->GetBufferSize(),
768                    (char *)message->GetBufferPointer());
769       message->Release();
770       validator->Release();
771       result->Release();
772       return false;
773    }
774 
775    validator->Release();
776    result->Release();
777    return true;
778 }
779 
780 static void
dump_blob(const char * path,const struct clc_dxil_object & dxil)781 dump_blob(const char *path, const struct clc_dxil_object &dxil)
782 {
783    FILE *fp = fopen(path, "wb");
784    if (fp) {
785       fwrite(dxil.binary.data, 1, dxil.binary.size, fp);
786       fclose(fp);
787       printf("D3D12: wrote '%s'...\n", path);
788    }
789 }
790 
791 ComputeTest::Shader
compile(const std::vector<const char * > & sources,const std::vector<const char * > & compile_args,bool create_library)792 ComputeTest::compile(const std::vector<const char *> &sources,
793                      const std::vector<const char *> &compile_args,
794                      bool create_library)
795 {
796    struct clc_compile_args args = {
797    };
798    args.args = compile_args.data();
799    args.num_args = (unsigned)compile_args.size();
800    ComputeTest::Shader shader;
801 
802    std::vector<Shader> shaders;
803 
804    args.source.name = "obj.cl";
805 
806    for (unsigned i = 0; i < sources.size(); i++) {
807       args.source.value = sources[i];
808 
809       clc_binary spirv{};
810       if (!clc_compile_c_to_spirv(&args, &logger, &spirv))
811          throw runtime_error("failed to compile object!");
812 
813       Shader shader;
814       shader.obj = std::shared_ptr<clc_binary>(new clc_binary(spirv), [](clc_binary *spirv)
815          {
816             clc_free_spirv(spirv);
817             delete spirv;
818          });
819       shaders.push_back(shader);
820    }
821 
822    if (shaders.size() == 1 && create_library)
823       return shaders[0];
824 
825    return link(shaders, create_library);
826 }
827 
828 ComputeTest::Shader
link(const std::vector<Shader> & sources,bool create_library)829 ComputeTest::link(const std::vector<Shader> &sources,
830                   bool create_library)
831 {
832    std::vector<const clc_binary*> objs;
833    for (auto& source : sources)
834       objs.push_back(&*source.obj);
835 
836    struct clc_linker_args link_args = {};
837    link_args.in_objs = objs.data();
838    link_args.num_in_objs = (unsigned)objs.size();
839    link_args.create_library = create_library;
840    clc_binary spirv{};
841    if (!clc_link_spirv(&link_args, &logger, &spirv))
842       throw runtime_error("failed to link objects!");
843 
844    ComputeTest::Shader shader;
845    shader.obj = std::shared_ptr<clc_binary>(new clc_binary(spirv), [](clc_binary *spirv)
846       {
847          clc_free_spirv(spirv);
848          delete spirv;
849       });
850    if (!link_args.create_library)
851       configure(shader, NULL);
852 
853    return shader;
854 }
855 
856 ComputeTest::Shader
assemble(const char * source)857 ComputeTest::assemble(const char *source)
858 {
859    spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
860    std::vector<uint32_t> binary;
861    if (!tools.Assemble(source, strlen(source), &binary))
862       throw runtime_error("failed to assemble");
863 
864    ComputeTest::Shader shader;
865    shader.obj = std::shared_ptr<clc_binary>(new clc_binary{}, [](clc_binary *spirv)
866       {
867          free(spirv->data);
868          delete spirv;
869       });
870    shader.obj->size = binary.size() * 4;
871    shader.obj->data = malloc(shader.obj->size);
872    memcpy(shader.obj->data, binary.data(), shader.obj->size);
873 
874    configure(shader, NULL);
875 
876    return shader;
877 }
878 
879 void
configure(Shader & shader,const struct clc_runtime_kernel_conf * conf)880 ComputeTest::configure(Shader &shader,
881                        const struct clc_runtime_kernel_conf *conf)
882 {
883    if (!shader.metadata) {
884       shader.metadata = std::shared_ptr<clc_parsed_spirv>(new clc_parsed_spirv{}, [](clc_parsed_spirv *metadata)
885          {
886             clc_free_parsed_spirv(metadata);
887             delete metadata;
888          });
889       if (!clc_parse_spirv(shader.obj.get(), NULL, shader.metadata.get()))
890          throw runtime_error("failed to parse spirv!");
891    }
892 
893    shader.dxil = std::shared_ptr<clc_dxil_object>(new clc_dxil_object{}, [](clc_dxil_object *dxil)
894       {
895          clc_free_dxil_object(dxil);
896          delete dxil;
897       });
898    if (!clc_spirv_to_dxil(compiler_ctx, shader.obj.get(), shader.metadata.get(), "main_test", conf, nullptr, &logger, shader.dxil.get()))
899       throw runtime_error("failed to compile kernel!");
900 }
901 
902 void
validate(ComputeTest::Shader & shader)903 ComputeTest::validate(ComputeTest::Shader &shader)
904 {
905    dump_blob("unsigned.cso", *shader.dxil);
906    if (!validate_module(*shader.dxil))
907       throw runtime_error("failed to validate module!");
908 
909    dump_blob("signed.cso", *shader.dxil);
910 }
911