[spirv-writer] Emit specialization constants.

This CL adds OpSpec constants to the SPIRV-Writer.

Bug: tint:151
Change-Id: I309013ca0b4cb514edd92fab3dab2e4faa15969a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/29101
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 6b4866d..c843386 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -355,7 +355,7 @@
 uint32_t Builder::GenerateU32Literal(uint32_t val) {
   ast::type::U32Type u32;
   ast::SintLiteral lit(&u32, val);
-  return GenerateLiteralIfNeeded(&lit);
+  return GenerateLiteralIfNeeded(nullptr, &lit);
 }
 
 bool Builder::GenerateAssignStatement(ast::AssignmentStatement* assign) {
@@ -465,7 +465,7 @@
     return GenerateCallExpression(expr->AsCall());
   }
   if (expr->IsConstructor()) {
-    return GenerateConstructorExpression(expr->AsConstructor(), false);
+    return GenerateConstructorExpression(nullptr, expr->AsConstructor(), false);
   }
   if (expr->IsIdentifier()) {
     return GenerateIdentifierExpression(expr->AsIdentifier());
@@ -611,7 +611,7 @@
   // TODO(dsinclair) We could detect if the constructor is fully const and emit
   // an initializer value for the variable instead of doing the OpLoad.
   ast::NullLiteral nl(var->type()->UnwrapPtrIfNeeded());
-  auto null_id = GenerateLiteralIfNeeded(&nl);
+  auto null_id = GenerateLiteralIfNeeded(var, &nl);
   if (null_id == 0) {
     return 0;
   }
@@ -642,8 +642,8 @@
       return false;
     }
 
-    init_id = GenerateConstructorExpression(var->constructor()->AsConstructor(),
-                                            true);
+    init_id = GenerateConstructorExpression(
+        var, var->constructor()->AsConstructor(), true);
     if (init_id == 0) {
       return false;
     }
@@ -689,7 +689,7 @@
         var->storage_class() == ast::StorageClass::kNone ||
         var->storage_class() == ast::StorageClass::kOutput) {
       ast::NullLiteral nl(var->type()->UnwrapPtrIfNeeded());
-      init_id = GenerateLiteralIfNeeded(&nl);
+      init_id = GenerateLiteralIfNeeded(var, &nl);
       if (init_id == 0) {
         return 0;
       }
@@ -718,6 +718,8 @@
             spv::Op::OpDecorate,
             {Operand::Int(var_id), Operand::Int(SpvDecorationDescriptorSet),
              Operand::Int(deco->AsSet()->value())});
