[spirv-reader] Add OpSNegate

Bug: tint:3
Change-Id: Id396319dd32216a71e21464d41bb2f2545929207
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19882
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index c76f882..a97cba8 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -27,6 +27,8 @@
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/storage_class.h"
 #include "src/ast/uint_literal.h"
+#include "src/ast/unary_op.h"
+#include "src/ast/unary_op_expression.h"
 #include "src/ast/variable.h"
 #include "src/ast/variable_decl_statement.h"
 #include "src/reader/spirv/fail_stream.h"
@@ -37,6 +39,25 @@
 namespace spirv {
 
 namespace {
+
+// Gets the AST unary opcode for the given SPIR-V opcode, if any
+// @param opcode SPIR-V opcode
+// @param ast_unary_op return parameter
+// @returns true if it was a unary operation
+bool GetUnaryOp(SpvOp opcode, ast::UnaryOp* ast_unary_op) {
+  switch (opcode) {
+    case SpvOpSNegate:
+      *ast_unary_op = ast::UnaryOp::kNegation;
+      return true;
+    // TODO(dneto): SpvOpNegate SpvOpNot SpvLogicalNot
+    default:
+      break;
+  }
+  return false;
+}
+
+// Converts a SPIR-V opcode to its corresponding AST binary opcode, if any
+// @param opcode SPIR-V opcode
 // @returns the AST binary op for the given opcode, or kNone
 ast::BinaryOp ConvertBinaryOp(SpvOp opcode) {
   switch (opcode) {
@@ -386,7 +407,20 @@
     return {ast_type, std::move(binary_expr)};
   }
 
-  // unary operator
+  auto unary_op = ast::UnaryOp::kNegation;
+  if (GetUnaryOp(inst.opcode(), &unary_op)) {
+    auto arg0 = operand(0);
+    auto unary_expr = std::make_unique<ast::UnaryOpExpression>(
+        unary_op, std::move(arg0.expr));
+    auto* forced_result_ty =
+        parser_impl_.ForcedResultType(inst.opcode(), arg0.type);
+    if (forced_result_ty && forced_result_ty != ast_type) {
+      return {ast_type, std::make_unique<ast::AsExpression>(
+                            ast_type, std::move(unary_expr))};
+    }
+    return {ast_type, std::move(unary_expr)};
+  }
+
   // builtin readonly function
   // glsl.std.450 readonly function
 
diff --git a/src/reader/spirv/function_arithmetic_test.cc b/src/reader/spirv/function_arithmetic_test.cc
index cd9f8a6..acbfb28 100644
--- a/src/reader/spirv/function_arithmetic_test.cc
+++ b/src/reader/spirv/function_arithmetic_test.cc
@@ -117,6 +117,223 @@
   return "bad case";
 }
 
+using SpvUnaryArithTest = SpvParserTestBase<::testing::Test>;
+
+TEST_F(SpvUnaryArithTest, SNegate_Int_Int) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSNegate %int %int_30
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_1
+    none
+    __i32
+    {
+      UnaryOp{
+        negation
+        ScalarConstructor{30}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvUnaryArithTest, SNegate_Int_Uint) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSNegate %int %uint_10
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_1
+    none
+    __i32
+    {
+      UnaryOp{
+        negation
+        As<__i32>{
+          ScalarConstructor{10}
+        }
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvUnaryArithTest, SNegate_Uint_Int) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSNegate %uint %int_30
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_1
+    none
+    __u32
+    {
+      As<__u32>{
+        UnaryOp{
+          negation
+          ScalarConstructor{30}
+        }
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvUnaryArithTest, SNegate_Uint_Uint) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSNegate %uint %uint_10
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_1
+    none
+    __u32
+    {
+      As<__u32>{
+        UnaryOp{
+          negation
+          As<__i32>{
+            ScalarConstructor{10}
+          }
+        }
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvUnaryArithTest, SNegate_SignedVec_SignedVec) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSNegate %v2int %v2int_30_40
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_1
+    none
+    __vec_2__i32
+    {
+      UnaryOp{
+        negation
+        TypeConstructor{
+          __vec_2__i32
+          ScalarConstructor{30}
+          ScalarConstructor{40}
+        }
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvUnaryArithTest, SNegate_SignedVec_UnsignedVec) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSNegate %v2int %v2uint_10_20
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_1
+    none
+    __vec_2__i32
+    {
+      UnaryOp{
+        negation
+        As<__vec_2__i32>{
+          TypeConstructor{
+            __vec_2__u32
+            ScalarConstructor{10}
+            ScalarConstructor{20}
+          }
+        }
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvUnaryArithTest, SNegate_UnsignedVec_UnsignedVec) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSNegate %v2uint %v2uint_10_20
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_1
+    none
+    __vec_2__u32
+    {
+      As<__vec_2__u32>{
+        UnaryOp{
+          negation
+          As<__vec_2__i32>{
+            TypeConstructor{
+              __vec_2__u32
+              ScalarConstructor{10}
+              ScalarConstructor{20}
+            }
+          }
+        }
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
 struct BinaryData {
   const std::string res_type;
   const std::string lhs;
@@ -804,6 +1021,28 @@
         BinaryData{"v2int", "v2int_40_30", "OpBitwiseXor", "v2uint_20_10",
                    "__vec_2__i32", AstFor("v2int_40_30"), "xor",
                    AstFor("v2uint_20_10")}));
+
+// TODO(dneto): OpSNegate
+// TODO(dneto): OpFNegate
+
+// TODO(dneto): OpSRem. Missing from WGSL
+// https://github.com/gpuweb/gpuweb/issues/702
+
+// TODO(dneto): OpFRem. Missing from WGSL
+// https://github.com/gpuweb/gpuweb/issues/702
+
+// TODO(dneto): OpVectorTimesScalar
+// TODO(dneto): OpMatrixTimesScalar
+// TODO(dneto): OpVectorTimesMatrix
+// TODO(dneto): OpMatrixTimesVector
+// TODO(dneto): OpMatrixTimesMatrix
+// TODO(dneto): OpOuterProduct
+// TODO(dneto): OpDot
+// TODO(dneto): OpIAddCarry
+// TODO(dneto): OpISubBorrow
+// TODO(dneto): OpIMulExtended
+// TODO(dneto): OpSMulExtended
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index fa8b9f5..ba7957b 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -60,6 +60,7 @@
 #include "src/ast/type/void_type.h"
 #include "src/ast/type_constructor_expression.h"
 #include "src/ast/uint_literal.h"
+#include "src/ast/unary_op_expression.h"
 #include "src/ast/variable.h"
 #include "src/ast/variable_decl_statement.h"
 #include "src/ast/variable_decoration.h"
@@ -124,6 +125,7 @@
 // Returns true if the opcode operates as if its operands are signed integral.
 bool AssumesSignedOperands(SpvOp opcode) {
   switch (opcode) {
+    case SpvOpSNegate:
     case SpvOpSDiv:
     case SpvOpSRem:
     case SpvOpSMod:
@@ -158,9 +160,9 @@
 // the signedness of the result to match the signedness of the first operand.
 bool AssumesResultSignednessMatchesBinaryFirstOperand(SpvOp opcode) {
   switch (opcode) {
-    // TODO(dneto): More arithmetic operations.
     case SpvOpSDiv:
     case SpvOpSMod:
+    case SpvOpSRem:
       return true;
     default:
       break;
@@ -947,7 +949,8 @@
     ast::type::Type* first_operand_type) {
   const bool binary_match_first_operand =
       AssumesResultSignednessMatchesBinaryFirstOperand(op);
-  if (binary_match_first_operand) {
+  const bool unary_match_operand = (op == SpvOpSNegate);
+  if (binary_match_first_operand || unary_match_operand) {
     return first_operand_type;
   }
   return nullptr;