spirv-reader: Fix signedness for extended instructions

SClamp is the only implemented instruction that is affected, so far

Bug: tint:405
Change-Id: I21c1cdd3e70fc3a64046f0473569ba906048cd37
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35240
Auto-Submit: David Neto <dneto@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 956aec6..61d35ec 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -2862,7 +2862,7 @@
     const spvtools::opt::Instruction& inst,
     uint32_t operand_index) {
   auto expr = this->MakeExpression(inst.GetSingleWordInOperand(operand_index));
-  return parser_impl_.RectifyOperandSignedness(inst.opcode(), std::move(expr));
+  return parser_impl_.RectifyOperandSignedness(inst, std::move(expr));
 }
 
 TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
@@ -2883,7 +2883,7 @@
     auto* binary_expr =
         create<ast::BinaryExpression>(binary_op, arg0.expr, arg1.expr);
     TypedExpression result{ast_type, binary_expr};
-    return parser_impl_.RectifyForcedResultType(result, opcode, arg0.type);
+    return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
   }
 
   auto unary_op = ast::UnaryOp::kNegation;
@@ -2891,7 +2891,7 @@
     auto arg0 = MakeOperand(inst, 0);
     auto* unary_expr = create<ast::UnaryOpExpression>(unary_op, arg0.expr);
     TypedExpression result{ast_type, unary_expr};
-    return parser_impl_.RectifyForcedResultType(result, opcode, arg0.type);
+    return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
   }
 
   const char* unary_builtin_name = GetUnaryBuiltInFunctionName(opcode);
@@ -3002,13 +3002,20 @@
 
   auto* func = create<ast::IdentifierExpression>(name);
   ast::ExpressionList operands;
+  ast::type::Type* first_operand_type = nullptr;
   // All parameters to GLSL.std.450 extended instructions are IDs.
   for (uint32_t iarg = 2; iarg < inst.NumInOperands(); ++iarg) {
-    operands.emplace_back(MakeOperand(inst, iarg).expr);
+    TypedExpression operand = MakeOperand(inst, iarg);
+    if (first_operand_type == nullptr) {
+      first_operand_type = operand.type;
+    }
+    operands.emplace_back(operand.expr);
   }
   auto* ast_type = parser_impl_.ConvertType(inst.type_id());
   auto* call = create<ast::CallExpression>(func, std::move(operands));
-  return {ast_type, call};
+  TypedExpression call_expr{ast_type, call};
+  return parser_impl_.RectifyForcedResultType(call_expr, inst,
+                                              first_operand_type);
 }
 
 ast::IdentifierExpression* FunctionEmitter::Swizzle(uint32_t i) {
diff --git a/src/reader/spirv/function_glsl_std_450_test.cc b/src/reader/spirv/function_glsl_std_450_test.cc
index 10c16da..203d357 100644
--- a/src/reader/spirv/function_glsl_std_450_test.cc
+++ b/src/reader/spirv/function_glsl_std_450_test.cc
@@ -826,6 +826,66 @@
                          SpvParserTest_GlslStd450_Uinting_UintingUintingUinting,
                          ::testing::Values(GlslStd450Case{"UClamp", "clamp"}));
 
+TEST_F(SpvParserTest, RectifyOperandsAndResult_GLSLstd450SClamp) {
+  const auto assembly = Preamble() + R"(
+     %1 = OpExtInst %uint %glsl SClamp %u1 %i2 %u3
+     %2 = OpExtInst %v2uint %glsl SClamp %v2u1 %v2i2 %v2u3
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  auto body = ToString(fe.ast_body());
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_1
+    none
+    __u32
+    {
+      Bitcast[not set]<__u32>{
+        Call[not set]{
+          Identifier[not set]{clamp}
+          (
+            Bitcast[not set]<__i32>{
+              Identifier[not set]{u1}
+            }
+            Identifier[not set]{i2}
+            Bitcast[not set]<__i32>{
+              Identifier[not set]{u3}
+            }
+          )
+        }
+      }
+    }
+  })"))
+      << body;
+  EXPECT_THAT(body, HasSubstr(R"(
+  VariableConst{
+    x_2
+    none
+    __vec_2__u32
+    {
+      Bitcast[not set]<__vec_2__u32>{
+        Call[not set]{
+          Identifier[not set]{clamp}
+          (
+            Bitcast[not set]<__vec_2__i32>{
+              Identifier[not set]{v2u1}
+            }
+            Identifier[not set]{v2i2}
+            Bitcast[not set]<__vec_2__i32>{
+              Identifier[not set]{v2u3}
+            }
+          )
+        }
+      }
+    }
+  })"))
+      << body;
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 12908b4..3543a77 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -159,6 +159,23 @@
   return false;
 }
 