+      } else if (deco->IsConstantId()) {
+        // Spec constants are handled elsewhere
       } else {
         error_ = "unknown decoration";
         return false;
@@ -1033,10 +1035,11 @@
 }
 
 uint32_t Builder::GenerateConstructorExpression(
+    ast::Variable* var,
     ast::ConstructorExpression* expr,
     bool is_global_init) {
   if (expr->IsScalarConstructor()) {
-    return GenerateLiteralIfNeeded(expr->AsScalarConstructor()->literal());
+    return GenerateLiteralIfNeeded(var, expr->AsScalarConstructor()->literal());
   }
   if (expr->IsTypeConstructor()) {
     return GenerateTypeConstructorExpression(expr->AsTypeConstructor(),
@@ -1055,7 +1058,7 @@
   // Generate the zero initializer if there are no values provided.
   if (values.empty()) {
     ast::NullLiteral nl(init->type()->UnwrapPtrIfNeeded());
-    return GenerateLiteralIfNeeded(&nl);
+    return GenerateLiteralIfNeeded(nullptr, &nl);
   }
 
   std::ostringstream out;
@@ -1102,7 +1105,8 @@
   for (const auto& e : values) {
     uint32_t id = 0;
     if (constructor_is_const) {
-      id = GenerateConstructorExpression(e->AsConstructor(), is_global_init);
+      id = GenerateConstructorExpression(nullptr, e->AsConstructor(),
+                                         is_global_init);
     } else {
       id = GenerateExpression(e.get());
       id = GenerateLoadIfNeeded(e->result_type(), id);
@@ -1268,12 +1272,21 @@
   return result_id;
 }
 
-uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) {
+uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var,
+                                          ast::Literal* lit) {
   auto type_id = GenerateTypeIfNeeded(lit->type());
   if (type_id == 0) {
     return 0;
   }
+
   auto name = lit->name();
+  bool is_spec_constant = false;
+  if (var && var->IsDecorated() &&
+      var->AsDecorated()->HasConstantIdDecoration()) {
+    name = "__spec" + name;
+    is_spec_constant = true;
+  }
+
   auto val = const_to_id_.find(name);
   if (val != const_to_id_.end()) {
     return val->second;
@@ -1282,21 +1295,34 @@
   auto result = result_op();
   auto result_id = result.to_i();
 
+  if (is_spec_constant) {
+    push_annot(spv::Op::OpDecorate,
+               {Operand::Int(result_id), Operand::Int(SpvDecorationSpecId),
+                Operand::Int(var->AsDecorated()->constant_id())});
+  }
+
   if (lit->IsBool()) {
     if (lit->AsBool()->IsTrue()) {
-      push_type(spv::Op::OpConstantTrue, {Operand::Int(type_id), result});
+      push_type(is_spec_constant ? spv::Op::OpSpecConstantTrue
+                                 : spv::Op::OpConstantTrue,
+                {Operand::Int(type_id), result});
     } else {
-      push_type(spv::Op::OpConstantFalse, {Operand::Int(type_id), result});
+      push_type(is_spec_constant ? spv::Op::OpSpecConstantFalse
+                                 : spv::Op::OpConstantFalse,
+                {Operand::Int(type_id), result});
     }
   } else if (lit->IsSint()) {
-    push_type(spv::Op::OpConstant, {Operand::Int(type_id), result,
-                                    Operand::Int(lit->AsSint()->value())});
+    push_type(
+        is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+        {Operand::Int(type_id), result, Operand::Int(lit->AsSint()->value())});
   } else if (lit->IsUint()) {
-    push_type(spv::Op::OpConstant, {Operand::Int(type_id), result,
-                                    Operand::Int(lit->AsUint()->value())});
+    push_type(
+        is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+        {Operand::Int(type_id), result, Operand::Int(lit->AsUint()->value())});
   } else if (lit->IsFloat()) {
-    push_type(spv::Op::OpConstant, {Operand::Int(type_id), result,
-                                    Operand::Float(lit->AsFloat()->value())});
+    push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+              {Operand::Int(type_id), result,
+               Operand::Float(lit->AsFloat()->value())});
   } else if (lit->IsNull()) {
     push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result});
   } else {
@@ -1730,7 +1756,8 @@
     spirv_params.push_back(Operand::Int(SpvImageOperandsLodMask));
     ast::type::F32Type f32;
     ast::FloatLiteral float_0(&f32, 0.0);
-    spirv_params.push_back(Operand::Int(GenerateLiteralIfNeeded(&float_0)));
+    spirv_params.push_back(
+        Operand::Int(GenerateLiteralIfNeeded(nullptr, &float_0)));
   }
   if (op == spv::Op::OpNop) {
     error_ = "unable to determine operator for: " + ident->name();
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index ccfce39..b6e3f7c 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -249,10 +249,12 @@
   /// Generates an import instruction
   void GenerateGLSLstd450Import();
   /// Generates a constructor expression
+  /// @param var the variable generated for, nullptr if no variable associated.
   /// @param expr the expression to generate
   /// @param is_global_init set true if this is a global variable constructor
   /// @returns the ID of the expression or 0 on failure.
-  uint32_t GenerateConstructorExpression(ast::ConstructorExpression* expr,
+  uint32_t GenerateConstructorExpression(ast::Variable* var,
+                                         ast::ConstructorExpression* expr,
                                          bool is_global_init);
   /// Generates a type constructor expression
   /// @param init the expression to generate
@@ -262,9 +264,10 @@
       ast::TypeConstructorExpression* init,
       bool is_global_init);
   /// Generates a literal constant if needed
+  /// @param var the variable generated for, nullptr if no variable associated.
   /// @param lit the literal to generate
   /// @returns the ID on success or 0 on failure
-  uint32_t GenerateLiteralIfNeeded(ast::Literal* lit);
+  uint32_t GenerateLiteralIfNeeded(ast::Variable* var, ast::Literal* lit);
   /// Generates a binary expression
   /// @param expr the expression to generate
   /// @returns the expression ID on success or 0 otherwise
diff --git a/src/writer/spirv/builder_constructor_expression_test.cc b/src/writer/spirv/builder_constructor_expression_test.cc
index 73f07ad..12ac999 100644
--- a/src/writer/spirv/builder_constructor_expression_test.cc
+++ b/src/writer/spirv/builder_constructor_expression_test.cc
@@ -56,7 +56,7 @@
 
   ast::Module mod;
   Builder b(&mod);
-  EXPECT_EQ(b.GenerateConstructorExpression(&c, true), 2u);
+  EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &c, true), 2u);
   ASSERT_FALSE(b.has_error()) << b.error();
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
@@ -84,7 +84,7 @@
   EXPECT_TRUE(td.DetermineResultType(&t)) << td.error();
 
   Builder b(&mod);
-  EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 5u);
+  EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &t, true), 5u);
   ASSERT_FALSE(b.has_error()) << b.error();
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
@@ -190,7 +190,7 @@
 
   ast::Module mod;
   Builder b(&mod);
