1 // Copyright 2018 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "dawn_native/ComputePassEncoder.h"
16 
17 #include "dawn_native/Buffer.h"
18 #include "dawn_native/CommandEncoder.h"
19 #include "dawn_native/CommandValidation.h"
20 #include "dawn_native/Commands.h"
21 #include "dawn_native/ComputePipeline.h"
22 #include "dawn_native/Device.h"
23 #include "dawn_native/QuerySet.h"
24 
25 namespace dawn_native {
26 
ComputePassEncoder(DeviceBase * device,CommandEncoder * commandEncoder,EncodingContext * encodingContext)27     ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
28                                            CommandEncoder* commandEncoder,
29                                            EncodingContext* encodingContext)
30         : ProgrammablePassEncoder(device, encodingContext, PassType::Compute),
31           mCommandEncoder(commandEncoder) {
32     }
33 
ComputePassEncoder(DeviceBase * device,CommandEncoder * commandEncoder,EncodingContext * encodingContext,ErrorTag errorTag)34     ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
35                                            CommandEncoder* commandEncoder,
36                                            EncodingContext* encodingContext,
37                                            ErrorTag errorTag)
38         : ProgrammablePassEncoder(device, encodingContext, errorTag, PassType::Compute),
39           mCommandEncoder(commandEncoder) {
40     }
41 
MakeError(DeviceBase * device,CommandEncoder * commandEncoder,EncodingContext * encodingContext)42     ComputePassEncoder* ComputePassEncoder::MakeError(DeviceBase* device,
43                                                       CommandEncoder* commandEncoder,
44                                                       EncodingContext* encodingContext) {
45         return new ComputePassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError);
46     }
47 
EndPass()48     void ComputePassEncoder::EndPass() {
49         if (mEncodingContext->TryEncode(this, [&](CommandAllocator* allocator) -> MaybeError {
50                 allocator->Allocate<EndComputePassCmd>(Command::EndComputePass);
51 
52                 return {};
53             })) {
54             mEncodingContext->ExitPass(this, mUsageTracker.AcquireResourceUsage());
55         }
56     }
57 
Dispatch(uint32_t x,uint32_t y,uint32_t z)58     void ComputePassEncoder::Dispatch(uint32_t x, uint32_t y, uint32_t z) {
59         mEncodingContext->TryEncode(this, [&](CommandAllocator* allocator) -> MaybeError {
60             DispatchCmd* dispatch = allocator->Allocate<DispatchCmd>(Command::Dispatch);
61             dispatch->x = x;
62             dispatch->y = y;
63             dispatch->z = z;
64 
65             return {};
66         });
67     }
68 
DispatchIndirect(BufferBase * indirectBuffer,uint64_t indirectOffset)69     void ComputePassEncoder::DispatchIndirect(BufferBase* indirectBuffer, uint64_t indirectOffset) {
70         mEncodingContext->TryEncode(this, [&](CommandAllocator* allocator) -> MaybeError {
71             DAWN_TRY(GetDevice()->ValidateObject(indirectBuffer));
72 
73             // Indexed dispatches need a compute-shader based validation to check that the dispatch
74             // sizes aren't too big. Disallow them as unsafe until the validation is implemented.
75             if (GetDevice()->IsToggleEnabled(Toggle::DisallowUnsafeAPIs)) {
76                 return DAWN_VALIDATION_ERROR(
77                     "DispatchIndirect is disallowed because it doesn't validate that the dispatch "
78                     "size is valid yet.");
79             }
80 
81             if (indirectOffset % 4 != 0) {
82                 return DAWN_VALIDATION_ERROR("Indirect offset must be a multiple of 4");
83             }
84 
85             if (indirectOffset >= indirectBuffer->GetSize() ||
86                 indirectOffset + kDispatchIndirectSize > indirectBuffer->GetSize()) {
87                 return DAWN_VALIDATION_ERROR("Indirect offset out of bounds");
88             }
89 
90             DispatchIndirectCmd* dispatch =
91                 allocator->Allocate<DispatchIndirectCmd>(Command::DispatchIndirect);
92             dispatch->indirectBuffer = indirectBuffer;
93             dispatch->indirectOffset = indirectOffset;
94 
95             mUsageTracker.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
96 
97             return {};
98         });
99     }
100 
SetPipeline(ComputePipelineBase * pipeline)101     void ComputePassEncoder::SetPipeline(ComputePipelineBase* pipeline) {
102         mEncodingContext->TryEncode(this, [&](CommandAllocator* allocator) -> MaybeError {
103             DAWN_TRY(GetDevice()->ValidateObject(pipeline));
104 
105             SetComputePipelineCmd* cmd =
106                 allocator->Allocate<SetComputePipelineCmd>(Command::SetComputePipeline);
107             cmd->pipeline = pipeline;
108 
109             return {};
110         });
111     }
112 
WriteTimestamp(QuerySetBase * querySet,uint32_t queryIndex)113     void ComputePassEncoder::WriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex) {
114         mEncodingContext->TryEncode(this, [&](CommandAllocator* allocator) -> MaybeError {
115             if (GetDevice()->IsValidationEnabled()) {
116                 DAWN_TRY(GetDevice()->ValidateObject(querySet));
117                 DAWN_TRY(ValidateTimestampQuery(querySet, queryIndex,
118                                                 mCommandEncoder->GetUsedQueryIndices()));
119                 mCommandEncoder->TrackUsedQuerySet(querySet);
120             }
121 
122             mCommandEncoder->TrackUsedQueryIndex(querySet, queryIndex);
123 
124             WriteTimestampCmd* cmd =
125                 allocator->Allocate<WriteTimestampCmd>(Command::WriteTimestamp);
126             cmd->querySet = querySet;
127             cmd->queryIndex = queryIndex;
128 
129             return {};
130         });
131     }
132 
133 }  // namespace dawn_native
134