1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include "source/fuzz/transformation_replace_linear_algebra_instruction.h"

#include <set>

#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/instruction_descriptor.h"
19
20 namespace spvtools {
21 namespace fuzz {
22
23 TransformationReplaceLinearAlgebraInstruction::
TransformationReplaceLinearAlgebraInstruction(protobufs::TransformationReplaceLinearAlgebraInstruction message)24 TransformationReplaceLinearAlgebraInstruction(
25 protobufs::TransformationReplaceLinearAlgebraInstruction message)
26 : message_(std::move(message)) {}
27
28 TransformationReplaceLinearAlgebraInstruction::
TransformationReplaceLinearAlgebraInstruction(const std::vector<uint32_t> & fresh_ids,const protobufs::InstructionDescriptor & instruction_descriptor)29 TransformationReplaceLinearAlgebraInstruction(
30 const std::vector<uint32_t>& fresh_ids,
31 const protobufs::InstructionDescriptor& instruction_descriptor) {
32 for (auto fresh_id : fresh_ids) {
33 message_.add_fresh_ids(fresh_id);
34 }
35 *message_.mutable_instruction_descriptor() = instruction_descriptor;
36 }
37
IsApplicable(opt::IRContext * ir_context,const TransformationContext &) const38 bool TransformationReplaceLinearAlgebraInstruction::IsApplicable(
39 opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
40 auto instruction =
41 FindInstruction(message_.instruction_descriptor(), ir_context);
42
43 // It must be a linear algebra instruction.
44 if (!spvOpcodeIsLinearAlgebra(instruction->opcode())) {
45 return false;
46 }
47
48 // |message_.fresh_ids.size| must be the exact number of fresh ids needed to
49 // apply the transformation.
50 if (static_cast<uint32_t>(message_.fresh_ids().size()) !=
51 GetRequiredFreshIdCount(ir_context, instruction)) {
52 return false;
53 }
54
55 // All ids in |message_.fresh_ids| must be fresh.
56 for (uint32_t fresh_id : message_.fresh_ids()) {
57 if (!fuzzerutil::IsFreshId(ir_context, fresh_id)) {
58 return false;
59 }
60 }
61
62 return true;
63 }
64
Apply(opt::IRContext * ir_context,TransformationContext *) const65 void TransformationReplaceLinearAlgebraInstruction::Apply(
66 opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
67 auto linear_algebra_instruction =
68 FindInstruction(message_.instruction_descriptor(), ir_context);
69
70 switch (linear_algebra_instruction->opcode()) {
71 case SpvOpTranspose:
72 ReplaceOpTranspose(ir_context, linear_algebra_instruction);
73 break;
74 case SpvOpVectorTimesScalar:
75 ReplaceOpVectorTimesScalar(ir_context, linear_algebra_instruction);
76 break;
77 case SpvOpMatrixTimesScalar:
78 ReplaceOpMatrixTimesScalar(ir_context, linear_algebra_instruction);
79 break;
80 case SpvOpVectorTimesMatrix:
81 ReplaceOpVectorTimesMatrix(ir_context, linear_algebra_instruction);
82 break;
83 case SpvOpMatrixTimesVector:
84 ReplaceOpMatrixTimesVector(ir_context, linear_algebra_instruction);
85 break;
86 case SpvOpMatrixTimesMatrix:
87 ReplaceOpMatrixTimesMatrix(ir_context, linear_algebra_instruction);
88 break;
89 case SpvOpOuterProduct:
90 ReplaceOpOuterProduct(ir_context, linear_algebra_instruction);
91 break;
92 case SpvOpDot:
93 ReplaceOpDot(ir_context, linear_algebra_instruction);
94 break;
95 default:
96 assert(false && "Should be unreachable.");
97 break;
98 }
99
100 ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
101 }
102
103 protobufs::Transformation
ToMessage() const104 TransformationReplaceLinearAlgebraInstruction::ToMessage() const {
105 protobufs::Transformation result;
106 *result.mutable_replace_linear_algebra_instruction() = message_;
107 return result;
108 }
109
GetRequiredFreshIdCount(opt::IRContext * ir_context,opt::Instruction * instruction)110 uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount(
111 opt::IRContext* ir_context, opt::Instruction* instruction) {
112 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
113 // Right now we only support certain operations.
114 switch (instruction->opcode()) {
115 case SpvOpTranspose: {
116 // For each matrix row, |2 * matrix_column_count| OpCompositeExtract and 1
117 // OpCompositeConstruct will be inserted.
118 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
119 instruction->GetSingleWordInOperand(0));
120 uint32_t matrix_column_count =
121 ir_context->get_type_mgr()
122 ->GetType(matrix_instruction->type_id())
123 ->AsMatrix()
124 ->element_count();
125 uint32_t matrix_row_count = ir_context->get_type_mgr()
126 ->GetType(matrix_instruction->type_id())
127 ->AsMatrix()
128 ->element_type()
129 ->AsVector()
130 ->element_count();
131 return matrix_row_count * (2 * matrix_column_count + 1);
132 }
133 case SpvOpVectorTimesScalar:
134 // For each vector component, 1 OpCompositeExtract and 1 OpFMul will be
135 // inserted.
136 return 2 *
137 ir_context->get_type_mgr()
138 ->GetType(ir_context->get_def_use_mgr()
139 ->GetDef(instruction->GetSingleWordInOperand(0))
140 ->type_id())
141 ->AsVector()
142 ->element_count();
143 case SpvOpMatrixTimesScalar: {
144 // For each matrix column, |1 + column.size| OpCompositeExtract,
145 // |column.size| OpFMul and 1 OpCompositeConstruct instructions will be
146 // inserted.
147 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
148 instruction->GetSingleWordInOperand(0));
149 auto matrix_type =
150 ir_context->get_type_mgr()->GetType(matrix_instruction->type_id());
151 return 2 * matrix_type->AsMatrix()->element_count() *
152 (1 + matrix_type->AsMatrix()
153 ->element_type()
154 ->AsVector()
155 ->element_count());
156 }
157 case SpvOpVectorTimesMatrix: {
158 // For each vector component, 1 OpCompositeExtract instruction will be
159 // inserted. For each matrix column, |1 + vector_component_count|
160 // OpCompositeExtract, |vector_component_count| OpFMul and
161 // |vector_component_count - 1| OpFAdd instructions will be inserted.
162 auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
163 instruction->GetSingleWordInOperand(0));
164 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
165 instruction->GetSingleWordInOperand(1));
166 uint32_t vector_component_count =
167 ir_context->get_type_mgr()
168 ->GetType(vector_instruction->type_id())
169 ->AsVector()
170 ->element_count();
171 uint32_t matrix_column_count =
172 ir_context->get_type_mgr()
173 ->GetType(matrix_instruction->type_id())
174 ->AsMatrix()
175 ->element_count();
176 return vector_component_count * (3 * matrix_column_count + 1);
177 }
178 case SpvOpMatrixTimesVector: {
179 // For each matrix column, |1 + matrix_row_count| OpCompositeExtract
180 // will be inserted. For each matrix row, |matrix_column_count| OpFMul and
181 // |matrix_column_count - 1| OpFAdd instructions will be inserted. For
182 // each vector component, 1 OpCompositeExtract instruction will be
183 // inserted.
184 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
185 instruction->GetSingleWordInOperand(0));
186 uint32_t matrix_column_count =
187 ir_context->get_type_mgr()
188 ->GetType(matrix_instruction->type_id())
189 ->AsMatrix()
190 ->element_count();
191 uint32_t matrix_row_count = ir_context->get_type_mgr()
192 ->GetType(matrix_instruction->type_id())
193 ->AsMatrix()
194 ->element_type()
195 ->AsVector()
196 ->element_count();
197 return 3 * matrix_column_count * matrix_row_count +
198 2 * matrix_column_count - matrix_row_count;
199 }
200 case SpvOpMatrixTimesMatrix: {
201 // For each matrix 2 column, 1 OpCompositeExtract, 1 OpCompositeConstruct,
202 // |3 * matrix_1_row_count * matrix_1_column_count| OpCompositeExtract,
203 // |matrix_1_row_count * matrix_1_column_count| OpFMul,
204 // |matrix_1_row_count * (matrix_1_column_count - 1)| OpFAdd instructions
205 // will be inserted.
206 auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef(
207 instruction->GetSingleWordInOperand(0));
208 uint32_t matrix_1_column_count =
209 ir_context->get_type_mgr()
210 ->GetType(matrix_1_instruction->type_id())
211 ->AsMatrix()
212 ->element_count();
213 uint32_t matrix_1_row_count =
214 ir_context->get_type_mgr()
215 ->GetType(matrix_1_instruction->type_id())
216 ->AsMatrix()
217 ->element_type()
218 ->AsVector()
219 ->element_count();
220
221 auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef(
222 instruction->GetSingleWordInOperand(1));
223 uint32_t matrix_2_column_count =
224 ir_context->get_type_mgr()
225 ->GetType(matrix_2_instruction->type_id())
226 ->AsMatrix()
227 ->element_count();
228 return matrix_2_column_count *
229 (2 + matrix_1_row_count * (5 * matrix_1_column_count - 1));
230 }
231 case SpvOpOuterProduct: {
232 // For each |vector_2| component, |vector_1_component_count + 1|
233 // OpCompositeExtract, |vector_1_component_count| OpFMul and 1
234 // OpCompositeConstruct instructions will be inserted.
235 auto vector_1_instruction = ir_context->get_def_use_mgr()->GetDef(
236 instruction->GetSingleWordInOperand(0));
237 auto vector_2_instruction = ir_context->get_def_use_mgr()->GetDef(
238 instruction->GetSingleWordInOperand(1));
239 uint32_t vector_1_component_count =
240 ir_context->get_type_mgr()
241 ->GetType(vector_1_instruction->type_id())
242 ->AsVector()
243 ->element_count();
244 uint32_t vector_2_component_count =
245 ir_context->get_type_mgr()
246 ->GetType(vector_2_instruction->type_id())
247 ->AsVector()
248 ->element_count();
249 return 2 * vector_2_component_count * (vector_1_component_count + 1);
250 }
251 case SpvOpDot:
252 // For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul
253 // will be inserted. The first two OpFMul instructions will result the
254 // first OpFAdd instruction to be inserted. For each remaining OpFMul, 1
255 // OpFAdd will be inserted. The last OpFAdd instruction is got by changing
256 // the OpDot instruction.
257 return 4 * ir_context->get_type_mgr()
258 ->GetType(
259 ir_context->get_def_use_mgr()
260 ->GetDef(instruction->GetSingleWordInOperand(0))
261 ->type_id())
262 ->AsVector()
263 ->element_count() -
264 2;
265 default:
266 assert(false && "Unsupported linear algebra instruction.");
267 return 0;
268 }
269 }
270
ReplaceOpTranspose(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const271 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpTranspose(
272 opt::IRContext* ir_context,
273 opt::Instruction* linear_algebra_instruction) const {
274 // Gets OpTranspose instruction information.
275 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
276 linear_algebra_instruction->GetSingleWordInOperand(0));
277 uint32_t matrix_column_count = ir_context->get_type_mgr()
278 ->GetType(matrix_instruction->type_id())
279 ->AsMatrix()
280 ->element_count();
281 auto matrix_column_type = ir_context->get_type_mgr()
282 ->GetType(matrix_instruction->type_id())
283 ->AsMatrix()
284 ->element_type();
285 auto matrix_column_component_type =
286 matrix_column_type->AsVector()->element_type();
287 uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
288 auto resulting_matrix_column_type =
289 ir_context->get_type_mgr()
290 ->GetType(linear_algebra_instruction->type_id())
291 ->AsMatrix()
292 ->element_type();
293
294 uint32_t fresh_id_index = 0;
295 std::vector<uint32_t> result_column_ids(matrix_row_count);
296 for (uint32_t i = 0; i < matrix_row_count; i++) {
297 std::vector<uint32_t> column_component_ids(matrix_column_count);
298 for (uint32_t j = 0; j < matrix_column_count; j++) {
299 // Extracts the matrix column.
300 uint32_t matrix_column_id = message_.fresh_ids(fresh_id_index++);
301 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
302 ir_context, SpvOpCompositeExtract,
303 ir_context->get_type_mgr()->GetId(matrix_column_type),
304 matrix_column_id,
305 opt::Instruction::OperandList(
306 {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
307 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
308
309 // Extracts the matrix column component.
310 column_component_ids[j] = message_.fresh_ids(fresh_id_index++);
311 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
312 ir_context, SpvOpCompositeExtract,
313 ir_context->get_type_mgr()->GetId(matrix_column_component_type),
314 column_component_ids[j],
315 opt::Instruction::OperandList(
316 {{SPV_OPERAND_TYPE_ID, {matrix_column_id}},
317 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
318 }
319
320 // Inserts the resulting matrix column.
321 opt::Instruction::OperandList in_operands;
322 for (auto& column_component_id : column_component_ids) {
323 in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
324 }
325 result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
326 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
327 ir_context, SpvOpCompositeConstruct,
328 ir_context->get_type_mgr()->GetId(resulting_matrix_column_type),
329 result_column_ids[i], opt::Instruction::OperandList(in_operands)));
330 }
331
332 // The OpTranspose instruction is changed to an OpCompositeConstruct
333 // instruction.
334 linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
335 linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
336 for (uint32_t i = 1; i < result_column_ids.size(); i++) {
337 linear_algebra_instruction->AddOperand(
338 {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
339 }
340
341 fuzzerutil::UpdateModuleIdBound(
342 ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
343 }
344
ReplaceOpVectorTimesScalar(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const345 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesScalar(
346 opt::IRContext* ir_context,
347 opt::Instruction* linear_algebra_instruction) const {
348 // Gets OpVectorTimesScalar in operands.
349 auto vector = ir_context->get_def_use_mgr()->GetDef(
350 linear_algebra_instruction->GetSingleWordInOperand(0));
351 auto scalar = ir_context->get_def_use_mgr()->GetDef(
352 linear_algebra_instruction->GetSingleWordInOperand(1));
353
354 uint32_t vector_component_count = ir_context->get_type_mgr()
355 ->GetType(vector->type_id())
356 ->AsVector()
357 ->element_count();
358 std::vector<uint32_t> float_multiplication_ids(vector_component_count);
359 uint32_t fresh_id_index = 0;
360
361 for (uint32_t i = 0; i < vector_component_count; i++) {
362 // Extracts |vector| component.
363 uint32_t vector_extract_id = message_.fresh_ids(fresh_id_index++);
364 fuzzerutil::UpdateModuleIdBound(ir_context, vector_extract_id);
365 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
366 ir_context, SpvOpCompositeExtract, scalar->type_id(), vector_extract_id,
367 opt::Instruction::OperandList(
368 {{SPV_OPERAND_TYPE_ID, {vector->result_id()}},
369 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
370
371 // Multiplies the |vector| component with the |scalar|.
372 uint32_t float_multiplication_id = message_.fresh_ids(fresh_id_index++);
373 float_multiplication_ids[i] = float_multiplication_id;
374 fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_id);
375 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
376 ir_context, SpvOpFMul, scalar->type_id(), float_multiplication_id,
377 opt::Instruction::OperandList(
378 {{SPV_OPERAND_TYPE_ID, {vector_extract_id}},
379 {SPV_OPERAND_TYPE_ID, {scalar->result_id()}}})));
380 }
381
382 // The OpVectorTimesScalar instruction is changed to an OpCompositeConstruct
383 // instruction.
384 linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
385 linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
386 linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
387 for (uint32_t i = 2; i < float_multiplication_ids.size(); i++) {
388 linear_algebra_instruction->AddOperand(
389 {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}});
390 }
391 }
392
ReplaceOpMatrixTimesScalar(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const393 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesScalar(
394 opt::IRContext* ir_context,
395 opt::Instruction* linear_algebra_instruction) const {
396 // Gets OpMatrixTimesScalar in operands.
397 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
398 linear_algebra_instruction->GetSingleWordInOperand(0));
399 auto scalar_instruction = ir_context->get_def_use_mgr()->GetDef(
400 linear_algebra_instruction->GetSingleWordInOperand(1));
401
402 // Gets matrix information.
403 uint32_t matrix_column_count = ir_context->get_type_mgr()
404 ->GetType(matrix_instruction->type_id())
405 ->AsMatrix()
406 ->element_count();
407 auto matrix_column_type = ir_context->get_type_mgr()
408 ->GetType(matrix_instruction->type_id())
409 ->AsMatrix()
410 ->element_type();
411 uint32_t matrix_column_size = matrix_column_type->AsVector()->element_count();
412
413 std::vector<uint32_t> composite_construct_ids(matrix_column_count);
414 uint32_t fresh_id_index = 0;
415
416 for (uint32_t i = 0; i < matrix_column_count; i++) {
417 // Extracts |matrix| column.
418 uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
419 fuzzerutil::UpdateModuleIdBound(ir_context, matrix_extract_id);
420 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
421 ir_context, SpvOpCompositeExtract,
422 ir_context->get_type_mgr()->GetId(matrix_column_type),
423 matrix_extract_id,
424 opt::Instruction::OperandList(
425 {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
426 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
427
428 std::vector<uint32_t> float_multiplication_ids(matrix_column_size);
429
430 for (uint32_t j = 0; j < matrix_column_size; j++) {
431 // Extracts |column| component.
432 uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
433 fuzzerutil::UpdateModuleIdBound(ir_context, column_extract_id);
434 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
435 ir_context, SpvOpCompositeExtract, scalar_instruction->type_id(),
436 column_extract_id,
437 opt::Instruction::OperandList(
438 {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
439 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
440
441 // Multiplies the |column| component with the |scalar|.
442 float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
443 fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[j]);
444 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
445 ir_context, SpvOpFMul, scalar_instruction->type_id(),
446 float_multiplication_ids[j],
447 opt::Instruction::OperandList(
448 {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
449 {SPV_OPERAND_TYPE_ID, {scalar_instruction->result_id()}}})));
450 }
451
452 // Constructs a new column multiplied by |scalar|.
453 opt::Instruction::OperandList composite_construct_in_operands;
454 for (uint32_t& float_multiplication_id : float_multiplication_ids) {
455 composite_construct_in_operands.push_back(
456 {SPV_OPERAND_TYPE_ID, {float_multiplication_id}});
457 }
458 composite_construct_ids[i] = message_.fresh_ids(fresh_id_index++);
459 fuzzerutil::UpdateModuleIdBound(ir_context, composite_construct_ids[i]);
460 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
461 ir_context, SpvOpCompositeConstruct,
462 ir_context->get_type_mgr()->GetId(matrix_column_type),
463 composite_construct_ids[i], composite_construct_in_operands));
464 }
465
466 // The OpMatrixTimesScalar instruction is changed to an OpCompositeConstruct
467 // instruction.
468 linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
469 linear_algebra_instruction->SetInOperand(0, {composite_construct_ids[0]});
470 linear_algebra_instruction->SetInOperand(1, {composite_construct_ids[1]});
471 for (uint32_t i = 2; i < composite_construct_ids.size(); i++) {
472 linear_algebra_instruction->AddOperand(
473 {SPV_OPERAND_TYPE_ID, {composite_construct_ids[i]}});
474 }
475 }
476
ReplaceOpVectorTimesMatrix(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const477 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesMatrix(
478 opt::IRContext* ir_context,
479 opt::Instruction* linear_algebra_instruction) const {
480 // Gets vector information.
481 auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
482 linear_algebra_instruction->GetSingleWordInOperand(0));
483 uint32_t vector_component_count = ir_context->get_type_mgr()
484 ->GetType(vector_instruction->type_id())
485 ->AsVector()
486 ->element_count();
487 auto vector_component_type = ir_context->get_type_mgr()
488 ->GetType(vector_instruction->type_id())
489 ->AsVector()
490 ->element_type();
491
492 // Extracts vector components.
493 uint32_t fresh_id_index = 0;
494 std::vector<uint32_t> vector_component_ids(vector_component_count);
495 for (uint32_t i = 0; i < vector_component_count; i++) {
496 vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
497 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
498 ir_context, SpvOpCompositeExtract,
499 ir_context->get_type_mgr()->GetId(vector_component_type),
500 vector_component_ids[i],
501 opt::Instruction::OperandList(
502 {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
503 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
504 }
505
506 // Gets matrix information.
507 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
508 linear_algebra_instruction->GetSingleWordInOperand(1));
509 uint32_t matrix_column_count = ir_context->get_type_mgr()
510 ->GetType(matrix_instruction->type_id())
511 ->AsMatrix()
512 ->element_count();
513 auto matrix_column_type = ir_context->get_type_mgr()
514 ->GetType(matrix_instruction->type_id())
515 ->AsMatrix()
516 ->element_type();
517
518 std::vector<uint32_t> result_component_ids(matrix_column_count);
519 for (uint32_t i = 0; i < matrix_column_count; i++) {
520 // Extracts matrix column.
521 uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
522 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
523 ir_context, SpvOpCompositeExtract,
524 ir_context->get_type_mgr()->GetId(matrix_column_type),
525 matrix_extract_id,
526 opt::Instruction::OperandList(
527 {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
528 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
529
530 std::vector<uint32_t> float_multiplication_ids(vector_component_count);
531 for (uint32_t j = 0; j < vector_component_count; j++) {
532 // Extracts column component.
533 uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
534 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
535 ir_context, SpvOpCompositeExtract,
536 ir_context->get_type_mgr()->GetId(vector_component_type),
537 column_extract_id,
538 opt::Instruction::OperandList(
539 {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
540 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
541
542 // Multiplies corresponding vector and column components.
543 float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
544 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
545 ir_context, SpvOpFMul,
546 ir_context->get_type_mgr()->GetId(vector_component_type),
547 float_multiplication_ids[j],
548 opt::Instruction::OperandList(
549 {{SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}},
550 {SPV_OPERAND_TYPE_ID, {column_extract_id}}})));
551 }
552
553 // Adds the multiplication results.
554 std::vector<uint32_t> float_add_ids;
555 uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
556 float_add_ids.push_back(float_add_id);
557 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
558 ir_context, SpvOpFAdd,
559 ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
560 opt::Instruction::OperandList(
561 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
562 {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
563 for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
564 float_add_id = message_.fresh_ids(fresh_id_index++);
565 float_add_ids.push_back(float_add_id);
566 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
567 ir_context, SpvOpFAdd,
568 ir_context->get_type_mgr()->GetId(vector_component_type),
569 float_add_id,
570 opt::Instruction::OperandList(
571 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
572 {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
573 }
574
575 result_component_ids[i] = float_add_ids.back();
576 }
577
578 // The OpVectorTimesMatrix instruction is changed to an OpCompositeConstruct
579 // instruction.
580 linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
581 linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
582 linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
583 for (uint32_t i = 2; i < result_component_ids.size(); i++) {
584 linear_algebra_instruction->AddOperand(
585 {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
586 }
587
588 fuzzerutil::UpdateModuleIdBound(
589 ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
590 }
591
ReplaceOpMatrixTimesVector(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const592 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesVector(
593 opt::IRContext* ir_context,
594 opt::Instruction* linear_algebra_instruction) const {
595 // Gets matrix information.
596 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
597 linear_algebra_instruction->GetSingleWordInOperand(0));
598 uint32_t matrix_column_count = ir_context->get_type_mgr()
599 ->GetType(matrix_instruction->type_id())
600 ->AsMatrix()
601 ->element_count();
602 auto matrix_column_type = ir_context->get_type_mgr()
603 ->GetType(matrix_instruction->type_id())
604 ->AsMatrix()
605 ->element_type();
606 uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
607
608 // Extracts matrix columns.
609 uint32_t fresh_id_index = 0;
610 std::vector<uint32_t> matrix_column_ids(matrix_column_count);
611 for (uint32_t i = 0; i < matrix_column_count; i++) {
612 matrix_column_ids[i] = message_.fresh_ids(fresh_id_index++);
613 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
614 ir_context, SpvOpCompositeExtract,
615 ir_context->get_type_mgr()->GetId(matrix_column_type),
616 matrix_column_ids[i],
617 opt::Instruction::OperandList(
618 {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
619 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
620 }
621
622 // Gets vector information.
623 auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
624 linear_algebra_instruction->GetSingleWordInOperand(1));
625 auto vector_component_type = ir_context->get_type_mgr()
626 ->GetType(vector_instruction->type_id())
627 ->AsVector()
628 ->element_type();
629
630 // Extracts vector components.
631 std::vector<uint32_t> vector_component_ids(matrix_column_count);
632 for (uint32_t i = 0; i < matrix_column_count; i++) {
633 vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
634 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
635 ir_context, SpvOpCompositeExtract,
636 ir_context->get_type_mgr()->GetId(vector_component_type),
637 vector_component_ids[i],
638 opt::Instruction::OperandList(
639 {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
640 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
641 }
642
643 std::vector<uint32_t> result_component_ids(matrix_row_count);
644 for (uint32_t i = 0; i < matrix_row_count; i++) {
645 std::vector<uint32_t> float_multiplication_ids(matrix_column_count);
646 for (uint32_t j = 0; j < matrix_column_count; j++) {
647 // Extracts column component.
648 uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
649 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
650 ir_context, SpvOpCompositeExtract,
651 ir_context->get_type_mgr()->GetId(vector_component_type),
652 column_extract_id,
653 opt::Instruction::OperandList(
654 {{SPV_OPERAND_TYPE_ID, {matrix_column_ids[j]}},
655 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
656
657 // Multiplies corresponding vector and column components.
658 float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
659 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
660 ir_context, SpvOpFMul,
661 ir_context->get_type_mgr()->GetId(vector_component_type),
662 float_multiplication_ids[j],
663 opt::Instruction::OperandList(
664 {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
665 {SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}}})));
666 }
667
668 // Adds the multiplication results.
669 std::vector<uint32_t> float_add_ids;
670 uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
671 float_add_ids.push_back(float_add_id);
672 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
673 ir_context, SpvOpFAdd,
674 ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
675 opt::Instruction::OperandList(
676 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
677 {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
678 for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
679 float_add_id = message_.fresh_ids(fresh_id_index++);
680 float_add_ids.push_back(float_add_id);
681 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
682 ir_context, SpvOpFAdd,
683 ir_context->get_type_mgr()->GetId(vector_component_type),
684 float_add_id,
685 opt::Instruction::OperandList(
686 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
687 {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
688 }
689
690 result_component_ids[i] = float_add_ids.back();
691 }
692
693 // The OpMatrixTimesVector instruction is changed to an OpCompositeConstruct
694 // instruction.
695 linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
696 linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
697 linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
698 for (uint32_t i = 2; i < result_component_ids.size(); i++) {
699 linear_algebra_instruction->AddOperand(
700 {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
701 }
702
703 fuzzerutil::UpdateModuleIdBound(
704 ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
705 }
706
ReplaceOpMatrixTimesMatrix(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const707 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesMatrix(
708 opt::IRContext* ir_context,
709 opt::Instruction* linear_algebra_instruction) const {
710 // Gets matrix 1 information.
711 auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef(
712 linear_algebra_instruction->GetSingleWordInOperand(0));
713 uint32_t matrix_1_column_count =
714 ir_context->get_type_mgr()
715 ->GetType(matrix_1_instruction->type_id())
716 ->AsMatrix()
717 ->element_count();
718 auto matrix_1_column_type = ir_context->get_type_mgr()
719 ->GetType(matrix_1_instruction->type_id())
720 ->AsMatrix()
721 ->element_type();
722 auto matrix_1_column_component_type =
723 matrix_1_column_type->AsVector()->element_type();
724 uint32_t matrix_1_row_count =
725 matrix_1_column_type->AsVector()->element_count();
726
727 // Gets matrix 2 information.
728 auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef(
729 linear_algebra_instruction->GetSingleWordInOperand(1));
730 uint32_t matrix_2_column_count =
731 ir_context->get_type_mgr()
732 ->GetType(matrix_2_instruction->type_id())
733 ->AsMatrix()
734 ->element_count();
735 auto matrix_2_column_type = ir_context->get_type_mgr()
736 ->GetType(matrix_2_instruction->type_id())
737 ->AsMatrix()
738 ->element_type();
739
740 uint32_t fresh_id_index = 0;
741 std::vector<uint32_t> result_column_ids(matrix_2_column_count);
742 for (uint32_t i = 0; i < matrix_2_column_count; i++) {
743 // Extracts matrix 2 column.
744 uint32_t matrix_2_column_id = message_.fresh_ids(fresh_id_index++);
745 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
746 ir_context, SpvOpCompositeExtract,
747 ir_context->get_type_mgr()->GetId(matrix_2_column_type),
748 matrix_2_column_id,
749 opt::Instruction::OperandList(
750 {{SPV_OPERAND_TYPE_ID, {matrix_2_instruction->result_id()}},
751 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
752
753 std::vector<uint32_t> column_component_ids(matrix_1_row_count);
754 for (uint32_t j = 0; j < matrix_1_row_count; j++) {
755 std::vector<uint32_t> float_multiplication_ids(matrix_1_column_count);
756 for (uint32_t k = 0; k < matrix_1_column_count; k++) {
757 // Extracts matrix 1 column.
758 uint32_t matrix_1_column_id = message_.fresh_ids(fresh_id_index++);
759 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
760 ir_context, SpvOpCompositeExtract,
761 ir_context->get_type_mgr()->GetId(matrix_1_column_type),
762 matrix_1_column_id,
763 opt::Instruction::OperandList(
764 {{SPV_OPERAND_TYPE_ID, {matrix_1_instruction->result_id()}},
765 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}})));
766
767 // Extracts matrix 1 column component.
768 uint32_t matrix_1_column_component_id =
769 message_.fresh_ids(fresh_id_index++);
770 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
771 ir_context, SpvOpCompositeExtract,
772 ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
773 matrix_1_column_component_id,
774 opt::Instruction::OperandList(
775 {{SPV_OPERAND_TYPE_ID, {matrix_1_column_id}},
776 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
777
778 // Extracts matrix 2 column component.
779 uint32_t matrix_2_column_component_id =
780 message_.fresh_ids(fresh_id_index++);
781 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
782 ir_context, SpvOpCompositeExtract,
783 ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
784 matrix_2_column_component_id,
785 opt::Instruction::OperandList(
786 {{SPV_OPERAND_TYPE_ID, {matrix_2_column_id}},
787 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}})));
788
789 // Multiplies corresponding matrix 1 and matrix 2 column components.
790 float_multiplication_ids[k] = message_.fresh_ids(fresh_id_index++);
791 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
792 ir_context, SpvOpFMul,
793 ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
794 float_multiplication_ids[k],
795 opt::Instruction::OperandList(
796 {{SPV_OPERAND_TYPE_ID, {matrix_1_column_component_id}},
797 {SPV_OPERAND_TYPE_ID, {matrix_2_column_component_id}}})));
798 }
799
800 // Adds the multiplication results.
801 std::vector<uint32_t> float_add_ids;
802 uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
803 float_add_ids.push_back(float_add_id);
804 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
805 ir_context, SpvOpFAdd,
806 ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
807 float_add_id,
808 opt::Instruction::OperandList(
809 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
810 {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
811 for (uint32_t k = 2; k < float_multiplication_ids.size(); k++) {
812 float_add_id = message_.fresh_ids(fresh_id_index++);
813 float_add_ids.push_back(float_add_id);
814 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
815 ir_context, SpvOpFAdd,
816 ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
817 float_add_id,
818 opt::Instruction::OperandList(
819 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[k]}},
820 {SPV_OPERAND_TYPE_ID, {float_add_ids[k - 2]}}})));
821 }
822
823 column_component_ids[j] = float_add_ids.back();
824 }
825
826 // Inserts the resulting matrix column.
827 opt::Instruction::OperandList in_operands;
828 for (auto& column_component_id : column_component_ids) {
829 in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
830 }
831 result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
832 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
833 ir_context, SpvOpCompositeConstruct,
834 ir_context->get_type_mgr()->GetId(matrix_1_column_type),
835 result_column_ids[i], opt::Instruction::OperandList(in_operands)));
836 }
837
838 // The OpMatrixTimesMatrix instruction is changed to an OpCompositeConstruct
839 // instruction.
840 linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
841 linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
842 linear_algebra_instruction->SetInOperand(1, {result_column_ids[1]});
843 for (uint32_t i = 2; i < result_column_ids.size(); i++) {
844 linear_algebra_instruction->AddOperand(
845 {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
846 }
847
848 fuzzerutil::UpdateModuleIdBound(
849 ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
850 }
851
ReplaceOpOuterProduct(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const852 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpOuterProduct(
853 opt::IRContext* ir_context,
854 opt::Instruction* linear_algebra_instruction) const {
855 // Gets vector 1 information.
856 auto vector_1_instruction = ir_context->get_def_use_mgr()->GetDef(
857 linear_algebra_instruction->GetSingleWordInOperand(0));
858 uint32_t vector_1_component_count =
859 ir_context->get_type_mgr()
860 ->GetType(vector_1_instruction->type_id())
861 ->AsVector()
862 ->element_count();
863 auto vector_1_component_type = ir_context->get_type_mgr()
864 ->GetType(vector_1_instruction->type_id())
865 ->AsVector()
866 ->element_type();
867
868 // Gets vector 2 information.
869 auto vector_2_instruction = ir_context->get_def_use_mgr()->GetDef(
870 linear_algebra_instruction->GetSingleWordInOperand(1));
871 uint32_t vector_2_component_count =
872 ir_context->get_type_mgr()
873 ->GetType(vector_2_instruction->type_id())
874 ->AsVector()
875 ->element_count();
876
877 uint32_t fresh_id_index = 0;
878 std::vector<uint32_t> result_column_ids(vector_2_component_count);
879 for (uint32_t i = 0; i < vector_2_component_count; i++) {
880 // Extracts |vector_2| component.
881 uint32_t vector_2_component_id = message_.fresh_ids(fresh_id_index++);
882 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
883 ir_context, SpvOpCompositeExtract,
884 ir_context->get_type_mgr()->GetId(vector_1_component_type),
885 vector_2_component_id,
886 opt::Instruction::OperandList(
887 {{SPV_OPERAND_TYPE_ID, {vector_2_instruction->result_id()}},
888 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
889
890 std::vector<uint32_t> column_component_ids(vector_1_component_count);
891 for (uint32_t j = 0; j < vector_1_component_count; j++) {
892 // Extracts |vector_1| component.
893 uint32_t vector_1_component_id = message_.fresh_ids(fresh_id_index++);
894 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
895 ir_context, SpvOpCompositeExtract,
896 ir_context->get_type_mgr()->GetId(vector_1_component_type),
897 vector_1_component_id,
898 opt::Instruction::OperandList(
899 {{SPV_OPERAND_TYPE_ID, {vector_1_instruction->result_id()}},
900 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
901
902 // Multiplies |vector_1| and |vector_2| components.
903 column_component_ids[j] = message_.fresh_ids(fresh_id_index++);
904 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
905 ir_context, SpvOpFMul,
906 ir_context->get_type_mgr()->GetId(vector_1_component_type),
907 column_component_ids[j],
908 opt::Instruction::OperandList(
909 {{SPV_OPERAND_TYPE_ID, {vector_2_component_id}},
910 {SPV_OPERAND_TYPE_ID, {vector_1_component_id}}})));
911 }
912
913 // Inserts the resulting matrix column.
914 opt::Instruction::OperandList in_operands;
915 for (auto& column_component_id : column_component_ids) {
916 in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
917 }
918 result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
919 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
920 ir_context, SpvOpCompositeConstruct, vector_1_instruction->type_id(),
921 result_column_ids[i], in_operands));
922 }
923
924 // The OpOuterProduct instruction is changed to an OpCompositeConstruct
925 // instruction.
926 linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
927 linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
928 linear_algebra_instruction->SetInOperand(1, {result_column_ids[1]});
929 for (uint32_t i = 2; i < result_column_ids.size(); i++) {
930 linear_algebra_instruction->AddOperand(
931 {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
932 }
933
934 fuzzerutil::UpdateModuleIdBound(
935 ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
936 }
937
ReplaceOpDot(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const938 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot(
939 opt::IRContext* ir_context,
940 opt::Instruction* linear_algebra_instruction) const {
941 // Gets OpDot in operands.
942 auto vector_1 = ir_context->get_def_use_mgr()->GetDef(
943 linear_algebra_instruction->GetSingleWordInOperand(0));
944 auto vector_2 = ir_context->get_def_use_mgr()->GetDef(
945 linear_algebra_instruction->GetSingleWordInOperand(1));
946
947 uint32_t vectors_component_count = ir_context->get_type_mgr()
948 ->GetType(vector_1->type_id())
949 ->AsVector()
950 ->element_count();
951 std::vector<uint32_t> float_multiplication_ids(vectors_component_count);
952 uint32_t fresh_id_index = 0;
953
954 for (uint32_t i = 0; i < vectors_component_count; i++) {
955 // Extracts |vector_1| component.
956 uint32_t vector_1_extract_id = message_.fresh_ids(fresh_id_index++);
957 fuzzerutil::UpdateModuleIdBound(ir_context, vector_1_extract_id);
958 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
959 ir_context, SpvOpCompositeExtract,
960 linear_algebra_instruction->type_id(), vector_1_extract_id,
961 opt::Instruction::OperandList(
962 {{SPV_OPERAND_TYPE_ID, {vector_1->result_id()}},
963 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
964
965 // Extracts |vector_2| component.
966 uint32_t vector_2_extract_id = message_.fresh_ids(fresh_id_index++);
967 fuzzerutil::UpdateModuleIdBound(ir_context, vector_2_extract_id);
968 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
969 ir_context, SpvOpCompositeExtract,
970 linear_algebra_instruction->type_id(), vector_2_extract_id,
971 opt::Instruction::OperandList(
972 {{SPV_OPERAND_TYPE_ID, {vector_2->result_id()}},
973 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
974
975 // Multiplies the pair of components.
976 float_multiplication_ids[i] = message_.fresh_ids(fresh_id_index++);
977 fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[i]);
978 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
979 ir_context, SpvOpFMul, linear_algebra_instruction->type_id(),
980 float_multiplication_ids[i],
981 opt::Instruction::OperandList(
982 {{SPV_OPERAND_TYPE_ID, {vector_1_extract_id}},
983 {SPV_OPERAND_TYPE_ID, {vector_2_extract_id}}})));
984 }
985
986 // If the vector has 2 components, then there will be 2 float multiplication
987 // instructions.
988 if (vectors_component_count == 2) {
989 linear_algebra_instruction->SetOpcode(SpvOpFAdd);
990 linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
991 linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
992 } else {
993 // The first OpFAdd instruction has as operands the first two OpFMul
994 // instructions.
995 std::vector<uint32_t> float_add_ids;
996 uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
997 float_add_ids.push_back(float_add_id);
998 fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
999 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
1000 ir_context, SpvOpFAdd, linear_algebra_instruction->type_id(),
1001 float_add_id,
1002 opt::Instruction::OperandList(
1003 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
1004 {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
1005
1006 // The remaining OpFAdd instructions has as operands an OpFMul and an OpFAdd
1007 // instruction.
1008 for (uint32_t i = 2; i < float_multiplication_ids.size() - 1; i++) {
1009 float_add_id = message_.fresh_ids(fresh_id_index++);
1010 fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
1011 float_add_ids.push_back(float_add_id);
1012 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
1013 ir_context, SpvOpFAdd, linear_algebra_instruction->type_id(),
1014 float_add_id,
1015 opt::Instruction::OperandList(
1016 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}},
1017 {SPV_OPERAND_TYPE_ID, {float_add_ids[i - 2]}}})));
1018 }
1019
1020 // The last OpFAdd instruction is got by changing some of the OpDot
1021 // instruction attributes.
1022 linear_algebra_instruction->SetOpcode(SpvOpFAdd);
1023 linear_algebra_instruction->SetInOperand(
1024 0, {float_multiplication_ids[float_multiplication_ids.size() - 1]});
1025 linear_algebra_instruction->SetInOperand(
1026 1, {float_add_ids[float_add_ids.size() - 1]});
1027 }
1028 }
1029
1030 std::unordered_set<uint32_t>
GetFreshIds() const1031 TransformationReplaceLinearAlgebraInstruction::GetFreshIds() const {
1032 std::unordered_set<uint32_t> result;
1033 for (auto id : message_.fresh_ids()) {
1034 result.insert(id);
1035 }
1036 return result;
1037 }
1038
1039 } // namespace fuzz
1040 } // namespace spvtools
1041