+// Returns true if the GLSL extended instruction expects operands to be signed.
+// @param extended_opcode GLSL.std.450 opcode
+// @returns true if all operands must be signed integral type
+bool AssumesSignedOperands(GLSLstd450 extended_opcode) {
+  switch (extended_opcode) {
+    case GLSLstd450SAbs:
+    case GLSLstd450SSign:
+    case GLSLstd450SMin:
+    case GLSLstd450SMax:
+    case GLSLstd450SClamp:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
 // Returns true if the opcode operates as if its operands are unsigned integral.
 bool AssumesUnsignedOperands(SpvOp opcode) {
   switch (opcode) {
@@ -176,6 +193,22 @@
   return false;
 }
 
+// Returns true if the GLSL extended instruction expects operands to be
+// unsigned.
+// @param extended_opcode GLSL.std.450 opcode
+// @returns true if all operands must be unsigned integral type
+bool AssumesUnsignedOperands(GLSLstd450 extended_opcode) {
+  switch (extended_opcode) {
+    case GLSLstd450UMin:
+    case GLSLstd450UMax:
+    case GLSLstd450UClamp:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
 // Returns true if the operation is binary, and the WGSL operation requires
 // the signedness of the result to match the signedness of the first operand.
 bool AssumesResultSignednessMatchesBinaryFirstOperand(SpvOp opcode) {
@@ -190,6 +223,29 @@
   return false;
 }
 
+// Returns true if the extended instruction requires the signedness of the
+// result to match the signedness of the first operand to the operation.
+// @param extended_opcode GLSL.std.450 opcode
+// @returns true if the result type must match the first operand type.
+bool AssumesResultSignednessMatchesFirstOperand(GLSLstd450 extended_opcode) {
+  switch (extended_opcode) {
+    case GLSLstd450SAbs:
+    case GLSLstd450SSign:
+    case GLSLstd450SMin:
+    case GLSLstd450SMax:
+    case GLSLstd450SClamp:
+    case GLSLstd450UMin:
+    case GLSLstd450UMax:
+    case GLSLstd450UClamp:
+      // TODO(dneto): FindSMsb?
+      // TODO(dneto): FindUMsb?
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
 }  // namespace
 
 ParserImpl::ParserImpl(const std::vector<uint32_t>& spv_binary)
@@ -561,6 +617,12 @@
   return true;
 }
 
+bool ParserImpl::IsGlslExtendedInstruction(
+    const spvtools::opt::Instruction& inst) const {
+  return (inst.opcode() == SpvOpExtInst) &&
+         (glsl_std_450_imports_.count(inst.GetSingleWordInOperand(0)) > 0);
+}
+
 bool ParserImpl::RegisterUserAndStructMemberNames() {
   if (!success_) {
     return false;
@@ -1389,10 +1451,21 @@
   return nullptr;
 }
 
-TypedExpression ParserImpl::RectifyOperandSignedness(SpvOp op,
-                                                     TypedExpression&& expr) {
-  const bool requires_signed = AssumesSignedOperands(op);
-  const bool requires_unsigned = AssumesUnsignedOperands(op);
+TypedExpression ParserImpl::RectifyOperandSignedness(
+    const spvtools::opt::Instruction& inst,
+    TypedExpression&& expr) {
+  bool requires_signed = false;
+  bool requires_unsigned = false;
+  if (IsGlslExtendedInstruction(inst)) {
+    const auto extended_opcode =
+        static_cast<GLSLstd450>(inst.GetSingleWordInOperand(1));
+    requires_signed = AssumesSignedOperands(extended_opcode);
+    requires_unsigned = AssumesUnsignedOperands(extended_opcode);
+  } else {
+    const auto opcode = inst.opcode();
+    requires_signed = AssumesSignedOperands(opcode);
+    requires_unsigned = AssumesUnsignedOperands(opcode);
+  }
   if (!requires_signed && !requires_unsigned) {
     // No conversion is required, assuming our tables are complete.
     return std::move(expr);
@@ -1425,14 +1498,26 @@
 }
 
 ast::type::Type* ParserImpl::ForcedResultType(
-    SpvOp op,
+    const spvtools::opt::Instruction& inst,
     ast::type::Type* first_operand_type) {
-  const bool binary_match_first_operand =
-      AssumesResultSignednessMatchesBinaryFirstOperand(op);
-  const bool unary_match_operand = (op == SpvOpSNegate) || (op == SpvOpNot);
-  if (binary_match_first_operand || unary_match_operand) {
+  const auto opcode = inst.opcode();
+  if ((opcode == SpvOpSNegate) || (opcode == SpvOpNot)) {
+    // The unary operation cases that force the result type to match the
+    // first operand type.
     return first_operand_type;
   }
+  if (AssumesResultSignednessMatchesBinaryFirstOperand(opcode)) {
+    // The binary operation cases that force the result type to match
+    // the first operand type.
+    return first_operand_type;
+  }
+  if (IsGlslExtendedInstruction(inst)) {
+    const auto extended_opcode =
+        static_cast<GLSLstd450>(inst.GetSingleWordInOperand(1));
+    if (AssumesResultSignednessMatchesFirstOperand(extended_opcode)) {
+      return first_operand_type;
+    }
+  }
   return nullptr;
 }
 
@@ -1474,9 +1559,9 @@
 
 TypedExpression ParserImpl::RectifyForcedResultType(
     TypedExpression expr,
-    SpvOp op,
+    const spvtools::opt::Instruction& inst,
     ast::type::Type* first_operand_type) {
-  auto* forced_result_ty = ForcedResultType(op, first_operand_type);
+  auto* forced_result_ty = ForcedResultType(inst, first_operand_type);
   if ((forced_result_ty == nullptr) || (forced_result_ty == expr.type)) {
     return expr;
   }
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index 346fecb..f5b9689 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -235,6 +235,12 @@
   /// @returns true if parser is still successful.
   bool RegisterExtendedInstructionImports();
 
+  // Returns true when the given instruction is an extended instruction
+  // for GLSL.std.450.
+  // @param inst a SPIR-V instruction
+  // @returns true if its an SpvOpExtInst for GLSL.std.450
+  bool IsGlslExtendedInstruction(const spvtools::opt::Instruction& inst) const;
+
   /// Registers user names for SPIR-V objects, from OpName, and OpMemberName.
   /// Also synthesizes struct field names.  Ensures uniqueness for names for
   /// SPIR-V IDs, and uniqueness of names of fields within any single struct.
@@ -301,25 +307,27 @@
   ast::Expression* MakeNullValue(ast::type::Type* type);
 
   /// Converts a given expression to the signedness demanded for an operand
-  /// of the given SPIR-V opcode, if required.  If the operation assumes
+  /// of the given SPIR-V instruction, if required.  If the instruction assumes
   /// signed integer operands, and `expr` is unsigned, then return an
   /// as-cast expression converting it to signed. Otherwise, return
   /// `expr` itself.  Similarly, convert as required from unsigned
   /// to signed. Assumes all SPIR-V types have been mapped to AST types.
-  /// @param op the SPIR-V opcode
+  /// @param inst the SPIR-V instruction
   /// @param expr an expression
   /// @returns expr, or a cast of expr
-  TypedExpression RectifyOperandSignedness(SpvOp op, TypedExpression&& expr);
+  TypedExpression RectifyOperandSignedness(
+      const spvtools::opt::Instruction& inst,
+      TypedExpression&& expr);
 
-  /// Returns the "forced" result type for the given SPIR-V opcode.
+  /// Returns the "forced" result type for the given SPIR-V instruction.
   /// If the WGSL result type for an operation has a more strict rule than
   /// requried by SPIR-V, then we say the result type is "forced".  This occurs
   /// for signed integer division (OpSDiv), for example, where the result type
   /// in WGSL must match the operand types.
-  /// @param op the SPIR-V opcode
+  /// @param inst the SPIR-V instruction
   /// @param first_operand_type the AST type for the first operand.
   /// @returns the forced AST result type, or nullptr if no forcing is required.
-  ast::type::Type* ForcedResultType(SpvOp op,
+  ast::type::Type* ForcedResultType(const spvtools::opt::Instruction& inst,
                                     ast::type::Type* first_operand_type);
 
   /// Returns a signed integer scalar or vector type matching the shape (scalar,
@@ -343,12 +351,13 @@
   /// from the expression's result type. Otherwise, returns the given expression
   /// unchanged.
   /// @param expr the expression to pass through or to wrap
-  /// @param op the SPIR-V opcode
+  /// @param inst the SPIR-V instruction
   /// @param first_operand_type the AST type for the first operand.
   /// @returns the forced AST result type, or nullptr if no forcing is required.
-  TypedExpression RectifyForcedResultType(TypedExpression expr,
-                                          SpvOp op,
-                                          ast::type::Type* first_operand_type);
+  TypedExpression RectifyForcedResultType(
+      TypedExpression expr,
+      const spvtools::opt::Instruction& inst,
+      ast::type::Type* first_operand_type);
 
   /// @returns the registered boolean type.
   ast::type::Type* Bool() const { return bool_type_; }