[spirv-reader] Fix OpSDiv operand and result signedness

(I expect that) the WGSL signed division operator expects both operands
to be signed and the result will also be signed.

When the operands of a SPIR-V OpSDiv is unsigned, then wrap
the operand in an as-cast to the corresponding signed type.

When the result type of a SPIR-V OpSDiv instruction is unsigned,
we have to wrap the generated WGSL operator with an as-cast to
that unsigned type.

This first CL addresses OpSDiv.  We'll address other operations in future CLs.

Bug: tint:3
Change-Id: If3849ceb44b21db87c1efd2c6a2cd63c6d648c88
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19800
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 0a17fd5..3d42885 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -20,6 +20,7 @@
 #include "source/opt/function.h"
 #include "source/opt/instruction.h"
 #include "source/opt/module.h"
+#include "src/ast/as_expression.h"
 #include "src/ast/assignment_statement.h"
 #include "src/ast/binary_expression.h"
 #include "src/ast/identifier_expression.h"
@@ -83,6 +84,7 @@
   }
   return ast::BinaryOp::kNone;
 }
+
 }  // namespace
 
 FunctionEmitter::FunctionEmitter(ParserImpl* pi,
@@ -358,19 +360,30 @@
   // TODO(dneto): Fill in the following cases.
 
   auto operand = [this, &inst](uint32_t operand_index) {
-    return this->MakeExpression(inst.GetSingleWordInOperand(operand_index));
+    auto expr =
+        this->MakeExpression(inst.GetSingleWordInOperand(operand_index));
+    return parser_impl_.RectifyOperandSignedness(inst.opcode(),
+                                                 std::move(expr));
   };
 
-  auto* ast_type =
+  ast::type::Type* ast_type =
       inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr;
 
   auto binary_op = ConvertBinaryOp(inst.opcode());
   if (binary_op != ast::BinaryOp::kNone) {
-    return {ast_type, std::make_unique<ast::BinaryExpression>(
-                          binary_op, std::move(operand(0).expr),
-                          std::move(operand(1).expr))};
+    auto arg0 = operand(0);
+    auto arg1 = operand(1);
+    auto binary_expr = std::make_unique<ast::BinaryExpression>(
+        binary_op, std::move(arg0.expr), std::move(arg1.expr));
+    auto* forced_result_ty =
+        parser_impl_.ForcedResultType(inst.opcode(), arg0.type);
+    if (forced_result_ty && forced_result_ty != ast_type) {
+      return {ast_type, std::make_unique<ast::AsExpression>(
+                            ast_type, std::move(binary_expr))};
+    }
+    return {ast_type, std::move(binary_expr)};
   }
-  // binary operator
+
   // unary operator
   // builtin readonly function
   // glsl.std.450 readonly function
diff --git a/src/reader/spirv/function_arithmetic_test.cc b/src/reader/spirv/function_arithmetic_test.cc
index a0ccf6e..31ac7f0 100644
--- a/src/reader/spirv/function_arithmetic_test.cc
+++ b/src/reader/spirv/function_arithmetic_test.cc
@@ -91,6 +91,15 @@
           ScalarConstructor{30}
         })";
   }
+  if (assembly == "cast_int_v2uint_10_20") {
+    return R"(As<__vec_2__i32>{
+          TypeConstructor{
+            __vec_2__u32
+            ScalarConstructor{10}
+            ScalarConstructor{20}
+          }
+        })";
+  }
   if (assembly == "v2float_50_60") {
     return R"(TypeConstructor{
           __vec_2__f32
@@ -126,6 +135,7 @@
 }
 
 using SpvBinaryTest = SpvParserTestBase<::testing::TestWithParam<BinaryData>>;
