[spirv-reader] Internally, generate typed expressions

The AST only wants expressions, not their result types.
But the SPIR-V reader wants to track the AST type as well.
So introduce a TypedExpression concept for internal use.

Bug: tint:3
Change-Id: Ia832f7422440ef0e8e04630cdca98cae20e18921
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/20040
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 2d520cb..0a17fd5 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -212,7 +212,8 @@
       // (OpenCL also allows the ID of an OpVariable, but we don't handle that
       // here.)
       var->set_constructor(
-          parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1)));
+          parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1))
+              .expr);
     }
     // TODO(dneto): Add the initializer via Variable::set_constructor.
     auto var_decl_stmt =
@@ -224,12 +225,14 @@
   return success();
 }
 
-std::unique_ptr<ast::Expression> FunctionEmitter::MakeExpression(uint32_t id) {
+TypedExpression FunctionEmitter::MakeExpression(uint32_t id) {
   if (failed()) {
-    return nullptr;
+    return {};
   }
   if (identifier_values_.count(id)) {
-    return std::make_unique<ast::IdentifierExpression>(namer_.Name(id));
+    return TypedExpression(
+        parser_impl_.ConvertType(def_use_mgr_->GetDef(id)->type_id()),
+        std::make_unique<ast::IdentifierExpression>(namer_.Name(id)));
   }
   if (singly_used_values_.count(id)) {
     auto expr = std::move(singly_used_values_[id]);
@@ -243,18 +246,19 @@
   const auto* inst = def_use_mgr_->GetDef(id);
   if (inst == nullptr) {
     Fail() << "ID " << id << " does not have a defining SPIR-V instruction";
-    return nullptr;
+    return {};
   }
   switch (inst->opcode()) {
     case SpvOpVariable:
       // This occurs for module-scope variables.
-      return std::make_unique<ast::IdentifierExpression>(
-          namer_.Name(inst->result_id()));
+      return TypedExpression(parser_impl_.ConvertType(inst->type_id()),
+                             std::make_unique<ast::IdentifierExpression>(
+                                 namer_.Name(inst->result_id())));
     default:
       break;
   }
   Fail() << "unhandled expression for ID " << id << "\n" << inst->PrettyPrint();
-  return nullptr;
+  return {};
 }
 
 bool FunctionEmitter::EmitFunctionBodyStatements() {
@@ -284,8 +288,8 @@
 
 bool FunctionEmitter::EmitConstDefinition(
     const spvtools::opt::Instruction& inst,
-    std::unique_ptr<ast::Expression> ast_expr) {
-  if (!ast_expr) {
+    TypedExpression ast_expr) {
+  if (!ast_expr.expr) {
     return false;
   }
   auto ast_const =
@@ -294,7 +298,7 @@
   if (!ast_const) {
     return false;
   }
-  ast_const->set_constructor(std::move(ast_expr));
+  ast_const->set_constructor(std::move(ast_expr.expr));
   ast_const->set_is_const(true);
   ast_body_.emplace_back(
       std::make_unique<ast::VariableDeclStatement>(std::move(ast_const)));
@@ -306,11 +310,12 @@
 bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
   // Handle combinatorial instructions first.
   auto combinatorial_expr = MaybeEmitCombinatorialValue(inst);
-  if (combinatorial_expr != nullptr) {
+  if (combinatorial_expr.expr != nullptr) {
     if (def_use_mgr_->NumUses(&inst) == 1) {
       // If it's used once, then defer emitting the expression until it's used.
       // Any supporting statements have already been emitted.
-      singly_used_values_[inst.result_id()] = std::move(combinatorial_expr);
+      singly_used_values_.insert(
+          std::make_pair(inst.result_id(), std::move(combinatorial_expr)));
       return success();
     }
     // Otherwise, generate a const definition for it now and later use
@@ -327,7 +332,7 @@
       auto lhs = MakeExpression(inst.GetSingleWordInOperand(0));
       auto rhs = MakeExpression(inst.GetSingleWordInOperand(1));
       ast_body_.emplace_back(std::make_unique<ast::AssignmentStatement>(
-          std::move(lhs), std::move(rhs)));
+          std::move(lhs.expr), std::move(rhs.expr)));
       return success();
     }
     case SpvOpLoad:
@@ -344,10 +349,10 @@
   return Fail() << "unhandled instruction with opcode " << inst.opcode();
 }
 
-std::unique_ptr<ast::Expression> FunctionEmitter::MaybeEmitCombinatorialValue(
+TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
     const spvtools::opt::Instruction& inst) {
   if (inst.result_id() == 0) {
-    return nullptr;
+    return {};
   }
 
   // TODO(dneto): Fill in the following cases.
@@ -356,10 +361,14 @@
     return this->MakeExpression(inst.GetSingleWordInOperand(operand_index));
   };
 
+  auto* ast_type =
+      inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr;
+
   auto binary_op = ConvertBinaryOp(inst.opcode());
   if (binary_op != ast::BinaryOp::kNone) {
-    return std::make_unique<ast::BinaryExpression>(binary_op, operand(0),
-                                                   operand(1));
+    return {ast_type, std::make_unique<ast::BinaryExpression>(
+                          binary_op, std::move(operand(0).expr),
+                          std::move(operand(1).expr))};
   }
   // binary operator
   // unary operator
@@ -393,7 +402,7 @@
   //    OpCompositeExtract
   //    OpCompositeInsert
 
-  return nullptr;
+  return {};
 }
 
 }  // namespace spirv
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index a22140c..433b078 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -97,12 +97,12 @@
   /// @param ast_expr the already-computed AST expression for the value
   /// @returns false if emission failed.
   bool EmitConstDefinition(const spvtools::opt::Instruction& inst,
-                           std::unique_ptr<ast::Expression> ast_expr);
+                           TypedExpression ast_expr);
 
   /// Makes an expression
   /// @param id the SPIR-V ID of the value
   /// @returns true if emission has not yet failed.
