spirv-reader: support scalar spec constants Translate OpSpecConstantTrue, OpSpecConstantFalse, and OpSpecConstant. The latter only can be used with integer or float scalars. If the constant has a SpecId decoration, then generate a module-scope decorated constant. Otherwise generate a module-scope constant without decorations. Register the ID so we know to use the declared const identifier in expressions later in the module. Bug: tint:156 Change-Id: Icd6e9b60225ced7ee99963c4f85cec1eb0e3ae6b Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/31541 Commit-Queue: David Neto <dneto@google.com> Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 949aadc..98b2c83 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc
@@ -1700,7 +1700,7 @@ if (failed()) { return {}; } - if (identifier_values_.count(id)) { + if (identifier_values_.count(id) || parser_impl_.IsScalarSpecConstant(id)) { return TypedExpression( parser_impl_.ConvertType(def_use_mgr_->GetDef(id)->type_id()), std::make_unique<ast::IdentifierExpression>(namer_.Name(id)));
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 108bd2a..d38e96c 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc
@@ -40,6 +40,7 @@ #include "src/ast/bool_literal.h" #include "src/ast/builtin.h" #include "src/ast/builtin_decoration.h" +#include "src/ast/constant_id_decoration.h" #include "src/ast/decorated_variable.h" #include "src/ast/float_literal.h" #include "src/ast/scalar_constructor_expression.h" @@ -534,6 +535,9 @@ if (!RegisterTypes()) { return false; } + if (!EmitScalarSpecConstants()) { + return false; + } if (!EmitModuleScopeVariables()) { return false; } @@ -947,6 +951,82 @@ return success_; } +bool ParserImpl::EmitScalarSpecConstants() { + if (!success_) { + return false; + } + // Generate a module-scope const declaration for each instruction + // that is OpSpecConstantTrue, OpSpecConstantFalse, or OpSpecConstant. + for (auto& inst : module_->types_values()) { + // These will be populated for a valid scalar spec constant. + ast::type::Type* ast_type = nullptr; + std::unique_ptr<ast::ScalarConstructorExpression> ast_expr; + + switch (inst.opcode()) { + case SpvOpSpecConstantTrue: + case SpvOpSpecConstantFalse: { + ast_type = ConvertType(inst.type_id()); + ast_expr = std::make_unique<ast::ScalarConstructorExpression>( + std::make_unique<ast::BoolLiteral>( + ast_type, inst.opcode() == SpvOpSpecConstantTrue)); + break; + } + case SpvOpSpecConstant: { + ast_type = ConvertType(inst.type_id()); + const uint32_t literal_value = inst.GetSingleWordInOperand(0); + if (ast_type->IsI32()) { + ast_expr = std::make_unique<ast::ScalarConstructorExpression>( + std::make_unique<ast::SintLiteral>( + ast_type, static_cast<int32_t>(literal_value))); + } else if (ast_type->IsU32()) { + ast_expr = std::make_unique<ast::ScalarConstructorExpression>( + std::make_unique<ast::UintLiteral>( + ast_type, static_cast<uint32_t>(literal_value))); + } else if (ast_type->IsF32()) { + float float_value; + // Copy the bits so we can read them as a float. + std::memcpy(&float_value, &literal_value, sizeof(float_value)); + ast_expr = std::make_unique<ast::ScalarConstructorExpression>( + std::make_unique<ast::FloatLiteral>(ast_type, float_value)); + } else { + return Fail() << " invalid result type for OpSpecConstant " + << inst.PrettyPrint(); + } + break; + } + default: + break; + } + if (ast_type && ast_expr) { + auto ast_var = + MakeVariable(inst.result_id(), ast::StorageClass::kNone, ast_type); + ast::VariableDecorationList spec_id_decos; + for (const auto& deco : GetDecorationsFor(inst.result_id())) { + if ((deco.size() == 2) && (deco[0] == SpvDecorationSpecId)) { + auto cid = std::make_unique<ast::ConstantIdDecoration>(deco[1]); + spec_id_decos.push_back(std::move(cid)); + break; + } + } + if (spec_id_decos.empty()) { + // Register it as a named constant, without specialization id. + ast_var->set_is_const(true); + ast_var->set_constructor(std::move(ast_expr)); + ast_module_.AddGlobalVariable(std::move(ast_var)); + } else { + auto ast_deco_var = + std::make_unique<ast::DecoratedVariable>(std::move(ast_var)); + ast_deco_var->set_is_const(true); + ast_deco_var->set_constructor(std::move(ast_expr)); + ast_deco_var->set_decorations(std::move(spec_id_decos)); + ast_module_.AddGlobalVariable(std::move(ast_deco_var)); + } + scalar_spec_constants_.insert(inst.result_id()); + } + } + return success_; +} + void ParserImpl::MaybeGenerateAlias(uint32_t type_id, const spvtools::opt::analysis::Type* type) { if (!success_) {
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index dfffc8d..7037397 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h
@@ -246,6 +246,12 @@ /// @returns true if parser is still successful. bool RegisterTypes(); + /// Emit const definitions for scalar specialization constants generated + /// by one of OpConstantTrue, OpConstantFalse, or OpSpecConstant. + /// This is a no-op if the parser has already failed. + /// @returns true if parser is still successful. + bool EmitScalarSpecConstants(); + /// Emits module-scope variables. /// This is a no-op if the parser has already failed. /// @returns true if parser is still successful. @@ -373,6 +379,14 @@ /// @returns true if the given string is a valid WGSL identifier. static bool IsValidIdentifier(const std::string& str); + /// Returns true if the given SPIR-V ID is a declared specialization constant, + /// generated by one of OpConstantTrue, OpConstantFalse, or OpSpecConstant + /// @param id a SPIR-V result ID + /// @returns true if the ID is a scalar spec constant. + bool IsScalarSpecConstant(uint32_t id) { + return scalar_spec_constants_.find(id) != scalar_spec_constants_.end(); + } + private: /// Converts a specific SPIR-V type to a Tint type. Integer case ast::type::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty); @@ -489,6 +503,9 @@ // The struct types with only read-only members. std::unordered_set<ast::type::Type*> read_only_struct_types_; + // The IDs of scalar spec constants + std::unordered_set<uint32_t> scalar_spec_constants_; + // Maps function_id to a list of entrypoint information std::unordered_map<uint32_t, std::vector<EntryPointInfo>> function_to_ep_info_;
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc index 467d79a..f829fff 100644 --- a/src/reader/spirv/parser_impl_module_var_test.cc +++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -15,6 +15,7 @@ #include <string> #include "gmock/gmock.h" +#include "src/reader/spirv/function.h" #include "src/reader/spirv/parser_impl.h" #include "src/reader/spirv/parser_impl_test_helper.h" #include "src/reader/spirv/spirv_tools_helpers_test.h" @@ -1489,6 +1490,186 @@ })")) << module_str; } +TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_DeclareConst_True) { + auto* p = parser(test::Assemble(R"( + OpName %c "myconst" + OpDecorate %c SpecId 12 + %bool = OpTypeBool + %c = OpSpecConstantTrue %bool + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + EXPECT_TRUE(p->error().empty()); + const auto module_str = p->module().to_str(); + EXPECT_THAT(module_str, HasSubstr(R"( + DecoratedVariableConst{ + Decorations{ + ConstantIdDecoration{12} + } + myconst + none + __bool + { + ScalarConstructor{true} + } + } +})")) << module_str; +} + +TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_DeclareConst_False) { + auto* p = parser(test::Assemble(R"( + OpName %c "myconst" + OpDecorate %c SpecId 12 + %bool = OpTypeBool + %c = OpSpecConstantFalse %bool + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + EXPECT_TRUE(p->error().empty()); + const auto module_str = p->module().to_str(); + EXPECT_THAT(module_str, HasSubstr(R"( + DecoratedVariableConst{ + Decorations{ + ConstantIdDecoration{12} + } + myconst + none + __bool + { + ScalarConstructor{false} + } + } +})")) << module_str; +} + +TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_DeclareConst_U32) { + auto* p = parser(test::Assemble(R"( + OpName %c "myconst" + OpDecorate %c SpecId 12 + %uint = OpTypeInt 32 0 + %c = OpSpecConstant %uint 42 + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + EXPECT_TRUE(p->error().empty()); + const auto module_str = p->module().to_str(); + EXPECT_THAT(module_str, HasSubstr(R"( + DecoratedVariableConst{ + Decorations{ + ConstantIdDecoration{12} + } + myconst + none + __u32 + { + ScalarConstructor{42} + } + } +})")) << module_str; +} + +TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_DeclareConst_I32) { + auto* p = parser(test::Assemble(R"( + OpName %c "myconst" + OpDecorate %c SpecId 12 + %int = OpTypeInt 32 1 + %c = OpSpecConstant %int 42 + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + EXPECT_TRUE(p->error().empty()); + const auto module_str = p->module().to_str(); + EXPECT_THAT(module_str, HasSubstr(R"( + DecoratedVariableConst{ + Decorations{ + ConstantIdDecoration{12} + } + myconst + none + __i32 + { + ScalarConstructor{42} + } + } +})")) << module_str; +} + +TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_DeclareConst_F32) { + auto* p = parser(test::Assemble(R"( + OpName %c "myconst" + OpDecorate %c SpecId 12 + %float = OpTypeFloat 32 + %c = OpSpecConstant %float 2.5 + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + EXPECT_TRUE(p->error().empty()); + const auto module_str = p->module().to_str(); + EXPECT_THAT(module_str, HasSubstr(R"( + DecoratedVariableConst{ + Decorations{ + ConstantIdDecoration{12} + } + myconst + none + __f32 + { + ScalarConstructor{2.500000} + } + } +})")) << module_str; +} + +TEST_F(SpvParserTest, + ModuleScopeVar_ScalarSpecConstant_DeclareConst_F32_WithoutSpecId) { + // When we don't have a spec ID, declare an undecorated module-scope constant. + auto* p = parser(test::Assemble(R"( + OpName %c "myconst" + %float = OpTypeFloat 32 + %c = OpSpecConstant %float 2.5 + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + EXPECT_TRUE(p->error().empty()); + const auto module_str = p->module().to_str(); + EXPECT_THAT(module_str, HasSubstr(R"( + VariableConst{ + myconst + none + __f32 + { + ScalarConstructor{2.500000} + } + } +})")) << module_str; +} + +TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_UsedInFunction) { + auto* p = parser(test::Assemble(R"( + OpName %c "myconst" + %float = OpTypeFloat 32 + %c = OpSpecConstant %float 2.5 + %floatfn = OpTypeFunction %float + %100 = OpFunction %float None %floatfn + %entry = OpLabel + %1 = OpIAdd %float %c %c + OpReturn + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_TRUE(p->error().empty()); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + VariableConst{ + x_1 + none + __f32 + { + Binary{ + Identifier{myconst} + add + Identifier{myconst} + } + } + })")) + << ToString(fe.ast_body()); +} + } // namespace } // namespace spirv } // namespace reader