[spirv-reader][ir] Add OpSelect support.
Bug: 391486001
Change-Id: I8f198accb403336f3ad04563c693e441163a6f60
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/229714
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc
index 74cdf3c..7c1cd54 100644
--- a/src/tint/lang/spirv/reader/lower/builtins.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins.cc
@@ -212,12 +212,25 @@
case spirv::BuiltinFn::kFMod:
FMod(builtin);
break;
+ case spirv::BuiltinFn::kSelect:
+ Select(builtin);
+ break;
default:
TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
}
}
}
+ void Select(spirv::ir::BuiltinCall* call) {
+ auto* cond = call->Args()[0];
+ auto* true_ = call->Args()[1];
+ auto* false_ = call->Args()[2];
+ b.InsertBefore(call, [&] {
+ b.CallWithResult(call->DetachResult(), core::BuiltinFn::kSelect, false_, true_, cond);
+ });
+ call->Destroy();
+ }
+
// FMod(x, y) emulated with: x - y * floor(x / y)
void FMod(spirv::ir::BuiltinCall* call) {
auto* x = call->Args()[0];
diff --git a/src/tint/lang/spirv/reader/lower/builtins_test.cc b/src/tint/lang/spirv/reader/lower/builtins_test.cc
index a89ab00..82b806b 100644
--- a/src/tint/lang/spirv/reader/lower/builtins_test.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins_test.cc
@@ -9062,5 +9062,64 @@
EXPECT_EQ(expect, str());
}
+TEST_F(SpirvReader_BuiltinsTest, Select_Scalar) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.Call<spirv::ir::BuiltinCall>(ty.f32(), spirv::BuiltinFn::kSelect, true, 1_f, 2_f);
+ b.Return(ep);
+ });
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:f32 = spirv.select true, 1.0f, 2.0f
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:f32 = select 2.0f, 1.0f, true
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+TEST_F(SpirvReader_BuiltinsTest, Select_Vector) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.Call<spirv::ir::BuiltinCall>(ty.vec2<f32>(), spirv::BuiltinFn::kSelect,
+ b.Splat<vec2<bool>>(false), b.Splat<vec2<f32>>(1_f),
+ b.Splat<vec2<f32>>(2_f));
+ b.Return(ep);
+ });
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<f32> = spirv.select vec2<bool>(false), vec2<f32>(1.0f), vec2<f32>(2.0f)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<f32> = select vec2<f32>(2.0f), vec2<f32>(1.0f), vec2<bool>(false)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
} // namespace
} // namespace tint::spirv::reader::lower
diff --git a/src/tint/lang/spirv/reader/parser/builtin_test.cc b/src/tint/lang/spirv/reader/parser/builtin_test.cc
index f6af996..f418106 100644
--- a/src/tint/lang/spirv/reader/parser/builtin_test.cc
+++ b/src/tint/lang/spirv/reader/parser/builtin_test.cc
@@ -1527,5 +1527,67 @@
)");
}
+TEST_F(SpirvParserTest, SelectScalar) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %float = OpTypeFloat 32
+ %bool = OpTypeBool
+ %ep_type = OpTypeFunction %void
+ %float_50 = OpConstant %float 50
+ %float_60 = OpConstant %float 60
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %ep_type
+ %entry = OpLabel
+ %1 = OpSelect %float %true %float_50 %float_60
+ OpReturn
+ OpFunctionEnd)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:f32 = spirv.select true, 50.0f, 60.0f
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, SelectVector) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %float = OpTypeFloat 32
+ %bool = OpTypeBool
+ %v2float = OpTypeVector %float 2
+ %v2bool = OpTypeVector %bool 2
+ %ep_type = OpTypeFunction %void
+ %float_50 = OpConstant %float 50
+ %float_60 = OpConstant %float 60
+%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
+%v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
+ %true = OpConstantTrue %bool
+ %false = OpConstantFalse %bool
+%true_false_vec2 = OpConstantComposite %v2bool %true %false
+ %main = OpFunction %void None %ep_type
+ %entry = OpLabel
+ %1 = OpSelect %v2float %true_false_vec2 %v2float_50_60 %v2float_60_50
+ OpReturn
+ OpFunctionEnd)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec2<f32> = spirv.select vec2<bool>(true, false), vec2<f32>(50.0f, 60.0f), vec2<f32>(60.0f, 50.0f)
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index 0c0435b..c515514 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -1002,6 +1002,9 @@
case spv::Op::OpFMod:
EmitSpirvBuiltinCall(inst, spirv::BuiltinFn::kFMod);
break;
+ case spv::Op::OpSelect:
+ EmitSpirvBuiltinCall(inst, spirv::BuiltinFn::kSelect);
+ break;
default:
TINT_UNIMPLEMENTED()
<< "unhandled SPIR-V instruction: " << static_cast<uint32_t>(inst.opcode());