1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/writer/msl/generator_impl.h"
16 
17 #include <algorithm>
18 #include <limits>
19 #include <utility>
20 #include <vector>
21 
22 #include "src/ast/array_accessor_expression.h"
23 #include "src/ast/assignment_statement.h"
24 #include "src/ast/binary_expression.h"
25 #include "src/ast/bitcast_expression.h"
26 #include "src/ast/block_statement.h"
27 #include "src/ast/bool_literal.h"
28 #include "src/ast/break_statement.h"
29 #include "src/ast/call_expression.h"
30 #include "src/ast/call_statement.h"
31 #include "src/ast/case_statement.h"
32 #include "src/ast/continue_statement.h"
33 #include "src/ast/decorated_variable.h"
34 #include "src/ast/else_statement.h"
35 #include "src/ast/float_literal.h"
36 #include "src/ast/function.h"
37 #include "src/ast/identifier_expression.h"
38 #include "src/ast/if_statement.h"
39 #include "src/ast/location_decoration.h"
40 #include "src/ast/loop_statement.h"
41 #include "src/ast/member_accessor_expression.h"
42 #include "src/ast/return_statement.h"
43 #include "src/ast/sint_literal.h"
44 #include "src/ast/struct_member_offset_decoration.h"
45 #include "src/ast/switch_statement.h"
46 #include "src/ast/type/access_control_type.h"
47 #include "src/ast/type/alias_type.h"
48 #include "src/ast/type/array_type.h"
49 #include "src/ast/type/bool_type.h"
50 #include "src/ast/type/depth_texture_type.h"
51 #include "src/ast/type/f32_type.h"
52 #include "src/ast/type/i32_type.h"
53 #include "src/ast/type/matrix_type.h"
54 #include "src/ast/type/multisampled_texture_type.h"
55 #include "src/ast/type/pointer_type.h"
56 #include "src/ast/type/sampled_texture_type.h"
57 #include "src/ast/type/sampler_type.h"
58 #include "src/ast/type/storage_texture_type.h"
59 #include "src/ast/type/struct_type.h"
60 #include "src/ast/type/u32_type.h"
61 #include "src/ast/type/vector_type.h"
62 #include "src/ast/type/void_type.h"
63 #include "src/ast/uint_literal.h"
64 #include "src/ast/unary_op_expression.h"
65 #include "src/ast/variable_decl_statement.h"
66 
67 namespace tint {
68 namespace writer {
69 namespace msl {
70 namespace {
71 
// Suffixes appended (after an underscore) to an entry point's name to form
// the names of the structs holding its stage inputs and outputs. The result
// is passed through generate_name(), which may add a numeric suffix.
const char kInStructNameSuffix[] = "in";
const char kOutStructNameSuffix[] = "out";
// Base names for the variables holding those input/output structs; also
// passed through generate_name() to keep them unique.
const char kTintStructInVarPrefix[] = "tint_in";
const char kTintStructOutVarPrefix[] = "tint_out";
76 
last_is_break_or_fallthrough(const ast::BlockStatement * stmts)77 bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
78   if (stmts->empty()) {
79     return false;
80   }
81 
82   return stmts->last()->IsBreak() || stmts->last()->IsFallthrough();
83 }
84 
adjust_for_alignment(uint32_t count,uint32_t alignment)85 uint32_t adjust_for_alignment(uint32_t count, uint32_t alignment) {
86   const auto spill = count % alignment;
87   if (spill == 0) {
88     return count;
89   }
90   return count + alignment - spill;
91 }
92 
93 }  // namespace
94 
GeneratorImpl(ast::Module * module)95 GeneratorImpl::GeneratorImpl(ast::Module* module) : module_(module) {}
96 
97 GeneratorImpl::~GeneratorImpl() = default;
98 
generate_name(const std::string & prefix)99 std::string GeneratorImpl::generate_name(const std::string& prefix) {
100   std::string name = prefix;
101   uint32_t i = 0;
102   while (namer_.IsMapped(name)) {
103     name = prefix + "_" + std::to_string(i);
104     ++i;
105   }
106   namer_.RegisterRemappedName(name);
107   return name;
108 }
109 
Generate()110 bool GeneratorImpl::Generate() {
111   out_ << "#include <metal_stdlib>" << std::endl << std::endl;
112 
113   for (const auto& global : module_->global_variables()) {
114     global_variables_.set(global->name(), global.get());
115   }
116 
117   for (auto* const ty : module_->constructed_types()) {
118     if (!EmitConstructedType(ty)) {
119       return false;
120     }
121   }
122   if (!module_->constructed_types().empty()) {
123     out_ << std::endl;
124   }
125 
126   for (const auto& var : module_->global_variables()) {
127     if (!var->is_const()) {
128       continue;
129     }
130     if (!EmitProgramConstVariable(var.get())) {
131       return false;
132     }
133   }
134 
135   // Make sure all entry point data is emitted before the entry point functions
136   for (const auto& func : module_->functions()) {
137     if (!func->IsEntryPoint()) {
138       continue;
139     }
140 
141     if (!EmitEntryPointData(func.get())) {
142       return false;
143     }
144   }
145 
146   for (const auto& func : module_->functions()) {
147     if (!EmitFunction(func.get())) {
148       return false;
149     }
150   }
151 
152   for (const auto& func : module_->functions()) {
153     if (!func->IsEntryPoint()) {
154       continue;
155     }
156     if (!EmitEntryPointFunction(func.get())) {
157       return false;
158     }
159     out_ << std::endl;
160   }
161 
162   return true;
163 }
164 
calculate_largest_alignment(ast::type::StructType * type)165 uint32_t GeneratorImpl::calculate_largest_alignment(
166     ast::type::StructType* type) {
167   auto* stct = type->AsStruct()->impl();
168   uint32_t largest_alignment = 0;
169   for (const auto& mem : stct->members()) {
170     auto align = calculate_alignment_size(mem->type());
171     if (align == 0) {
172       return 0;
173     }
174     if (!mem->type()->IsStruct()) {
175       largest_alignment = std::max(largest_alignment, align);
176     } else {
177       largest_alignment =
178           std::max(largest_alignment,
179                    calculate_largest_alignment(mem->type()->AsStruct()));
180     }
181   }
182   return largest_alignment;
183 }
184 
calculate_alignment_size(ast::type::Type * type)185 uint32_t GeneratorImpl::calculate_alignment_size(ast::type::Type* type) {
186   if (type->IsAlias()) {
187     return calculate_alignment_size(type->AsAlias()->type());
188   }
189   if (type->IsArray()) {
190     auto* ary = type->AsArray();
191     // TODO(dsinclair): Handle array stride and adjust for alignment.
192     uint32_t type_size = calculate_alignment_size(ary->type());
193     return ary->size() * type_size;
194   }
195   if (type->IsBool()) {
196     return 1;
197   }
198   if (type->IsPointer()) {
199     return 0;
200   }
201   if (type->IsF32() || type->IsI32() || type->IsU32()) {
202     return 4;
203   }
204   if (type->IsMatrix()) {
205     auto* mat = type->AsMatrix();
206     // TODO(dsinclair): Handle MatrixStride
207     // https://github.com/gpuweb/gpuweb/issues/773
208     uint32_t type_size = calculate_alignment_size(mat->type());
209     return mat->rows() * mat->columns() * type_size;
210   }
211   if (type->IsStruct()) {
212     auto* stct = type->AsStruct()->impl();
213     uint32_t count = 0;
214     uint32_t largest_alignment = 0;
215     // Offset decorations in WGSL must be in increasing order.
216     for (const auto& mem : stct->members()) {
217       for (const auto& deco : mem->decorations()) {
218         if (deco->IsOffset()) {
219           count = deco->AsOffset()->offset();
220         }
221       }
222       auto align = calculate_alignment_size(mem->type());
223       if (align == 0) {
224         return 0;
225       }
226       if (!mem->type()->IsStruct()) {
227         largest_alignment = std::max(largest_alignment, align);
228       } else {
229         largest_alignment =
230             std::max(largest_alignment,
231                      calculate_largest_alignment(mem->type()->AsStruct()));
232       }
233 
234       // Round up to the alignment size
235       count = adjust_for_alignment(count, align);
236       count += align;
237     }
238     // Round struct up to largest align size
239     count = adjust_for_alignment(count, largest_alignment);
240     return count;
241   }
242   if (type->IsVector()) {
243     auto* vec = type->AsVector();
244     uint32_t type_size = calculate_alignment_size(vec->type());
245     if (vec->size() == 2) {
246       return 2 * type_size;
247     }
248     return 4 * type_size;
249   }
250   return 0;
251 }
252 
EmitConstructedType(const ast::type::Type * ty)253 bool GeneratorImpl::EmitConstructedType(const ast::type::Type* ty) {
254   make_indent();
255 
256   if (ty->IsAlias()) {
257     auto* alias = ty->AsAlias();
258 
259     out_ << "typedef ";
260     if (!EmitType(alias->type(), "")) {
261       return false;
262     }
263     out_ << " " << namer_.NameFor(alias->name()) << ";" << std::endl;
264   } else if (ty->IsStruct()) {
265     if (!EmitStructType(ty->AsStruct())) {
266       return false;
267     }
268   } else {
269     error_ = "unknown alias type: " + ty->type_name();
270     return false;
271   }
272 
273   return true;
274 }
275 
EmitArrayAccessor(ast::ArrayAccessorExpression * expr)276 bool GeneratorImpl::EmitArrayAccessor(ast::ArrayAccessorExpression* expr) {
277   if (!EmitExpression(expr->array())) {
278     return false;
279   }
280   out_ << "[";
281 
282   if (!EmitExpression(expr->idx_expr())) {
283     return false;
284   }
285   out_ << "]";
286 
287   return true;
288 }
289 
EmitBitcast(ast::BitcastExpression * expr)290 bool GeneratorImpl::EmitBitcast(ast::BitcastExpression* expr) {
291   out_ << "as_type<";
292   if (!EmitType(expr->type(), "")) {
293     return false;
294   }
295 
296   out_ << ">(";
297   if (!EmitExpression(expr->expr())) {
298     return false;
299   }
300 
301   out_ << ")";
302   return true;
303 }
304 
EmitAssign(ast::AssignmentStatement * stmt)305 bool GeneratorImpl::EmitAssign(ast::AssignmentStatement* stmt) {
306   make_indent();
307 
308   if (!EmitExpression(stmt->lhs())) {
309     return false;
310   }
311 
312   out_ << " = ";
313 
314   if (!EmitExpression(stmt->rhs())) {
315     return false;
316   }
317 
318   out_ << ";" << std::endl;
319 
320   return true;
321 }
322 
EmitBinary(ast::BinaryExpression * expr)323 bool GeneratorImpl::EmitBinary(ast::BinaryExpression* expr) {
324   out_ << "(";
325 
326   if (!EmitExpression(expr->lhs())) {
327     return false;
328   }
329   out_ << " ";
330 
331   switch (expr->op()) {
332     case ast::BinaryOp::kAnd:
333       out_ << "&";
334       break;
335     case ast::BinaryOp::kOr:
336       out_ << "|";
337       break;
338     case ast::BinaryOp::kXor:
339       out_ << "^";
340       break;
341     case ast::BinaryOp::kLogicalAnd:
342       out_ << "&&";
343       break;
344     case ast::BinaryOp::kLogicalOr:
345       out_ << "||";
346       break;
347     case ast::BinaryOp::kEqual:
348       out_ << "==";
349       break;
350     case ast::BinaryOp::kNotEqual:
351       out_ << "!=";
352       break;
353     case ast::BinaryOp::kLessThan:
354       out_ << "<";
355       break;
356     case ast::BinaryOp::kGreaterThan:
357       out_ << ">";
358       break;
359     case ast::BinaryOp::kLessThanEqual:
360       out_ << "<=";
361       break;
362     case ast::BinaryOp::kGreaterThanEqual:
363       out_ << ">=";
364       break;
365     case ast::BinaryOp::kShiftLeft:
366       out_ << "<<";
367       break;
368     case ast::BinaryOp::kShiftRight:
369       // TODO(dsinclair): MSL is based on C++14, and >> in C++14 has
370       // implementation-defined behaviour for negative LHS.  We may have to
371       // generate extra code to implement WGSL-specified behaviour for negative
372       // LHS.
373       out_ << R"(>>)";
374       break;
375 
376     case ast::BinaryOp::kAdd:
377       out_ << "+";
378       break;
379     case ast::BinaryOp::kSubtract:
380       out_ << "-";
381       break;
382     case ast::BinaryOp::kMultiply:
383       out_ << "*";
384       break;
385     case ast::BinaryOp::kDivide:
386       out_ << "/";
387       break;
388     case ast::BinaryOp::kModulo:
389       out_ << "%";
390       break;
391     case ast::BinaryOp::kNone:
392       error_ = "missing binary operation type";
393       return false;
394   }
395   out_ << " ";
396 
397   if (!EmitExpression(expr->rhs())) {
398     return false;
399   }
400 
401   out_ << ")";
402   return true;
403 }
404 
EmitBreak(ast::BreakStatement *)405 bool GeneratorImpl::EmitBreak(ast::BreakStatement*) {
406   make_indent();
407   out_ << "break;" << std::endl;
408   return true;
409 }
410 
current_ep_var_name(VarType type)411 std::string GeneratorImpl::current_ep_var_name(VarType type) {
412   std::string name = "";
413   switch (type) {
414     case VarType::kIn: {
415       auto in_it = ep_name_to_in_data_.find(current_ep_name_);
416       if (in_it != ep_name_to_in_data_.end()) {
417         name = in_it->second.var_name;
418       }
419       break;
420     }
421     case VarType::kOut: {
422       auto out_it = ep_name_to_out_data_.find(current_ep_name_);
423       if (out_it != ep_name_to_out_data_.end()) {
424         name = out_it->second.var_name;
425       }
426       break;
427     }
428   }
429   return name;
430 }
431 
generate_intrinsic_name(ast::Intrinsic intrinsic)432 std::string GeneratorImpl::generate_intrinsic_name(ast::Intrinsic intrinsic) {
433   if (intrinsic == ast::Intrinsic::kAny) {
434     return "any";
435   }
436   if (intrinsic == ast::Intrinsic::kAll) {
437     return "all";
438   }
439   if (intrinsic == ast::Intrinsic::kCountOneBits) {
440     return "popcount";
441   }
442   if (intrinsic == ast::Intrinsic::kDot) {
443     return "dot";
444   }
445   if (intrinsic == ast::Intrinsic::kDpdy ||
446       intrinsic == ast::Intrinsic::kDpdyFine ||
447       intrinsic == ast::Intrinsic::kDpdyCoarse) {
448     return "dfdy";
449   }
450   if (intrinsic == ast::Intrinsic::kDpdx ||
451       intrinsic == ast::Intrinsic::kDpdxFine ||
452       intrinsic == ast::Intrinsic::kDpdxCoarse) {
453     return "dfdx";
454   }
455   if (intrinsic == ast::Intrinsic::kFwidth ||
456       intrinsic == ast::Intrinsic::kFwidthFine ||
457       intrinsic == ast::Intrinsic::kFwidthCoarse) {
458     return "fwidth";
459   }
460   if (intrinsic == ast::Intrinsic::kIsFinite) {
461     return "isfinite";
462   }
463   if (intrinsic == ast::Intrinsic::kIsInf) {
464     return "isinf";
465   }
466   if (intrinsic == ast::Intrinsic::kIsNan) {
467     return "isnan";
468   }
469   if (intrinsic == ast::Intrinsic::kIsNormal) {
470     return "isnormal";
471   }
472   if (intrinsic == ast::Intrinsic::kReverseBits) {
473     return "reverse_bits";
474   }
475   if (intrinsic == ast::Intrinsic::kSelect) {
476     return "select";
477   }
478   return "";
479 }
480 
EmitCall(ast::CallExpression * expr)481 bool GeneratorImpl::EmitCall(ast::CallExpression* expr) {
482   if (!expr->func()->IsIdentifier()) {
483     error_ = "invalid function name";
484     return 0;
485   }
486 
487   auto* ident = expr->func()->AsIdentifier();
488   if (ident->IsIntrinsic()) {
489     const auto& params = expr->params();
490     if (ident->intrinsic() == ast::Intrinsic::kOuterProduct) {
491       error_ = "outer_product not supported yet";
492       return false;
493       // TODO(dsinclair): This gets tricky. We need to generate two variables to
494       // hold the outer_product expressions, but we maybe inside an expression
495       // ourselves. So, this will need to, possibly, output the variables
496       // _before_ the expression which contains the outer product.
497       //
498       // This then has the follow on, what if we have `(false &&
499       // outer_product())` in that case, we shouldn't evaluate the expressions
500       // at all because of short circuting.
501       //
502       // So .... this turns out to be hard ...
503 
504       // // We create variables to hold the two parameters in case they're
505       // // function calls with side effects.
506       // auto* param0 = param[0].get();
507       // auto* name0 = generate_name("outer_product_expr_0");
508 
509       // auto* param1 = param[1].get();
510       // auto* name1 = generate_name("outer_product_expr_1");
511 
512       // make_indent();
513       // if (!EmitType(expr->result_type(), "")) {
514       //   return false;
515       // }
516       // out_ << "(";
517 
518       // auto param1_type = params[1]->result_type()->UnwrapPtrIfNeeded();
519       // if (!param1_type->IsVector()) {
520       //   error_ = "invalid param type in outer_product got: " +
521       //            param1_type->type_name();
522       //   return false;
523       // }
524 
525       // for (uint32_t i = 0; i < param1_type->AsVector()->size(); ++i) {
526       //   if (i > 0) {
527       //     out_ << ", ";
528       //   }
529 
530       //   if (!EmitExpression(params[0].get())) {
531       //     return false;
532       //   }
533       //   out_ << " * ";
534 
535       //   if (!EmitExpression(params[1].get())) {
536       //     return false;
537       //   }
538       //   out_ << "[" << i << "]";
539       // }
540 
541       // out_ << ")";
542     } else {
543       auto name = generate_intrinsic_name(ident->intrinsic());
544       if (name.empty()) {
545         if (ast::intrinsic::IsTextureIntrinsic(ident->intrinsic())) {
546           error_ = "Textures not implemented yet";
547           return false;
548         }
549         name = generate_builtin_name(ident);
550         if (name.empty()) {
551           return false;
552         }
553       }
554 
555       make_indent();
556       out_ << name << "(";
557 
558       bool first = true;
559       for (const auto& param : params) {
560         if (!first) {
561           out_ << ", ";
562         }
563         first = false;
564 
565         if (!EmitExpression(param.get())) {
566           return false;
567         }
568       }
569 
570       out_ << ")";
571     }
572     return true;
573   }
574 
575   auto name = ident->name();
576   auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name);
577   if (it != ep_func_name_remapped_.end()) {
578     name = it->second;
579   }
580 
581   auto* func = module_->FindFunctionByName(ident->name());
582   if (func == nullptr) {
583     error_ = "Unable to find function: " + name;
584     return false;
585   }
586 
587   out_ << name << "(";
588 
589   bool first = true;
590   if (has_referenced_in_var_needing_struct(func)) {
591     auto var_name = current_ep_var_name(VarType::kIn);
592     if (!var_name.empty()) {
593       out_ << var_name;
594       first = false;
595     }
596   }
597   if (has_referenced_out_var_needing_struct(func)) {
598     auto var_name = current_ep_var_name(VarType::kOut);
599     if (!var_name.empty()) {
600       if (!first) {
601         out_ << ", ";
602       }
603       first = false;
604       out_ << var_name;
605     }
606   }
607 
608   for (const auto& data : func->referenced_builtin_variables()) {
609     auto* var = data.first;
610     if (var->storage_class() != ast::StorageClass::kInput) {
611       continue;
612     }
613     if (!first) {
614       out_ << ", ";
615     }
616     first = false;
617     out_ << var->name();
618   }
619 
620   for (const auto& data : func->referenced_uniform_variables()) {
621     auto* var = data.first;
622     if (!first) {
623       out_ << ", ";
624     }
625     first = false;
626     out_ << var->name();
627   }
628 
629   for (const auto& data : func->referenced_storagebuffer_variables()) {
630     auto* var = data.first;
631     if (!first) {
632       out_ << ", ";
633     }
634     first = false;
635     out_ << var->name();
636   }
637 
638   const auto& params = expr->params();
639   for (const auto& param : params) {
640     if (!first) {
641       out_ << ", ";
642     }
643     first = false;
644 
645     if (!EmitExpression(param.get())) {
646       return false;
647     }
648   }
649 
650   out_ << ")";
651 
652   return true;
653 }
654 
generate_builtin_name(ast::IdentifierExpression * ident)655 std::string GeneratorImpl::generate_builtin_name(
656     ast::IdentifierExpression* ident) {
657   std::string out = "metal::";
658   switch (ident->intrinsic()) {
659     case ast::Intrinsic::kAcos:
660     case ast::Intrinsic::kAsin:
661     case ast::Intrinsic::kAtan:
662     case ast::Intrinsic::kAtan2:
663     case ast::Intrinsic::kCeil:
664     case ast::Intrinsic::kCos:
665     case ast::Intrinsic::kCosh:
666     case ast::Intrinsic::kCross:
667     case ast::Intrinsic::kDeterminant:
668     case ast::Intrinsic::kDistance:
669     case ast::Intrinsic::kExp:
670     case ast::Intrinsic::kExp2:
671     case ast::Intrinsic::kFloor:
672     case ast::Intrinsic::kFma:
673     case ast::Intrinsic::kFract:
674     case ast::Intrinsic::kLength:
675     case ast::Intrinsic::kLog:
676     case ast::Intrinsic::kLog2:
677     case ast::Intrinsic::kMix:
678     case ast::Intrinsic::kNormalize:
679     case ast::Intrinsic::kPow:
680     case ast::Intrinsic::kReflect:
681     case ast::Intrinsic::kRound:
682     case ast::Intrinsic::kSin:
683     case ast::Intrinsic::kSinh:
684     case ast::Intrinsic::kSqrt:
685     case ast::Intrinsic::kStep:
686     case ast::Intrinsic::kTan:
687     case ast::Intrinsic::kTanh:
688     case ast::Intrinsic::kTrunc:
689     case ast::Intrinsic::kSign:
690     case ast::Intrinsic::kClamp:
691       out += ident->name();
692       break;
693     case ast::Intrinsic::kAbs:
694       if (ident->result_type()->IsF32()) {
695         out += "fabs";
696       } else if (ident->result_type()->IsU32() ||
697                  ident->result_type()->IsI32()) {
698         out += "abs";
699       }
700       break;
701     case ast::Intrinsic::kMax:
702       if (ident->result_type()->IsF32()) {
703         out += "fmax";
704       } else if (ident->result_type()->IsU32() ||
705                  ident->result_type()->IsI32()) {
706         out += "max";
707       }
708       break;
709     case ast::Intrinsic::kMin:
710       if (ident->result_type()->IsF32()) {
711         out += "fmin";
712       } else if (ident->result_type()->IsU32() ||
713                  ident->result_type()->IsI32()) {
714         out += "min";
715       }
716       break;
717     case ast::Intrinsic::kFaceForward:
718       out += "faceforward";
719       break;
720     case ast::Intrinsic::kSmoothStep:
721       out += "smoothstep";
722       break;
723     case ast::Intrinsic::kInverseSqrt:
724       out += "rsqrt";
725       break;
726     default:
727       error_ = "Unknown import method: " + ident->name();
728       return "";
729   }
730   return out;
731 }
732 
EmitCase(ast::CaseStatement * stmt)733 bool GeneratorImpl::EmitCase(ast::CaseStatement* stmt) {
734   make_indent();
735 
736   if (stmt->IsDefault()) {
737     out_ << "default:";
738   } else {
739     bool first = true;
740     for (const auto& selector : stmt->selectors()) {
741       if (!first) {
742         out_ << std::endl;
743         make_indent();
744       }
745       first = false;
746 
747       out_ << "case ";
748       if (!EmitLiteral(selector.get())) {
749         return false;
750       }
751       out_ << ":";
752     }
753   }
754 
755   out_ << " {" << std::endl;
756 
757   increment_indent();
758 
759   for (const auto& s : *(stmt->body())) {
760     if (!EmitStatement(s.get())) {
761       return false;
762     }
763   }
764 
765   if (!last_is_break_or_fallthrough(stmt->body())) {
766     make_indent();
767     out_ << "break;" << std::endl;
768   }
769 
770   decrement_indent();
771   make_indent();
772   out_ << "}" << std::endl;
773 
774   return true;
775 }
776 
EmitConstructor(ast::ConstructorExpression * expr)777 bool GeneratorImpl::EmitConstructor(ast::ConstructorExpression* expr) {
778   if (expr->IsScalarConstructor()) {
779     return EmitScalarConstructor(expr->AsScalarConstructor());
780   }
781   return EmitTypeConstructor(expr->AsTypeConstructor());
782 }
783 
EmitContinue(ast::ContinueStatement *)784 bool GeneratorImpl::EmitContinue(ast::ContinueStatement*) {
785   make_indent();
786   out_ << "continue;" << std::endl;
787   return true;
788 }
789 
EmitTypeConstructor(ast::TypeConstructorExpression * expr)790 bool GeneratorImpl::EmitTypeConstructor(ast::TypeConstructorExpression* expr) {
791   if (expr->type()->IsArray()) {
792     out_ << "{";
793   } else {
794     if (!EmitType(expr->type(), "")) {
795       return false;
796     }
797     out_ << "(";
798   }
799 
800   // If the type constructor is empty then we need to construct with the zero
801   // value for all components.
802   if (expr->values().empty()) {
803     if (!EmitZeroValue(expr->type())) {
804       return false;
805     }
806   } else {
807     bool first = true;
808     for (const auto& e : expr->values()) {
809       if (!first) {
810         out_ << ", ";
811       }
812       first = false;
813 
814       if (!EmitExpression(e.get())) {
815         return false;
816       }
817     }
818   }
819 
820   if (expr->type()->IsArray()) {
821     out_ << "}";
822   } else {
823     out_ << ")";
824   }
825   return true;
826 }
827 
EmitZeroValue(ast::type::Type * type)828 bool GeneratorImpl::EmitZeroValue(ast::type::Type* type) {
829   if (type->IsBool()) {
830     out_ << "false";
831   } else if (type->IsF32()) {
832     out_ << "0.0f";
833   } else if (type->IsI32()) {
834     out_ << "0";
835   } else if (type->IsU32()) {
836     out_ << "0u";
837   } else if (type->IsVector()) {
838     return EmitZeroValue(type->AsVector()->type());
839   } else if (type->IsMatrix()) {
840     return EmitZeroValue(type->AsMatrix()->type());
841   } else if (type->IsArray()) {
842     out_ << "{";
843     if (!EmitZeroValue(type->AsArray()->type())) {
844       return false;
845     }
846     out_ << "}";
847   } else if (type->IsStruct()) {
848     out_ << "{}";
849   } else {
850     error_ = "Invalid type for zero emission: " + type->type_name();
851     return false;
852   }
853   return true;
854 }
855 
EmitScalarConstructor(ast::ScalarConstructorExpression * expr)856 bool GeneratorImpl::EmitScalarConstructor(
857     ast::ScalarConstructorExpression* expr) {
858   return EmitLiteral(expr->literal());
859 }
860 
EmitLiteral(ast::Literal * lit)861 bool GeneratorImpl::EmitLiteral(ast::Literal* lit) {
862   if (lit->IsBool()) {
863     out_ << (lit->AsBool()->IsTrue() ? "true" : "false");
864   } else if (lit->IsFloat()) {
865     auto flags = out_.flags();
866     auto precision = out_.precision();
867 
868     out_.flags(flags | std::ios_base::showpoint);
869     out_.precision(std::numeric_limits<float>::max_digits10);
870 
871     out_ << lit->AsFloat()->value() << "f";
872 
873     out_.precision(precision);
874     out_.flags(flags);
875   } else if (lit->IsSint()) {
876     out_ << lit->AsSint()->value();
877   } else if (lit->IsUint()) {
878     out_ << lit->AsUint()->value() << "u";
879   } else {
880     error_ = "unknown literal type";
881     return false;
882   }
883   return true;
884 }
885 
EmitEntryPointData(ast::Function * func)886 bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
887   std::vector<std::pair<ast::Variable*, uint32_t>> in_locations;
888   std::vector<std::pair<ast::Variable*, ast::VariableDecoration*>>
889       out_variables;
890   for (auto data : func->referenced_location_variables()) {
891     auto* var = data.first;
892     auto* deco = data.second;
893 
894     if (var->storage_class() == ast::StorageClass::kInput) {
895       in_locations.push_back({var, deco->value()});
896     } else if (var->storage_class() == ast::StorageClass::kOutput) {
897       out_variables.push_back({var, deco});
898     }
899   }
900 
901   for (auto data : func->referenced_builtin_variables()) {
902     auto* var = data.first;
903     auto* deco = data.second;
904 
905     if (var->storage_class() == ast::StorageClass::kOutput) {
906       out_variables.push_back({var, deco});
907     }
908   }
909 
910   if (!in_locations.empty()) {
911     auto in_struct_name =
912         generate_name(func->name() + "_" + kInStructNameSuffix);
913     auto in_var_name = generate_name(kTintStructInVarPrefix);
914     ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name};
915 
916     make_indent();
917     out_ << "struct " << in_struct_name << " {" << std::endl;
918 
919     increment_indent();
920 
921     for (auto& data : in_locations) {
922       auto* var = data.first;
923       uint32_t loc = data.second;
924 
925       make_indent();
926       if (!EmitType(var->type(), var->name())) {
927         return false;
928       }
929 
930       out_ << " " << var->name() << " [[";
931       if (func->pipeline_stage() == ast::PipelineStage::kVertex) {
932         out_ << "attribute(" << loc << ")";
933       } else if (func->pipeline_stage() == ast::PipelineStage::kFragment) {
934         out_ << "user(locn" << loc << ")";
935       } else {
936         error_ = "invalid location variable for pipeline stage";
937         return false;
938       }
939       out_ << "]];" << std::endl;
940     }
941     decrement_indent();
942     make_indent();
943 
944     out_ << "};" << std::endl << std::endl;
945   }
946 
947   if (!out_variables.empty()) {
948     auto out_struct_name =
949         generate_name(func->name() + "_" + kOutStructNameSuffix);
950     auto out_var_name = generate_name(kTintStructOutVarPrefix);
951     ep_name_to_out_data_[func->name()] = {out_struct_name, out_var_name};
952 
953     make_indent();
954     out_ << "struct " << out_struct_name << " {" << std::endl;
955 
956     increment_indent();
957     for (auto& data : out_variables) {
958       auto* var = data.first;
959       auto* deco = data.second;
960 
961       make_indent();
962       if (!EmitType(var->type(), var->name())) {
963         return false;
964       }
965 
966       out_ << " " << var->name() << " [[";
967 
968       if (deco->IsLocation()) {
969         auto loc = deco->AsLocation()->value();
970         if (func->pipeline_stage() == ast::PipelineStage::kVertex) {
971           out_ << "user(locn" << loc << ")";
972         } else if (func->pipeline_stage() == ast::PipelineStage::kFragment) {
973           out_ << "color(" << loc << ")";
974         } else {
975           error_ = "invalid location variable for pipeline stage";
976           return false;
977         }
978       } else if (deco->IsBuiltin()) {
979         auto attr = builtin_to_attribute(deco->AsBuiltin()->value());
980         if (attr.empty()) {
981           error_ = "unsupported builtin";
982           return false;
983         }
984         out_ << attr;
985       } else {
986         error_ = "unsupported variable decoration for entry point output";
987         return false;
988       }
989       out_ << "]];" << std::endl;
990     }
991     decrement_indent();
992     make_indent();
993     out_ << "};" << std::endl << std::endl;
994   }
995 
996   return true;
997 }
998 
EmitExpression(ast::Expression * expr)999 bool GeneratorImpl::EmitExpression(ast::Expression* expr) {
1000   if (expr->IsArrayAccessor()) {
1001     return EmitArrayAccessor(expr->AsArrayAccessor());
1002   }
1003   if (expr->IsBinary()) {
1004     return EmitBinary(expr->AsBinary());
1005   }
1006   if (expr->IsBitcast()) {
1007     return EmitBitcast(expr->AsBitcast());
1008   }
1009   if (expr->IsCall()) {
1010     return EmitCall(expr->AsCall());
1011   }
1012   if (expr->IsConstructor()) {
1013     return EmitConstructor(expr->AsConstructor());
1014   }
1015   if (expr->IsIdentifier()) {
1016     return EmitIdentifier(expr->AsIdentifier());
1017   }
1018   if (expr->IsMemberAccessor()) {
1019     return EmitMemberAccessor(expr->AsMemberAccessor());
1020   }
1021   if (expr->IsUnaryOp()) {
1022     return EmitUnaryOp(expr->AsUnaryOp());
1023   }
1024 
1025   error_ = "unknown expression type: " + expr->str();
1026   return false;
1027 }
1028 
EmitStage(ast::PipelineStage stage)1029 void GeneratorImpl::EmitStage(ast::PipelineStage stage) {
1030   switch (stage) {
1031     case ast::PipelineStage::kFragment:
1032       out_ << "fragment";
1033       break;
1034     case ast::PipelineStage::kVertex:
1035       out_ << "vertex";
1036       break;
1037     case ast::PipelineStage::kCompute:
1038       out_ << "kernel";
1039       break;
1040     case ast::PipelineStage::kNone:
1041       break;
1042   }
1043   return;
1044 }
1045 
has_referenced_in_var_needing_struct(ast::Function * func)1046 bool GeneratorImpl::has_referenced_in_var_needing_struct(ast::Function* func) {
1047   for (auto data : func->referenced_location_variables()) {
1048     auto* var = data.first;
1049     if (var->storage_class() == ast::StorageClass::kInput) {
1050       return true;
1051     }
1052   }
1053   return false;
1054 }
1055 
has_referenced_out_var_needing_struct(ast::Function * func)1056 bool GeneratorImpl::has_referenced_out_var_needing_struct(ast::Function* func) {
1057   for (auto data : func->referenced_location_variables()) {
1058     auto* var = data.first;
1059     if (var->storage_class() == ast::StorageClass::kOutput) {
1060       return true;
1061     }
1062   }
1063 
1064   for (auto data : func->referenced_builtin_variables()) {
1065     auto* var = data.first;
1066     if (var->storage_class() == ast::StorageClass::kOutput) {
1067       return true;
1068     }
1069   }
1070   return false;
1071 }
1072 
has_referenced_var_needing_struct(ast::Function * func)1073 bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) {
1074   return has_referenced_in_var_needing_struct(func) ||
1075          has_referenced_out_var_needing_struct(func);
1076 }
1077 
EmitFunction(ast::Function * func)1078 bool GeneratorImpl::EmitFunction(ast::Function* func) {
1079   make_indent();
1080 
1081   // Entry points will be emitted later, skip for now.
1082   if (func->IsEntryPoint()) {
1083     return true;
1084   }
1085 
1086   // TODO(dsinclair): This could be smarter. If the input/outputs for multiple
1087   // entry points are the same we could generate a single struct and then have
1088   // this determine it's the same struct and just emit once.
1089   bool emit_duplicate_functions = func->ancestor_entry_points().size() > 0 &&
1090                                   has_referenced_var_needing_struct(func);
1091 
1092   if (emit_duplicate_functions) {
1093     for (const auto& ep_name : func->ancestor_entry_points()) {
1094       if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) {
1095         return false;
1096       }
1097       out_ << std::endl;
1098     }
1099   } else {
1100     // Emit as non-duplicated
1101     if (!EmitFunctionInternal(func, false, "")) {
1102       return false;
1103     }
1104     out_ << std::endl;
1105   }
1106 
1107   return true;
1108 }
1109 
EmitFunctionInternal(ast::Function * func,bool emit_duplicate_functions,const std::string & ep_name)1110 bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
1111                                          bool emit_duplicate_functions,
1112                                          const std::string& ep_name) {
1113   auto name = func->name();
1114 
1115   if (!EmitType(func->return_type(), "")) {
1116     return false;
1117   }
1118 
1119   out_ << " ";
1120   if (emit_duplicate_functions) {
1121     name = generate_name(name + "_" + ep_name);
1122     ep_func_name_remapped_[ep_name + "_" + func->name()] = name;
1123   } else {
1124     name = namer_.NameFor(name);
1125   }
1126   out_ << name << "(";
1127 
1128   bool first = true;
1129 
1130   // If we're emitting duplicate functions that means the function takes
1131   // the stage_in or stage_out value from the entry point, emit them.
1132   //
1133   // We emit both of them if they're there regardless of if they're both used.
1134   if (emit_duplicate_functions) {
1135     auto in_it = ep_name_to_in_data_.find(ep_name);
1136     if (in_it != ep_name_to_in_data_.end()) {
1137       out_ << "thread " << in_it->second.struct_name << "& "
1138            << in_it->second.var_name;
1139       first = false;
1140     }
1141 
1142     auto out_it = ep_name_to_out_data_.find(ep_name);
1143     if (out_it != ep_name_to_out_data_.end()) {
1144       if (!first) {
1145         out_ << ", ";
1146       }
1147       out_ << "thread " << out_it->second.struct_name << "& "
1148            << out_it->second.var_name;
1149       first = false;
1150     }
1151   }
1152 
1153   for (const auto& data : func->referenced_builtin_variables()) {
1154     auto* var = data.first;
1155     if (var->storage_class() != ast::StorageClass::kInput) {
1156       continue;
1157     }
1158     if (!first) {
1159       out_ << ", ";
1160     }
1161     first = false;
1162 
1163     out_ << "thread ";
1164     if (!EmitType(var->type(), "")) {
1165       return false;
1166     }
1167     out_ << "& " << var->name();
1168   }
1169 
1170   for (const auto& data : func->referenced_uniform_variables()) {
1171     auto* var = data.first;
1172     if (!first) {
1173       out_ << ", ";
1174     }
1175     first = false;
1176 
1177     out_ << "constant ";
1178     // TODO(dsinclair): Can arrays be uniform? If so, fix this ...
1179     if (!EmitType(var->type(), "")) {
1180       return false;
1181     }
1182     out_ << "& " << var->name();
1183   }
1184 
1185   for (const auto& data : func->referenced_storagebuffer_variables()) {
1186     auto* var = data.first;
1187     if (!first) {
1188       out_ << ", ";
1189     }
1190     first = false;
1191 
1192     if (!var->type()->IsAccessControl()) {
1193       error_ = "invalid type for storage buffer, expected access control";
1194       return false;
1195     }
1196     auto* ac = var->type()->AsAccessControl();
1197     if (ac->IsReadOnly()) {
1198       out_ << "const ";
1199     }
1200 
1201     out_ << "device ";
1202     if (!EmitType(ac->type(), "")) {
1203       return false;
1204     }
1205     out_ << "& " << var->name();
1206   }
1207 
1208   for (const auto& v : func->params()) {
1209     if (!first) {
1210       out_ << ", ";
1211     }
1212     first = false;
1213 
1214     if (!EmitType(v->type(), v->name())) {
1215       return false;
1216     }
1217     // Array name is output as part of the type
1218     if (!v->type()->IsArray()) {
1219       out_ << " " << v->name();
1220     }
1221   }
1222 
1223   out_ << ") ";
1224 
1225   current_ep_name_ = ep_name;
1226 
1227   if (!EmitBlockAndNewline(func->body())) {
1228     return false;
1229   }
1230 
1231   current_ep_name_ = "";
1232 
1233   return true;
1234 }
1235 
builtin_to_attribute(ast::Builtin builtin) const1236 std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
1237   switch (builtin) {
1238     case ast::Builtin::kPosition:
1239       return "position";
1240     case ast::Builtin::kVertexIdx:
1241       return "vertex_id";
1242     case ast::Builtin::kInstanceIdx:
1243       return "instance_id";
1244     case ast::Builtin::kFrontFacing:
1245       return "front_facing";
1246     case ast::Builtin::kFragCoord:
1247       return "position";
1248     case ast::Builtin::kFragDepth:
1249       return "depth(any)";
1250     case ast::Builtin::kLocalInvocationId:
1251       return "thread_position_in_threadgroup";
1252     case ast::Builtin::kLocalInvocationIdx:
1253       return "thread_index_in_threadgroup";
1254     case ast::Builtin::kGlobalInvocationId:
1255       return "thread_position_in_grid";
1256     default:
1257       break;
1258   }
1259   return "";
1260 }
1261 
EmitEntryPointFunction(ast::Function * func)1262 bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
1263   make_indent();
1264 
1265   current_ep_name_ = func->name();
1266 
1267   EmitStage(func->pipeline_stage());
1268   out_ << " ";
1269 
1270   // This is an entry point, the return type is the entry point output structure
1271   // if one exists, or void otherwise.
1272   auto out_data = ep_name_to_out_data_.find(current_ep_name_);
1273   bool has_out_data = out_data != ep_name_to_out_data_.end();
1274   if (has_out_data) {
1275     out_ << out_data->second.struct_name;
1276   } else {
1277     out_ << "void";
1278   }
1279   out_ << " " << namer_.NameFor(current_ep_name_) << "(";
1280 
1281   bool first = true;
1282   auto in_data = ep_name_to_in_data_.find(current_ep_name_);
1283   if (in_data != ep_name_to_in_data_.end()) {
1284     out_ << in_data->second.struct_name << " " << in_data->second.var_name
1285          << " [[stage_in]]";
1286     first = false;
1287   }
1288 
1289   for (auto data : func->referenced_builtin_variables()) {
1290     auto* var = data.first;
1291     if (var->storage_class() != ast::StorageClass::kInput) {
1292       continue;
1293     }
1294 
1295     if (!first) {
1296       out_ << ", ";
1297     }
1298     first = false;
1299 
1300     auto* builtin = data.second;
1301 
1302     if (!EmitType(var->type(), "")) {
1303       return false;
1304     }
1305 
1306     auto attr = builtin_to_attribute(builtin->value());
1307     if (attr.empty()) {
1308       error_ = "unknown builtin";
1309       return false;
1310     }
1311     out_ << " " << var->name() << " [[" << attr << "]]";
1312   }
1313 
1314   for (auto data : func->referenced_uniform_variables()) {
1315     if (!first) {
1316       out_ << ", ";
1317     }
1318     first = false;
1319 
1320     auto* var = data.first;
1321     // TODO(dsinclair): We're using the binding to make up the buffer number but
1322     // we should instead be using a provided mapping that uses both buffer and
1323     // set. https://bugs.chromium.org/p/tint/issues/detail?id=104
1324     auto* binding = data.second.binding;
1325     if (binding == nullptr) {
1326       error_ = "unable to find binding information for uniform: " + var->name();
1327       return false;
1328     }
1329     // auto* set = data.second.set;
1330 
1331     out_ << "constant ";
1332     // TODO(dsinclair): Can you have a uniform array? If so, this needs to be
1333     // updated to handle arrays property.
1334     if (!EmitType(var->type(), "")) {
1335       return false;
1336     }
1337     out_ << "& " << var->name() << " [[buffer(" << binding->value() << ")]]";
1338   }
1339 
1340   for (auto data : func->referenced_storagebuffer_variables()) {
1341     if (!first) {
1342       out_ << ", ";
1343     }
1344     first = false;
1345 
1346     auto* var = data.first;
1347     // TODO(dsinclair): We're using the binding to make up the buffer number but
1348     // we should instead be using a provided mapping that uses both buffer and
1349     // set. https://bugs.chromium.org/p/tint/issues/detail?id=104
1350     auto* binding = data.second.binding;
1351     // auto* set = data.second.set;
1352 
1353     if (!var->type()->IsAccessControl()) {
1354       error_ = "invalid type for storage buffer, expected access control";
1355       return false;
1356     }
1357     auto* ac = var->type()->AsAccessControl();
1358     if (ac->IsReadOnly()) {
1359       out_ << "const ";
1360     }
1361 
1362     out_ << "device ";
1363     if (!EmitType(ac->type(), "")) {
1364       return false;
1365     }
1366     out_ << "& " << var->name() << " [[buffer(" << binding->value() << ")]]";
1367   }
1368 
1369   out_ << ") {" << std::endl;
1370 
1371   increment_indent();
1372 
1373   if (has_out_data) {
1374     make_indent();
1375     out_ << out_data->second.struct_name << " " << out_data->second.var_name
1376          << " = {};" << std::endl;
1377   }
1378 
1379   generating_entry_point_ = true;
1380   for (const auto& s : *(func->body())) {
1381     if (!EmitStatement(s.get())) {
1382       return false;
1383     }
1384   }
1385   generating_entry_point_ = false;
1386 
1387   decrement_indent();
1388   make_indent();
1389   out_ << "}" << std::endl;
1390 
1391   current_ep_name_ = "";
1392   return true;
1393 }
1394 
global_is_in_struct(ast::Variable * var) const1395 bool GeneratorImpl::global_is_in_struct(ast::Variable* var) const {
1396   bool in_or_out_struct_has_location =
1397       var->IsDecorated() && var->AsDecorated()->HasLocationDecoration() &&
1398       (var->storage_class() == ast::StorageClass::kInput ||
1399        var->storage_class() == ast::StorageClass::kOutput);
1400   bool in_struct_has_builtin =
1401       var->IsDecorated() && var->AsDecorated()->HasBuiltinDecoration() &&
1402       var->storage_class() == ast::StorageClass::kOutput;
1403   return in_or_out_struct_has_location || in_struct_has_builtin;
1404 }
1405 
EmitIdentifier(ast::IdentifierExpression * expr)1406 bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) {
1407   auto* ident = expr->AsIdentifier();
1408   ast::Variable* var = nullptr;
1409   if (global_variables_.get(ident->name(), &var)) {
1410     if (global_is_in_struct(var)) {
1411       auto var_type = var->storage_class() == ast::StorageClass::kInput
1412                           ? VarType::kIn
1413                           : VarType::kOut;
1414       auto name = current_ep_var_name(var_type);
1415       if (name.empty()) {
1416         error_ = "unable to find entry point data for variable";
1417         return false;
1418       }
1419       out_ << name << ".";
1420     }
1421   }
1422   out_ << namer_.NameFor(ident->name());
1423 
1424   return true;
1425 }
1426 
EmitLoop(ast::LoopStatement * stmt)1427 bool GeneratorImpl::EmitLoop(ast::LoopStatement* stmt) {
1428   loop_emission_counter_++;
1429 
1430   std::string guard = namer_.NameFor("tint_msl_is_first_" +
1431                                      std::to_string(loop_emission_counter_));
1432 
1433   if (stmt->has_continuing()) {
1434     make_indent();
1435 
1436     // Continuing variables get their own scope.
1437     out_ << "{" << std::endl;
1438     increment_indent();
1439 
1440     make_indent();
1441     out_ << "bool " << guard << " = true;" << std::endl;
1442 
1443     // A continuing block may use variables declared in the method body. As a
1444     // first pass, if we have a continuing, we pull all declarations outside
1445     // the for loop into the continuing scope. Then, the variable declarations
1446     // will be turned into assignments.
1447     for (const auto& s : *(stmt->body())) {
1448       if (!s->IsVariableDecl()) {
1449         continue;
1450       }
1451       if (!EmitVariable(s->AsVariableDecl()->variable(), true)) {
1452         return false;
1453       }
1454     }
1455   }
1456 
1457   make_indent();
1458   out_ << "for(;;) {" << std::endl;
1459   increment_indent();
1460 
1461   if (stmt->has_continuing()) {
1462     make_indent();
1463     out_ << "if (!" << guard << ") ";
1464 
1465     if (!EmitBlockAndNewline(stmt->continuing())) {
1466       return false;
1467     }
1468 
1469     make_indent();
1470     out_ << guard << " = false;" << std::endl;
1471     out_ << std::endl;
1472   }
1473 
1474   for (const auto& s : *(stmt->body())) {
1475     // If we have a continuing block we've already emitted the variable
1476     // declaration before the loop, so treat it as an assignment.
1477     if (s->IsVariableDecl() && stmt->has_continuing()) {
1478       make_indent();
1479 
1480       auto* var = s->AsVariableDecl()->variable();
1481       out_ << var->name() << " = ";
1482       if (var->constructor() != nullptr) {
1483         if (!EmitExpression(var->constructor())) {
1484           return false;
1485         }
1486       } else {
1487         if (!EmitZeroValue(var->type())) {
1488           return false;
1489         }
1490       }
1491       out_ << ";" << std::endl;
1492       continue;
1493     }
1494 
1495     if (!EmitStatement(s.get())) {
1496       return false;
1497     }
1498   }
1499 
1500   decrement_indent();
1501   make_indent();
1502   out_ << "}" << std::endl;
1503 
1504   // Close the scope for any continuing variables.
1505   if (stmt->has_continuing()) {
1506     decrement_indent();
1507     make_indent();
1508     out_ << "}" << std::endl;
1509   }
1510 
1511   return true;
1512 }
1513 
EmitDiscard(ast::DiscardStatement *)1514 bool GeneratorImpl::EmitDiscard(ast::DiscardStatement*) {
1515   make_indent();
1516   // TODO(dsinclair): Verify this is correct when the discard semantics are
1517   // defined for WGSL (https://github.com/gpuweb/gpuweb/issues/361)
1518   out_ << "discard_fragment();" << std::endl;
1519   return true;
1520 }
1521 
EmitElse(ast::ElseStatement * stmt)1522 bool GeneratorImpl::EmitElse(ast::ElseStatement* stmt) {
1523   if (stmt->HasCondition()) {
1524     out_ << " else if (";
1525     if (!EmitExpression(stmt->condition())) {
1526       return false;
1527     }
1528     out_ << ") ";
1529   } else {
1530     out_ << " else ";
1531   }
1532 
1533   return EmitBlock(stmt->body());
1534 }
1535 
EmitIf(ast::IfStatement * stmt)1536 bool GeneratorImpl::EmitIf(ast::IfStatement* stmt) {
1537   make_indent();
1538 
1539   out_ << "if (";
1540   if (!EmitExpression(stmt->condition())) {
1541     return false;
1542   }
1543   out_ << ") ";
1544 
1545   if (!EmitBlock(stmt->body())) {
1546     return false;
1547   }
1548 
1549   for (const auto& e : stmt->else_statements()) {
1550     if (!EmitElse(e.get())) {
1551       return false;
1552     }
1553   }
1554   out_ << std::endl;
1555 
1556   return true;
1557 }
1558 
EmitMemberAccessor(ast::MemberAccessorExpression * expr)1559 bool GeneratorImpl::EmitMemberAccessor(ast::MemberAccessorExpression* expr) {
1560   if (!EmitExpression(expr->structure())) {
1561     return false;
1562   }
1563 
1564   out_ << ".";
1565 
1566   return EmitExpression(expr->member());
1567 }
1568 
EmitReturn(ast::ReturnStatement * stmt)1569 bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) {
1570   make_indent();
1571 
1572   out_ << "return";
1573 
1574   if (generating_entry_point_) {
1575     auto out_data = ep_name_to_out_data_.find(current_ep_name_);
1576     if (out_data != ep_name_to_out_data_.end()) {
1577       out_ << " " << out_data->second.var_name;
1578     }
1579   } else if (stmt->has_value()) {
1580     out_ << " ";
1581     if (!EmitExpression(stmt->value())) {
1582       return false;
1583     }
1584   }
1585   out_ << ";" << std::endl;
1586   return true;
1587 }
1588 
EmitBlock(const ast::BlockStatement * stmt)1589 bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
1590   out_ << "{" << std::endl;
1591   increment_indent();
1592 
1593   for (const auto& s : *stmt) {
1594     if (!EmitStatement(s.get())) {
1595       return false;
1596     }
1597   }
1598 
1599   decrement_indent();
1600   make_indent();
1601   out_ << "}";
1602 
1603   return true;
1604 }
1605 
EmitBlockAndNewline(const ast::BlockStatement * stmt)1606 bool GeneratorImpl::EmitBlockAndNewline(const ast::BlockStatement* stmt) {
1607   const bool result = EmitBlock(stmt);
1608   if (result) {
1609     out_ << std::endl;
1610   }
1611   return result;
1612 }
1613 
EmitIndentedBlockAndNewline(ast::BlockStatement * stmt)1614 bool GeneratorImpl::EmitIndentedBlockAndNewline(ast::BlockStatement* stmt) {
1615   make_indent();
1616   const bool result = EmitBlock(stmt);
1617   if (result) {
1618     out_ << std::endl;
1619   }
1620   return result;
1621 }
1622 
EmitStatement(ast::Statement * stmt)1623 bool GeneratorImpl::EmitStatement(ast::Statement* stmt) {
1624   if (stmt->IsAssign()) {
1625     return EmitAssign(stmt->AsAssign());
1626   }
1627   if (stmt->IsBlock()) {
1628     return EmitIndentedBlockAndNewline(stmt->AsBlock());
1629   }
1630   if (stmt->IsBreak()) {
1631     return EmitBreak(stmt->AsBreak());
1632   }
1633   if (stmt->IsCall()) {
1634     make_indent();
1635     if (!EmitCall(stmt->AsCall()->expr())) {
1636       return false;
1637     }
1638     out_ << ";" << std::endl;
1639     return true;
1640   }
1641   if (stmt->IsContinue()) {
1642     return EmitContinue(stmt->AsContinue());
1643   }
1644   if (stmt->IsDiscard()) {
1645     return EmitDiscard(stmt->AsDiscard());
1646   }
1647   if (stmt->IsFallthrough()) {
1648     make_indent();
1649     out_ << "/* fallthrough */" << std::endl;
1650     return true;
1651   }
1652   if (stmt->IsIf()) {
1653     return EmitIf(stmt->AsIf());
1654   }
1655   if (stmt->IsLoop()) {
1656     return EmitLoop(stmt->AsLoop());
1657   }
1658   if (stmt->IsReturn()) {
1659     return EmitReturn(stmt->AsReturn());
1660   }
1661   if (stmt->IsSwitch()) {
1662     return EmitSwitch(stmt->AsSwitch());
1663   }
1664   if (stmt->IsVariableDecl()) {
1665     return EmitVariable(stmt->AsVariableDecl()->variable(), false);
1666   }
1667 
1668   error_ = "unknown statement type: " + stmt->str();
1669   return false;
1670 }
1671 
EmitSwitch(ast::SwitchStatement * stmt)1672 bool GeneratorImpl::EmitSwitch(ast::SwitchStatement* stmt) {
1673   make_indent();
1674 
1675   out_ << "switch(";
1676   if (!EmitExpression(stmt->condition())) {
1677     return false;
1678   }
1679   out_ << ") {" << std::endl;
1680 
1681   increment_indent();
1682 
1683   for (const auto& s : stmt->body()) {
1684     if (!EmitCase(s.get())) {
1685       return false;
1686     }
1687   }
1688 
1689   decrement_indent();
1690   make_indent();
1691   out_ << "}" << std::endl;
1692 
1693   return true;
1694 }
1695 
EmitType(ast::type::Type * type,const std::string & name)1696 bool GeneratorImpl::EmitType(ast::type::Type* type, const std::string& name) {
1697   if (type->IsAlias()) {
1698     auto* alias = type->AsAlias();
1699     out_ << namer_.NameFor(alias->name());
1700   } else if (type->IsArray()) {
1701     auto* ary = type->AsArray();
1702 
1703     ast::type::Type* base_type = ary;
1704     std::vector<uint32_t> sizes;
1705     while (base_type->IsArray()) {
1706       if (base_type->AsArray()->IsRuntimeArray()) {
1707         sizes.push_back(1);
1708       } else {
1709         sizes.push_back(base_type->AsArray()->size());
1710       }
1711       base_type = base_type->AsArray()->type();
1712     }
1713     if (!EmitType(base_type, "")) {
1714       return false;
1715     }
1716     if (!name.empty()) {
1717       out_ << " " << namer_.NameFor(name);
1718     }
1719     for (uint32_t size : sizes) {
1720       out_ << "[" << size << "]";
1721     }
1722   } else if (type->IsBool()) {
1723     out_ << "bool";
1724   } else if (type->IsF32()) {
1725     out_ << "float";
1726   } else if (type->IsI32()) {
1727     out_ << "int";
1728   } else if (type->IsMatrix()) {
1729     auto* mat = type->AsMatrix();
1730     if (!EmitType(mat->type(), "")) {
1731       return false;
1732     }
1733     out_ << mat->columns() << "x" << mat->rows();
1734   } else if (type->IsPointer()) {
1735     auto* ptr = type->AsPointer();
1736     // TODO(dsinclair): Storage class?
1737     if (!EmitType(ptr->type(), "")) {
1738       return false;
1739     }
1740     out_ << "*";
1741   } else if (type->IsSampler()) {
1742     out_ << "sampler";
1743   } else if (type->IsStruct()) {
1744     // The struct type emits as just the name. The declaration would be emitted
1745     // as part of emitting the constructed types.
1746     out_ << type->AsStruct()->name();
1747   } else if (type->IsTexture()) {
1748     auto* tex = type->AsTexture();
1749 
1750     if (tex->IsDepth()) {
1751       out_ << "depth";
1752     } else {
1753       out_ << "texture";
1754     }
1755 
1756     switch (tex->dim()) {
1757       case ast::type::TextureDimension::k1d:
1758         out_ << "1d";
1759         break;
1760       case ast::type::TextureDimension::k1dArray:
1761         out_ << "1d_array";
1762         break;
1763       case ast::type::TextureDimension::k2d:
1764         out_ << "2d";
1765         break;
1766       case ast::type::TextureDimension::k2dArray:
1767         out_ << "2d_array";
1768         break;
1769       case ast::type::TextureDimension::k3d:
1770         out_ << "3d";
1771         break;
1772       case ast::type::TextureDimension::kCube:
1773         out_ << "cube";
1774         break;
1775       case ast::type::TextureDimension::kCubeArray:
1776         out_ << "cube_array";
1777         break;
1778       default:
1779         error_ = "Invalid texture dimensions";
1780         return false;
1781     }
1782     if (tex->IsMultisampled()) {
1783       out_ << "_ms";
1784     }
1785     out_ << "<";
1786     if (tex->IsDepth()) {
1787       out_ << "float, access::sample";
1788     } else if (tex->IsStorage()) {
1789       auto* storage = tex->AsStorage();
1790       if (!EmitType(storage->type(), "")) {
1791         return false;
1792       }
1793       out_ << ", access::";
1794       if (storage->access() == ast::AccessControl::kReadOnly) {
1795         out_ << "read";
1796       } else if (storage->access() == ast::AccessControl::kWriteOnly) {
1797         out_ << "write";
1798       } else {
1799         error_ = "Invalid access control for storage texture";
1800         return false;
1801       }
1802     } else if (tex->IsMultisampled()) {
1803       if (!EmitType(tex->AsMultisampled()->type(), "")) {
1804         return false;
1805       }
1806       out_ << ", access::sample";
1807     } else if (tex->IsSampled()) {
1808       if (!EmitType(tex->AsSampled()->type(), "")) {
1809         return false;
1810       }
1811       out_ << ", access::sample";
1812     } else {
1813       error_ = "invalid texture type";
1814       return false;
1815     }
1816     out_ << ">";
1817 
1818   } else if (type->IsU32()) {
1819     out_ << "uint";
1820   } else if (type->IsVector()) {
1821     auto* vec = type->AsVector();
1822     if (!EmitType(vec->type(), "")) {
1823       return false;
1824     }
1825     out_ << vec->size();
1826   } else if (type->IsVoid()) {
1827     out_ << "void";
1828   } else {
1829     error_ = "unknown type in EmitType: " + type->type_name();
1830     return false;
1831   }
1832 
1833   return true;
1834 }
1835 
EmitStructType(const ast::type::StructType * str)1836 bool GeneratorImpl::EmitStructType(const ast::type::StructType* str) {
1837   // TODO(dsinclair): Block decoration?
1838   // if (str->impl()->decoration() != ast::StructDecoration::kNone) {
1839   // }
1840   out_ << "struct " << str->name() << " {" << std::endl;
1841 
1842   increment_indent();
1843   uint32_t current_offset = 0;
1844   uint32_t pad_count = 0;
1845   for (const auto& mem : str->impl()->members()) {
1846     make_indent();
1847     for (const auto& deco : mem->decorations()) {
1848       if (deco->IsOffset()) {
1849         uint32_t offset = deco->AsOffset()->offset();
1850         if (offset != current_offset) {
1851           out_ << "int8_t pad_" << pad_count << "[" << (offset - current_offset)
1852                << "];" << std::endl;
1853           pad_count++;
1854           make_indent();
1855         }
1856         current_offset = offset;
1857       } else {
1858         error_ = "unsupported member decoration: " + deco->to_str();
1859         return false;
1860       }
1861     }
1862 
1863     if (!EmitType(mem->type(), mem->name())) {
1864       return false;
1865     }
1866     auto size = calculate_alignment_size(mem->type());
1867     if (size == 0) {
1868       error_ = "unable to calculate byte size for: " + mem->type()->type_name();
1869       return false;
1870     }
1871     current_offset += size;
1872 
1873     // Array member name will be output with the type
1874     if (!mem->type()->IsArray()) {
1875       out_ << " " << namer_.NameFor(mem->name());
1876     }
1877     out_ << ";" << std::endl;
1878   }
1879   decrement_indent();
1880   make_indent();
1881 
1882   out_ << "};" << std::endl;
1883   return true;
1884 }
1885 
EmitUnaryOp(ast::UnaryOpExpression * expr)1886 bool GeneratorImpl::EmitUnaryOp(ast::UnaryOpExpression* expr) {
1887   switch (expr->op()) {
1888     case ast::UnaryOp::kNot:
1889       out_ << "!";
1890       break;
1891     case ast::UnaryOp::kNegation:
1892       out_ << "-";
1893       break;
1894   }
1895   out_ << "(";
1896 
1897   if (!EmitExpression(expr->expr())) {
1898     return false;
1899   }
1900 
1901   out_ << ")";
1902 
1903   return true;
1904 }
1905 
EmitVariable(ast::Variable * var,bool skip_constructor)1906 bool GeneratorImpl::EmitVariable(ast::Variable* var, bool skip_constructor) {
1907   make_indent();
1908 
1909   // TODO(dsinclair): Handle variable decorations
1910   if (var->IsDecorated()) {
1911     error_ = "Variable decorations are not handled yet";
1912     return false;
1913   }
1914 
1915   if (var->is_const()) {
1916     out_ << "const ";
1917   }
1918   if (!EmitType(var->type(), var->name())) {
1919     return false;
1920   }
1921   if (!var->type()->IsArray()) {
1922     out_ << " " << var->name();
1923   }
1924 
1925   if (!skip_constructor) {
1926     out_ << " = ";
1927     if (var->constructor() != nullptr) {
1928       if (!EmitExpression(var->constructor())) {
1929         return false;
1930       }
1931     } else if (var->storage_class() == ast::StorageClass::kPrivate ||
1932                var->storage_class() == ast::StorageClass::kFunction ||
1933                var->storage_class() == ast::StorageClass::kNone ||
1934                var->storage_class() == ast::StorageClass::kOutput) {
1935       if (!EmitZeroValue(var->type())) {
1936         return false;
1937       }
1938     }
1939   }
1940   out_ << ";" << std::endl;
1941 
1942   return true;
1943 }
1944 
EmitProgramConstVariable(const ast::Variable * var)1945 bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
1946   make_indent();
1947 
1948   if (var->IsDecorated() && !var->AsDecorated()->HasConstantIdDecoration()) {
1949     error_ = "Decorated const values not valid";
1950     return false;
1951   }
1952   if (!var->is_const()) {
1953     error_ = "Expected a const value";
1954     return false;
1955   }
1956 
1957   out_ << "constant ";
1958   if (!EmitType(var->type(), var->name())) {
1959     return false;
1960   }
1961   if (!var->type()->IsArray()) {
1962     out_ << " " << var->name();
1963   }
1964 
1965   if (var->IsDecorated() && var->AsDecorated()->HasConstantIdDecoration()) {
1966     out_ << " [[function_constant(" << var->AsDecorated()->constant_id()
1967          << ")]]";
1968   } else if (var->constructor() != nullptr) {
1969     out_ << " = ";
1970     if (!EmitExpression(var->constructor())) {
1971       return false;
1972     }
1973   }
1974   out_ << ";" << std::endl;
1975 
1976   return true;
1977 }
1978 
1979 }  // namespace msl
1980 }  // namespace writer
1981 }  // namespace tint
1982