1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "source/fuzz/transformation_replace_linear_algebra_instruction.h"

#include <unordered_set>

#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/instruction_descriptor.h"
19 
20 namespace spvtools {
21 namespace fuzz {
22 
23 TransformationReplaceLinearAlgebraInstruction::
TransformationReplaceLinearAlgebraInstruction(protobufs::TransformationReplaceLinearAlgebraInstruction message)24     TransformationReplaceLinearAlgebraInstruction(
25         protobufs::TransformationReplaceLinearAlgebraInstruction message)
26     : message_(std::move(message)) {}
27 
28 TransformationReplaceLinearAlgebraInstruction::
TransformationReplaceLinearAlgebraInstruction(const std::vector<uint32_t> & fresh_ids,const protobufs::InstructionDescriptor & instruction_descriptor)29     TransformationReplaceLinearAlgebraInstruction(
30         const std::vector<uint32_t>& fresh_ids,
31         const protobufs::InstructionDescriptor& instruction_descriptor) {
32   for (auto fresh_id : fresh_ids) {
33     message_.add_fresh_ids(fresh_id);
34   }
35   *message_.mutable_instruction_descriptor() = instruction_descriptor;
36 }
37 
IsApplicable(opt::IRContext * ir_context,const TransformationContext &) const38 bool TransformationReplaceLinearAlgebraInstruction::IsApplicable(
39     opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
40   auto instruction =
41       FindInstruction(message_.instruction_descriptor(), ir_context);
42 
43   // It must be a linear algebra instruction.
44   if (!spvOpcodeIsLinearAlgebra(instruction->opcode())) {
45     return false;
46   }
47 
48   // |message_.fresh_ids.size| must be the exact number of fresh ids needed to
49   // apply the transformation.
50   if (static_cast<uint32_t>(message_.fresh_ids().size()) !=
51       GetRequiredFreshIdCount(ir_context, instruction)) {
52     return false;
53   }
54 
55   // All ids in |message_.fresh_ids| must be fresh.
56   for (uint32_t fresh_id : message_.fresh_ids()) {
57     if (!fuzzerutil::IsFreshId(ir_context, fresh_id)) {
58       return false;
59     }
60   }
61 
62   return true;
63 }
64 
Apply(opt::IRContext * ir_context,TransformationContext *) const65 void TransformationReplaceLinearAlgebraInstruction::Apply(
66     opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
67   auto linear_algebra_instruction =
68       FindInstruction(message_.instruction_descriptor(), ir_context);
69 
70   switch (linear_algebra_instruction->opcode()) {
71     case SpvOpTranspose:
72       ReplaceOpTranspose(ir_context, linear_algebra_instruction);
73       break;
74     case SpvOpVectorTimesScalar:
75       ReplaceOpVectorTimesScalar(ir_context, linear_algebra_instruction);
76       break;
77     case SpvOpMatrixTimesScalar:
78       ReplaceOpMatrixTimesScalar(ir_context, linear_algebra_instruction);
79       break;
80     case SpvOpVectorTimesMatrix:
81       ReplaceOpVectorTimesMatrix(ir_context, linear_algebra_instruction);
82       break;
83     case SpvOpMatrixTimesVector:
84       ReplaceOpMatrixTimesVector(ir_context, linear_algebra_instruction);
85       break;
86     case SpvOpMatrixTimesMatrix:
87       ReplaceOpMatrixTimesMatrix(ir_context, linear_algebra_instruction);
88       break;
89     case SpvOpOuterProduct:
90       ReplaceOpOuterProduct(ir_context, linear_algebra_instruction);
91       break;
92     case SpvOpDot:
93       ReplaceOpDot(ir_context, linear_algebra_instruction);
94       break;
95     default:
96       assert(false && "Should be unreachable.");
97       break;
98   }
99 
100   ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
101 }
102 
103 protobufs::Transformation
ToMessage() const104 TransformationReplaceLinearAlgebraInstruction::ToMessage() const {
105   protobufs::Transformation result;
106   *result.mutable_replace_linear_algebra_instruction() = message_;
107   return result;
108 }
109 
GetRequiredFreshIdCount(opt::IRContext * ir_context,opt::Instruction * instruction)110 uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount(
111     opt::IRContext* ir_context, opt::Instruction* instruction) {
112   // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
113   // Right now we only support certain operations.
114   switch (instruction->opcode()) {
115     case SpvOpTranspose: {
116       // For each matrix row, |2 * matrix_column_count| OpCompositeExtract and 1
117       // OpCompositeConstruct will be inserted.
118       auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
119           instruction->GetSingleWordInOperand(0));
120       uint32_t matrix_column_count =
121           ir_context->get_type_mgr()
122               ->GetType(matrix_instruction->type_id())
123               ->AsMatrix()
124               ->element_count();
125       uint32_t matrix_row_count = ir_context->get_type_mgr()
126                                       ->GetType(matrix_instruction->type_id())
127                                       ->AsMatrix()
128                                       ->element_type()
129                                       ->AsVector()
130                                       ->element_count();
131       return matrix_row_count * (2 * matrix_column_count + 1);
132     }
133     case SpvOpVectorTimesScalar:
134       // For each vector component, 1 OpCompositeExtract and 1 OpFMul will be
135       // inserted.
136       return 2 *
137              ir_context->get_type_mgr()
138                  ->GetType(ir_context->get_def_use_mgr()
139                                ->GetDef(instruction->GetSingleWordInOperand(0))
140                                ->type_id())
141                  ->AsVector()
142                  ->element_count();
143     case SpvOpMatrixTimesScalar: {
144       // For each matrix column, |1 + column.size| OpCompositeExtract,
145       // |column.size| OpFMul and 1 OpCompositeConstruct instructions will be
146       // inserted.
147       auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
148           instruction->GetSingleWordInOperand(0));
149       auto matrix_type =
150           ir_context->get_type_mgr()->GetType(matrix_instruction->type_id());
151       return 2 * matrix_type->AsMatrix()->element_count() *
152              (1 + matrix_type->AsMatrix()
153                       ->element_type()
154                       ->AsVector()
155                       ->element_count());
156     }
157     case SpvOpVectorTimesMatrix: {
158       // For each vector component, 1 OpCompositeExtract instruction will be
159       // inserted. For each matrix column, |1 + vector_component_count|
160       // OpCompositeExtract, |vector_component_count| OpFMul and
161       // |vector_component_count - 1| OpFAdd instructions will be inserted.
162       auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
163           instruction->GetSingleWordInOperand(0));
164       auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
165           instruction->GetSingleWordInOperand(1));
166       uint32_t vector_component_count =
167           ir_context->get_type_mgr()
168               ->GetType(vector_instruction->type_id())
169               ->AsVector()
170               ->element_count();
171       uint32_t matrix_column_count =
172           ir_context->get_type_mgr()
173               ->GetType(matrix_instruction->type_id())
174               ->AsMatrix()
175               ->element_count();
176       return vector_component_count * (3 * matrix_column_count + 1);
177     }
178     case SpvOpMatrixTimesVector: {
179       // For each matrix column, |1 + matrix_row_count| OpCompositeExtract
180       // will be inserted. For each matrix row, |matrix_column_count| OpFMul and
181       // |matrix_column_count - 1| OpFAdd instructions will be inserted. For
182       // each vector component, 1 OpCompositeExtract instruction will be
183       // inserted.
184       auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
185           instruction->GetSingleWordInOperand(0));
186       uint32_t matrix_column_count =
187           ir_context->get_type_mgr()
188               ->GetType(matrix_instruction->type_id())
189               ->AsMatrix()
190               ->element_count();
191       uint32_t matrix_row_count = ir_context->get_type_mgr()
192                                       ->GetType(matrix_instruction->type_id())
193                                       ->AsMatrix()
194                                       ->element_type()
195                                       ->AsVector()
196                                       ->element_count();
197       return 3 * matrix_column_count * matrix_row_count +
198              2 * matrix_column_count - matrix_row_count;
199     }
200     case SpvOpMatrixTimesMatrix: {
201       // For each matrix 2 column, 1 OpCompositeExtract, 1 OpCompositeConstruct,
202       // |3 * matrix_1_row_count * matrix_1_column_count| OpCompositeExtract,
203       // |matrix_1_row_count * matrix_1_column_count| OpFMul,
204       // |matrix_1_row_count * (matrix_1_column_count - 1)| OpFAdd instructions
205       // will be inserted.
206       auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef(
207           instruction->GetSingleWordInOperand(0));
208       uint32_t matrix_1_column_count =
209           ir_context->get_type_mgr()
210               ->GetType(matrix_1_instruction->type_id())
211               ->AsMatrix()
212               ->element_count();
213       uint32_t matrix_1_row_count =
214           ir_context->get_type_mgr()
215               ->GetType(matrix_1_instruction->type_id())
216               ->AsMatrix()
217               ->element_type()
218               ->AsVector()
219               ->element_count();
220 
221       auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef(
222           instruction->GetSingleWordInOperand(1));
223       uint32_t matrix_2_column_count =
224           ir_context->get_type_mgr()
225               ->GetType(matrix_2_instruction->type_id())
226               ->AsMatrix()
227               ->element_count();
228       return matrix_2_column_count *
229              (2 + matrix_1_row_count * (5 * matrix_1_column_count - 1));
230     }
231     case SpvOpOuterProduct: {
232       // For each |vector_2| component, |vector_1_component_count + 1|
233       // OpCompositeExtract, |vector_1_component_count| OpFMul and 1
234       // OpCompositeConstruct instructions will be inserted.
235       auto vector_1_instruction = ir_context->get_def_use_mgr()->GetDef(
236           instruction->GetSingleWordInOperand(0));
237       auto vector_2_instruction = ir_context->get_def_use_mgr()->GetDef(
238           instruction->GetSingleWordInOperand(1));
239       uint32_t vector_1_component_count =
240           ir_context->get_type_mgr()
241               ->GetType(vector_1_instruction->type_id())
242               ->AsVector()
243               ->element_count();
244       uint32_t vector_2_component_count =
245           ir_context->get_type_mgr()
246               ->GetType(vector_2_instruction->type_id())
247               ->AsVector()
248               ->element_count();
249       return 2 * vector_2_component_count * (vector_1_component_count + 1);
250     }
251     case SpvOpDot:
252       // For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul
253       // will be inserted. The first two OpFMul instructions will result the
254       // first OpFAdd instruction to be inserted. For each remaining OpFMul, 1
255       // OpFAdd will be inserted. The last OpFAdd instruction is got by changing
256       // the OpDot instruction.
257       return 4 * ir_context->get_type_mgr()
258                      ->GetType(
259                          ir_context->get_def_use_mgr()
260                              ->GetDef(instruction->GetSingleWordInOperand(0))
261                              ->type_id())
262                      ->AsVector()
263                      ->element_count() -
264              2;
265     default:
266       assert(false && "Unsupported linear algebra instruction.");
267       return 0;
268   }
269 }
270 
ReplaceOpTranspose(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const271 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpTranspose(
272     opt::IRContext* ir_context,
273     opt::Instruction* linear_algebra_instruction) const {
274   // Gets OpTranspose instruction information.
275   auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
276       linear_algebra_instruction->GetSingleWordInOperand(0));
277   uint32_t matrix_column_count = ir_context->get_type_mgr()
278                                      ->GetType(matrix_instruction->type_id())
279                                      ->AsMatrix()
280                                      ->element_count();
281   auto matrix_column_type = ir_context->get_type_mgr()
282                                 ->GetType(matrix_instruction->type_id())
283                                 ->AsMatrix()
284                                 ->element_type();
285   auto matrix_column_component_type =
286       matrix_column_type->AsVector()->element_type();
287   uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
288   auto resulting_matrix_column_type =
289       ir_context->get_type_mgr()
290           ->GetType(linear_algebra_instruction->type_id())
291           ->AsMatrix()
292           ->element_type();
293 
294   uint32_t fresh_id_index = 0;
295   std::vector<uint32_t> result_column_ids(matrix_row_count);
296   for (uint32_t i = 0; i < matrix_row_count; i++) {
297     std::vector<uint32_t> column_component_ids(matrix_column_count);
298     for (uint32_t j = 0; j < matrix_column_count; j++) {
299       // Extracts the matrix column.
300       uint32_t matrix_column_id = message_.fresh_ids(fresh_id_index++);
301       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
302           ir_context, SpvOpCompositeExtract,
303           ir_context->get_type_mgr()->GetId(matrix_column_type),
304           matrix_column_id,
305           opt::Instruction::OperandList(
306               {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
307                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
308 
309       // Extracts the matrix column component.
310       column_component_ids[j] = message_.fresh_ids(fresh_id_index++);
311       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
312           ir_context, SpvOpCompositeExtract,
313           ir_context->get_type_mgr()->GetId(matrix_column_component_type),
314           column_component_ids[j],
315           opt::Instruction::OperandList(
316               {{SPV_OPERAND_TYPE_ID, {matrix_column_id}},
317                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
318     }
319 
320     // Inserts the resulting matrix column.
321     opt::Instruction::OperandList in_operands;
322     for (auto& column_component_id : column_component_ids) {
323       in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
324     }
325     result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
326     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
327         ir_context, SpvOpCompositeConstruct,
328         ir_context->get_type_mgr()->GetId(resulting_matrix_column_type),
329         result_column_ids[i], opt::Instruction::OperandList(in_operands)));
330   }
331 
332   // The OpTranspose instruction is changed to an OpCompositeConstruct
333   // instruction.
334   linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
335   linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
336   for (uint32_t i = 1; i < result_column_ids.size(); i++) {
337     linear_algebra_instruction->AddOperand(
338         {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
339   }
340 
341   fuzzerutil::UpdateModuleIdBound(
342       ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
343 }
344 
ReplaceOpVectorTimesScalar(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const345 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesScalar(
346     opt::IRContext* ir_context,
347     opt::Instruction* linear_algebra_instruction) const {
348   // Gets OpVectorTimesScalar in operands.
349   auto vector = ir_context->get_def_use_mgr()->GetDef(
350       linear_algebra_instruction->GetSingleWordInOperand(0));
351   auto scalar = ir_context->get_def_use_mgr()->GetDef(
352       linear_algebra_instruction->GetSingleWordInOperand(1));
353 
354   uint32_t vector_component_count = ir_context->get_type_mgr()
355                                         ->GetType(vector->type_id())
356                                         ->AsVector()
357                                         ->element_count();
358   std::vector<uint32_t> float_multiplication_ids(vector_component_count);
359   uint32_t fresh_id_index = 0;
360 
361   for (uint32_t i = 0; i < vector_component_count; i++) {
362     // Extracts |vector| component.
363     uint32_t vector_extract_id = message_.fresh_ids(fresh_id_index++);
364     fuzzerutil::UpdateModuleIdBound(ir_context, vector_extract_id);
365     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
366         ir_context, SpvOpCompositeExtract, scalar->type_id(), vector_extract_id,
367         opt::Instruction::OperandList(
368             {{SPV_OPERAND_TYPE_ID, {vector->result_id()}},
369              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
370 
371     // Multiplies the |vector| component with the |scalar|.
372     uint32_t float_multiplication_id = message_.fresh_ids(fresh_id_index++);
373     float_multiplication_ids[i] = float_multiplication_id;
374     fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_id);
375     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
376         ir_context, SpvOpFMul, scalar->type_id(), float_multiplication_id,
377         opt::Instruction::OperandList(
378             {{SPV_OPERAND_TYPE_ID, {vector_extract_id}},
379              {SPV_OPERAND_TYPE_ID, {scalar->result_id()}}})));
380   }
381 
382   // The OpVectorTimesScalar instruction is changed to an OpCompositeConstruct
383   // instruction.
384   linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
385   linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
386   linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
387   for (uint32_t i = 2; i < float_multiplication_ids.size(); i++) {
388     linear_algebra_instruction->AddOperand(
389         {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}});
390   }
391 }
392 
ReplaceOpMatrixTimesScalar(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const393 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesScalar(
394     opt::IRContext* ir_context,
395     opt::Instruction* linear_algebra_instruction) const {
396   // Gets OpMatrixTimesScalar in operands.
397   auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
398       linear_algebra_instruction->GetSingleWordInOperand(0));
399   auto scalar_instruction = ir_context->get_def_use_mgr()->GetDef(
400       linear_algebra_instruction->GetSingleWordInOperand(1));
401 
402   // Gets matrix information.
403   uint32_t matrix_column_count = ir_context->get_type_mgr()
404                                      ->GetType(matrix_instruction->type_id())
405                                      ->AsMatrix()
406                                      ->element_count();
407   auto matrix_column_type = ir_context->get_type_mgr()
408                                 ->GetType(matrix_instruction->type_id())
409                                 ->AsMatrix()
410                                 ->element_type();
411   uint32_t matrix_column_size = matrix_column_type->AsVector()->element_count();
412 
413   std::vector<uint32_t> composite_construct_ids(matrix_column_count);
414   uint32_t fresh_id_index = 0;
415 
416   for (uint32_t i = 0; i < matrix_column_count; i++) {
417     // Extracts |matrix| column.
418     uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
419     fuzzerutil::UpdateModuleIdBound(ir_context, matrix_extract_id);
420     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
421         ir_context, SpvOpCompositeExtract,
422         ir_context->get_type_mgr()->GetId(matrix_column_type),
423         matrix_extract_id,
424         opt::Instruction::OperandList(
425             {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
426              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
427 
428     std::vector<uint32_t> float_multiplication_ids(matrix_column_size);
429 
430     for (uint32_t j = 0; j < matrix_column_size; j++) {
431       // Extracts |column| component.
432       uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
433       fuzzerutil::UpdateModuleIdBound(ir_context, column_extract_id);
434       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
435           ir_context, SpvOpCompositeExtract, scalar_instruction->type_id(),
436           column_extract_id,
437           opt::Instruction::OperandList(
438               {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
439                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
440 
441       // Multiplies the |column| component with the |scalar|.
442       float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
443       fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[j]);
444       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
445           ir_context, SpvOpFMul, scalar_instruction->type_id(),
446           float_multiplication_ids[j],
447           opt::Instruction::OperandList(
448               {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
449                {SPV_OPERAND_TYPE_ID, {scalar_instruction->result_id()}}})));
450     }
451 
452     // Constructs a new column multiplied by |scalar|.
453     opt::Instruction::OperandList composite_construct_in_operands;
454     for (uint32_t& float_multiplication_id : float_multiplication_ids) {
455       composite_construct_in_operands.push_back(
456           {SPV_OPERAND_TYPE_ID, {float_multiplication_id}});
457     }
458     composite_construct_ids[i] = message_.fresh_ids(fresh_id_index++);
459     fuzzerutil::UpdateModuleIdBound(ir_context, composite_construct_ids[i]);
460     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
461         ir_context, SpvOpCompositeConstruct,
462         ir_context->get_type_mgr()->GetId(matrix_column_type),
463         composite_construct_ids[i], composite_construct_in_operands));
464   }
465 
466   // The OpMatrixTimesScalar instruction is changed to an OpCompositeConstruct
467   // instruction.
468   linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
469   linear_algebra_instruction->SetInOperand(0, {composite_construct_ids[0]});
470   linear_algebra_instruction->SetInOperand(1, {composite_construct_ids[1]});
471   for (uint32_t i = 2; i < composite_construct_ids.size(); i++) {
472     linear_algebra_instruction->AddOperand(
473         {SPV_OPERAND_TYPE_ID, {composite_construct_ids[i]}});
474   }
475 }
476 
ReplaceOpVectorTimesMatrix(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const477 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesMatrix(
478     opt::IRContext* ir_context,
479     opt::Instruction* linear_algebra_instruction) const {
480   // Gets vector information.
481   auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
482       linear_algebra_instruction->GetSingleWordInOperand(0));
483   uint32_t vector_component_count = ir_context->get_type_mgr()
484                                         ->GetType(vector_instruction->type_id())
485                                         ->AsVector()
486                                         ->element_count();
487   auto vector_component_type = ir_context->get_type_mgr()
488                                    ->GetType(vector_instruction->type_id())
489                                    ->AsVector()
490                                    ->element_type();
491 
492   // Extracts vector components.
493   uint32_t fresh_id_index = 0;
494   std::vector<uint32_t> vector_component_ids(vector_component_count);
495   for (uint32_t i = 0; i < vector_component_count; i++) {
496     vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
497     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
498         ir_context, SpvOpCompositeExtract,
499         ir_context->get_type_mgr()->GetId(vector_component_type),
500         vector_component_ids[i],
501         opt::Instruction::OperandList(
502             {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
503              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
504   }
505 
506   // Gets matrix information.
507   auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
508       linear_algebra_instruction->GetSingleWordInOperand(1));
509   uint32_t matrix_column_count = ir_context->get_type_mgr()
510                                      ->GetType(matrix_instruction->type_id())
511                                      ->AsMatrix()
512                                      ->element_count();
513   auto matrix_column_type = ir_context->get_type_mgr()
514                                 ->GetType(matrix_instruction->type_id())
515                                 ->AsMatrix()
516                                 ->element_type();
517 
518   std::vector<uint32_t> result_component_ids(matrix_column_count);
519   for (uint32_t i = 0; i < matrix_column_count; i++) {
520     // Extracts matrix column.
521     uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
522     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
523         ir_context, SpvOpCompositeExtract,
524         ir_context->get_type_mgr()->GetId(matrix_column_type),
525         matrix_extract_id,
526         opt::Instruction::OperandList(
527             {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
528              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
529 
530     std::vector<uint32_t> float_multiplication_ids(vector_component_count);
531     for (uint32_t j = 0; j < vector_component_count; j++) {
532       // Extracts column component.
533       uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
534       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
535           ir_context, SpvOpCompositeExtract,
536           ir_context->get_type_mgr()->GetId(vector_component_type),
537           column_extract_id,
538           opt::Instruction::OperandList(
539               {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
540                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
541 
542       // Multiplies corresponding vector and column components.
543       float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
544       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
545           ir_context, SpvOpFMul,
546           ir_context->get_type_mgr()->GetId(vector_component_type),
547           float_multiplication_ids[j],
548           opt::Instruction::OperandList(
549               {{SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}},
550                {SPV_OPERAND_TYPE_ID, {column_extract_id}}})));
551     }
552 
553     // Adds the multiplication results.
554     std::vector<uint32_t> float_add_ids;
555     uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
556     float_add_ids.push_back(float_add_id);
557     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
558         ir_context, SpvOpFAdd,
559         ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
560         opt::Instruction::OperandList(
561             {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
562              {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
563     for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
564       float_add_id = message_.fresh_ids(fresh_id_index++);
565       float_add_ids.push_back(float_add_id);
566       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
567           ir_context, SpvOpFAdd,
568           ir_context->get_type_mgr()->GetId(vector_component_type),
569           float_add_id,
570           opt::Instruction::OperandList(
571               {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
572                {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
573     }
574 
575     result_component_ids[i] = float_add_ids.back();
576   }
577 
578   // The OpVectorTimesMatrix instruction is changed to an OpCompositeConstruct
579   // instruction.
580   linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
581   linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
582   linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
583   for (uint32_t i = 2; i < result_component_ids.size(); i++) {
584     linear_algebra_instruction->AddOperand(
585         {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
586   }
587 
588   fuzzerutil::UpdateModuleIdBound(
589       ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
590 }
591 
ReplaceOpMatrixTimesVector(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const592 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesVector(
593     opt::IRContext* ir_context,
594     opt::Instruction* linear_algebra_instruction) const {
595   // Gets matrix information.
596   auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
597       linear_algebra_instruction->GetSingleWordInOperand(0));
598   uint32_t matrix_column_count = ir_context->get_type_mgr()
599                                      ->GetType(matrix_instruction->type_id())
600                                      ->AsMatrix()
601                                      ->element_count();
602   auto matrix_column_type = ir_context->get_type_mgr()
603                                 ->GetType(matrix_instruction->type_id())
604                                 ->AsMatrix()
605                                 ->element_type();
606   uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
607 
608   // Extracts matrix columns.
609   uint32_t fresh_id_index = 0;
610   std::vector<uint32_t> matrix_column_ids(matrix_column_count);
611   for (uint32_t i = 0; i < matrix_column_count; i++) {
612     matrix_column_ids[i] = message_.fresh_ids(fresh_id_index++);
613     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
614         ir_context, SpvOpCompositeExtract,
615         ir_context->get_type_mgr()->GetId(matrix_column_type),
616         matrix_column_ids[i],
617         opt::Instruction::OperandList(
618             {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
619              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
620   }
621 
622   // Gets vector information.
623   auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
624       linear_algebra_instruction->GetSingleWordInOperand(1));
625   auto vector_component_type = ir_context->get_type_mgr()
626                                    ->GetType(vector_instruction->type_id())
627                                    ->AsVector()
628                                    ->element_type();
629 
630   // Extracts vector components.
631   std::vector<uint32_t> vector_component_ids(matrix_column_count);
632   for (uint32_t i = 0; i < matrix_column_count; i++) {
633     vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
634     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
635         ir_context, SpvOpCompositeExtract,
636         ir_context->get_type_mgr()->GetId(vector_component_type),
637         vector_component_ids[i],
638         opt::Instruction::OperandList(
639             {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
640              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
641   }
642 
643   std::vector<uint32_t> result_component_ids(matrix_row_count);
644   for (uint32_t i = 0; i < matrix_row_count; i++) {
645     std::vector<uint32_t> float_multiplication_ids(matrix_column_count);
646     for (uint32_t j = 0; j < matrix_column_count; j++) {
647       // Extracts column component.
648       uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
649       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
650           ir_context, SpvOpCompositeExtract,
651           ir_context->get_type_mgr()->GetId(vector_component_type),
652           column_extract_id,
653           opt::Instruction::OperandList(
654               {{SPV_OPERAND_TYPE_ID, {matrix_column_ids[j]}},
655                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
656 
657       // Multiplies corresponding vector and column components.
658       float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
659       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
660           ir_context, SpvOpFMul,
661           ir_context->get_type_mgr()->GetId(vector_component_type),
662           float_multiplication_ids[j],
663           opt::Instruction::OperandList(
664               {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
665                {SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}}})));
666     }
667 
668     // Adds the multiplication results.
669     std::vector<uint32_t> float_add_ids;
670     uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
671     float_add_ids.push_back(float_add_id);
672     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
673         ir_context, SpvOpFAdd,
674         ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
675         opt::Instruction::OperandList(
676             {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
677              {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
678     for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
679       float_add_id = message_.fresh_ids(fresh_id_index++);
680       float_add_ids.push_back(float_add_id);
681       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
682           ir_context, SpvOpFAdd,
683           ir_context->get_type_mgr()->GetId(vector_component_type),
684           float_add_id,
685           opt::Instruction::OperandList(
686               {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
687                {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
688     }
689 
690     result_component_ids[i] = float_add_ids.back();
691   }
692 
693   // The OpMatrixTimesVector instruction is changed to an OpCompositeConstruct
694   // instruction.
695   linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
696   linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
697   linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
698   for (uint32_t i = 2; i < result_component_ids.size(); i++) {
699     linear_algebra_instruction->AddOperand(
700         {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
701   }
702 
703   fuzzerutil::UpdateModuleIdBound(
704       ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
705 }
706 
ReplaceOpMatrixTimesMatrix(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const707 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesMatrix(
708     opt::IRContext* ir_context,
709     opt::Instruction* linear_algebra_instruction) const {
710   // Gets matrix 1 information.
711   auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef(
712       linear_algebra_instruction->GetSingleWordInOperand(0));
713   uint32_t matrix_1_column_count =
714       ir_context->get_type_mgr()
715           ->GetType(matrix_1_instruction->type_id())
716           ->AsMatrix()
717           ->element_count();
718   auto matrix_1_column_type = ir_context->get_type_mgr()
719                                   ->GetType(matrix_1_instruction->type_id())
720                                   ->AsMatrix()
721                                   ->element_type();
722   auto matrix_1_column_component_type =
723       matrix_1_column_type->AsVector()->element_type();
724   uint32_t matrix_1_row_count =
725       matrix_1_column_type->AsVector()->element_count();
726 
727   // Gets matrix 2 information.
728   auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef(
729       linear_algebra_instruction->GetSingleWordInOperand(1));
730   uint32_t matrix_2_column_count =
731       ir_context->get_type_mgr()
732           ->GetType(matrix_2_instruction->type_id())
733           ->AsMatrix()
734           ->element_count();
735   auto matrix_2_column_type = ir_context->get_type_mgr()
736                                   ->GetType(matrix_2_instruction->type_id())
737                                   ->AsMatrix()
738                                   ->element_type();
739 
740   uint32_t fresh_id_index = 0;
741   std::vector<uint32_t> result_column_ids(matrix_2_column_count);
742   for (uint32_t i = 0; i < matrix_2_column_count; i++) {
743     // Extracts matrix 2 column.
744     uint32_t matrix_2_column_id = message_.fresh_ids(fresh_id_index++);
745     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
746         ir_context, SpvOpCompositeExtract,
747         ir_context->get_type_mgr()->GetId(matrix_2_column_type),
748         matrix_2_column_id,
749         opt::Instruction::OperandList(
750             {{SPV_OPERAND_TYPE_ID, {matrix_2_instruction->result_id()}},
751              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
752 
753     std::vector<uint32_t> column_component_ids(matrix_1_row_count);
754     for (uint32_t j = 0; j < matrix_1_row_count; j++) {
755       std::vector<uint32_t> float_multiplication_ids(matrix_1_column_count);
756       for (uint32_t k = 0; k < matrix_1_column_count; k++) {
757         // Extracts matrix 1 column.
758         uint32_t matrix_1_column_id = message_.fresh_ids(fresh_id_index++);
759         linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
760             ir_context, SpvOpCompositeExtract,
761             ir_context->get_type_mgr()->GetId(matrix_1_column_type),
762             matrix_1_column_id,
763             opt::Instruction::OperandList(
764                 {{SPV_OPERAND_TYPE_ID, {matrix_1_instruction->result_id()}},
765                  {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}})));
766 
767         // Extracts matrix 1 column component.
768         uint32_t matrix_1_column_component_id =
769             message_.fresh_ids(fresh_id_index++);
770         linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
771             ir_context, SpvOpCompositeExtract,
772             ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
773             matrix_1_column_component_id,
774             opt::Instruction::OperandList(
775                 {{SPV_OPERAND_TYPE_ID, {matrix_1_column_id}},
776                  {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
777 
778         // Extracts matrix 2 column component.
779         uint32_t matrix_2_column_component_id =
780             message_.fresh_ids(fresh_id_index++);
781         linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
782             ir_context, SpvOpCompositeExtract,
783             ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
784             matrix_2_column_component_id,
785             opt::Instruction::OperandList(
786                 {{SPV_OPERAND_TYPE_ID, {matrix_2_column_id}},
787                  {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}})));
788 
789         // Multiplies corresponding matrix 1 and matrix 2 column components.
790         float_multiplication_ids[k] = message_.fresh_ids(fresh_id_index++);
791         linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
792             ir_context, SpvOpFMul,
793             ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
794             float_multiplication_ids[k],
795             opt::Instruction::OperandList(
796                 {{SPV_OPERAND_TYPE_ID, {matrix_1_column_component_id}},
797                  {SPV_OPERAND_TYPE_ID, {matrix_2_column_component_id}}})));
798       }
799 
800       // Adds the multiplication results.
801       std::vector<uint32_t> float_add_ids;
802       uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
803       float_add_ids.push_back(float_add_id);
804       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
805           ir_context, SpvOpFAdd,
806           ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
807           float_add_id,
808           opt::Instruction::OperandList(
809               {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
810                {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
811       for (uint32_t k = 2; k < float_multiplication_ids.size(); k++) {
812         float_add_id = message_.fresh_ids(fresh_id_index++);
813         float_add_ids.push_back(float_add_id);
814         linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
815             ir_context, SpvOpFAdd,
816             ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
817             float_add_id,
818             opt::Instruction::OperandList(
819                 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[k]}},
820                  {SPV_OPERAND_TYPE_ID, {float_add_ids[k - 2]}}})));
821       }
822 
823       column_component_ids[j] = float_add_ids.back();
824     }
825 
826     // Inserts the resulting matrix column.
827     opt::Instruction::OperandList in_operands;
828     for (auto& column_component_id : column_component_ids) {
829       in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
830     }
831     result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
832     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
833         ir_context, SpvOpCompositeConstruct,
834         ir_context->get_type_mgr()->GetId(matrix_1_column_type),
835         result_column_ids[i], opt::Instruction::OperandList(in_operands)));
836   }
837 
838   // The OpMatrixTimesMatrix instruction is changed to an OpCompositeConstruct
839   // instruction.
840   linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
841   linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
842   linear_algebra_instruction->SetInOperand(1, {result_column_ids[1]});
843   for (uint32_t i = 2; i < result_column_ids.size(); i++) {
844     linear_algebra_instruction->AddOperand(
845         {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
846   }
847 
848   fuzzerutil::UpdateModuleIdBound(
849       ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
850 }
851 
ReplaceOpOuterProduct(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const852 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpOuterProduct(
853     opt::IRContext* ir_context,
854     opt::Instruction* linear_algebra_instruction) const {
855   // Gets vector 1 information.
856   auto vector_1_instruction = ir_context->get_def_use_mgr()->GetDef(
857       linear_algebra_instruction->GetSingleWordInOperand(0));
858   uint32_t vector_1_component_count =
859       ir_context->get_type_mgr()
860           ->GetType(vector_1_instruction->type_id())
861           ->AsVector()
862           ->element_count();
863   auto vector_1_component_type = ir_context->get_type_mgr()
864                                      ->GetType(vector_1_instruction->type_id())
865                                      ->AsVector()
866                                      ->element_type();
867 
868   // Gets vector 2 information.
869   auto vector_2_instruction = ir_context->get_def_use_mgr()->GetDef(
870       linear_algebra_instruction->GetSingleWordInOperand(1));
871   uint32_t vector_2_component_count =
872       ir_context->get_type_mgr()
873           ->GetType(vector_2_instruction->type_id())
874           ->AsVector()
875           ->element_count();
876 
877   uint32_t fresh_id_index = 0;
878   std::vector<uint32_t> result_column_ids(vector_2_component_count);
879   for (uint32_t i = 0; i < vector_2_component_count; i++) {
880     // Extracts |vector_2| component.
881     uint32_t vector_2_component_id = message_.fresh_ids(fresh_id_index++);
882     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
883         ir_context, SpvOpCompositeExtract,
884         ir_context->get_type_mgr()->GetId(vector_1_component_type),
885         vector_2_component_id,
886         opt::Instruction::OperandList(
887             {{SPV_OPERAND_TYPE_ID, {vector_2_instruction->result_id()}},
888              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
889 
890     std::vector<uint32_t> column_component_ids(vector_1_component_count);
891     for (uint32_t j = 0; j < vector_1_component_count; j++) {
892       // Extracts |vector_1| component.
893       uint32_t vector_1_component_id = message_.fresh_ids(fresh_id_index++);
894       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
895           ir_context, SpvOpCompositeExtract,
896           ir_context->get_type_mgr()->GetId(vector_1_component_type),
897           vector_1_component_id,
898           opt::Instruction::OperandList(
899               {{SPV_OPERAND_TYPE_ID, {vector_1_instruction->result_id()}},
900                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
901 
902       // Multiplies |vector_1| and |vector_2| components.
903       column_component_ids[j] = message_.fresh_ids(fresh_id_index++);
904       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
905           ir_context, SpvOpFMul,
906           ir_context->get_type_mgr()->GetId(vector_1_component_type),
907           column_component_ids[j],
908           opt::Instruction::OperandList(
909               {{SPV_OPERAND_TYPE_ID, {vector_2_component_id}},
910                {SPV_OPERAND_TYPE_ID, {vector_1_component_id}}})));
911     }
912 
913     // Inserts the resulting matrix column.
914     opt::Instruction::OperandList in_operands;
915     for (auto& column_component_id : column_component_ids) {
916       in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
917     }
918     result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
919     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
920         ir_context, SpvOpCompositeConstruct, vector_1_instruction->type_id(),
921         result_column_ids[i], in_operands));
922   }
923 
924   // The OpOuterProduct instruction is changed to an OpCompositeConstruct
925   // instruction.
926   linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
927   linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
928   linear_algebra_instruction->SetInOperand(1, {result_column_ids[1]});
929   for (uint32_t i = 2; i < result_column_ids.size(); i++) {
930     linear_algebra_instruction->AddOperand(
931         {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
932   }
933 
934   fuzzerutil::UpdateModuleIdBound(
935       ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
936 }
937 
ReplaceOpDot(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const938 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot(
939     opt::IRContext* ir_context,
940     opt::Instruction* linear_algebra_instruction) const {
941   // Gets OpDot in operands.
942   auto vector_1 = ir_context->get_def_use_mgr()->GetDef(
943       linear_algebra_instruction->GetSingleWordInOperand(0));
944   auto vector_2 = ir_context->get_def_use_mgr()->GetDef(
945       linear_algebra_instruction->GetSingleWordInOperand(1));
946 
947   uint32_t vectors_component_count = ir_context->get_type_mgr()
948                                          ->GetType(vector_1->type_id())
949                                          ->AsVector()
950                                          ->element_count();
951   std::vector<uint32_t> float_multiplication_ids(vectors_component_count);
952   uint32_t fresh_id_index = 0;
953 
954   for (uint32_t i = 0; i < vectors_component_count; i++) {
955     // Extracts |vector_1| component.
956     uint32_t vector_1_extract_id = message_.fresh_ids(fresh_id_index++);
957     fuzzerutil::UpdateModuleIdBound(ir_context, vector_1_extract_id);
958     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
959         ir_context, SpvOpCompositeExtract,
960         linear_algebra_instruction->type_id(), vector_1_extract_id,
961         opt::Instruction::OperandList(
962             {{SPV_OPERAND_TYPE_ID, {vector_1->result_id()}},
963              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
964 
965     // Extracts |vector_2| component.
966     uint32_t vector_2_extract_id = message_.fresh_ids(fresh_id_index++);
967     fuzzerutil::UpdateModuleIdBound(ir_context, vector_2_extract_id);
968     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
969         ir_context, SpvOpCompositeExtract,
970         linear_algebra_instruction->type_id(), vector_2_extract_id,
971         opt::Instruction::OperandList(
972             {{SPV_OPERAND_TYPE_ID, {vector_2->result_id()}},
973              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
974 
975     // Multiplies the pair of components.
976     float_multiplication_ids[i] = message_.fresh_ids(fresh_id_index++);
977     fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[i]);
978     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
979         ir_context, SpvOpFMul, linear_algebra_instruction->type_id(),
980         float_multiplication_ids[i],
981         opt::Instruction::OperandList(
982             {{SPV_OPERAND_TYPE_ID, {vector_1_extract_id}},
983              {SPV_OPERAND_TYPE_ID, {vector_2_extract_id}}})));
984   }
985 
986   // If the vector has 2 components, then there will be 2 float multiplication
987   // instructions.
988   if (vectors_component_count == 2) {
989     linear_algebra_instruction->SetOpcode(SpvOpFAdd);
990     linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
991     linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
992   } else {
993     // The first OpFAdd instruction has as operands the first two OpFMul
994     // instructions.
995     std::vector<uint32_t> float_add_ids;
996     uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
997     float_add_ids.push_back(float_add_id);
998     fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
999     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
1000         ir_context, SpvOpFAdd, linear_algebra_instruction->type_id(),
1001         float_add_id,
1002         opt::Instruction::OperandList(
1003             {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
1004              {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
1005 
1006     // The remaining OpFAdd instructions has as operands an OpFMul and an OpFAdd
1007     // instruction.
1008     for (uint32_t i = 2; i < float_multiplication_ids.size() - 1; i++) {
1009       float_add_id = message_.fresh_ids(fresh_id_index++);
1010       fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
1011       float_add_ids.push_back(float_add_id);
1012       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
1013           ir_context, SpvOpFAdd, linear_algebra_instruction->type_id(),
1014           float_add_id,
1015           opt::Instruction::OperandList(
1016               {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}},
1017                {SPV_OPERAND_TYPE_ID, {float_add_ids[i - 2]}}})));
1018     }
1019 
1020     // The last OpFAdd instruction is got by changing some of the OpDot
1021     // instruction attributes.
1022     linear_algebra_instruction->SetOpcode(SpvOpFAdd);
1023     linear_algebra_instruction->SetInOperand(
1024         0, {float_multiplication_ids[float_multiplication_ids.size() - 1]});
1025     linear_algebra_instruction->SetInOperand(
1026         1, {float_add_ids[float_add_ids.size() - 1]});
1027   }
1028 }
1029 
1030 std::unordered_set<uint32_t>
GetFreshIds() const1031 TransformationReplaceLinearAlgebraInstruction::GetFreshIds() const {
1032   std::unordered_set<uint32_t> result;
1033   for (auto id : message_.fresh_ids()) {
1034     result.insert(id);
1035   }
1036   return result;
1037 }
1038 
1039 }  // namespace fuzz
1040 }  // namespace spvtools
1041