writer/spirv: Clean up constant generation

Requiring a temporary stack-allocated ast::Literal is an unpleasant requirement to generate a SPIR-V constant value.
GenerateU32Literal() was also creating an invalid AST - the type was U32, yet an an ast::SintLiteral was used.

Instead add Constant for holding a constant value, and use this as the map key.

This also removes the last remaining use of ast::NullLiteral, which will be removed in the next change.

Change-Id: Ia85732784075f153503dbef101ba95018eaa4bf5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45342
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 787afe2..fc82118 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -558,6 +558,7 @@
     "writer/spirv/instruction.h",
     "writer/spirv/operand.cc",
     "writer/spirv/operand.h",
+    "writer/spirv/scalar_constant.h",
   ]
 
   configs += [ ":tint_common_config" ]
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 9614ed4..9e5f29e 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -343,6 +343,7 @@
     writer/spirv/instruction.h
     writer/spirv/operand.cc
     writer/spirv/operand.h
+    writer/spirv/scalar_constant.h
   )
 endif()
 
@@ -678,6 +679,7 @@
       writer/spirv/builder_unary_op_expression_test.cc
       writer/spirv/instruction_test.cc
       writer/spirv/operand_test.cc
+      writer/spirv/scalar_constant_test.cc
       writer/spirv/spv_dump.cc
       writer/spirv/spv_dump.h
       writer/spirv/test_helper.h
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 71054c8..7209aba 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -21,7 +21,6 @@
 #include "src/ast/call_statement.h"
 #include "src/ast/constant_id_decoration.h"
 #include "src/ast/fallthrough_statement.h"
-#include "src/ast/null_literal.h"
 #include "src/semantic/array.h"
 #include "src/semantic/call.h"
 #include "src/semantic/function.h"
@@ -365,12 +364,6 @@
   return true;
 }
 
-uint32_t Builder::GenerateU32Literal(uint32_t val) {
-  type::U32 u32;
-  ast::SintLiteral lit(Source{}, &u32, val);
-  return GenerateLiteralIfNeeded(nullptr, &lit);
-}
-
 bool Builder::GenerateAssignStatement(ast::AssignmentStatement* assign) {
   auto lhs_id = GenerateExpression(assign->lhs());
   if (lhs_id == 0) {
@@ -648,8 +641,7 @@
 
   // TODO(dsinclair) We could detect if the constructor is fully const and emit
   // an initializer value for the variable instead of doing the OpLoad.
-  ast::NullLiteral nl(Source{}, var->type()->UnwrapPtrIfNeeded());
-  auto null_id = GenerateLiteralIfNeeded(var, &nl);
+  auto null_id = GenerateConstantNullIfNeeded(var->type()->UnwrapPtrIfNeeded());
   if (null_id == 0) {
     return 0;
   }
@@ -779,8 +771,7 @@
     } else if (sem->StorageClass() == ast::StorageClass::kPrivate ||
                sem->StorageClass() == ast::StorageClass::kNone ||
                sem->StorageClass() == ast::StorageClass::kOutput) {
-      ast::NullLiteral nl(Source{}, type);
-      init_id = GenerateLiteralIfNeeded(var, &nl);
+      init_id = GenerateConstantNullIfNeeded(type);
       if (init_id == 0) {
         return 0;
       }
@@ -888,7 +879,7 @@
       }
     }
 
-    auto idx_id = GenerateU32Literal(i);
+    auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i));
     if (idx_id == 0) {
       return 0;
     }
@@ -913,7 +904,7 @@
     }
 
     if (info->source_type->Is<type::Pointer>()) {
-      auto idx_id = GenerateU32Literal(val);
+      auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(val));
       if (idx_id == 0) {
         return 0;
       }
@@ -1044,8 +1035,7 @@
 
       auto ary_result = result_op();
 
-      ast::NullLiteral nl(Source{}, ary_res_type);
-      auto init = GenerateLiteralIfNeeded(nullptr, &nl);
+      auto init = GenerateConstantNullIfNeeded(ary_res_type);
 
       // If we're access chaining into an array then we must be in a function
       push_function_var(
@@ -1259,8 +1249,7 @@
 
   // Generate the zero initializer if there are no values provided.
   if (values.empty()) {
-    ast::NullLiteral nl(Source{}, init->type()->UnwrapPtrIfNeeded());
-    return GenerateLiteralIfNeeded(nullptr, &nl);
+    return GenerateConstantNullIfNeeded(init->type()->UnwrapPtrIfNeeded());
   }
 
   std::ostringstream out;
@@ -1370,7 +1359,7 @@
           result_is_constant_composite = false;
         } else {
           // A global initializer, must use OpSpecConstantOp. Case 1.
-          auto idx_id = GenerateU32Literal(i);
+          auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i));
           if (idx_id == 0) {
             return 0;
           }