-  std::unique_ptr<ast::Expression> MakeExpression(uint32_t id);
+  TypedExpression MakeExpression(uint32_t id);
 
   /// Creates an expression and supporting statements for a combinatorial
   /// instruction, or returns null.  A SPIR-V instruction is combinatorial
@@ -113,7 +113,7 @@
   /// combinatorial.
   /// @param inst a SPIR-V instruction representing an exrpression
   /// @returns an AST expression for the instruction, or nullptr.
-  std::unique_ptr<ast::Expression> MaybeEmitCombinatorialValue(
+  TypedExpression MaybeEmitCombinatorialValue(
       const spvtools::opt::Instruction& inst);
 
  private:
@@ -135,8 +135,7 @@
   // The set of IDs that have already had an identifier name generated for it.
   std::unordered_set<uint32_t> identifier_values_;
   // Mapping from SPIR-V ID that is used at most once, to its AST expression.
-  std::unordered_map<uint32_t, std::unique_ptr<ast::Expression>>
-      singly_used_values_;
+  std::unordered_map<uint32_t, TypedExpression> singly_used_values_;
 };
 
 }  // namespace spirv
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index cd36971..795dc7d 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -715,7 +715,7 @@
       // (OpenCL also allows the ID of an OpVariable, but we don't handle that
       // here.)
       ast_var->set_constructor(
-          MakeConstantExpression(var.GetSingleWordInOperand(1)));
+          MakeConstantExpression(var.GetSingleWordInOperand(1)).expr);
     }
     // TODO(dneto): initializers (a.k.a. constructor expression)
     ast_module_.AddGlobalVariable(std::move(ast_var));
@@ -763,48 +763,50 @@
   return ast_var;
 }
 