-  EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 0u);
+  EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &t, true), 0u);
   EXPECT_TRUE(b.has_error());
   EXPECT_EQ(b.error(), R"(constructor must be a constant expression)");
 }
@@ -765,7 +765,7 @@
 
   Builder b(&mod);
   b.push_function(Function{});
-  EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u);
+  EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u);
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
 %1 = OpTypeVector %2 3
@@ -807,7 +807,7 @@
 
   Builder b(&mod);
   b.push_function(Function{});
-  EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u);
+  EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u);
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
 %1 = OpTypeVector %2 3
@@ -851,7 +851,7 @@
 
   Builder b(&mod);
   b.push_function(Function{});
-  EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u);
+  EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u);
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
 %1 = OpTypeVector %2 4
@@ -895,7 +895,7 @@
 
   Builder b(&mod);
   b.push_function(Function{});
-  EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u);
+  EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u);
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
 %1 = OpTypeVector %2 4
@@ -939,7 +939,7 @@
 
   Builder b(&mod);
   b.push_function(Function{});
-  EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u);
+  EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u);
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
 %1 = OpTypeVector %2 4
@@ -987,7 +987,7 @@
 
   Builder b(&mod);
   b.push_function(Function{});
-  EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u);
+  EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 13u);
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
 %1 = OpTypeVector %2 4
@@ -1033,7 +1033,7 @@
 
   Builder b(&mod);
   b.push_function(Function{});
-  EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u);
+  EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 13u);
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
 %1 = OpTypeVector %2 4
@@ -1079,7 +1079,7 @@
 
   Builder b(&mod);
   b.push_function(Function{});
-  EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u);
+  EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 13u);
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
 %1 = OpTypeVector %2 4
diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc
index 1e60c0f..73ae303 100644
--- a/src/writer/spirv/builder_global_variable_test.cc
+++ b/src/writer/spirv/builder_global_variable_test.cc
@@ -16,14 +16,18 @@
 
 #include "gtest/gtest.h"
 #include "src/ast/binding_decoration.h"
+#include "src/ast/bool_literal.h"
 #include "src/ast/builtin.h"
 #include "src/ast/builtin_decoration.h"
+#include "src/ast/constant_id_decoration.h"
 #include "src/ast/decorated_variable.h"
 #include "src/ast/float_literal.h"
 #include "src/ast/location_decoration.h"
+#include "src/ast/module.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/set_decoration.h"
 #include "src/ast/storage_class.h"
+#include "src/ast/type/bool_type.h"
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/vector_type.h"
 #include "src/ast/type_constructor_expression.h"
@@ -359,6 +363,58 @@
 )");
 }
 