+using SpvBinaryTestBasic = SpvParserTestBase<::testing::Test>;
 
 TEST_P(SpvBinaryTest, EmitExpression) {
   const auto assembly = CommonTypes() + R"(
@@ -325,6 +335,110 @@
                    AstFor("v2int_40_30")}));
 
 INSTANTIATE_TEST_SUITE_P(
+    SpvParserTest_SDiv_MixedSignednessOperands,
+    SpvBinaryTest,
+    ::testing::Values(
+        // Mixed, returning int, second arg uint
+        BinaryData{"int", "int_30", "OpSDiv", "uint_10", "__i32",
+                   "ScalarConstructor{30}", "divide",
+                   R"(As<__i32>{
+          ScalarConstructor{10}
+        })"},
+        // Mixed, returning int, first arg uint
+        BinaryData{"int", "uint_10", "OpSDiv", "int_30", "__i32",
+                   R"(As<__i32>{
+          ScalarConstructor{10}
+        })",
+                   "divide", "ScalarConstructor{30}"},
+        // Mixed, returning v2int, first arg v2uint
+        BinaryData{"v2int", "v2uint_10_20", "OpSDiv", "v2int_30_40",
+                   "__vec_2__i32", AstFor("cast_int_v2uint_10_20"), "divide",
+                   AstFor("v2int_30_40")},
+        // Mixed, returning v2int, second arg v2uint
+        BinaryData{"v2int", "v2int_30_40", "OpSDiv", "v2uint_10_20",
+                   "__vec_2__i32", AstFor("v2int_30_40"), "divide",
+                   AstFor("cast_int_v2uint_10_20")}));
+
+TEST_F(SpvBinaryTestBasic, SDiv_Scalar_UnsignedResult) {
+  // The WGSL signed division operator expects both operands to be signed
+  // and the result is signed as well.
+  // In this test SPIR-V demands an unsigned result, so we have to
+  // wrap the result with an as-cast.
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSDiv %uint %int_30 %int_40
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+      << p->error() << "\n"
+      << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_1
+    none
+    __u32
+    {
+      As<__u32>{
+        Binary{
+          ScalarConstructor{30}
+          divide
+          ScalarConstructor{40}
+        }
+      }
+    }
+  })"));
+}
+
+TEST_F(SpvBinaryTestBasic, SDiv_Vector_UnsignedResult) {
+  // The WGSL signed division operator expects both operands to be signed
+  // and the result is signed as well.
+  // In this test SPIR-V demands an unsigned result, so we have to
+  // wrap the result with an as-cast.
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSDiv %v2uint %v2int_30_40 %v2int_40_30
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+      << p->error() << "\n"
+      << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_1
+    none
+    __vec_2__u32
+    {
+      As<__vec_2__u32>{
+        Binary{
+          TypeConstructor{
+            __vec_2__i32
+            ScalarConstructor{30}
+            ScalarConstructor{40}
+          }
+          divide
+          TypeConstructor{
+            __vec_2__i32
+            ScalarConstructor{40}
+            ScalarConstructor{30}
+          }
+        }
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+INSTANTIATE_TEST_SUITE_P(
     SpvParserTest_FDiv,
     SpvBinaryTest,
     ::testing::Values(
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 795dc7d..3139cb9 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -33,6 +33,8 @@
 #include "source/opt/type_manager.h"
 #include "source/opt/types.h"
 #include "spirv-tools/libspirv.hpp"
+#include "src/ast/as_expression.h"
+#include "src/ast/binary_expression.h"
 #include "src/ast/bool_literal.h"
 #include "src/ast/builtin_decoration.h"
 #include "src/ast/decorated_variable.h"
@@ -119,6 +121,52 @@
   std::vector<const spvtools::opt::Function*> ordered_;
 };
 
+// Returns true if the opcode operates as if its operands are signed integral.
+bool AssumesSignedOperands(SpvOp opcode) {
+  switch (opcode) {
+    case SpvOpSDiv:
+    case SpvOpSRem:
+    case SpvOpSMod:
+    case SpvOpSLessThan:
+    case SpvOpSLessThanEqual:
+    case SpvOpSGreaterThan:
+    case SpvOpSGreaterThanEqual:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
+// Returns true if the opcode operates as if its operands are unsigned integral.
+bool AssumesUnsignedOperands(SpvOp opcode) {
+  switch (opcode) {
+    case SpvOpUDiv:
+    case SpvOpUMod:
+    case SpvOpULessThan:
+    case SpvOpULessThanEqual:
+    case SpvOpUGreaterThan:
+    case SpvOpUGreaterThanEqual:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
+// Returns true if the operation is binary, and the WGSL operation requires
+// the signedness of the result to match the signedness of the first operand.
+bool AssumesResultSignednessMatchesBinaryFirstOperand(SpvOp opcode) {
+  switch (opcode) {
+    // TODO(dneto): More arithmetic operations.
+    case SpvOpSDiv:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
 }  // namespace
 
 ParserImpl::ParserImpl(Context* ctx, const std::vector<uint32_t>& spv_binary)
@@ -458,11 +506,13 @@
 ast::type::Type* ParserImpl::ConvertType(
     const spvtools::opt::analysis::Integer* int_ty) {
   if (int_ty->width() == 32) {
-    if (int_ty->IsSigned()) {
-      return ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
-    } else {
-      return ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
-    }
+    auto signed_ty =
+        ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
+    auto unsigned_ty =
+        ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
+    signed_type_for_[unsigned_ty] = signed_ty;
+    unsigned_type_for_[signed_ty] = unsigned_ty;
+    return int_ty->IsSigned() ? signed_ty : unsigned_ty;
   }
   Fail() << "unhandled integer width: " << int_ty->width();
   return nullptr;
@@ -484,8 +534,23 @@
   if (ast_elem_ty == nullptr) {
     return nullptr;
   }
-  return ctx_.type_mgr().Get(
+  auto* this_ty = ctx_.type_mgr().Get(
       std::make_unique<ast::type::VectorType>(ast_elem_ty, num_elem));
+  // Generate the opposite-signedness vector type, if this type is integral.
+  if (unsigned_type_for_.count(ast_elem_ty)) {
+    auto* other_ty =
+        ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+            unsigned_type_for_[ast_elem_ty], num_elem));
+    signed_type_for_[other_ty] = this_ty;
+    unsigned_type_for_[this_ty] = other_ty;
+  } else if (signed_type_for_.count(ast_elem_ty)) {
+    auto* other_ty =
+        ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+            signed_type_for_[ast_elem_ty], num_elem));
+    unsigned_type_for_[other_ty] = this_ty;
+    signed_type_for_[this_ty] = other_ty;
+  }
+  return this_ty;
 }
 
 ast::type::Type* ParserImpl::ConvertType(
@@ -782,6 +847,7 @@
     Fail() << "ID " << id << " is not a constant";
     return {};
   }
+
   // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0.
   // So canonicalization should map that way too.
   // Currently "null<type>" is missing from the WGSL parser.
@@ -839,6 +905,53 @@
   return {};
 }
 
+TypedExpression ParserImpl::RectifyOperandSignedness(SpvOp op,
+                                                     TypedExpression&& expr) {
+  const bool requires_signed = AssumesSignedOperands(op);
+  const bool requires_unsigned = AssumesUnsignedOperands(op);
+  if (!requires_signed && !requires_unsigned) {
+    // No conversion is required, assuming our tables are complete.
+    return std::move(expr);
+  }
+  if (!expr.expr) {
+    Fail() << "internal error: RectifyOperandSignedness given a null expr\n";
+    return {};
+  }
+  auto* type = expr.type;
+  if (!type) {
+    Fail() << "internal error: unmapped type for: " << expr.expr->str() << "\n";
+    return {};
+  }
+  if (requires_unsigned) {
+    auto* unsigned_ty = unsigned_type_for_[type];
+    if (unsigned_ty != nullptr) {
+      // Conversion is required.
+      return {unsigned_ty, std::make_unique<ast::AsExpression>(
+                               unsigned_ty, std::move(expr.expr))};
+    }
+  } else if (requires_signed) {
+    auto* signed_ty = signed_type_for_[type];
+    if (signed_ty != nullptr) {
+      // Conversion is required.
+      return {signed_ty, std::make_unique<ast::AsExpression>(
+                             signed_ty, std::move(expr.expr))};
+    }
+  }
+  // We should not reach here.
+  return std::move(expr);
+}
+
+ast::type::Type* ParserImpl::ForcedResultType(
+    SpvOp op,
+    ast::type::Type* first_operand_type) {
+  const bool binary_match_first_operand =
+      AssumesResultSignednessMatchesBinaryFirstOperand(op);
+  if (binary_match_first_operand) {
+    return first_operand_type;
+  }
+  return nullptr;
+}
+
 bool ParserImpl::EmitFunctions() {
   if (!success_) {
     return false;
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index 91c22e0..2412398 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -30,6 +30,7 @@
 #include "source/opt/type_manager.h"
 #include "source/opt/types.h"
 #include "spirv-tools/libspirv.hpp"
+#include "src/ast/expression.h"
 #include "src/ast/import.h"
 #include "src/ast/module.h"
 #include "src/ast/struct_member_decoration.h"
@@ -245,6 +246,28 @@
   /// @returns a new Literal node
   TypedExpression MakeConstantExpression(uint32_t id);
 
+  /// Converts a given expression to the signedness demanded for an operand
+  /// of the given SPIR-V opcode, if required.  If the operation assumes
+  /// signed integer operands, and |expr| is unsigned, then return an
+  /// as-cast expression converting it to signed. Otherwise, return
+  /// |expr| itself.  Similarly, convert as required from unsigned
+  /// to signed. Assumes all SPIR-V types have been mapped to AST types.
+  /// @param op the SPIR-V opcode
+  /// @param expr an expression
+  /// @returns expr, or a cast of expr
+  TypedExpression RectifyOperandSignedness(SpvOp op, TypedExpression&& expr);
+
+  /// Returns the "forced" result type for the given SPIR-V opcode.
+  /// If the WGSL result type for an operation has a more strict rule than
+  /// requried by SPIR-V, then we say the result type is "forced".  This occurs
+  /// for signed integer division (OpSDiv), for example, where the result type
+  /// in WGSL must match the operand types.
+  /// @param op the SPIR-V opcode
+  /// @param first_operand_type the AST type for the first operand.
+  /// @returns the forced AST result type, or nullptr if no forcing is required.
+  ast::type::Type* ForcedResultType(SpvOp op,
+                                    ast::type::Type* first_operand_type);
+
  private:
   /// Converts a specific SPIR-V type to a Tint type. Integer case
   ast::type::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty);
@@ -303,6 +326,11 @@
 
   // Maps a SPIR-V type ID to a Tint type.
   std::unordered_map<uint32_t, ast::type::Type*> id_to_type_;
+
+  // Maps an unsigned type corresponding to the given signed type.
+  std::unordered_map<ast::type::Type*, ast::type::Type*> signed_type_for_;
+  // Maps an signed type corresponding to the given unsigned type.
+  std::unordered_map<ast::type::Type*, ast::type::Type*> unsigned_type_for_;
 };
 
 }  // namespace spirv