// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/fuzz/force_render_red.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

#include "source/fuzz/fact_manager/fact_manager.h"
#include "source/fuzz/instruction_descriptor.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/fuzz/transformation_context.h"
#include "source/fuzz/transformation_replace_constant_with_uniform.h"
#include "source/fuzz/uniform_buffer_element_descriptor.h"
#include "source/opt/build_module.h"
#include "source/opt/ir_context.h"
#include "source/opt/types.h"
#include "source/util/make_unique.h"
#include "tools/util/cli_consumer.h"

namespace spvtools {
namespace fuzz {

namespace {

// Helper method to find the fragment shader entry point, complaining if there
// is no shader or if there is no fragment entry point.
opt::Function* FindFragmentShaderEntryPoint(opt::IRContext* ir_context, MessageConsumer message_consumer) { // Check that this is a fragment shader bool found_capability_shader = false; for (auto& capability : ir_context->capabilities()) { assert(capability.opcode() == SpvOpCapability); if (capability.GetSingleWordInOperand(0) == SpvCapabilityShader) { found_capability_shader = true; break; } } if (!found_capability_shader) { message_consumer( SPV_MSG_ERROR, nullptr, {}, "Forcing of red rendering requires the Shader capability."); return nullptr; } opt::Instruction* fragment_entry_point = nullptr; for (auto& entry_point : ir_context->module()->entry_points()) { if (entry_point.GetSingleWordInOperand(0) == SpvExecutionModelFragment) { fragment_entry_point = &entry_point; break; } } if (fragment_entry_point == nullptr) { message_consumer(SPV_MSG_ERROR, nullptr, {}, "Forcing of red rendering requires an entry point with " "the Fragment execution model."); return nullptr; } for (auto& function : *ir_context->module()) { if (function.result_id() == fragment_entry_point->GetSingleWordInOperand(1)) { return &function; } } assert( false && "A valid module must have a function associate with each entry point."); return nullptr; } // Helper method to check that there is a single vec4 output variable and get a // pointer to it. opt::Instruction* FindVec4OutputVariable(opt::IRContext* ir_context, MessageConsumer message_consumer) { opt::Instruction* output_variable = nullptr; for (auto& inst : ir_context->types_values()) { if (inst.opcode() == SpvOpVariable && inst.GetSingleWordInOperand(0) == SpvStorageClassOutput) { if (output_variable != nullptr) { message_consumer(SPV_MSG_ERROR, nullptr, {}, "Only one output variable can be handled at present; " "found multiple."); return nullptr; } output_variable = &inst; // Do not break, as we want to check for multiple output variables. 
} } if (output_variable == nullptr) { message_consumer(SPV_MSG_ERROR, nullptr, {}, "No output variable to which to write red was found."); return nullptr; } auto output_variable_base_type = ir_context->get_type_mgr() ->GetType(output_variable->type_id()) ->AsPointer() ->pointee_type() ->AsVector(); if (!output_variable_base_type || output_variable_base_type->element_count() != 4 || !output_variable_base_type->element_type()->AsFloat()) { message_consumer(SPV_MSG_ERROR, nullptr, {}, "The output variable must have type vec4."); return nullptr; } return output_variable; } // Helper to get the ids of float constants 0.0 and 1.0, creating them if // necessary. std::pair FindOrCreateFloatZeroAndOne( opt::IRContext* ir_context, opt::analysis::Float* float_type) { float one = 1.0; uint32_t one_as_uint; memcpy(&one_as_uint, &one, sizeof(float)); std::vector zero_bytes = {0}; std::vector one_bytes = {one_as_uint}; auto constant_zero = ir_context->get_constant_mgr()->RegisterConstant( MakeUnique(float_type, zero_bytes)); auto constant_one = ir_context->get_constant_mgr()->RegisterConstant( MakeUnique(float_type, one_bytes)); auto constant_zero_id = ir_context->get_constant_mgr() ->GetDefiningInstruction(constant_zero) ->result_id(); auto constant_one_id = ir_context->get_constant_mgr() ->GetDefiningInstruction(constant_one) ->result_id(); return std::pair(constant_zero_id, constant_one_id); } std::unique_ptr MakeConstantUniformReplacement(opt::IRContext* ir_context, const FactManager& fact_manager, uint32_t constant_id, uint32_t greater_than_instruction, uint32_t in_operand_index) { return MakeUnique( MakeIdUseDescriptor(constant_id, MakeInstructionDescriptor(greater_than_instruction, SpvOpFOrdGreaterThan, 0), in_operand_index), fact_manager.GetUniformDescriptorsForConstant(constant_id)[0], ir_context->TakeNextId(), ir_context->TakeNextId()); } } // namespace bool ForceRenderRed( const spv_target_env& target_env, spv_validator_options validator_options, const std::vector& 
binary_in, const spvtools::fuzz::protobufs::FactSequence& initial_facts, std::vector* binary_out) { auto message_consumer = spvtools::utils::CLIMessageConsumer; spvtools::SpirvTools tools(target_env); if (!tools.IsValid()) { message_consumer(SPV_MSG_ERROR, nullptr, {}, "Failed to create SPIRV-Tools interface; stopping."); return false; } // Initial binary should be valid. if (!tools.Validate(&binary_in[0], binary_in.size(), validator_options)) { message_consumer(SPV_MSG_ERROR, nullptr, {}, "Initial binary is invalid; stopping."); return false; } // Build the module from the input binary. std::unique_ptr ir_context = BuildModule( target_env, message_consumer, binary_in.data(), binary_in.size()); assert(ir_context); // Set up a fact manager with any given initial facts. TransformationContext transformation_context( MakeUnique(ir_context.get()), validator_options); for (auto& fact : initial_facts.fact()) { transformation_context.GetFactManager()->MaybeAddFact(fact); } auto entry_point_function = FindFragmentShaderEntryPoint(ir_context.get(), message_consumer); auto output_variable = FindVec4OutputVariable(ir_context.get(), message_consumer); if (entry_point_function == nullptr || output_variable == nullptr) { return false; } opt::analysis::Float temp_float_type(32); opt::analysis::Float* float_type = ir_context->get_type_mgr() ->GetRegisteredType(&temp_float_type) ->AsFloat(); std::pair zero_one_float_ids = FindOrCreateFloatZeroAndOne(ir_context.get(), float_type); // Make the new exit block auto new_exit_block_id = ir_context->TakeNextId(); { auto label = MakeUnique(ir_context.get(), SpvOpLabel, 0, new_exit_block_id, opt::Instruction::OperandList()); auto new_exit_block = MakeUnique(std::move(label)); new_exit_block->AddInstruction(MakeUnique( ir_context.get(), SpvOpReturn, 0, 0, opt::Instruction::OperandList())); new_exit_block->SetParent(entry_point_function); entry_point_function->AddBasicBlock(std::move(new_exit_block)); } // Make the new entry block { auto label 
= MakeUnique(ir_context.get(), SpvOpLabel, 0, ir_context->TakeNextId(), opt::Instruction::OperandList()); auto new_entry_block = MakeUnique(std::move(label)); // Make an instruction to construct vec4(1.0, 0.0, 0.0, 1.0), representing // the colour red. opt::Operand zero_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.first}}; opt::Operand one_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.second}}; opt::Instruction::OperandList op_composite_construct_operands = { one_float, zero_float, zero_float, one_float}; auto temp_vec4 = opt::analysis::Vector(float_type, 4); auto vec4_id = ir_context->get_type_mgr()->GetId(&temp_vec4); auto red = MakeUnique( ir_context.get(), SpvOpCompositeConstruct, vec4_id, ir_context->TakeNextId(), op_composite_construct_operands); auto red_id = red->result_id(); new_entry_block->AddInstruction(std::move(red)); // Make an instruction to store red into the output color. opt::Operand variable_to_store_into = {SPV_OPERAND_TYPE_ID, {output_variable->result_id()}}; opt::Operand value_to_be_stored = {SPV_OPERAND_TYPE_ID, {red_id}}; opt::Instruction::OperandList op_store_operands = {variable_to_store_into, value_to_be_stored}; new_entry_block->AddInstruction(MakeUnique( ir_context.get(), SpvOpStore, 0, 0, op_store_operands)); // We are going to attempt to construct 'false' as an expression of the form // 'literal1 > literal2'. If we succeed, we will later replace each literal // with a uniform of the same value - we can only do that replacement once // we have added the entry block to the module. 
std::unique_ptr first_greater_then_operand_replacement = nullptr; std::unique_ptr second_greater_then_operand_replacement = nullptr; uint32_t id_guaranteed_to_be_false = 0; opt::analysis::Bool temp_bool_type; opt::analysis::Bool* registered_bool_type = ir_context->get_type_mgr() ->GetRegisteredType(&temp_bool_type) ->AsBool(); auto float_type_id = ir_context->get_type_mgr()->GetId(float_type); auto types_for_which_uniforms_are_known = transformation_context.GetFactManager() ->GetTypesForWhichUniformValuesAreKnown(); // Check whether we have any float uniforms. if (std::find(types_for_which_uniforms_are_known.begin(), types_for_which_uniforms_are_known.end(), float_type_id) != types_for_which_uniforms_are_known.end()) { // We have at least one float uniform; let's see whether we have at least // two. auto available_constants = transformation_context.GetFactManager() ->GetConstantsAvailableFromUniformsForType(float_type_id); if (available_constants.size() > 1) { // Grab the float constants associated with the first two known float // uniforms. auto first_constant = ir_context->get_constant_mgr() ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef( available_constants[0])) ->AsFloatConstant(); auto second_constant = ir_context->get_constant_mgr() ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef( available_constants[1])) ->AsFloatConstant(); // Now work out which of the two constants is larger than the other. uint32_t larger_constant_index = 0; uint32_t smaller_constant_index = 0; if (first_constant->GetFloat() > second_constant->GetFloat()) { larger_constant_index = 0; smaller_constant_index = 1; } else if (first_constant->GetFloat() < second_constant->GetFloat()) { larger_constant_index = 1; smaller_constant_index = 0; } // Only proceed with these constants if they have turned out to be // distinct. 
if (larger_constant_index != smaller_constant_index) { // We are in a position to create 'false' as 'literal1 > literal2', so // reserve an id for this computation; this id will end up being // guaranteed to be 'false'. id_guaranteed_to_be_false = ir_context->TakeNextId(); auto smaller_constant = available_constants[smaller_constant_index]; auto larger_constant = available_constants[larger_constant_index]; opt::Instruction::OperandList greater_than_operands = { {SPV_OPERAND_TYPE_ID, {smaller_constant}}, {SPV_OPERAND_TYPE_ID, {larger_constant}}}; new_entry_block->AddInstruction(MakeUnique( ir_context.get(), SpvOpFOrdGreaterThan, ir_context->get_type_mgr()->GetId(registered_bool_type), id_guaranteed_to_be_false, greater_than_operands)); first_greater_then_operand_replacement = MakeConstantUniformReplacement( ir_context.get(), *transformation_context.GetFactManager(), smaller_constant, id_guaranteed_to_be_false, 0); second_greater_then_operand_replacement = MakeConstantUniformReplacement( ir_context.get(), *transformation_context.GetFactManager(), larger_constant, id_guaranteed_to_be_false, 1); } } } if (id_guaranteed_to_be_false == 0) { auto constant_false = ir_context->get_constant_mgr()->RegisterConstant( MakeUnique(registered_bool_type, false)); id_guaranteed_to_be_false = ir_context->get_constant_mgr() ->GetDefiningInstruction(constant_false) ->result_id(); } opt::Operand false_condition = {SPV_OPERAND_TYPE_ID, {id_guaranteed_to_be_false}}; opt::Operand then_block = {SPV_OPERAND_TYPE_ID, {entry_point_function->entry()->id()}}; opt::Operand else_block = {SPV_OPERAND_TYPE_ID, {new_exit_block_id}}; opt::Instruction::OperandList op_branch_conditional_operands = { false_condition, then_block, else_block}; new_entry_block->AddInstruction( MakeUnique(ir_context.get(), SpvOpBranchConditional, 0, 0, op_branch_conditional_operands)); entry_point_function->InsertBasicBlockBefore( std::move(new_entry_block), entry_point_function->entry().get()); for (auto& replacement : 
{first_greater_then_operand_replacement.get(), second_greater_then_operand_replacement.get()}) { if (replacement) { assert(replacement->IsApplicable(ir_context.get(), transformation_context)); replacement->Apply(ir_context.get(), &transformation_context); } } } // Write out the module as a binary. ir_context->module()->ToBinary(binary_out, false); return true; } } // namespace fuzz } // namespace spvtools