@@ -1392,8 +1381,8 @@
   }
 
   auto str = out.str();
-  auto val = const_to_id_.find(str);
-  if (val != const_to_id_.end()) {
+  auto val = type_constructor_to_id_.find(str);
+  if (val != type_constructor_to_id_.end()) {
     return val->second;
   }
 
@@ -1401,7 +1390,7 @@
   ops.insert(ops.begin(), result);
   ops.insert(ops.begin(), Operand::Int(type_id));
 
-  const_to_id_[str] = result.to_i();
+  type_constructor_to_id_[str] = result.to_i();
 
   if (result_is_spec_composite) {
     push_type(spv::Op::OpSpecConstantComposite, ops);
@@ -1480,59 +1469,133 @@
 
 uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var,
                                           ast::Literal* lit) {
-  auto type_id = GenerateTypeIfNeeded(lit->type());
-  if (type_id == 0) {
-    return 0;
-  }
+  ScalarConstant constant;
 
-  auto name = lit->name();
-  bool is_spec_constant = false;
   if (var && var->HasConstantIdDecoration()) {
-    name = "__spec" + name;
-    is_spec_constant = true;
-  }
-
-  auto val = const_to_id_.find(name);
-  if (val != const_to_id_.end()) {
-    return val->second;
-  }
-
-  auto result = result_op();
-  auto result_id = result.to_i();
-
-  if (is_spec_constant) {
-    push_annot(spv::Op::OpDecorate,
-               {Operand::Int(result_id), Operand::Int(SpvDecorationSpecId),
-                Operand::Int(var->constant_id())});
+    constant.is_spec_op = true;
+    constant.constant_id = var->constant_id();
   }
 
   if (auto* l = lit->As<ast::BoolLiteral>()) {
-    if (l->IsTrue()) {
-      push_type(is_spec_constant ? spv::Op::OpSpecConstantTrue
-                                 : spv::Op::OpConstantTrue,
-                {Operand::Int(type_id), result});
-    } else {
-      push_type(is_spec_constant ? spv::Op::OpSpecConstantFalse
-                                 : spv::Op::OpConstantFalse,
-                {Operand::Int(type_id), result});
-    }
+    constant.kind = ScalarConstant::Kind::kBool;
+    constant.value.b = l->IsTrue();
   } else if (auto* sl = lit->As<ast::SintLiteral>()) {
-    push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
-              {Operand::Int(type_id), result, Operand::Int(sl->value())});
+    constant.kind = ScalarConstant::Kind::kI32;
+    constant.value.i32 = sl->value();
   } else if (auto* ul = lit->As<ast::UintLiteral>()) {
-    push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
-              {Operand::Int(type_id), result, Operand::Int(ul->value())});
+    constant.kind = ScalarConstant::Kind::kU32;
+    constant.value.u32 = ul->value();
   } else if (auto* fl = lit->As<ast::FloatLiteral>()) {
-    push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
-              {Operand::Int(type_id), result, Operand::Float(fl->value())});
-  } else if (lit->Is<ast::NullLiteral>()) {
-    push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result});
+    constant.kind = ScalarConstant::Kind::kF32;
+    constant.value.f32 = fl->value();
   } else {
     error_ = "unknown literal type";
     return 0;
   }
 
