1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/writer/msl/generator_impl.h"
16
17 #include <algorithm>
18 #include <limits>
19 #include <utility>
20 #include <vector>
21
22 #include "src/ast/array_accessor_expression.h"
23 #include "src/ast/assignment_statement.h"
24 #include "src/ast/binary_expression.h"
25 #include "src/ast/bitcast_expression.h"
26 #include "src/ast/block_statement.h"
27 #include "src/ast/bool_literal.h"
28 #include "src/ast/break_statement.h"
29 #include "src/ast/call_expression.h"
30 #include "src/ast/call_statement.h"
31 #include "src/ast/case_statement.h"
32 #include "src/ast/continue_statement.h"
33 #include "src/ast/decorated_variable.h"
34 #include "src/ast/else_statement.h"
35 #include "src/ast/float_literal.h"
36 #include "src/ast/function.h"
37 #include "src/ast/identifier_expression.h"
38 #include "src/ast/if_statement.h"
39 #include "src/ast/location_decoration.h"
40 #include "src/ast/loop_statement.h"
41 #include "src/ast/member_accessor_expression.h"
42 #include "src/ast/return_statement.h"
43 #include "src/ast/sint_literal.h"
44 #include "src/ast/struct_member_offset_decoration.h"
45 #include "src/ast/switch_statement.h"
46 #include "src/ast/type/access_control_type.h"
47 #include "src/ast/type/alias_type.h"
48 #include "src/ast/type/array_type.h"
49 #include "src/ast/type/bool_type.h"
50 #include "src/ast/type/depth_texture_type.h"
51 #include "src/ast/type/f32_type.h"
52 #include "src/ast/type/i32_type.h"
53 #include "src/ast/type/matrix_type.h"
54 #include "src/ast/type/multisampled_texture_type.h"
55 #include "src/ast/type/pointer_type.h"
56 #include "src/ast/type/sampled_texture_type.h"
57 #include "src/ast/type/sampler_type.h"
58 #include "src/ast/type/storage_texture_type.h"
59 #include "src/ast/type/struct_type.h"
60 #include "src/ast/type/u32_type.h"
61 #include "src/ast/type/vector_type.h"
62 #include "src/ast/type/void_type.h"
63 #include "src/ast/uint_literal.h"
64 #include "src/ast/unary_op_expression.h"
65 #include "src/ast/variable_decl_statement.h"
66
67 namespace tint {
68 namespace writer {
69 namespace msl {
70 namespace {
71
// Suffixes appended to an entry point's name to form the names of its
// generated input / output structs.
constexpr char kInStructNameSuffix[] = "in";
constexpr char kOutStructNameSuffix[] = "out";

// Base names of the variables that hold the generated input / output structs.
constexpr char kTintStructInVarPrefix[] = "tint_in";
constexpr char kTintStructOutVarPrefix[] = "tint_out";
76
last_is_break_or_fallthrough(const ast::BlockStatement * stmts)77 bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
78 if (stmts->empty()) {
79 return false;
80 }
81
82 return stmts->last()->IsBreak() || stmts->last()->IsFallthrough();
83 }
84
// Rounds `count` up to the next multiple of `alignment`.
//
// @param count the value to round up
// @param alignment the required alignment; 0 means "no alignment
//        requirement", in which case `count` is returned unchanged
// @returns the smallest multiple of `alignment` that is >= `count`
uint32_t adjust_for_alignment(uint32_t count, uint32_t alignment) {
  // A zero alignment would make the modulo below undefined behaviour. A
  // struct without members reaches this function with alignment == 0.
  if (alignment == 0) {
    return count;
  }
  const uint32_t spill = count % alignment;
  if (spill == 0) {
    return count;
  }
  return count + (alignment - spill);
}
92
93 } // namespace
94
GeneratorImpl(ast::Module * module)95 GeneratorImpl::GeneratorImpl(ast::Module* module) : module_(module) {}
96
97 GeneratorImpl::~GeneratorImpl() = default;
98
generate_name(const std::string & prefix)99 std::string GeneratorImpl::generate_name(const std::string& prefix) {
100 std::string name = prefix;
101 uint32_t i = 0;
102 while (namer_.IsMapped(name)) {
103 name = prefix + "_" + std::to_string(i);
104 ++i;
105 }
106 namer_.RegisterRemappedName(name);
107 return name;
108 }
109
Generate()110 bool GeneratorImpl::Generate() {
111 out_ << "#include <metal_stdlib>" << std::endl << std::endl;
112
113 for (const auto& global : module_->global_variables()) {
114 global_variables_.set(global->name(), global.get());
115 }
116
117 for (auto* const ty : module_->constructed_types()) {
118 if (!EmitConstructedType(ty)) {
119 return false;
120 }
121 }
122 if (!module_->constructed_types().empty()) {
123 out_ << std::endl;
124 }
125
126 for (const auto& var : module_->global_variables()) {
127 if (!var->is_const()) {
128 continue;
129 }
130 if (!EmitProgramConstVariable(var.get())) {
131 return false;
132 }
133 }
134
135 // Make sure all entry point data is emitted before the entry point functions
136 for (const auto& func : module_->functions()) {
137 if (!func->IsEntryPoint()) {
138 continue;
139 }
140
141 if (!EmitEntryPointData(func.get())) {
142 return false;
143 }
144 }
145
146 for (const auto& func : module_->functions()) {
147 if (!EmitFunction(func.get())) {
148 return false;
149 }
150 }
151
152 for (const auto& func : module_->functions()) {
153 if (!func->IsEntryPoint()) {
154 continue;
155 }
156 if (!EmitEntryPointFunction(func.get())) {
157 return false;
158 }
159 out_ << std::endl;
160 }
161
162 return true;
163 }
164
calculate_largest_alignment(ast::type::StructType * type)165 uint32_t GeneratorImpl::calculate_largest_alignment(
166 ast::type::StructType* type) {
167 auto* stct = type->AsStruct()->impl();
168 uint32_t largest_alignment = 0;
169 for (const auto& mem : stct->members()) {
170 auto align = calculate_alignment_size(mem->type());
171 if (align == 0) {
172 return 0;
173 }
174 if (!mem->type()->IsStruct()) {
175 largest_alignment = std::max(largest_alignment, align);
176 } else {
177 largest_alignment =
178 std::max(largest_alignment,
179 calculate_largest_alignment(mem->type()->AsStruct()));
180 }
181 }
182 return largest_alignment;
183 }
184
calculate_alignment_size(ast::type::Type * type)185 uint32_t GeneratorImpl::calculate_alignment_size(ast::type::Type* type) {
186 if (type->IsAlias()) {
187 return calculate_alignment_size(type->AsAlias()->type());
188 }
189 if (type->IsArray()) {
190 auto* ary = type->AsArray();
191 // TODO(dsinclair): Handle array stride and adjust for alignment.
192 uint32_t type_size = calculate_alignment_size(ary->type());
193 return ary->size() * type_size;
194 }
195 if (type->IsBool()) {
196 return 1;
197 }
198 if (type->IsPointer()) {
199 return 0;
200 }
201 if (type->IsF32() || type->IsI32() || type->IsU32()) {
202 return 4;
203 }
204 if (type->IsMatrix()) {
205 auto* mat = type->AsMatrix();
206 // TODO(dsinclair): Handle MatrixStride
207 // https://github.com/gpuweb/gpuweb/issues/773
208 uint32_t type_size = calculate_alignment_size(mat->type());
209 return mat->rows() * mat->columns() * type_size;
210 }
211 if (type->IsStruct()) {
212 auto* stct = type->AsStruct()->impl();
213 uint32_t count = 0;
214 uint32_t largest_alignment = 0;
215 // Offset decorations in WGSL must be in increasing order.
216 for (const auto& mem : stct->members()) {
217 for (const auto& deco : mem->decorations()) {
218 if (deco->IsOffset()) {
219 count = deco->AsOffset()->offset();
220 }
221 }
222 auto align = calculate_alignment_size(mem->type());
223 if (align == 0) {
224 return 0;
225 }
226 if (!mem->type()->IsStruct()) {
227 largest_alignment = std::max(largest_alignment, align);
228 } else {
229 largest_alignment =
230 std::max(largest_alignment,
231 calculate_largest_alignment(mem->type()->AsStruct()));
232 }
233
234 // Round up to the alignment size
235 count = adjust_for_alignment(count, align);
236 count += align;
237 }
238 // Round struct up to largest align size
239 count = adjust_for_alignment(count, largest_alignment);
240 return count;
241 }
242 if (type->IsVector()) {
243 auto* vec = type->AsVector();
244 uint32_t type_size = calculate_alignment_size(vec->type());
245 if (vec->size() == 2) {
246 return 2 * type_size;
247 }
248 return 4 * type_size;
249 }
250 return 0;
251 }
252
EmitConstructedType(const ast::type::Type * ty)253 bool GeneratorImpl::EmitConstructedType(const ast::type::Type* ty) {
254 make_indent();
255
256 if (ty->IsAlias()) {
257 auto* alias = ty->AsAlias();
258
259 out_ << "typedef ";
260 if (!EmitType(alias->type(), "")) {
261 return false;
262 }
263 out_ << " " << namer_.NameFor(alias->name()) << ";" << std::endl;
264 } else if (ty->IsStruct()) {
265 if (!EmitStructType(ty->AsStruct())) {
266 return false;
267 }
268 } else {
269 error_ = "unknown alias type: " + ty->type_name();
270 return false;
271 }
272
273 return true;
274 }
275
EmitArrayAccessor(ast::ArrayAccessorExpression * expr)276 bool GeneratorImpl::EmitArrayAccessor(ast::ArrayAccessorExpression* expr) {
277 if (!EmitExpression(expr->array())) {
278 return false;
279 }
280 out_ << "[";
281
282 if (!EmitExpression(expr->idx_expr())) {
283 return false;
284 }
285 out_ << "]";
286
287 return true;
288 }
289
EmitBitcast(ast::BitcastExpression * expr)290 bool GeneratorImpl::EmitBitcast(ast::BitcastExpression* expr) {
291 out_ << "as_type<";
292 if (!EmitType(expr->type(), "")) {
293 return false;
294 }
295
296 out_ << ">(";
297 if (!EmitExpression(expr->expr())) {
298 return false;
299 }
300
301 out_ << ")";
302 return true;
303 }
304
EmitAssign(ast::AssignmentStatement * stmt)305 bool GeneratorImpl::EmitAssign(ast::AssignmentStatement* stmt) {
306 make_indent();
307
308 if (!EmitExpression(stmt->lhs())) {
309 return false;
310 }
311
312 out_ << " = ";
313
314 if (!EmitExpression(stmt->rhs())) {
315 return false;
316 }
317
318 out_ << ";" << std::endl;
319
320 return true;
321 }
322
EmitBinary(ast::BinaryExpression * expr)323 bool GeneratorImpl::EmitBinary(ast::BinaryExpression* expr) {
324 out_ << "(";
325
326 if (!EmitExpression(expr->lhs())) {
327 return false;
328 }
329 out_ << " ";
330
331 switch (expr->op()) {
332 case ast::BinaryOp::kAnd:
333 out_ << "&";
334 break;
335 case ast::BinaryOp::kOr:
336 out_ << "|";
337 break;
338 case ast::BinaryOp::kXor:
339 out_ << "^";
340 break;
341 case ast::BinaryOp::kLogicalAnd:
342 out_ << "&&";
343 break;
344 case ast::BinaryOp::kLogicalOr:
345 out_ << "||";
346 break;
347 case ast::BinaryOp::kEqual:
348 out_ << "==";
349 break;
350 case ast::BinaryOp::kNotEqual:
351 out_ << "!=";
352 break;
353 case ast::BinaryOp::kLessThan:
354 out_ << "<";
355 break;
356 case ast::BinaryOp::kGreaterThan:
357 out_ << ">";
358 break;
359 case ast::BinaryOp::kLessThanEqual:
360 out_ << "<=";
361 break;
362 case ast::BinaryOp::kGreaterThanEqual:
363 out_ << ">=";
364 break;
365 case ast::BinaryOp::kShiftLeft:
366 out_ << "<<";
367 break;
368 case ast::BinaryOp::kShiftRight:
369 // TODO(dsinclair): MSL is based on C++14, and >> in C++14 has
370 // implementation-defined behaviour for negative LHS. We may have to
371 // generate extra code to implement WGSL-specified behaviour for negative
372 // LHS.
373 out_ << R"(>>)";
374 break;
375
376 case ast::BinaryOp::kAdd:
377 out_ << "+";
378 break;
379 case ast::BinaryOp::kSubtract:
380 out_ << "-";
381 break;
382 case ast::BinaryOp::kMultiply:
383 out_ << "*";
384 break;
385 case ast::BinaryOp::kDivide:
386 out_ << "/";
387 break;
388 case ast::BinaryOp::kModulo:
389 out_ << "%";
390 break;
391 case ast::BinaryOp::kNone:
392 error_ = "missing binary operation type";
393 return false;
394 }
395 out_ << " ";
396
397 if (!EmitExpression(expr->rhs())) {
398 return false;
399 }
400
401 out_ << ")";
402 return true;
403 }
404
EmitBreak(ast::BreakStatement *)405 bool GeneratorImpl::EmitBreak(ast::BreakStatement*) {
406 make_indent();
407 out_ << "break;" << std::endl;
408 return true;
409 }
410
current_ep_var_name(VarType type)411 std::string GeneratorImpl::current_ep_var_name(VarType type) {
412 std::string name = "";
413 switch (type) {
414 case VarType::kIn: {
415 auto in_it = ep_name_to_in_data_.find(current_ep_name_);
416 if (in_it != ep_name_to_in_data_.end()) {
417 name = in_it->second.var_name;
418 }
419 break;
420 }
421 case VarType::kOut: {
422 auto out_it = ep_name_to_out_data_.find(current_ep_name_);
423 if (out_it != ep_name_to_out_data_.end()) {
424 name = out_it->second.var_name;
425 }
426 break;
427 }
428 }
429 return name;
430 }
431
generate_intrinsic_name(ast::Intrinsic intrinsic)432 std::string GeneratorImpl::generate_intrinsic_name(ast::Intrinsic intrinsic) {
433 if (intrinsic == ast::Intrinsic::kAny) {
434 return "any";
435 }
436 if (intrinsic == ast::Intrinsic::kAll) {
437 return "all";
438 }
439 if (intrinsic == ast::Intrinsic::kCountOneBits) {
440 return "popcount";
441 }
442 if (intrinsic == ast::Intrinsic::kDot) {
443 return "dot";
444 }
445 if (intrinsic == ast::Intrinsic::kDpdy ||
446 intrinsic == ast::Intrinsic::kDpdyFine ||
447 intrinsic == ast::Intrinsic::kDpdyCoarse) {
448 return "dfdy";
449 }
450 if (intrinsic == ast::Intrinsic::kDpdx ||
451 intrinsic == ast::Intrinsic::kDpdxFine ||
452 intrinsic == ast::Intrinsic::kDpdxCoarse) {
453 return "dfdx";
454 }
455 if (intrinsic == ast::Intrinsic::kFwidth ||
456 intrinsic == ast::Intrinsic::kFwidthFine ||
457 intrinsic == ast::Intrinsic::kFwidthCoarse) {
458 return "fwidth";
459 }
460 if (intrinsic == ast::Intrinsic::kIsFinite) {
461 return "isfinite";
462 }
463 if (intrinsic == ast::Intrinsic::kIsInf) {
464 return "isinf";
465 }
466 if (intrinsic == ast::Intrinsic::kIsNan) {
467 return "isnan";
468 }
469 if (intrinsic == ast::Intrinsic::kIsNormal) {
470 return "isnormal";
471 }
472 if (intrinsic == ast::Intrinsic::kReverseBits) {
473 return "reverse_bits";
474 }
475 if (intrinsic == ast::Intrinsic::kSelect) {
476 return "select";
477 }
478 return "";
479 }
480
EmitCall(ast::CallExpression * expr)481 bool GeneratorImpl::EmitCall(ast::CallExpression* expr) {
482 if (!expr->func()->IsIdentifier()) {
483 error_ = "invalid function name";
484 return 0;
485 }
486
487 auto* ident = expr->func()->AsIdentifier();
488 if (ident->IsIntrinsic()) {
489 const auto& params = expr->params();
490 if (ident->intrinsic() == ast::Intrinsic::kOuterProduct) {
491 error_ = "outer_product not supported yet";
492 return false;
493 // TODO(dsinclair): This gets tricky. We need to generate two variables to
494 // hold the outer_product expressions, but we maybe inside an expression
495 // ourselves. So, this will need to, possibly, output the variables
496 // _before_ the expression which contains the outer product.
497 //
498 // This then has the follow on, what if we have `(false &&
499 // outer_product())` in that case, we shouldn't evaluate the expressions
500 // at all because of short circuting.
501 //
502 // So .... this turns out to be hard ...
503
504 // // We create variables to hold the two parameters in case they're
505 // // function calls with side effects.
506 // auto* param0 = param[0].get();
507 // auto* name0 = generate_name("outer_product_expr_0");
508
509 // auto* param1 = param[1].get();
510 // auto* name1 = generate_name("outer_product_expr_1");
511
512 // make_indent();
513 // if (!EmitType(expr->result_type(), "")) {
514 // return false;
515 // }
516 // out_ << "(";
517
518 // auto param1_type = params[1]->result_type()->UnwrapPtrIfNeeded();
519 // if (!param1_type->IsVector()) {
520 // error_ = "invalid param type in outer_product got: " +
521 // param1_type->type_name();
522 // return false;
523 // }
524
525 // for (uint32_t i = 0; i < param1_type->AsVector()->size(); ++i) {
526 // if (i > 0) {
527 // out_ << ", ";
528 // }
529
530 // if (!EmitExpression(params[0].get())) {
531 // return false;
532 // }
533 // out_ << " * ";
534
535 // if (!EmitExpression(params[1].get())) {
536 // return false;
537 // }
538 // out_ << "[" << i << "]";
539 // }
540
541 // out_ << ")";
542 } else {
543 auto name = generate_intrinsic_name(ident->intrinsic());
544 if (name.empty()) {
545 if (ast::intrinsic::IsTextureIntrinsic(ident->intrinsic())) {
546 error_ = "Textures not implemented yet";
547 return false;
548 }
549 name = generate_builtin_name(ident);
550 if (name.empty()) {
551 return false;
552 }
553 }
554
555 make_indent();
556 out_ << name << "(";
557
558 bool first = true;
559 for (const auto& param : params) {
560 if (!first) {
561 out_ << ", ";
562 }
563 first = false;
564
565 if (!EmitExpression(param.get())) {
566 return false;
567 }
568 }
569
570 out_ << ")";
571 }
572 return true;
573 }
574
575 auto name = ident->name();
576 auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name);
577 if (it != ep_func_name_remapped_.end()) {
578 name = it->second;
579 }
580
581 auto* func = module_->FindFunctionByName(ident->name());
582 if (func == nullptr) {
583 error_ = "Unable to find function: " + name;
584 return false;
585 }
586
587 out_ << name << "(";
588
589 bool first = true;
590 if (has_referenced_in_var_needing_struct(func)) {
591 auto var_name = current_ep_var_name(VarType::kIn);
592 if (!var_name.empty()) {
593 out_ << var_name;
594 first = false;
595 }
596 }
597 if (has_referenced_out_var_needing_struct(func)) {
598 auto var_name = current_ep_var_name(VarType::kOut);
599 if (!var_name.empty()) {
600 if (!first) {
601 out_ << ", ";
602 }
603 first = false;
604 out_ << var_name;
605 }
606 }
607
608 for (const auto& data : func->referenced_builtin_variables()) {
609 auto* var = data.first;
610 if (var->storage_class() != ast::StorageClass::kInput) {
611 continue;
612 }
613 if (!first) {
614 out_ << ", ";
615 }
616 first = false;
617 out_ << var->name();
618 }
619
620 for (const auto& data : func->referenced_uniform_variables()) {
621 auto* var = data.first;
622 if (!first) {
623 out_ << ", ";
624 }
625 first = false;
626 out_ << var->name();
627 }
628
629 for (const auto& data : func->referenced_storagebuffer_variables()) {
630 auto* var = data.first;
631 if (!first) {
632 out_ << ", ";
633 }
634 first = false;
635 out_ << var->name();
636 }
637
638 const auto& params = expr->params();
639 for (const auto& param : params) {
640 if (!first) {
641 out_ << ", ";
642 }
643 first = false;
644
645 if (!EmitExpression(param.get())) {
646 return false;
647 }
648 }
649
650 out_ << ")";
651
652 return true;
653 }
654
generate_builtin_name(ast::IdentifierExpression * ident)655 std::string GeneratorImpl::generate_builtin_name(
656 ast::IdentifierExpression* ident) {
657 std::string out = "metal::";
658 switch (ident->intrinsic()) {
659 case ast::Intrinsic::kAcos:
660 case ast::Intrinsic::kAsin:
661 case ast::Intrinsic::kAtan:
662 case ast::Intrinsic::kAtan2:
663 case ast::Intrinsic::kCeil:
664 case ast::Intrinsic::kCos:
665 case ast::Intrinsic::kCosh:
666 case ast::Intrinsic::kCross:
667 case ast::Intrinsic::kDeterminant:
668 case ast::Intrinsic::kDistance:
669 case ast::Intrinsic::kExp:
670 case ast::Intrinsic::kExp2:
671 case ast::Intrinsic::kFloor:
672 case ast::Intrinsic::kFma:
673 case ast::Intrinsic::kFract:
674 case ast::Intrinsic::kLength:
675 case ast::Intrinsic::kLog:
676 case ast::Intrinsic::kLog2:
677 case ast::Intrinsic::kMix:
678 case ast::Intrinsic::kNormalize:
679 case ast::Intrinsic::kPow:
680 case ast::Intrinsic::kReflect:
681 case ast::Intrinsic::kRound:
682 case ast::Intrinsic::kSin:
683 case ast::Intrinsic::kSinh:
684 case ast::Intrinsic::kSqrt:
685 case ast::Intrinsic::kStep:
686 case ast::Intrinsic::kTan:
687 case ast::Intrinsic::kTanh:
688 case ast::Intrinsic::kTrunc:
689 case ast::Intrinsic::kSign:
690 case ast::Intrinsic::kClamp:
691 out += ident->name();
692 break;
693 case ast::Intrinsic::kAbs:
694 if (ident->result_type()->IsF32()) {
695 out += "fabs";
696 } else if (ident->result_type()->IsU32() ||
697 ident->result_type()->IsI32()) {
698 out += "abs";
699 }
700 break;
701 case ast::Intrinsic::kMax:
702 if (ident->result_type()->IsF32()) {
703 out += "fmax";
704 } else if (ident->result_type()->IsU32() ||
705 ident->result_type()->IsI32()) {
706 out += "max";
707 }
708 break;
709 case ast::Intrinsic::kMin:
710 if (ident->result_type()->IsF32()) {
711 out += "fmin";
712 } else if (ident->result_type()->IsU32() ||
713 ident->result_type()->IsI32()) {
714 out += "min";
715 }
716 break;
717 case ast::Intrinsic::kFaceForward:
718 out += "faceforward";
719 break;
720 case ast::Intrinsic::kSmoothStep:
721 out += "smoothstep";
722 break;
723 case ast::Intrinsic::kInverseSqrt:
724 out += "rsqrt";
725 break;
726 default:
727 error_ = "Unknown import method: " + ident->name();
728 return "";
729 }
730 return out;
731 }
732
EmitCase(ast::CaseStatement * stmt)733 bool GeneratorImpl::EmitCase(ast::CaseStatement* stmt) {
734 make_indent();
735
736 if (stmt->IsDefault()) {
737 out_ << "default:";
738 } else {
739 bool first = true;
740 for (const auto& selector : stmt->selectors()) {
741 if (!first) {
742 out_ << std::endl;
743 make_indent();
744 }
745 first = false;
746
747 out_ << "case ";
748 if (!EmitLiteral(selector.get())) {
749 return false;
750 }
751 out_ << ":";
752 }
753 }
754
755 out_ << " {" << std::endl;
756
757 increment_indent();
758
759 for (const auto& s : *(stmt->body())) {
760 if (!EmitStatement(s.get())) {
761 return false;
762 }
763 }
764
765 if (!last_is_break_or_fallthrough(stmt->body())) {
766 make_indent();
767 out_ << "break;" << std::endl;
768 }
769
770 decrement_indent();
771 make_indent();
772 out_ << "}" << std::endl;
773
774 return true;
775 }
776
EmitConstructor(ast::ConstructorExpression * expr)777 bool GeneratorImpl::EmitConstructor(ast::ConstructorExpression* expr) {
778 if (expr->IsScalarConstructor()) {
779 return EmitScalarConstructor(expr->AsScalarConstructor());
780 }
781 return EmitTypeConstructor(expr->AsTypeConstructor());
782 }
783
EmitContinue(ast::ContinueStatement *)784 bool GeneratorImpl::EmitContinue(ast::ContinueStatement*) {
785 make_indent();
786 out_ << "continue;" << std::endl;
787 return true;
788 }
789
EmitTypeConstructor(ast::TypeConstructorExpression * expr)790 bool GeneratorImpl::EmitTypeConstructor(ast::TypeConstructorExpression* expr) {
791 if (expr->type()->IsArray()) {
792 out_ << "{";
793 } else {
794 if (!EmitType(expr->type(), "")) {
795 return false;
796 }
797 out_ << "(";
798 }
799
800 // If the type constructor is empty then we need to construct with the zero
801 // value for all components.
802 if (expr->values().empty()) {
803 if (!EmitZeroValue(expr->type())) {
804 return false;
805 }
806 } else {
807 bool first = true;
808 for (const auto& e : expr->values()) {
809 if (!first) {
810 out_ << ", ";
811 }
812 first = false;
813
814 if (!EmitExpression(e.get())) {
815 return false;
816 }
817 }
818 }
819
820 if (expr->type()->IsArray()) {
821 out_ << "}";
822 } else {
823 out_ << ")";
824 }
825 return true;
826 }
827
EmitZeroValue(ast::type::Type * type)828 bool GeneratorImpl::EmitZeroValue(ast::type::Type* type) {
829 if (type->IsBool()) {
830 out_ << "false";
831 } else if (type->IsF32()) {
832 out_ << "0.0f";
833 } else if (type->IsI32()) {
834 out_ << "0";
835 } else if (type->IsU32()) {
836 out_ << "0u";
837 } else if (type->IsVector()) {
838 return EmitZeroValue(type->AsVector()->type());
839 } else if (type->IsMatrix()) {
840 return EmitZeroValue(type->AsMatrix()->type());
841 } else if (type->IsArray()) {
842 out_ << "{";
843 if (!EmitZeroValue(type->AsArray()->type())) {
844 return false;
845 }
846 out_ << "}";
847 } else if (type->IsStruct()) {
848 out_ << "{}";
849 } else {
850 error_ = "Invalid type for zero emission: " + type->type_name();
851 return false;
852 }
853 return true;
854 }
855
EmitScalarConstructor(ast::ScalarConstructorExpression * expr)856 bool GeneratorImpl::EmitScalarConstructor(
857 ast::ScalarConstructorExpression* expr) {
858 return EmitLiteral(expr->literal());
859 }
860
EmitLiteral(ast::Literal * lit)861 bool GeneratorImpl::EmitLiteral(ast::Literal* lit) {
862 if (lit->IsBool()) {
863 out_ << (lit->AsBool()->IsTrue() ? "true" : "false");
864 } else if (lit->IsFloat()) {
865 auto flags = out_.flags();
866 auto precision = out_.precision();
867
868 out_.flags(flags | std::ios_base::showpoint);
869 out_.precision(std::numeric_limits<float>::max_digits10);
870
871 out_ << lit->AsFloat()->value() << "f";
872
873 out_.precision(precision);
874 out_.flags(flags);
875 } else if (lit->IsSint()) {
876 out_ << lit->AsSint()->value();
877 } else if (lit->IsUint()) {
878 out_ << lit->AsUint()->value() << "u";
879 } else {
880 error_ = "unknown literal type";
881 return false;
882 }
883 return true;
884 }
885
EmitEntryPointData(ast::Function * func)886 bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
887 std::vector<std::pair<ast::Variable*, uint32_t>> in_locations;
888 std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>>
889 out_variables;
890 for (auto data : func->referenced_location_variables()) {
891 auto* var = data.first;
892 auto* deco = data.second;
893
894 if (var->storage_class() == ast::StorageClass::kInput) {
895 in_locations.push_back({var, deco->value()});
896 } else if (var->storage_class() == ast::StorageClass::kOutput) {
897 out_variables.push_back({var, deco});
898 }
899 }
900
901 for (auto data : func->referenced_builtin_variables()) {
902 auto* var = data.first;
903 auto* deco = data.second;
904
905 if (var->storage_class() == ast::StorageClass::kOutput) {
906 out_variables.push_back({var, deco});
907 }
908 }
909
910 if (!in_locations.empty()) {
911 auto in_struct_name =
912 generate_name(func->name() + "_" + kInStructNameSuffix);
913 auto in_var_name = generate_name(kTintStructInVarPrefix);
914 ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name};
915
916 make_indent();
917 out_ << "struct " << in_struct_name << " {" << std::endl;
918
919 increment_indent();
920
921 for (auto& data : in_locations) {
922 auto* var = data.first;
923 uint32_t loc = data.second;
924
925 make_indent();
926 if (!EmitType(var->type(), var->name())) {
927 return false;
928 }
929
930 out_ << " " << var->name() << " [[";
931 if (func->pipeline_stage() == ast::PipelineStage::kVertex) {
932 out_ << "attribute(" << loc << ")";
933 } else if (func->pipeline_stage() == ast::PipelineStage::kFragment) {
934 out_ << "user(locn" << loc << ")";
935 } else {
936 error_ = "invalid location variable for pipeline stage";
937 return false;
938 }
939 out_ << "]];" << std::endl;
940 }
941 decrement_indent();
942 make_indent();
943
944 out_ << "};" << std::endl << std::endl;
945 }
946
947 if (!out_variables.empty()) {
948 auto out_struct_name =
949 generate_name(func->name() + "_" + kOutStructNameSuffix);
950 auto out_var_name = generate_name(kTintStructOutVarPrefix);
951 ep_name_to_out_data_[func->name()] = {out_struct_name, out_var_name};
952
953 make_indent();
954 out_ << "struct " << out_struct_name << " {" << std::endl;
955
956 increment_indent();
957 for (auto& data : out_variables) {
958 auto* var = data.first;
959 auto* deco = data.second;
960
961 make_indent();
962 if (!EmitType(var->type(), var->name())) {
963 return false;
964 }
965
966 out_ << " " << var->name() << " [[";
967
968 if (deco->IsLocation()) {
969 auto loc = deco->AsLocation()->value();
970 if (func->pipeline_stage() == ast::PipelineStage::kVertex) {
971 out_ << "user(locn" << loc << ")";
972 } else if (func->pipeline_stage() == ast::PipelineStage::kFragment) {
973 out_ << "color(" << loc << ")";
974 } else {
975 error_ = "invalid location variable for pipeline stage";
976 return false;
977 }
978 } else if (deco->IsBuiltin()) {
979 auto attr = builtin_to_attribute(deco->AsBuiltin()->value());
980 if (attr.empty()) {
981 error_ = "unsupported builtin";
982 return false;
983 }
984 out_ << attr;
985 } else {
986 error_ = "unsupported variable decoration for entry point output";
987 return false;
988 }
989 out_ << "]];" << std::endl;
990 }
991 decrement_indent();
992 make_indent();
993 out_ << "};" << std::endl << std::endl;
994 }
995
996 return true;
997 }
998
EmitExpression(ast::Expression * expr)999 bool GeneratorImpl::EmitExpression(ast::Expression* expr) {
1000 if (expr->IsArrayAccessor()) {
1001 return EmitArrayAccessor(expr->AsArrayAccessor());
1002 }
1003 if (expr->IsBinary()) {
1004 return EmitBinary(expr->AsBinary());
1005 }
1006 if (expr->IsBitcast()) {
1007 return EmitBitcast(expr->AsBitcast());
1008 }
1009 if (expr->IsCall()) {
1010 return EmitCall(expr->AsCall());
1011 }
1012 if (expr->IsConstructor()) {
1013 return EmitConstructor(expr->AsConstructor());
1014 }
1015 if (expr->IsIdentifier()) {
1016 return EmitIdentifier(expr->AsIdentifier());
1017 }
1018 if (expr->IsMemberAccessor()) {
1019 return EmitMemberAccessor(expr->AsMemberAccessor());
1020 }
1021 if (expr->IsUnaryOp()) {
1022 return EmitUnaryOp(expr->AsUnaryOp());
1023 }
1024
1025 error_ = "unknown expression type: " + expr->str();
1026 return false;
1027 }
1028
EmitStage(ast::PipelineStage stage)1029 void GeneratorImpl::EmitStage(ast::PipelineStage stage) {
1030 switch (stage) {
1031 case ast::PipelineStage::kFragment:
1032 out_ << "fragment";
1033 break;
1034 case ast::PipelineStage::kVertex:
1035 out_ << "vertex";
1036 break;
1037 case ast::PipelineStage::kCompute:
1038 out_ << "kernel";
1039 break;
1040 case ast::PipelineStage::kNone:
1041 break;
1042 }
1043 return;
1044 }
1045
has_referenced_in_var_needing_struct(ast::Function * func)1046 bool GeneratorImpl::has_referenced_in_var_needing_struct(ast::Function* func) {
1047 for (auto data : func->referenced_location_variables()) {
1048 auto* var = data.first;
1049 if (var->storage_class() == ast::StorageClass::kInput) {
1050 return true;
1051 }
1052 }
1053 return false;
1054 }
1055
has_referenced_out_var_needing_struct(ast::Function * func)1056 bool GeneratorImpl::has_referenced_out_var_needing_struct(ast::Function* func) {
1057 for (auto data : func->referenced_location_variables()) {
1058 auto* var = data.first;
1059 if (var->storage_class() == ast::StorageClass::kOutput) {
1060 return true;
1061 }
1062 }
1063
1064 for (auto data : func->referenced_builtin_variables()) {
1065 auto* var = data.first;
1066 if (var->storage_class() == ast::StorageClass::kOutput) {
1067 return true;
1068 }
1069 }
1070 return false;
1071 }
1072
has_referenced_var_needing_struct(ast::Function * func)1073 bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) {
1074 return has_referenced_in_var_needing_struct(func) ||
1075 has_referenced_out_var_needing_struct(func);
1076 }
1077
EmitFunction(ast::Function * func)1078 bool GeneratorImpl::EmitFunction(ast::Function* func) {
1079 make_indent();
1080
1081 // Entry points will be emitted later, skip for now.
1082 if (func->IsEntryPoint()) {
1083 return true;
1084 }
1085
1086 // TODO(dsinclair): This could be smarter. If the input/outputs for multiple
1087 // entry points are the same we could generate a single struct and then have
1088 // this determine it's the same struct and just emit once.
1089 bool emit_duplicate_functions = func->ancestor_entry_points().size() > 0 &&
1090 has_referenced_var_needing_struct(func);
1091
1092 if (emit_duplicate_functions) {
1093 for (const auto& ep_name : func->ancestor_entry_points()) {
1094 if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) {
1095 return false;
1096 }
1097 out_ << std::endl;
1098 }
1099 } else {
1100 // Emit as non-duplicated
1101 if (!EmitFunctionInternal(func, false, "")) {
1102 return false;
1103 }
1104 out_ << std::endl;
1105 }
1106
1107 return true;
1108 }
1109
EmitFunctionInternal(ast::Function * func,bool emit_duplicate_functions,const std::string & ep_name)1110 bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
1111 bool emit_duplicate_functions,
1112 const std::string& ep_name) {
1113 auto name = func->name();
1114
1115 if (!EmitType(func->return_type(), "")) {
1116 return false;
1117 }
1118
1119 out_ << " ";
1120 if (emit_duplicate_functions) {
1121 name = generate_name(name + "_" + ep_name);
1122 ep_func_name_remapped_[ep_name + "_" + func->name()] = name;
1123 } else {
1124 name = namer_.NameFor(name);
1125 }
1126 out_ << name << "(";
1127
1128 bool first = true;
1129
1130 // If we're emitting duplicate functions that means the function takes
1131 // the stage_in or stage_out value from the entry point, emit them.
1132 //
1133 // We emit both of them if they're there regardless of if they're both used.
1134 if (emit_duplicate_functions) {
1135 auto in_it = ep_name_to_in_data_.find(ep_name);
1136 if (in_it != ep_name_to_in_data_.end()) {
1137 out_ << "thread " << in_it->second.struct_name << "& "
1138 << in_it->second.var_name;
1139 first = false;
1140 }
1141
1142 auto out_it = ep_name_to_out_data_.find(ep_name);
1143 if (out_it != ep_name_to_out_data_.end()) {
1144 if (!first) {
1145 out_ << ", ";
1146 }
1147 out_ << "thread " << out_it->second.struct_name << "& "
1148 << out_it->second.var_name;
1149 first = false;
1150 }
1151 }
1152
1153 for (const auto& data : func->referenced_builtin_variables()) {
1154 auto* var = data.first;
1155 if (var->storage_class() != ast::StorageClass::kInput) {
1156 continue;
1157 }
1158 if (!first) {
1159 out_ << ", ";
1160 }
1161 first = false;
1162
1163 out_ << "thread ";
1164 if (!EmitType(var->type(), "")) {
1165 return false;
1166 }
1167 out_ << "& " << var->name();
1168 }
1169
1170 for (const auto& data : func->referenced_uniform_variables()) {
1171 auto* var = data.first;
1172 if (!first) {
1173 out_ << ", ";
1174 }
1175 first = false;
1176
1177 out_ << "constant ";
1178 // TODO(dsinclair): Can arrays be uniform? If so, fix this ...
1179 if (!EmitType(var->type(), "")) {
1180 return false;
1181 }
1182 out_ << "& " << var->name();
1183 }
1184
1185 for (const auto& data : func->referenced_storagebuffer_variables()) {
1186 auto* var = data.first;
1187 if (!first) {
1188 out_ << ", ";
1189 }
1190 first = false;
1191
1192 if (!var->type()->IsAccessControl()) {
1193 error_ = "invalid type for storage buffer, expected access control";
1194 return false;
1195 }
1196 auto* ac = var->type()->AsAccessControl();
1197 if (ac->IsReadOnly()) {
1198 out_ << "const ";
1199 }
1200
1201 out_ << "device ";
1202 if (!EmitType(ac->type(), "")) {
1203 return false;
1204 }
1205 out_ << "& " << var->name();
1206 }
1207
1208 for (const auto& v : func->params()) {
1209 if (!first) {
1210 out_ << ", ";
1211 }
1212 first = false;
1213
1214 if (!EmitType(v->type(), v->name())) {
1215 return false;
1216 }
1217 // Array name is output as part of the type
1218 if (!v->type()->IsArray()) {
1219 out_ << " " << v->name();
1220 }
1221 }
1222
1223 out_ << ") ";
1224
1225 current_ep_name_ = ep_name;
1226
1227 if (!EmitBlockAndNewline(func->body())) {
1228 return false;
1229 }
1230
1231 current_ep_name_ = "";
1232
1233 return true;
1234 }
1235
builtin_to_attribute(ast::Builtin builtin) const1236 std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
1237 switch (builtin) {
1238 case ast::Builtin::kPosition:
1239 return "position";
1240 case ast::Builtin::kVertexIdx:
1241 return "vertex_id";
1242 case ast::Builtin::kInstanceIdx:
1243 return "instance_id";
1244 case ast::Builtin::kFrontFacing:
1245 return "front_facing";
1246 case ast::Builtin::kFragCoord:
1247 return "position";
1248 case ast::Builtin::kFragDepth:
1249 return "depth(any)";
1250 case ast::Builtin::kLocalInvocationId:
1251 return "thread_position_in_threadgroup";
1252 case ast::Builtin::kLocalInvocationIdx:
1253 return "thread_index_in_threadgroup";
1254 case ast::Builtin::kGlobalInvocationId:
1255 return "thread_position_in_grid";
1256 default:
1257 break;
1258 }
1259 return "";
1260 }
1261
EmitEntryPointFunction(ast::Function * func)1262 bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
1263 make_indent();
1264
1265 current_ep_name_ = func->name();
1266
1267 EmitStage(func->pipeline_stage());
1268 out_ << " ";
1269
1270 // This is an entry point, the return type is the entry point output structure
1271 // if one exists, or void otherwise.
1272 auto out_data = ep_name_to_out_data_.find(current_ep_name_);
1273 bool has_out_data = out_data != ep_name_to_out_data_.end();
1274 if (has_out_data) {
1275 out_ << out_data->second.struct_name;
1276 } else {
1277 out_ << "void";
1278 }
1279 out_ << " " << namer_.NameFor(current_ep_name_) << "(";
1280
1281 bool first = true;
1282 auto in_data = ep_name_to_in_data_.find(current_ep_name_);
1283 if (in_data != ep_name_to_in_data_.end()) {
1284 out_ << in_data->second.struct_name << " " << in_data->second.var_name
1285 << " [[stage_in]]";
1286 first = false;
1287 }
1288
1289 for (auto data : func->referenced_builtin_variables()) {
1290 auto* var = data.first;
1291 if (var->storage_class() != ast::StorageClass::kInput) {
1292 continue;
1293 }
1294
1295 if (!first) {
1296 out_ << ", ";
1297 }
1298 first = false;
1299
1300 auto* builtin = data.second;
1301
1302 if (!EmitType(var->type(), "")) {
1303 return false;
1304 }
1305
1306 auto attr = builtin_to_attribute(builtin->value());
1307 if (attr.empty()) {
1308 error_ = "unknown builtin";
1309 return false;
1310 }
1311 out_ << " " << var->name() << " [[" << attr << "]]";
1312 }
1313
1314 for (auto data : func->referenced_uniform_variables()) {
1315 if (!first) {
1316 out_ << ", ";
1317 }
1318 first = false;
1319
1320 auto* var = data.first;
1321 // TODO(dsinclair): We're using the binding to make up the buffer number but
1322 // we should instead be using a provided mapping that uses both buffer and
1323 // set. https://bugs.chromium.org/p/tint/issues/detail?id=104
1324 auto* binding = data.second.binding;
1325 if (binding == nullptr) {
1326 error_ = "unable to find binding information for uniform: " + var->name();
1327 return false;
1328 }
1329 // auto* set = data.second.set;
1330
1331 out_ << "constant ";
1332 // TODO(dsinclair): Can you have a uniform array? If so, this needs to be
1333 // updated to handle arrays property.
1334 if (!EmitType(var->type(), "")) {
1335 return false;
1336 }
1337 out_ << "& " << var->name() << " [[buffer(" << binding->value() << ")]]";
1338 }
1339
1340 for (auto data : func->referenced_storagebuffer_variables()) {
1341 if (!first) {
1342 out_ << ", ";
1343 }
1344 first = false;
1345
1346 auto* var = data.first;
1347 // TODO(dsinclair): We're using the binding to make up the buffer number but
1348 // we should instead be using a provided mapping that uses both buffer and
1349 // set. https://bugs.chromium.org/p/tint/issues/detail?id=104
1350 auto* binding = data.second.binding;
1351 // auto* set = data.second.set;
1352
1353 if (!var->type()->IsAccessControl()) {
1354 error_ = "invalid type for storage buffer, expected access control";
1355 return false;
1356 }
1357 auto* ac = var->type()->AsAccessControl();
1358 if (ac->IsReadOnly()) {
1359 out_ << "const ";
1360 }
1361
1362 out_ << "device ";
1363 if (!EmitType(ac->type(), "")) {
1364 return false;
1365 }
1366 out_ << "& " << var->name() << " [[buffer(" << binding->value() << ")]]";
1367 }
1368
1369 out_ << ") {" << std::endl;
1370
1371 increment_indent();
1372
1373 if (has_out_data) {
1374 make_indent();
1375 out_ << out_data->second.struct_name << " " << out_data->second.var_name
1376 << " = {};" << std::endl;
1377 }
1378
1379 generating_entry_point_ = true;
1380 for (const auto& s : *(func->body())) {
1381 if (!EmitStatement(s.get())) {
1382 return false;
1383 }
1384 }
1385 generating_entry_point_ = false;
1386
1387 decrement_indent();
1388 make_indent();
1389 out_ << "}" << std::endl;
1390
1391 current_ep_name_ = "";
1392 return true;
1393 }
1394
global_is_in_struct(ast::Variable * var) const1395 bool GeneratorImpl::global_is_in_struct(ast::Variable* var) const {
1396 bool in_or_out_struct_has_location =
1397 var->IsDecorated() && var->AsDecorated()->HasLocationDecoration() &&
1398 (var->storage_class() == ast::StorageClass::kInput ||
1399 var->storage_class() == ast::StorageClass::kOutput);
1400 bool in_struct_has_builtin =
1401 var->IsDecorated() && var->AsDecorated()->HasBuiltinDecoration() &&
1402 var->storage_class() == ast::StorageClass::kOutput;
1403 return in_or_out_struct_has_location || in_struct_has_builtin;
1404 }
1405
EmitIdentifier(ast::IdentifierExpression * expr)1406 bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) {
1407 auto* ident = expr->AsIdentifier();
1408 ast::Variable* var = nullptr;
1409 if (global_variables_.get(ident->name(), &var)) {
1410 if (global_is_in_struct(var)) {
1411 auto var_type = var->storage_class() == ast::StorageClass::kInput
1412 ? VarType::kIn
1413 : VarType::kOut;
1414 auto name = current_ep_var_name(var_type);
1415 if (name.empty()) {
1416 error_ = "unable to find entry point data for variable";
1417 return false;
1418 }
1419 out_ << name << ".";
1420 }
1421 }
1422 out_ << namer_.NameFor(ident->name());
1423
1424 return true;
1425 }
1426
EmitLoop(ast::LoopStatement * stmt)1427 bool GeneratorImpl::EmitLoop(ast::LoopStatement* stmt) {
1428 loop_emission_counter_++;
1429
1430 std::string guard = namer_.NameFor("tint_msl_is_first_" +
1431 std::to_string(loop_emission_counter_));
1432
1433 if (stmt->has_continuing()) {
1434 make_indent();
1435
1436 // Continuing variables get their own scope.
1437 out_ << "{" << std::endl;
1438 increment_indent();
1439
1440 make_indent();
1441 out_ << "bool " << guard << " = true;" << std::endl;
1442
1443 // A continuing block may use variables declared in the method body. As a
1444 // first pass, if we have a continuing, we pull all declarations outside
1445 // the for loop into the continuing scope. Then, the variable declarations
1446 // will be turned into assignments.
1447 for (const auto& s : *(stmt->body())) {
1448 if (!s->IsVariableDecl()) {
1449 continue;
1450 }
1451 if (!EmitVariable(s->AsVariableDecl()->variable(), true)) {
1452 return false;
1453 }
1454 }
1455 }
1456
1457 make_indent();
1458 out_ << "for(;;) {" << std::endl;
1459 increment_indent();
1460
1461 if (stmt->has_continuing()) {
1462 make_indent();
1463 out_ << "if (!" << guard << ") ";
1464
1465 if (!EmitBlockAndNewline(stmt->continuing())) {
1466 return false;
1467 }
1468
1469 make_indent();
1470 out_ << guard << " = false;" << std::endl;
1471 out_ << std::endl;
1472 }
1473
1474 for (const auto& s : *(stmt->body())) {
1475 // If we have a continuing block we've already emitted the variable
1476 // declaration before the loop, so treat it as an assignment.
1477 if (s->IsVariableDecl() && stmt->has_continuing()) {
1478 make_indent();
1479
1480 auto* var = s->AsVariableDecl()->variable();
1481 out_ << var->name() << " = ";
1482 if (var->constructor() != nullptr) {
1483 if (!EmitExpression(var->constructor())) {
1484 return false;
1485 }
1486 } else {
1487 if (!EmitZeroValue(var->type())) {
1488 return false;
1489 }
1490 }
1491 out_ << ";" << std::endl;
1492 continue;
1493 }
1494
1495 if (!EmitStatement(s.get())) {
1496 return false;
1497 }
1498 }
1499
1500 decrement_indent();
1501 make_indent();
1502 out_ << "}" << std::endl;
1503
1504 // Close the scope for any continuing variables.
1505 if (stmt->has_continuing()) {
1506 decrement_indent();
1507 make_indent();
1508 out_ << "}" << std::endl;
1509 }
1510
1511 return true;
1512 }
1513
EmitDiscard(ast::DiscardStatement *)1514 bool GeneratorImpl::EmitDiscard(ast::DiscardStatement*) {
1515 make_indent();
1516 // TODO(dsinclair): Verify this is correct when the discard semantics are
1517 // defined for WGSL (https://github.com/gpuweb/gpuweb/issues/361)
1518 out_ << "discard_fragment();" << std::endl;
1519 return true;
1520 }
1521
EmitElse(ast::ElseStatement * stmt)1522 bool GeneratorImpl::EmitElse(ast::ElseStatement* stmt) {
1523 if (stmt->HasCondition()) {
1524 out_ << " else if (";
1525 if (!EmitExpression(stmt->condition())) {
1526 return false;
1527 }
1528 out_ << ") ";
1529 } else {
1530 out_ << " else ";
1531 }
1532
1533 return EmitBlock(stmt->body());
1534 }
1535
EmitIf(ast::IfStatement * stmt)1536 bool GeneratorImpl::EmitIf(ast::IfStatement* stmt) {
1537 make_indent();
1538
1539 out_ << "if (";
1540 if (!EmitExpression(stmt->condition())) {
1541 return false;
1542 }
1543 out_ << ") ";
1544
1545 if (!EmitBlock(stmt->body())) {
1546 return false;
1547 }
1548
1549 for (const auto& e : stmt->else_statements()) {
1550 if (!EmitElse(e.get())) {
1551 return false;
1552 }
1553 }
1554 out_ << std::endl;
1555
1556 return true;
1557 }
1558
EmitMemberAccessor(ast::MemberAccessorExpression * expr)1559 bool GeneratorImpl::EmitMemberAccessor(ast::MemberAccessorExpression* expr) {
1560 if (!EmitExpression(expr->structure())) {
1561 return false;
1562 }
1563
1564 out_ << ".";
1565
1566 return EmitExpression(expr->member());
1567 }
1568
EmitReturn(ast::ReturnStatement * stmt)1569 bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) {
1570 make_indent();
1571
1572 out_ << "return";
1573
1574 if (generating_entry_point_) {
1575 auto out_data = ep_name_to_out_data_.find(current_ep_name_);
1576 if (out_data != ep_name_to_out_data_.end()) {
1577 out_ << " " << out_data->second.var_name;
1578 }
1579 } else if (stmt->has_value()) {
1580 out_ << " ";
1581 if (!EmitExpression(stmt->value())) {
1582 return false;
1583 }
1584 }
1585 out_ << ";" << std::endl;
1586 return true;
1587 }
1588
EmitBlock(const ast::BlockStatement * stmt)1589 bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
1590 out_ << "{" << std::endl;
1591 increment_indent();
1592
1593 for (const auto& s : *stmt) {
1594 if (!EmitStatement(s.get())) {
1595 return false;
1596 }
1597 }
1598
1599 decrement_indent();
1600 make_indent();
1601 out_ << "}";
1602
1603 return true;
1604 }
1605
EmitBlockAndNewline(const ast::BlockStatement * stmt)1606 bool GeneratorImpl::EmitBlockAndNewline(const ast::BlockStatement* stmt) {
1607 const bool result = EmitBlock(stmt);
1608 if (result) {
1609 out_ << std::endl;
1610 }
1611 return result;
1612 }
1613
EmitIndentedBlockAndNewline(ast::BlockStatement * stmt)1614 bool GeneratorImpl::EmitIndentedBlockAndNewline(ast::BlockStatement* stmt) {
1615 make_indent();
1616 const bool result = EmitBlock(stmt);
1617 if (result) {
1618 out_ << std::endl;
1619 }
1620 return result;
1621 }
1622
EmitStatement(ast::Statement * stmt)1623 bool GeneratorImpl::EmitStatement(ast::Statement* stmt) {
1624 if (stmt->IsAssign()) {
1625 return EmitAssign(stmt->AsAssign());
1626 }
1627 if (stmt->IsBlock()) {
1628 return EmitIndentedBlockAndNewline(stmt->AsBlock());
1629 }
1630 if (stmt->IsBreak()) {
1631 return EmitBreak(stmt->AsBreak());
1632 }
1633 if (stmt->IsCall()) {
1634 make_indent();
1635 if (!EmitCall(stmt->AsCall()->expr())) {
1636 return false;
1637 }
1638 out_ << ";" << std::endl;
1639 return true;
1640 }
1641 if (stmt->IsContinue()) {
1642 return EmitContinue(stmt->AsContinue());
1643 }
1644 if (stmt->IsDiscard()) {
1645 return EmitDiscard(stmt->AsDiscard());
1646 }
1647 if (stmt->IsFallthrough()) {
1648 make_indent();
1649 out_ << "/* fallthrough */" << std::endl;
1650 return true;
1651 }
1652 if (stmt->IsIf()) {
1653 return EmitIf(stmt->AsIf());
1654 }
1655 if (stmt->IsLoop()) {
1656 return EmitLoop(stmt->AsLoop());
1657 }
1658 if (stmt->IsReturn()) {
1659 return EmitReturn(stmt->AsReturn());
1660 }
1661 if (stmt->IsSwitch()) {
1662 return EmitSwitch(stmt->AsSwitch());
1663 }
1664 if (stmt->IsVariableDecl()) {
1665 return EmitVariable(stmt->AsVariableDecl()->variable(), false);
1666 }
1667
1668 error_ = "unknown statement type: " + stmt->str();
1669 return false;
1670 }
1671
EmitSwitch(ast::SwitchStatement * stmt)1672 bool GeneratorImpl::EmitSwitch(ast::SwitchStatement* stmt) {
1673 make_indent();
1674
1675 out_ << "switch(";
1676 if (!EmitExpression(stmt->condition())) {
1677 return false;
1678 }
1679 out_ << ") {" << std::endl;
1680
1681 increment_indent();
1682
1683 for (const auto& s : stmt->body()) {
1684 if (!EmitCase(s.get())) {
1685 return false;
1686 }
1687 }
1688
1689 decrement_indent();
1690 make_indent();
1691 out_ << "}" << std::endl;
1692
1693 return true;
1694 }
1695
EmitType(ast::type::Type * type,const std::string & name)1696 bool GeneratorImpl::EmitType(ast::type::Type* type, const std::string& name) {
1697 if (type->IsAlias()) {
1698 auto* alias = type->AsAlias();
1699 out_ << namer_.NameFor(alias->name());
1700 } else if (type->IsArray()) {
1701 auto* ary = type->AsArray();
1702
1703 ast::type::Type* base_type = ary;
1704 std::vector<uint32_t> sizes;
1705 while (base_type->IsArray()) {
1706 if (base_type->AsArray()->IsRuntimeArray()) {
1707 sizes.push_back(1);
1708 } else {
1709 sizes.push_back(base_type->AsArray()->size());
1710 }
1711 base_type = base_type->AsArray()->type();
1712 }
1713 if (!EmitType(base_type, "")) {
1714 return false;
1715 }
1716 if (!name.empty()) {
1717 out_ << " " << namer_.NameFor(name);
1718 }
1719 for (uint32_t size : sizes) {
1720 out_ << "[" << size << "]";
1721 }
1722 } else if (type->IsBool()) {
1723 out_ << "bool";
1724 } else if (type->IsF32()) {
1725 out_ << "float";
1726 } else if (type->IsI32()) {
1727 out_ << "int";
1728 } else if (type->IsMatrix()) {
1729 auto* mat = type->AsMatrix();
1730 if (!EmitType(mat->type(), "")) {
1731 return false;
1732 }
1733 out_ << mat->columns() << "x" << mat->rows();
1734 } else if (type->IsPointer()) {
1735 auto* ptr = type->AsPointer();
1736 // TODO(dsinclair): Storage class?
1737 if (!EmitType(ptr->type(), "")) {
1738 return false;
1739 }
1740 out_ << "*";
1741 } else if (type->IsSampler()) {
1742 out_ << "sampler";
1743 } else if (type->IsStruct()) {
1744 // The struct type emits as just the name. The declaration would be emitted
1745 // as part of emitting the constructed types.
1746 out_ << type->AsStruct()->name();
1747 } else if (type->IsTexture()) {
1748 auto* tex = type->AsTexture();
1749
1750 if (tex->IsDepth()) {
1751 out_ << "depth";
1752 } else {
1753 out_ << "texture";
1754 }
1755
1756 switch (tex->dim()) {
1757 case ast::type::TextureDimension::k1d:
1758 out_ << "1d";
1759 break;
1760 case ast::type::TextureDimension::k1dArray:
1761 out_ << "1d_array";
1762 break;
1763 case ast::type::TextureDimension::k2d:
1764 out_ << "2d";
1765 break;
1766 case ast::type::TextureDimension::k2dArray:
1767 out_ << "2d_array";
1768 break;
1769 case ast::type::TextureDimension::k3d:
1770 out_ << "3d";
1771 break;
1772 case ast::type::TextureDimension::kCube:
1773 out_ << "cube";
1774 break;
1775 case ast::type::TextureDimension::kCubeArray:
1776 out_ << "cube_array";
1777 break;
1778 default:
1779 error_ = "Invalid texture dimensions";
1780 return false;
1781 }
1782 if (tex->IsMultisampled()) {
1783 out_ << "_ms";
1784 }
1785 out_ << "<";
1786 if (tex->IsDepth()) {
1787 out_ << "float, access::sample";
1788 } else if (tex->IsStorage()) {
1789 auto* storage = tex->AsStorage();
1790 if (!EmitType(storage->type(), "")) {
1791 return false;
1792 }
1793 out_ << ", access::";
1794 if (storage->access() == ast::AccessControl::kReadOnly) {
1795 out_ << "read";
1796 } else if (storage->access() == ast::AccessControl::kWriteOnly) {
1797 out_ << "write";
1798 } else {
1799 error_ = "Invalid access control for storage texture";
1800 return false;
1801 }
1802 } else if (tex->IsMultisampled()) {
1803 if (!EmitType(tex->AsMultisampled()->type(), "")) {
1804 return false;
1805 }
1806 out_ << ", access::sample";
1807 } else if (tex->IsSampled()) {
1808 if (!EmitType(tex->AsSampled()->type(), "")) {
1809 return false;
1810 }
1811 out_ << ", access::sample";
1812 } else {
1813 error_ = "invalid texture type";
1814 return false;
1815 }
1816 out_ << ">";
1817
1818 } else if (type->IsU32()) {
1819 out_ << "uint";
1820 } else if (type->IsVector()) {
1821 auto* vec = type->AsVector();
1822 if (!EmitType(vec->type(), "")) {
1823 return false;
1824 }
1825 out_ << vec->size();
1826 } else if (type->IsVoid()) {
1827 out_ << "void";
1828 } else {
1829 error_ = "unknown type in EmitType: " + type->type_name();
1830 return false;
1831 }
1832
1833 return true;
1834 }
1835
EmitStructType(const ast::type::StructType * str)1836 bool GeneratorImpl::EmitStructType(const ast::type::StructType* str) {
1837 // TODO(dsinclair): Block decoration?
1838 // if (str->impl()->decoration() != ast::StructDecoration::kNone) {
1839 // }
1840 out_ << "struct " << str->name() << " {" << std::endl;
1841
1842 increment_indent();
1843 uint32_t current_offset = 0;
1844 uint32_t pad_count = 0;
1845 for (const auto& mem : str->impl()->members()) {
1846 make_indent();
1847 for (const auto& deco : mem->decorations()) {
1848 if (deco->IsOffset()) {
1849 uint32_t offset = deco->AsOffset()->offset();
1850 if (offset != current_offset) {
1851 out_ << "int8_t pad_" << pad_count << "[" << (offset - current_offset)
1852 << "];" << std::endl;
1853 pad_count++;
1854 make_indent();
1855 }
1856 current_offset = offset;
1857 } else {
1858 error_ = "unsupported member decoration: " + deco->to_str();
1859 return false;
1860 }
1861 }
1862
1863 if (!EmitType(mem->type(), mem->name())) {
1864 return false;
1865 }
1866 auto size = calculate_alignment_size(mem->type());
1867 if (size == 0) {
1868 error_ = "unable to calculate byte size for: " + mem->type()->type_name();
1869 return false;
1870 }
1871 current_offset += size;
1872
1873 // Array member name will be output with the type
1874 if (!mem->type()->IsArray()) {
1875 out_ << " " << namer_.NameFor(mem->name());
1876 }
1877 out_ << ";" << std::endl;
1878 }
1879 decrement_indent();
1880 make_indent();
1881
1882 out_ << "};" << std::endl;
1883 return true;
1884 }
1885
EmitUnaryOp(ast::UnaryOpExpression * expr)1886 bool GeneratorImpl::EmitUnaryOp(ast::UnaryOpExpression* expr) {
1887 switch (expr->op()) {
1888 case ast::UnaryOp::kNot:
1889 out_ << "!";
1890 break;
1891 case ast::UnaryOp::kNegation:
1892 out_ << "-";
1893 break;
1894 }
1895 out_ << "(";
1896
1897 if (!EmitExpression(expr->expr())) {
1898 return false;
1899 }
1900
1901 out_ << ")";
1902
1903 return true;
1904 }
1905
EmitVariable(ast::Variable * var,bool skip_constructor)1906 bool GeneratorImpl::EmitVariable(ast::Variable* var, bool skip_constructor) {
1907 make_indent();
1908
1909 // TODO(dsinclair): Handle variable decorations
1910 if (var->IsDecorated()) {
1911 error_ = "Variable decorations are not handled yet";
1912 return false;
1913 }
1914
1915 if (var->is_const()) {
1916 out_ << "const ";
1917 }
1918 if (!EmitType(var->type(), var->name())) {
1919 return false;
1920 }
1921 if (!var->type()->IsArray()) {
1922 out_ << " " << var->name();
1923 }
1924
1925 if (!skip_constructor) {
1926 out_ << " = ";
1927 if (var->constructor() != nullptr) {
1928 if (!EmitExpression(var->constructor())) {
1929 return false;
1930 }
1931 } else if (var->storage_class() == ast::StorageClass::kPrivate ||
1932 var->storage_class() == ast::StorageClass::kFunction ||
1933 var->storage_class() == ast::StorageClass::kNone ||
1934 var->storage_class() == ast::StorageClass::kOutput) {
1935 if (!EmitZeroValue(var->type())) {
1936 return false;
1937 }
1938 }
1939 }
1940 out_ << ";" << std::endl;
1941
1942 return true;
1943 }
1944
EmitProgramConstVariable(const ast::Variable * var)1945 bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
1946 make_indent();
1947
1948 if (var->IsDecorated() && !var->AsDecorated()->HasConstantIdDecoration()) {
1949 error_ = "Decorated const values not valid";
1950 return false;
1951 }
1952 if (!var->is_const()) {
1953 error_ = "Expected a const value";
1954 return false;
1955 }
1956
1957 out_ << "constant ";
1958 if (!EmitType(var->type(), var->name())) {
1959 return false;
1960 }
1961 if (!var->type()->IsArray()) {
1962 out_ << " " << var->name();
1963 }
1964
1965 if (var->IsDecorated() && var->AsDecorated()->HasConstantIdDecoration()) {
1966 out_ << " [[function_constant(" << var->AsDecorated()->constant_id()
1967 << ")]]";
1968 } else if (var->constructor() != nullptr) {
1969 out_ << " = ";
1970 if (!EmitExpression(var->constructor())) {
1971 return false;
1972 }
1973 }
1974 out_ << ";" << std::endl;
1975
1976 return true;
1977 }
1978
1979 } // namespace msl
1980 } // namespace writer
1981 } // namespace tint
1982