spirv-reader: support scalar spec constants

Translate OpSpecConstantTrue, OpSpecConstantFalse, and OpSpecConstant.
The latter only can be used with integer or float scalars.

If the constant has a SpecId decoration, then generate a module-scope
decorated constant.  Otherwise generate a module-scope constant without
decorations.

Register the ID so we know to use the declared const identifier in
expressions later in the module.

Bug: tint:156
Change-Id: Icd6e9b60225ced7ee99963c4f85cec1eb0e3ae6b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/31541
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 949aadc..98b2c83 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -1700,7 +1700,7 @@
   if (failed()) {
     return {};
   }
-  if (identifier_values_.count(id)) {
+  if (identifier_values_.count(id) || parser_impl_.IsScalarSpecConstant(id)) {
     return TypedExpression(
         parser_impl_.ConvertType(def_use_mgr_->GetDef(id)->type_id()),
         std::make_unique<ast::IdentifierExpression>(namer_.Name(id)));
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 108bd2a..d38e96c 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -40,6 +40,7 @@
 #include "src/ast/bool_literal.h"
 #include "src/ast/builtin.h"
 #include "src/ast/builtin_decoration.h"
+#include "src/ast/constant_id_decoration.h"
 #include "src/ast/decorated_variable.h"
 #include "src/ast/float_literal.h"
 #include "src/ast/scalar_constructor_expression.h"
@@ -534,6 +535,9 @@
   if (!RegisterTypes()) {
     return false;
   }
+  if (!EmitScalarSpecConstants()) {
+    return false;
+  }
   if (!EmitModuleScopeVariables()) {
     return false;
   }
@@ -947,6 +951,82 @@
   return success_;
 }
 
+bool ParserImpl::EmitScalarSpecConstants() {
+  if (!success_) {
+    return false;
+  }
+  // Generate a module-scope const declaration for each instruction
+  // that is OpSpecConstantTrue, OpSpecConstantFalse, or OpSpecConstant.
+  for (auto& inst : module_->types_values()) {
+    // These will be populated for a valid scalar spec constant.
+    ast::type::Type* ast_type = nullptr;
+    std::unique_ptr<ast::ScalarConstructorExpression> ast_expr;
+
+    switch (inst.opcode()) {
+      case SpvOpSpecConstantTrue:
+      case SpvOpSpecConstantFalse: {
+        ast_type = ConvertType(inst.type_id());
+        ast_expr = std::make_unique<ast::ScalarConstructorExpression>(
+            std::make_unique<ast::BoolLiteral>(
+                ast_type, inst.opcode() == SpvOpSpecConstantTrue));
+        break;
+      }
+      case SpvOpSpecConstant: {
+        ast_type = ConvertType(inst.type_id());
+        const uint32_t literal_value = inst.GetSingleWordInOperand(0);
+        if (ast_type->IsI32()) {
+          ast_expr = std::make_unique<ast::ScalarConstructorExpression>(
+              std::make_unique<ast::SintLiteral>(
+                  ast_type, static_cast<int32_t>(literal_value)));
+        } else if (ast_type->IsU32()) {
+          ast_expr = std::make_unique<ast::ScalarConstructorExpression>(
+              std::make_unique<ast::UintLiteral>(
+                  ast_type, static_cast<uint32_t>(literal_value)));
+        } else if (ast_type->IsF32()) {
+          float float_value;
+          // Copy the bits so we can read them as a float.
+          std::memcpy(&float_value, &literal_value, sizeof(float_value));
+          ast_expr = std::make_unique<ast::ScalarConstructorExpression>(
+              std::make_unique<ast::FloatLiteral>(ast_type, float_value));
+        } else {
+          return Fail() << " invalid result type for OpSpecConstant "
+                        << inst.PrettyPrint();
+        }
+        break;
+      }
+      default:
+        break;
+    }
+    if (ast_type && ast_expr) {
+      auto ast_var =
+          MakeVariable(inst.result_id(), ast::StorageClass::kNone, ast_type);
+      ast::VariableDecorationList spec_id_decos;
+      for (const auto& deco : GetDecorationsFor(inst.result_id())) {
+        if ((deco.size() == 2) && (deco[0] == SpvDecorationSpecId)) {
+          auto cid = std::make_unique<ast::ConstantIdDecoration>(deco[1]);
+          spec_id_decos.push_back(std::move(cid));
+          break;
+        }
+      }
+      if (spec_id_decos.empty()) {
+        // Register it as a named constant, without specialization id.
+        ast_var->set_is_const(true);
+        ast_var->set_constructor(std::move(ast_expr));
+        ast_module_.AddGlobalVariable(std::move(ast_var));
+      } else {
+        auto ast_deco_var =
+            std::make_unique<ast::DecoratedVariable>(std::move(ast_var));
+        ast_deco_var->set_is_const(true);
+        ast_deco_var->set_constructor(std::move(ast_expr));
+        ast_deco_var->set_decorations(std::move(spec_id_decos));
+        ast_module_.AddGlobalVariable(std::move(ast_deco_var));
+      }
+      scalar_spec_constants_.insert(inst.result_id());
+    }
+  }
+  return success_;
+}
+
 void ParserImpl::MaybeGenerateAlias(uint32_t type_id,
                                     const spvtools::opt::analysis::Type* type) {
   if (!success_) {
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index dfffc8d..7037397 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -246,6 +246,12 @@
   /// @returns true if parser is still successful.
   bool RegisterTypes();
 
+  /// Emit const definitions for scalar specialization constants generated
+  /// by one of OpConstantTrue, OpConstantFalse, or OpSpecConstant.
+  /// This is a no-op if the parser has already failed.
+  /// @returns true if parser is still successful.
+  bool EmitScalarSpecConstants();
+
   /// Emits module-scope variables.
   /// This is a no-op if the parser has already failed.
   /// @returns true if parser is still successful.
@@ -373,6 +379,14 @@
   /// @returns true if the given string is a valid WGSL identifier.
   static bool IsValidIdentifier(const std::string& str);
 
+  /// Returns true if the given SPIR-V ID is a declared specialization constant,
+  /// generated by one of OpConstantTrue, OpConstantFalse, or OpSpecConstant
+  /// @param id a SPIR-V result ID
+  /// @returns true if the ID is a scalar spec constant.
+  bool IsScalarSpecConstant(uint32_t id) {
+    return scalar_spec_constants_.find(id) != scalar_spec_constants_.end();
+  }
+
  private:
   /// Converts a specific SPIR-V type to a Tint type. Integer case
   ast::type::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty);
@@ -489,6 +503,9 @@
   // The struct types with only read-only members.
   std::unordered_set<ast::type::Type*> read_only_struct_types_;
 
+  // The IDs of scalar spec constants
+  std::unordered_set<uint32_t> scalar_spec_constants_;
+
   // Maps function_id to a list of entrypoint information
   std::unordered_map<uint32_t, std::vector<EntryPointInfo>>
       function_to_ep_info_;
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index 467d79a..f829fff 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -15,6 +15,7 @@
 #include <string>
 
 #include "gmock/gmock.h"
+#include "src/reader/spirv/function.h"
 #include "src/reader/spirv/parser_impl.h"
 #include "src/reader/spirv/parser_impl_test_helper.h"
 #include "src/reader/spirv/spirv_tools_helpers_test.h"
@@ -1489,6 +1490,186 @@
 })")) << module_str;
 }
 
+TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_DeclareConst_True) {
+  auto* p = parser(test::Assemble(R"(
+     OpName %c "myconst"
+     OpDecorate %c SpecId 12
+     %bool = OpTypeBool
+     %c = OpSpecConstantTrue %bool
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->module().to_str();
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  DecoratedVariableConst{
+    Decorations{
+      ConstantIdDecoration{12}
+    }
+    myconst
+    none
+    __bool
+    {
+      ScalarConstructor{true}
+    }
+  }
+})")) << module_str;
+}
+
+TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_DeclareConst_False) {
+  auto* p = parser(test::Assemble(R"(
+     OpName %c "myconst"
+     OpDecorate %c SpecId 12
+     %bool = OpTypeBool
+     %c = OpSpecConstantFalse %bool
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->module().to_str();
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  DecoratedVariableConst{
+    Decorations{
+      ConstantIdDecoration{12}
+    }
+    myconst
+    none
+    __bool
+    {
+      ScalarConstructor{false}
+    }
+  }
+})")) << module_str;
+}
+
+TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_DeclareConst_U32) {
+  auto* p = parser(test::Assemble(R"(
+     OpName %c "myconst"
+     OpDecorate %c SpecId 12
+     %uint = OpTypeInt 32 0
+     %c = OpSpecConstant %uint 42
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->module().to_str();
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  DecoratedVariableConst{
+    Decorations{
+      ConstantIdDecoration{12}
+    }
+    myconst
+    none
+    __u32
+    {
+      ScalarConstructor{42}
+    }
+  }
+})")) << module_str;
+}
+
+TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_DeclareConst_I32) {
+  auto* p = parser(test::Assemble(R"(
+     OpName %c "myconst"
+     OpDecorate %c SpecId 12
+     %int = OpTypeInt 32 1
+     %c = OpSpecConstant %int 42
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->module().to_str();
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  DecoratedVariableConst{
+    Decorations{
+      ConstantIdDecoration{12}
+    }
+    myconst
+    none
+    __i32
+    {
+      ScalarConstructor{42}
+    }
+  }
+})")) << module_str;
+}
+
+TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_DeclareConst_F32) {
+  auto* p = parser(test::Assemble(R"(
+     OpName %c "myconst"
+     OpDecorate %c SpecId 12
+     %float = OpTypeFloat 32
+     %c = OpSpecConstant %float 2.5
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->module().to_str();
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  DecoratedVariableConst{
+    Decorations{
+      ConstantIdDecoration{12}
+    }
+    myconst
+    none
+    __f32
+    {
+      ScalarConstructor{2.500000}
+    }
+  }
+})")) << module_str;
+}
+
+TEST_F(SpvParserTest,
+       ModuleScopeVar_ScalarSpecConstant_DeclareConst_F32_WithoutSpecId) {
+  // When we don't have a spec ID, declare an undecorated module-scope constant.
+  auto* p = parser(test::Assemble(R"(
+     OpName %c "myconst"
+     %float = OpTypeFloat 32
+     %c = OpSpecConstant %float 2.5
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->module().to_str();
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  VariableConst{
+    myconst
+    none
+    __f32
+    {
+      ScalarConstructor{2.500000}
+    }
+  }
+})")) << module_str;
+}
+
+TEST_F(SpvParserTest, ModuleScopeVar_ScalarSpecConstant_UsedInFunction) {
+  auto* p = parser(test::Assemble(R"(
+     OpName %c "myconst"
+     %float = OpTypeFloat 32
+     %c = OpSpecConstant %float 2.5
+     %floatfn = OpTypeFunction %float
+     %100 = OpFunction %float None %floatfn
+     %entry = OpLabel
+     %1 = OpIAdd %float %c %c
+     OpReturn
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_TRUE(p->error().empty());
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __f32
+    {
+      Binary{
+        Identifier{myconst}
+        add
+        Identifier{myconst}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader