[spirv-reader] Support select over scalars and vectors

Still TODO: OpSelect over arrays and structures, as permitted in SPIR-V 1.4

Bug: tint:3, tint:99
Change-Id: I70f6c8a43ea3339cd715813c6eb0128d66ff0df8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/25301
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 3137cc4..c3a6d9b 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -2605,7 +2605,8 @@
     default:
       break;
   }
-  return Fail() << "unhandled instruction with opcode " << inst.opcode();
+  return Fail() << "unhandled instruction with opcode " << inst.opcode() << ": "
+                << inst.PrettyPrint();
 }
 
 TypedExpression FunctionEmitter::MakeOperand(
@@ -2703,6 +2704,10 @@
     return {ast_type, parser_impl_.MakeNullValue(ast_type)};
   }
 
+  if (opcode == SpvOpSelect) {
+    return MakeSimpleSelect(inst);
+  }
+
   // builtin readonly function
   // glsl.std.450 readonly function
 
@@ -3375,6 +3380,33 @@
   return EmitConstDefOrWriteToHoistedVar(inst, std::move(expr));
 }
 
+TypedExpression FunctionEmitter::MakeSimpleSelect(
+    const spvtools::opt::Instruction& inst) {
+  auto condition = MakeOperand(inst, 0);
+  auto operand1 = MakeOperand(inst, 1);
+  auto operand2 = MakeOperand(inst, 2);
+
+  // SPIR-V validation requires:
+  // - the condition to be bool or bool vector, so we don't check it here.
+  // - operand1, operand2, and result type to match.
+  // - you can't select over pointers or pointer vectors, unless you also have
+  //   a VariablePointers* capability, which is not allowed in by WebGPU.
+  auto* op_ty = operand1.type;
+  if (op_ty->IsVector() || op_ty->is_float_scalar() ||
+      op_ty->is_integer_scalar() || op_ty->IsBool()) {
+    ast::ExpressionList params;
+    params.push_back(std::move(operand1.expr));
+    params.push_back(std::move(operand2.expr));
+    // The condition goes last.
+    params.push_back(std::move(condition.expr));
+    return {operand1.type,
+            std::make_unique<ast::CallExpression>(
+                std::make_unique<ast::IdentifierExpression>("select"),
+                std::move(params))};
+  }
+  return {};
+}
+
 }  // namespace spirv
 }  // namespace reader
 }  // namespace tint
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index bde237d..5861570 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -676,6 +676,13 @@
   /// @returns false if emission failed
   bool EmitFunctionCall(const spvtools::opt::Instruction& inst);
 
+  /// Returns an expression for an OpSelect, if its operands are scalars
+  /// or vectors. These translate directly to WGSL select.  Otherwise, return
+  /// an expression with a null owned expression
+  /// @param inst the SPIR-V OpSelect instruction
+  /// @returns a typed expression, or one with a null owned expression
+  TypedExpression MakeSimpleSelect(const spvtools::opt::Instruction& inst);
+
   /// Finds the header block for a structured construct that we can "break"
   /// out from, from deeply nested control flow, if such a block exists.
   /// If the construct is:
diff --git a/src/reader/spirv/function_logical_test.cc b/src/reader/spirv/function_logical_test.cc
index 4a0f1df..8f4272c 100644
--- a/src/reader/spirv/function_logical_test.cc
+++ b/src/reader/spirv/function_logical_test.cc
@@ -1104,12 +1104,187 @@
       << ToString(fe.ast_body());
 }
 
+TEST_F(SpvFUnordTest, Select_BoolCond_BoolParams) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSelect %bool %true %true %false
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{
+  Variable{
+    x_1
+    none
+    __bool
+    {
+      Call{
+        Identifier{select}
+        (
+          ScalarConstructor{true}
+          ScalarConstructor{false}
+          ScalarConstructor{true}
+        )
+      }
+    }
+  }
+})")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvFUnordTest, Select_BoolCond_IntScalarParams) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSelect %uint %true %uint_10 %uint_20
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{
+  Variable{
+    x_1
+    none
+    __u32
+    {
+      Call{
+        Identifier{select}
+        (
+          ScalarConstructor{10}
+          ScalarConstructor{20}
+          ScalarConstructor{true}
+        )
+      }
+    }
+  }
+})")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvFUnordTest, Select_BoolCond_FloatScalarParams) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSelect %float %true %float_50 %float_60
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{
+  Variable{
+    x_1
+    none
+    __f32
+    {
+      Call{
+        Identifier{select}
+        (
+          ScalarConstructor{50.000000}
+          ScalarConstructor{60.000000}
+          ScalarConstructor{true}
+        )
+      }
+    }
+  }
+})")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvFUnordTest, Select_BoolCond_VectorParams) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSelect %v2uint %true %v2uint_10_20 %v2uint_20_10
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{
+  Variable{
+    x_1
+    none
+    __vec_2__u32
+    {
+      Call{
+        Identifier{select}
+        (
+          TypeConstructor{
+            __vec_2__u32
+            ScalarConstructor{10}
+            ScalarConstructor{20}
+          }
+          TypeConstructor{
+            __vec_2__u32
+            ScalarConstructor{20}
+            ScalarConstructor{10}
+          }
+          ScalarConstructor{true}
+        )
+      }
+    }
+  }
+})")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvFUnordTest, Select_VecBoolCond_VectorParams) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpSelect %v2uint %v2bool_t_f %v2uint_10_20 %v2uint_20_10
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{
+  Variable{
+    x_1
+    none
+    __vec_2__u32
+    {
+      Call{
+        Identifier{select}
+        (
+          TypeConstructor{
+            __vec_2__u32
+            ScalarConstructor{10}
+            ScalarConstructor{20}
+          }
+          TypeConstructor{
+            __vec_2__u32
+            ScalarConstructor{20}
+            ScalarConstructor{10}
+          }
+          TypeConstructor{
+            __vec_2__bool
+            ScalarConstructor{true}
+            ScalarConstructor{false}
+          }
+        )
+      }
+    }
+  }
+})")) << ToString(fe.ast_body());
+}
+
 // TODO(dneto): OpAny  - likely builtin function TBD
 // TODO(dneto): OpAll  - likely builtin function TBD
 // TODO(dneto): OpIsNan - likely builtin function TBD
 // TODO(dneto): OpIsInf - likely builtin function TBD
 // TODO(dneto): Kernel-guarded instructions.
-// TODO(dneto): OpSelect - likely builtin function TBD
+// TODO(dneto): OpSelect over more general types, as in SPIR-V 1.4
 
 }  // namespace
 }  // namespace spirv