-  const_to_id_[name] = result_id;
+  return GenerateConstantIfNeeded(constant);
+}
+
+uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
+  auto it = const_to_id_.find(constant);
+  if (it != const_to_id_.end()) {
+    return it->second;
+  }
+
+  uint32_t type_id = 0;
+
+  switch (constant.kind) {
+    case ScalarConstant::Kind::kU32: {
+      type::U32 u32;
+      type_id = GenerateTypeIfNeeded(&u32);
+      break;
+    }
+    case ScalarConstant::Kind::kI32: {
+      type::I32 i32;
+      type_id = GenerateTypeIfNeeded(&i32);
+      break;
+    }
+    case ScalarConstant::Kind::kF32: {
+      type::F32 f32;
+      type_id = GenerateTypeIfNeeded(&f32);
+      break;
+    }
+    case ScalarConstant::Kind::kBool: {
+      type::Bool bool_;
+      type_id = GenerateTypeIfNeeded(&bool_);
+      break;
+    }
+  }
+
+  if (type_id == 0) {
+    return 0;
+  }
+
+  auto result = result_op();
+  auto result_id = result.to_i();
+
+  if (constant.is_spec_op) {
+    push_annot(spv::Op::OpDecorate,
+               {Operand::Int(result_id), Operand::Int(SpvDecorationSpecId),
+                Operand::Int(constant.constant_id)});
+  }
+
+  switch (constant.kind) {
+    case ScalarConstant::Kind::kU32: {
+      push_type(
+          constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+          {Operand::Int(type_id), result, Operand::Int(constant.value.u32)});
+      break;
+    }
+    case ScalarConstant::Kind::kI32: {
+      push_type(
+          constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+          {Operand::Int(type_id), result, Operand::Int(constant.value.i32)});
+      break;
+    }
+    case ScalarConstant::Kind::kF32: {
+      push_type(
+          constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+          {Operand::Int(type_id), result, Operand::Float(constant.value.f32)});
+      break;
+    }
+    case ScalarConstant::Kind::kBool: {
+      if (constant.value.b) {
+        push_type(constant.is_spec_op ? spv::Op::OpSpecConstantTrue
+                                      : spv::Op::OpConstantTrue,
+                  {Operand::Int(type_id), result});
+      } else {
+        push_type(constant.is_spec_op ? spv::Op::OpSpecConstantFalse
+                                      : spv::Op::OpConstantFalse,
+                  {Operand::Int(type_id), result});
+      }
+      break;
+    }
+  }
+
+  const_to_id_[constant] = result_id;
+  return result_id;
+}
+
+uint32_t Builder::GenerateConstantNullIfNeeded(type::Type* type) {
+  auto type_id = GenerateTypeIfNeeded(type);
+  if (type_id == 0) {
+    return 0;
+  }
+
+  auto name = type->type_name();
+
+  auto it = const_null_to_id_.find(name);
+  if (it != const_null_to_id_.end()) {
+    return it->second;
+  }
+
+  auto result = result_op();
+  auto result_id = result.to_i();
+
+  push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result});
+
+  const_null_to_id_[name] = result_id;
   return result_id;
 }
 
@@ -2955,7 +3018,7 @@
   if (ary->IsRuntimeArray()) {
     push_type(spv::Op::OpTypeRuntimeArray, {result, Operand::Int(elem_type)});
   } else {
-    auto len_id = GenerateU32Literal(ary->size());
+    auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(ary->size()));
     if (len_id == 0) {
       return false;
     }
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 4f3251d..6147c9f 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -37,6 +37,7 @@
 #include "src/type/access_control_type.h"
 #include "src/type/storage_texture_type.h"
 #include "src/writer/spirv/function.h"
+#include "src/writer/spirv/scalar_constant.h"
 
 namespace tint {
 
@@ -208,10 +209,6 @@
   /// @param id the id to use for the label
   /// @returns true on success.
   bool GenerateLabel(uint32_t id);
-  /// Generates a uint32_t literal.
-  /// @param val the value to generate
-  /// @returns the ID of the generated literal
-  uint32_t GenerateU32Literal(uint32_t val);
   /// Generates an assignment statement
   /// @param assign the statement to generate
   /// @returns true if the statement was successfully generated
@@ -486,6 +483,16 @@
     return builder_.TypeOf(expr);
   }
 
+  /// Generates a constant if needed
+  /// @param constant the constant to generate.
+  /// @returns the ID on success or 0 on failure
+  uint32_t GenerateConstantIfNeeded(const ScalarConstant& constant);
+
+  /// Generates a constant-null of the given type, if needed
+  /// @param type the type of the constant null to generate.
+  /// @returns the ID on success or 0 on failure
+  uint32_t GenerateConstantNullIfNeeded(type::Type* type);
+
   ProgramBuilder builder_;
   std::string error_;
   uint32_t next_id_ = 1;
@@ -504,7 +511,9 @@
   std::unordered_map<std::string, uint32_t> import_name_to_id_;
   std::unordered_map<Symbol, uint32_t> func_symbol_to_id_;
   std::unordered_map<std::string, uint32_t> type_name_to_id_;
-  std::unordered_map<std::string, uint32_t> const_to_id_;
+  std::unordered_map<ScalarConstant, uint32_t> const_to_id_;
+  std::unordered_map<std::string, uint32_t> type_constructor_to_id_;
+  std::unordered_map<std::string, uint32_t> const_null_to_id_;
   std::unordered_map<std::string, uint32_t>
       texture_type_name_to_sampled_image_type_id_;
   ScopeStack<uint32_t> scope_stack_;