+TEST_F(BuilderTest, GlobalVar_ConstantId_Bool) {
+  ast::type::BoolType bool_type;
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::ConstantIdDecoration>(1200));
+
+  ast::DecoratedVariable v(std::make_unique<ast::Variable>(
+      "var", ast::StorageClass::kNone, &bool_type));
+  v.set_decorations(std::move(decos));
+  v.set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true)));
+
+  ast::Module mod;
+  Builder b(&mod);
+  EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "var"
+)");
+  EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 1200
+)");
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
+%2 = OpSpecConstantTrue %1
+%4 = OpTypePointer Private %1
+%3 = OpVariable %4 Private %2
+)");
+}
+
+TEST_F(BuilderTest, GlobalVar_ConstantId_Scalar) {
+  ast::type::F32Type f32;
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::ConstantIdDecoration>(0));
+
+  ast::DecoratedVariable v(
+      std::make_unique<ast::Variable>("var", ast::StorageClass::kNone, &f32));
+  v.set_decorations(std::move(decos));
+  v.set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 2.0)));
+
+  ast::Module mod;
+  Builder b(&mod);
+  EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "var"
+)");
+  EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0
+)");
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
+%2 = OpSpecConstant %1 2
+%4 = OpTypePointer Private %1
+%3 = OpVariable %4 Private %2
+)");
+}
+
 struct BuiltinData {
   ast::Builtin builtin;
   SpvBuiltIn result;
diff --git a/src/writer/spirv/builder_literal_test.cc b/src/writer/spirv/builder_literal_test.cc
index 1c96661..294c6b6 100644
--- a/src/writer/spirv/builder_literal_test.cc
+++ b/src/writer/spirv/builder_literal_test.cc
@@ -38,7 +38,7 @@
 
   ast::Module mod;
   Builder b(&mod);
-  auto id = b.GenerateLiteralIfNeeded(&b_true);
+  auto id = b.GenerateLiteralIfNeeded(nullptr, &b_true);
   ASSERT_FALSE(b.has_error()) << b.error();
   EXPECT_EQ(2u, id);
 
@@ -53,7 +53,7 @@
 
   ast::Module mod;
   Builder b(&mod);
-  auto id = b.GenerateLiteralIfNeeded(&b_false);
+  auto id = b.GenerateLiteralIfNeeded(nullptr, &b_false);
   ASSERT_FALSE(b.has_error()) << b.error();
   EXPECT_EQ(2u, id);
 
@@ -69,11 +69,11 @@
 
   ast::Module mod;
   Builder b(&mod);
-  ASSERT_NE(b.GenerateLiteralIfNeeded(&b_true), 0u);
+  ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &b_true), 0u);
   ASSERT_FALSE(b.has_error()) << b.error();
-  ASSERT_NE(b.GenerateLiteralIfNeeded(&b_false), 0u);
+  ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &b_false), 0u);
   ASSERT_FALSE(b.has_error()) << b.error();
-  ASSERT_NE(b.GenerateLiteralIfNeeded(&b_true), 0u);
+  ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &b_true), 0u);
   ASSERT_FALSE(b.has_error()) << b.error();
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
@@ -88,7 +88,7 @@
 
   ast::Module mod;
   Builder b(&mod);
-  auto id = b.GenerateLiteralIfNeeded(&i);
+  auto id = b.GenerateLiteralIfNeeded(nullptr, &i);
   ASSERT_FALSE(b.has_error()) << b.error();
   EXPECT_EQ(2u, id);
 
@@ -104,8 +104,8 @@
 
   ast::Module mod;
   Builder b(&mod);
-  ASSERT_NE(b.GenerateLiteralIfNeeded(&i1), 0u);
-  ASSERT_NE(b.GenerateLiteralIfNeeded(&i2), 0u);
+  ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i1), 0u);
+  ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i2), 0u);
   ASSERT_FALSE(b.has_error()) << b.error();
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
@@ -119,7 +119,7 @@
 
   ast::Module mod;
   Builder b(&mod);
-  auto id = b.GenerateLiteralIfNeeded(&i);
+  auto id = b.GenerateLiteralIfNeeded(nullptr, &i);
   ASSERT_FALSE(b.has_error()) << b.error();
   EXPECT_EQ(2u, id);
 
@@ -135,8 +135,8 @@
 
   ast::Module mod;
   Builder b(&mod);
-  ASSERT_NE(b.GenerateLiteralIfNeeded(&i1), 0u);
-  ASSERT_NE(b.GenerateLiteralIfNeeded(&i2), 0u);
+  ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i1), 0u);
+  ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i2), 0u);
   ASSERT_FALSE(b.has_error()) << b.error();
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
@@ -150,7 +150,7 @@
 
   ast::Module mod;
   Builder b(&mod);
-  auto id = b.GenerateLiteralIfNeeded(&i);
+  auto id = b.GenerateLiteralIfNeeded(nullptr, &i);
   ASSERT_FALSE(b.has_error()) << b.error();
   EXPECT_EQ(2u, id);
 
@@ -166,8 +166,8 @@
 
   ast::Module mod;
   Builder b(&mod);
-  ASSERT_NE(b.GenerateLiteralIfNeeded(&i1), 0u);
-  ASSERT_NE(b.GenerateLiteralIfNeeded(&i2), 0u);
+  ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i1), 0u);
+  ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i2), 0u);
   ASSERT_FALSE(b.has_error()) << b.error();
 
   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32