[spirv-reader] Support select over scalars and vectors
Still TODO: OpSelect over arrays and structures, as permitted in SPIR-V 1.4
Bug: tint:3, tint:99
Change-Id: I70f6c8a43ea3339cd715813c6eb0128d66ff0df8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/25301
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 3137cc4..c3a6d9b 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -2605,7 +2605,8 @@
default:
break;
}
- return Fail() << "unhandled instruction with opcode " << inst.opcode();
+ return Fail() << "unhandled instruction with opcode " << inst.opcode() << ": "
+ << inst.PrettyPrint();
}
TypedExpression FunctionEmitter::MakeOperand(
@@ -2703,6 +2704,10 @@
return {ast_type, parser_impl_.MakeNullValue(ast_type)};
}
+ if (opcode == SpvOpSelect) {
+ return MakeSimpleSelect(inst);
+ }
+
// builtin readonly function
// glsl.std.450 readonly function
@@ -3375,6 +3380,33 @@
return EmitConstDefOrWriteToHoistedVar(inst, std::move(expr));
}
+TypedExpression FunctionEmitter::MakeSimpleSelect(
+ const spvtools::opt::Instruction& inst) {
+ auto condition = MakeOperand(inst, 0);
+ auto operand1 = MakeOperand(inst, 1);
+ auto operand2 = MakeOperand(inst, 2);
+
+ // SPIR-V validation requires:
+ // - the condition to be bool or bool vector, so we don't check it here.
+ // - operand1, operand2, and result type to match.
+ // - you can't select over pointers or pointer vectors, unless you also have
+ // a VariablePointers* capability, which is not allowed in by WebGPU.
+ auto* op_ty = operand1.type;
+ if (op_ty->IsVector() || op_ty->is_float_scalar() ||
+ op_ty->is_integer_scalar() || op_ty->IsBool()) {
+ ast::ExpressionList params;
+ params.push_back(std::move(operand1.expr));
+ params.push_back(std::move(operand2.expr));
+ // The condition goes last.
+ params.push_back(std::move(condition.expr));
+ return {operand1.type,
+ std::make_unique<ast::CallExpression>(
+ std::make_unique<ast::IdentifierExpression>("select"),
+ std::move(params))};
+ }
+ return {};
+}
+
} // namespace spirv
} // namespace reader
} // namespace tint
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index bde237d..5861570 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -676,6 +676,13 @@
/// @returns false if emission failed
bool EmitFunctionCall(const spvtools::opt::Instruction& inst);
+ /// Returns an expression for an OpSelect, if its operands are scalars
+ /// or vectors. These translate directly to WGSL select. Otherwise, return
+ /// an expression with a null owned expression
+ /// @param inst the SPIR-V OpSelect instruction
+ /// @returns a typed expression, or one with a null owned expression
+ TypedExpression MakeSimpleSelect(const spvtools::opt::Instruction& inst);
+
/// Finds the header block for a structured construct that we can "break"
/// out from, from deeply nested control flow, if such a block exists.
/// If the construct is:
diff --git a/src/reader/spirv/function_logical_test.cc b/src/reader/spirv/function_logical_test.cc
index 4a0f1df..8f4272c 100644
--- a/src/reader/spirv/function_logical_test.cc
+++ b/src/reader/spirv/function_logical_test.cc
@@ -1104,12 +1104,187 @@
<< ToString(fe.ast_body());
}
+TEST_F(SpvFUnordTest, Select_BoolCond_BoolParams) {
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpSelect %bool %true %true %false
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{
+ Variable{
+ x_1
+ none
+ __bool
+ {
+ Call{
+ Identifier{select}
+ (
+ ScalarConstructor{true}
+ ScalarConstructor{false}
+ ScalarConstructor{true}
+ )
+ }
+ }
+ }
+})")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvFUnordTest, Select_BoolCond_IntScalarParams) {
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpSelect %uint %true %uint_10 %uint_20
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{
+ Variable{
+ x_1
+ none
+ __u32
+ {
+ Call{
+ Identifier{select}
+ (
+ ScalarConstructor{10}
+ ScalarConstructor{20}
+ ScalarConstructor{true}
+ )
+ }
+ }
+ }
+})")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvFUnordTest, Select_BoolCond_FloatScalarParams) {
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpSelect %float %true %float_50 %float_60
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{
+ Variable{
+ x_1
+ none
+ __f32
+ {
+ Call{
+ Identifier{select}
+ (
+ ScalarConstructor{50.000000}
+ ScalarConstructor{60.000000}
+ ScalarConstructor{true}
+ )
+ }
+ }
+ }
+})")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvFUnordTest, Select_BoolCond_VectorParams) {
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpSelect %v2uint %true %v2uint_10_20 %v2uint_20_10
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{
+ Variable{
+ x_1
+ none
+ __vec_2__u32
+ {
+ Call{
+ Identifier{select}
+ (
+ TypeConstructor{
+ __vec_2__u32
+ ScalarConstructor{10}
+ ScalarConstructor{20}
+ }
+ TypeConstructor{
+ __vec_2__u32
+ ScalarConstructor{20}
+ ScalarConstructor{10}
+ }
+ ScalarConstructor{true}
+ )
+ }
+ }
+ }
+})")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvFUnordTest, Select_VecBoolCond_VectorParams) {
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpSelect %v2uint %v2bool_t_f %v2uint_10_20 %v2uint_20_10
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableDeclStatement{
+ Variable{
+ x_1
+ none
+ __vec_2__u32
+ {
+ Call{
+ Identifier{select}
+ (
+ TypeConstructor{
+ __vec_2__u32
+ ScalarConstructor{10}
+ ScalarConstructor{20}
+ }
+ TypeConstructor{
+ __vec_2__u32
+ ScalarConstructor{20}
+ ScalarConstructor{10}
+ }
+ TypeConstructor{
+ __vec_2__bool
+ ScalarConstructor{true}
+ ScalarConstructor{false}
+ }
+ )
+ }
+ }
+ }
+})")) << ToString(fe.ast_body());
+}
+
// TODO(dneto): OpAny - likely builtin function TBD
// TODO(dneto): OpAll - likely builtin function TBD
// TODO(dneto): OpIsNan - likely builtin function TBD
// TODO(dneto): OpIsInf - likely builtin function TBD
// TODO(dneto): Kernel-guarded instructions.
-// TODO(dneto): OpSelect - likely builtin function TBD
+// TODO(dneto): OpSelect over more general types, as in SPIR-V 1.4
} // namespace
} // namespace spirv