[spirv-reader][ir] Correctly handle GLSL 450 UMax
The SPIR-V `UMax` method allows unsigned arguments and return types.
This is not permitted in WGSL. The SPIR-V spec states that the argument
is treated as a signed value, so bitcast the argument/result as needed.
Bug: 42250952
Change-Id: Iec13d8771c9f0f4b42610cf442a8bbf1b1bb9d0b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/222754
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/builtin_fn.cc b/src/tint/lang/spirv/builtin_fn.cc
index 8acd2d0..78d0254 100644
--- a/src/tint/lang/spirv/builtin_fn.cc
+++ b/src/tint/lang/spirv/builtin_fn.cc
@@ -122,6 +122,8 @@
return "smin";
case BuiltinFn::kSclamp:
return "sclamp";
+ case BuiltinFn::kUmax:
+ return "umax";
case BuiltinFn::kSdot:
return "sdot";
case BuiltinFn::kUdot:
@@ -182,6 +184,7 @@
case BuiltinFn::kSmax:
case BuiltinFn::kSmin:
case BuiltinFn::kSclamp:
+ case BuiltinFn::kUmax:
break;
}
return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.cc.tmpl b/src/tint/lang/spirv/builtin_fn.cc.tmpl
index 03ab304..c20bf32 100644
--- a/src/tint/lang/spirv/builtin_fn.cc.tmpl
+++ b/src/tint/lang/spirv/builtin_fn.cc.tmpl
@@ -79,6 +79,7 @@
case BuiltinFn::kSmax:
case BuiltinFn::kSmin:
case BuiltinFn::kSclamp:
+ case BuiltinFn::kUmax:
break;
}
return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.h b/src/tint/lang/spirv/builtin_fn.h
index 07afe1d..c34ecac 100644
--- a/src/tint/lang/spirv/builtin_fn.h
+++ b/src/tint/lang/spirv/builtin_fn.h
@@ -88,6 +88,7 @@
kSmax,
kSmin,
kSclamp,
+ kUmax,
kSdot,
kUdot,
kNone,
diff --git a/src/tint/lang/spirv/intrinsic/data.cc b/src/tint/lang/spirv/intrinsic/data.cc
index fd411aa..ff0dd41 100644
--- a/src/tint/lang/spirv/intrinsic/data.cc
+++ b/src/tint/lang/spirv/intrinsic/data.cc
@@ -5675,12 +5675,19 @@
},
{
/* [40] */
+ /* fn umax<R : iu32>[T : iu32, U : iu32](T, U) -> R */
+ /* fn umax<R : iu32>[T : iu32, U : iu32, N : num](vec<N, T>, vec<N, U>) -> vec<N, R> */
+ /* num overloads */ 2,
+ /* overloads */ OverloadIndex(156),
+ },
+ {
+ /* [41] */
/* fn sdot(u32, u32, u32) -> i32 */
/* num overloads */ 1,
/* overloads */ OverloadIndex(171),
},
{
- /* [41] */
+ /* [42] */
/* fn udot(u32, u32, u32) -> u32 */
/* num overloads */ 1,
/* overloads */ OverloadIndex(172),
diff --git a/src/tint/lang/spirv/reader/import_glsl_std450_test.cc b/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
index f135bc1..0d67a1e 100644
--- a/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
+++ b/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
@@ -646,23 +646,6 @@
)");
}
-TEST_F(SpirvReaderTest, DISABLED_RectifyOperandsAndResult_UMax) {
- EXPECT_IR(Preamble() + R"(
- %1 = OpExtInst %int %glsl UMax %int_30 %int_35
- %2 = OpExtInst %v2int %glsl UMax %v2int_30_40 %v2int_40_30
- OpReturn
- OpFunctionEnd
- )",
- R"(
-%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
- $B1: {
- let x_1 = bitcast<i32>(max(bitcast<u32>(i1), bitcast<u32>(i2)));
- let x_2 = bitcast<vec2i>(max(bitcast<vec2u>(v2i1), bitcast<vec2u>(v2i2)));
- }
-}
-)");
-}
-
TEST_F(SpirvReaderTest, DISABLED_RectifyOperandsAndResult_UMin) {
EXPECT_IR(Preamble() + R"(
%1 = OpExtInst %int %glsl UMin %int_30 %int_35
diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc
index c0dfd05..3cd419e 100644
--- a/src/tint/lang/spirv/reader/lower/builtins.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins.cc
@@ -74,13 +74,16 @@
Abs(builtin);
break;
case spirv::BuiltinFn::kSmax:
- Max(builtin);
+ SMax(builtin);
break;
case spirv::BuiltinFn::kSmin:
- Min(builtin);
+ SMin(builtin);
break;
case spirv::BuiltinFn::kSclamp:
- Clamp(builtin);
+ SClamp(builtin);
+ break;
+ case spirv::BuiltinFn::kUmax:
+ UMax(builtin);
break;
default:
TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
@@ -112,7 +115,6 @@
auto* new_call = b.Call(result_ty, func, new_args);
core::ir::Value* replacement = new_call->Result(0);
-
if (result_ty->DeepestElement() == ty.u32()) {
new_call->Result(0)->SetType(ty.MatchWidth(ty.i32(), result_ty));
replacement = b.Bitcast(result_ty, replacement)->Result(0);
@@ -126,12 +128,49 @@
WrapSignedSpirvMethods(call, core::BuiltinFn::kSign);
}
void Abs(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kAbs); }
- void Max(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kMax); }
- void Min(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kMin); }
- void Clamp(spirv::ir::BuiltinCall* call) {
+ void SMax(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kMax); }
+ void SMin(spirv::ir::BuiltinCall* call) { WrapSignedSpirvMethods(call, core::BuiltinFn::kMin); }
+ void SClamp(spirv::ir::BuiltinCall* call) {
WrapSignedSpirvMethods(call, core::BuiltinFn::kClamp);
}
+ // The SPIR-V Unsigned methods all interpret their arguments as unsigned (regardless of the type
+ // of the argument). In order to satisfy this, we must bitcast any signed argument to an
+ // unsigned type before calling the WGSL equivalent method.
+ //
+ // The result of the WGSL method will match the arguments, or in this case an unsigned value. If
+ // the SPIR-V instruction expected a signed result we must bitcast the WGSL result to the
+ // correct signed type.
+ void WrapUnsignedSpirvMethods(spirv::ir::BuiltinCall* call, core::BuiltinFn func) {
+ auto args = call->Args();
+
+ b.InsertBefore(call, [&] {
+ auto* result_ty = call->Result(0)->Type();
+ Vector<core::ir::Value*, 2> new_args;
+
+ for (auto* arg : args) {
+ if (arg->Type()->IsSignedIntegerScalarOrVector()) {
+ arg = b.Bitcast(ty.MatchWidth(ty.u32(), result_ty), arg)->Result(0);
+ }
+ new_args.Push(arg);
+ }
+
+ auto* new_call = b.Call(result_ty, func, new_args);
+
+ core::ir::Value* replacement = new_call->Result(0);
+ if (result_ty->DeepestElement() == ty.i32()) {
+ new_call->Result(0)->SetType(ty.MatchWidth(ty.u32(), result_ty));
+ replacement = b.Bitcast(result_ty, replacement)->Result(0);
+ }
+ call->Result(0)->ReplaceAllUsesWith(replacement);
+ });
+ call->Destroy();
+ }
+
+ void UMax(spirv::ir::BuiltinCall* call) {
+ WrapUnsignedSpirvMethods(call, core::BuiltinFn::kMax);
+ }
+
void Normalize(spirv::ir::BuiltinCall* call) {
auto* arg = call->Args()[0];
diff --git a/src/tint/lang/spirv/reader/lower/builtins_test.cc b/src/tint/lang/spirv/reader/lower/builtins_test.cc
index 2af04bc..6e87fd1 100644
--- a/src/tint/lang/spirv/reader/lower/builtins_test.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins_test.cc
@@ -1170,10 +1170,10 @@
return out;
}
-using SpirvParser_BuiltinsTest_TwoParam =
+using SpirvParser_BuiltinsTest_TwoParamSigned =
core::ir::transform::TransformTestWithParam<SpirvReaderParams>;
-TEST_P(SpirvParser_BuiltinsTest_TwoParam, UnsignedToUnsigned) {
+TEST_P(SpirvParser_BuiltinsTest_TwoParamSigned, UnsignedToUnsigned) {
auto& params = GetParam();
auto* ep = b.ComputeFunction("foo");
@@ -1222,7 +1222,7 @@
EXPECT_EQ(expect, str());
}
-TEST_P(SpirvParser_BuiltinsTest_TwoParam, SignedToSigned) {
+TEST_P(SpirvParser_BuiltinsTest_TwoParamSigned, SignedToSigned) {
auto params = GetParam();
auto* ep = b.ComputeFunction("foo");
@@ -1265,7 +1265,7 @@
EXPECT_EQ(expect, str());
}
-TEST_P(SpirvParser_BuiltinsTest_TwoParam, MixedToUnsigned) {
+TEST_P(SpirvParser_BuiltinsTest_TwoParamSigned, MixedToUnsigned) {
auto params = GetParam();
auto* ep = b.ComputeFunction("foo");
@@ -1312,7 +1312,7 @@
EXPECT_EQ(expect, str());
}
-TEST_P(SpirvParser_BuiltinsTest_TwoParam, MixedToSigned) {
+TEST_P(SpirvParser_BuiltinsTest_TwoParamSigned, MixedToSigned) {
auto params = GetParam();
auto* ep = b.ComputeFunction("foo");
@@ -1358,10 +1358,201 @@
}
INSTANTIATE_TEST_SUITE_P(SpirvReader,
- SpirvParser_BuiltinsTest_TwoParam,
+ SpirvParser_BuiltinsTest_TwoParamSigned,
::testing::Values(SpirvReaderParams{spirv::BuiltinFn::kSmax, "max"},
SpirvReaderParams{spirv::BuiltinFn::kSmin, "min"}));
+using SpirvParser_BuiltinsTest_TwoParamUnsigned =
+ core::ir::transform::TransformTestWithParam<SpirvReaderParams>;
+
+TEST_P(SpirvParser_BuiltinsTest_TwoParamUnsigned, UnsignedToUnsigned) {
+ auto& params = GetParam();
+
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(
+ ty.u32(), params.fn, Vector<const core::type::Type*, 1>{ty.u32()}, 10_u, 15_u);
+ b.CallExplicit<spirv::ir::BuiltinCall>(
+ ty.vec2<u32>(), params.fn, Vector<const core::type::Type*, 1>{ty.u32()},
+ b.Splat(ty.vec2<u32>(), 10_u), b.Splat(ty.vec2<u32>(), 15_u));
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = spirv.u)" +
+ params.name + R"(<u32> 10u, 15u
+ %3:vec2<u32> = spirv.u)" +
+ params.name + R"(<u32> vec2<u32>(10u), vec2<u32>(15u)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = )" + params.name +
+ R"( 10u, 15u
+ %3:vec2<u32> = )" +
+ params.name + R"( vec2<u32>(10u), vec2<u32>(15u)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_P(SpirvParser_BuiltinsTest_TwoParamUnsigned, SignedToSigned) {
+ auto params = GetParam();
+
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(
+ ty.i32(), params.fn, Vector<const core::type::Type*, 1>{ty.i32()}, 10_i, 15_i);
+ b.CallExplicit<spirv::ir::BuiltinCall>(
+ ty.vec2<i32>(), params.fn, Vector<const core::type::Type*, 1>{ty.i32()},
+ b.Splat(ty.vec2<i32>(), 10_i), b.Splat(ty.vec2<i32>(), 15_i));
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = spirv.u)" +
+ params.name + R"(<i32> 10i, 15i
+ %3:vec2<i32> = spirv.u)" +
+ params.name + R"(<i32> vec2<i32>(10i), vec2<i32>(15i)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = bitcast 10i
+ %3:u32 = bitcast 15i
+ %4:u32 = )" + params.name +
+ R"( %2, %3
+ %5:i32 = bitcast %4
+ %6:vec2<u32> = bitcast vec2<i32>(10i)
+ %7:vec2<u32> = bitcast vec2<i32>(15i)
+ %8:vec2<u32> = )" +
+ params.name + R"( %6, %7
+ %9:vec2<i32> = bitcast %8
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_P(SpirvParser_BuiltinsTest_TwoParamUnsigned, MixedToUnsigned) {
+ auto params = GetParam();
+
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(
+ ty.u32(), params.fn, Vector<const core::type::Type*, 1>{ty.u32()}, 10_i, 10_u);
+ b.CallExplicit<spirv::ir::BuiltinCall>(
+ ty.vec2<u32>(), params.fn, Vector<const core::type::Type*, 1>{ty.u32()},
+ b.Splat(ty.vec2<i32>(), 10_i), b.Splat(ty.vec2<u32>(), 10_u));
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = spirv.u)" +
+ params.name + R"(<u32> 10i, 10u
+ %3:vec2<u32> = spirv.u)" +
+ params.name + R"(<u32> vec2<i32>(10i), vec2<u32>(10u)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = bitcast 10i
+ %3:u32 = )" + params.name +
+ R"( %2, 10u
+ %4:vec2<u32> = bitcast vec2<i32>(10i)
+ %5:vec2<u32> = )" +
+ params.name + R"( %4, vec2<u32>(10u)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_P(SpirvParser_BuiltinsTest_TwoParamUnsigned, MixedToSigned) {
+ auto params = GetParam();
+
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(
+ ty.i32(), params.fn, Vector<const core::type::Type*, 1>{ty.i32()}, 10_u, 10_i);
+ b.CallExplicit<spirv::ir::BuiltinCall>(
+ ty.vec2<i32>(), params.fn, Vector<const core::type::Type*, 1>{ty.i32()},
+ b.Splat(ty.vec2<u32>(), 10_u), b.Splat(ty.vec2<i32>(), 10_i));
+ b.Return(ep);
+ });
+
+ auto src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = spirv.u)" +
+ params.name + R"(<i32> 10u, 10i
+ %3:vec2<i32> = spirv.u)" +
+ params.name + R"(<i32> vec2<u32>(10u), vec2<i32>(10i)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = bitcast 10i
+ %3:u32 = )" + params.name +
+ R"( 10u, %2
+ %4:i32 = bitcast %3
+ %5:vec2<u32> = bitcast vec2<i32>(10i)
+ %6:vec2<u32> = )" +
+ params.name + R"( vec2<u32>(10u), %5
+ %7:vec2<i32> = bitcast %6
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+INSTANTIATE_TEST_SUITE_P(SpirvReader,
+ SpirvParser_BuiltinsTest_TwoParamUnsigned,
+ ::testing::Values(SpirvReaderParams{spirv::BuiltinFn::kUmax, "max"}));
+
TEST_F(SpirvParser_BuiltinsTest, SClamp_UnsignedToUnsigned) {
auto* ep = b.ComputeFunction("foo");
diff --git a/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc b/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
index 4a4f22d..32010c6 100644
--- a/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
+++ b/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
@@ -373,7 +373,8 @@
INSTANTIATE_TEST_SUITE_P(SpirvParser,
GlslStd450TwoParamTest,
::testing::Values(GlslStd450TwoParams{"SMax", "smax"},
- GlslStd450TwoParams{"SMin", "smin"}));
+ GlslStd450TwoParams{"SMin", "smin"},
+ GlslStd450TwoParams{"UMax", "umax"}));
TEST_F(SpirvParserTest, GlslStd450_SClamp_UnsignedToUnsigned) {
EXPECT_IR(Preamble() + R"(
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index 15de75e..9258df4 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -699,7 +699,6 @@
case GLSLstd450NMin:
case GLSLstd450FMin: // FMin is less prescriptive about NaN operands
return core::BuiltinFn::kMin;
- case GLSLstd450UMax:
case GLSLstd450NMax:
case GLSLstd450FMax: // FMax is less prescriptive about NaN operands
return core::BuiltinFn::kMax;
@@ -770,6 +769,8 @@
return spirv::BuiltinFn::kSmin;
case GLSLstd450SClamp:
return spirv::BuiltinFn::kSclamp;
+ case GLSLstd450UMax:
+ return spirv::BuiltinFn::kUmax;
default:
break;
}
@@ -780,7 +781,7 @@
const core::type::Type* result_ty) {
if (ext_opcode == GLSLstd450SSign || ext_opcode == GLSLstd450SAbs ||
ext_opcode == GLSLstd450SMax || ext_opcode == GLSLstd450SMin ||
- ext_opcode == GLSLstd450SClamp) {
+ ext_opcode == GLSLstd450SClamp || ext_opcode == GLSLstd450UMax) {
return {result_ty->DeepestElement()};
}
return {};
diff --git a/src/tint/lang/spirv/spirv.def b/src/tint/lang/spirv/spirv.def
index 5efb7b5..1202a8f 100644
--- a/src/tint/lang/spirv/spirv.def
+++ b/src/tint/lang/spirv/spirv.def
@@ -338,6 +338,9 @@
implicit(T: iu32, U: iu32, V: iu32) fn sclamp<R: iu32>(T, U, V) -> R
implicit(T: iu32, U: iu32, V: iu32, N: num) fn sclamp<R: iu32>(vec<N, T>, vec<N, U>, vec<N, V>) -> vec<N, R>
+implicit(T: iu32, U: iu32) fn umax<R: iu32>(T, U) -> R
+implicit(T: iu32, U: iu32, N: num) fn umax<R: iu32>(vec<N, T>, vec<N, U>) -> vec<N, R>
+
////////////////////////////////////////////////////////////////////////////////
// SPV_KHR_integer_dot_product instructions
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index 402695a..e768da5 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -1387,6 +1387,9 @@
case spirv::BuiltinFn::kSclamp:
ext_inst(GLSLstd450SClamp);
break;
+ case spirv::BuiltinFn::kUmax:
+ ext_inst(GLSLstd450UMax);
+ break;
case spirv::BuiltinFn::kNormalize:
ext_inst(GLSLstd450Normalize);
break;