-std::unique_ptr<ast::Expression> ParserImpl::MakeConstantExpression(
-    uint32_t id) {
+TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
   if (!success_) {
-    return nullptr;
+    return {};
   }
   const auto* inst = def_use_mgr_->GetDef(id);
   if (inst == nullptr) {
     Fail() << "ID " << id << " is not a registered instruction";
-    return nullptr;
+    return {};
   }
   auto* ast_type = ConvertType(inst->type_id());
   if (ast_type == nullptr) {
-    return nullptr;
+    return {};
   }
   // TODO(dneto): Handle spec constants too?
   const auto* spirv_const = constant_mgr_->FindDeclaredConstant(id);
   if (spirv_const == nullptr) {
     Fail() << "ID " << id << " is not a constant";
-    return nullptr;
+    return {};
   }
   // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0.
   // So canonicalization should map that way too.
   // Currently "null<type>" is missing from the WGSL parser.
   // See https://bugs.chromium.org/p/tint/issues/detail?id=34
   if (ast_type->IsU32()) {
-    return std::make_unique<ast::ScalarConstructorExpression>(
-        std::make_unique<ast::UintLiteral>(ast_type, spirv_const->GetU32()));
+    return {ast_type, std::make_unique<ast::ScalarConstructorExpression>(
+                          std::make_unique<ast::UintLiteral>(
+                              ast_type, spirv_const->GetU32()))};
   }
   if (ast_type->IsI32()) {
-    return std::make_unique<ast::ScalarConstructorExpression>(
-        std::make_unique<ast::IntLiteral>(ast_type, spirv_const->GetS32()));
+    return {ast_type, std::make_unique<ast::ScalarConstructorExpression>(
+                          std::make_unique<ast::IntLiteral>(
+                              ast_type, spirv_const->GetS32()))};
   }
   if (ast_type->IsF32()) {
-    return std::make_unique<ast::ScalarConstructorExpression>(
-        std::make_unique<ast::FloatLiteral>(ast_type, spirv_const->GetFloat()));
+    return {ast_type, std::make_unique<ast::ScalarConstructorExpression>(
+                          std::make_unique<ast::FloatLiteral>(
+                              ast_type, spirv_const->GetFloat()))};
   }
   if (ast_type->IsBool()) {
     const bool value = spirv_const->AsNullConstant()
                            ? false
                            : spirv_const->AsBoolConstant()->value();
-    return std::make_unique<ast::ScalarConstructorExpression>(
-        std::make_unique<ast::BoolLiteral>(ast_type, value));
+    return {ast_type, std::make_unique<ast::ScalarConstructorExpression>(
+                          std::make_unique<ast::BoolLiteral>(ast_type, value))};
   }
   auto* spirv_composite_const = spirv_const->AsCompositeConstant();
   if (spirv_composite_const != nullptr) {
@@ -820,21 +822,21 @@
       if (def == nullptr) {
         Fail() << "internal error: SPIR-V constant doesn't have defining "
                   "instruction";
-        return nullptr;
+        return {};
       }
       auto ast_component = MakeConstantExpression(def->result_id());
       if (!success_) {
         // We've already emitted a diagnostic.
-        return nullptr;
+        return {};
       }
-      ast_components.emplace_back(std::move(ast_component));
+      ast_components.emplace_back(std::move(ast_component.expr));
     }
-    return std::make_unique<ast::TypeConstructorExpression>(
-        ast_type, std::move(ast_components));
+    return {ast_type, std::make_unique<ast::TypeConstructorExpression>(
+                          ast_type, std::move(ast_components))};
   }
   Fail() << "Unhandled constant type " << inst->type_id() << " for value ID "
          << id;
-  return nullptr;
+  return {};
 }
 
 bool ParserImpl::EmitFunctions() {
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index da73bcf..91c22e0 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -50,6 +50,25 @@
 using Decoration = std::vector<uint32_t>;
 using DecorationList = std::vector<Decoration>;
 
+// An AST expression with its type.
+struct TypedExpression {
+  /// Dummy constructor
+  TypedExpression() : type(nullptr), expr(nullptr) {}
+  /// Constructor
+  /// @param t the type
+  /// @param e the expression
+  TypedExpression(ast::type::Type* t, std::unique_ptr<ast::Expression> e)
+      : type(t), expr(std::move(e)) {}
+  /// Move constructor
+  /// @param other the other typed expression
+  TypedExpression(TypedExpression&& other)
+      : type(other.type), expr(std::move(other.expr)) {}
+  /// The type
+  ast::type::Type* type;
+  /// The expression
+  std::unique_ptr<ast::Expression> expr;
+};
+
 /// Parser implementation for SPIR-V.
 class ParserImpl : Reader {
  public:
@@ -224,7 +243,7 @@
   /// Creates an AST expression node for a SPIR-V constant.
   /// @param id the SPIR-V ID of the constant
   /// @returns a new Literal node
-  std::unique_ptr<ast::Expression> MakeConstantExpression(uint32_t id);
+  TypedExpression MakeConstantExpression(uint32_t id);
 
  private:
   /// Converts a specific SPIR-V type to a Tint type. Integer case