diff --git a/src/writer/spirv/scalar_constant.h b/src/writer/spirv/scalar_constant.h
new file mode 100644
index 0000000..18dba91
--- /dev/null
+++ b/src/writer/spirv/scalar_constant.h
@@ -0,0 +1,115 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_
+#define SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_
+
+#include <stdint.h>
+
+#include <cstring>
+#include <functional>
+
+namespace tint {
+
+// Forward declarations
+namespace semantic {
+class Call;
+}  // namespace semantic
+
+namespace writer {
+namespace spirv {
+
+/// ScalarConstant represents a scalar constant value
+struct ScalarConstant {
+  /// The constant value
+  union Value {
+    /// The value as a bool
+    bool b;
+    /// The value as a uint32_t
+    uint32_t u32;
+    /// The value as a int32_t
+    int32_t i32;
+    /// The value as a float
+    float f32;
+
+    /// The value that is wide enough to encompass all other types (including
+    /// future 64-bit data types).
+    uint64_t u64;
+  };
+
+  /// The kind of constant
+  enum class Kind { kBool, kU32, kI32, kF32 };
+
+  /// Constructor
+  inline ScalarConstant() { value.u64 = 0; }
+
+  /// @param value the value of the constant
+  /// @returns a new ScalarConstant with the provided value and kind Kind::kU32
+  static inline ScalarConstant U32(uint32_t value) {
+    ScalarConstant c;
+    c.value.u32 = value;
+    c.kind = Kind::kU32;
+    return c;
+  }
+
+  /// Equality operator
+  /// @param rhs the ScalarConstant to compare against
+  /// @returns true if this ScalarConstant is equal to `rhs`
+  inline bool operator==(const ScalarConstant& rhs) const {
+    return value.u64 == rhs.value.u64 && kind == rhs.kind &&
+           is_spec_op == rhs.is_spec_op && constant_id == rhs.constant_id;
+  }
+
+  /// Inequality operator
+  /// @param rhs the ScalarConstant to compare against
+  /// @returns true if this ScalarConstant is not equal to `rhs`
+  inline bool operator!=(const ScalarConstant& rhs) const {
+    return !(*this == rhs);
+  }
+
+  /// The constant value
+  Value value;
+  /// The constant value kind
+  Kind kind = Kind::kBool;
+  /// True if the constant is a specialization op
+  bool is_spec_op = false;
+  /// The identifier if a specialization op
+  uint32_t constant_id = 0;
+};
+
+}  // namespace spirv
+}  // namespace writer
+}  // namespace tint
+
+namespace std {
+
+/// Custom std::hash specialization for tint::Symbol so symbols can be used as
+/// keys for std::unordered_map and std::unordered_set.
+template <>
+class hash<tint::writer::spirv::ScalarConstant> {
+ public:
+  /// @param c the ScalarConstant
+  /// @return the Symbol internal value
+  inline std::size_t operator()(
+      const tint::writer::spirv::ScalarConstant& c) const {
+    uint32_t value = 0;
+    std::memcpy(&value, &c.value, sizeof(value));
+    return (static_cast<std::size_t>(value) << 2) |
+           (static_cast<std::size_t>(c.kind) & 3);
+  }
+};
+
+}  // namespace std
+
+#endif  // SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_
diff --git a/src/writer/spirv/scalar_constant_test.cc b/src/writer/spirv/scalar_constant_test.cc
new file mode 100644
index 0000000..b514146
--- /dev/null
+++ b/src/writer/spirv/scalar_constant_test.cc
@@ -0,0 +1,60 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/writer/spirv/scalar_constant.h"
+#include "src/writer/spirv/test_helper.h"
+
+namespace tint {
+namespace writer {
+namespace spirv {
+namespace {
+
+using SpirvScalarConstantTest = TestHelper;
+
+TEST_F(SpirvScalarConstantTest, Equality) {
+  ScalarConstant a{};
+  ScalarConstant b{};
+  EXPECT_EQ(a, b);
+
+  a.kind = ScalarConstant::Kind::kU32;
+  EXPECT_NE(a, b);
+  b.kind = ScalarConstant::Kind::kU32;
+  EXPECT_EQ(a, b);
+
+  a.value.b = true;
+  EXPECT_NE(a, b);
+  b.value.b = true;
+  EXPECT_EQ(a, b);
+
+  a.is_spec_op = true;
+  EXPECT_NE(a, b);
+  b.is_spec_op = true;
+  EXPECT_EQ(a, b);
+
+  a.constant_id = 3;
+  EXPECT_NE(a, b);
+  b.constant_id = 3;
+  EXPECT_EQ(a, b);
+}
+
+TEST_F(SpirvScalarConstantTest, U32) {
+  auto c = ScalarConstant::U32(123);
+  EXPECT_EQ(c.value.u32, 123u);
+  EXPECT_EQ(c.kind, ScalarConstant::Kind::kU32);
+}
+
+}  // namespace
+}  // namespace spirv
+}  // namespace writer
+}  // namespace tint
diff --git a/test/BUILD.gn b/test/BUILD.gn
index b49cf59..b69dd73 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -336,6 +336,7 @@
     "../src/writer/spirv/builder_unary_op_expression_test.cc",
     "../src/writer/spirv/instruction_test.cc",
     "../src/writer/spirv/operand_test.cc",
+    "../src/writer/spirv/scalar_constant_test.cc",
     "../src/writer/spirv/spv_dump.cc",
     "../src/writer/spirv/spv_dump.h",
     "../src/writer/spirv/test_helper.h",