[spirv-reader][ir] Correctly handle GLSL 450 SAbs
The SPIR-V `SAbs` treats all arguments as if they were an `i32`. The
argument type does not need to match the return type. This differs from
WGSL where the types must match.
Bug: 42250952
Change-Id: Ied37500361cdff635997dcb4313d90cc0614060c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/221735
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/builtin_fn.cc b/src/tint/lang/spirv/builtin_fn.cc
index 6016b84..462d047 100644
--- a/src/tint/lang/spirv/builtin_fn.cc
+++ b/src/tint/lang/spirv/builtin_fn.cc
@@ -114,6 +114,8 @@
return "inverse";
case BuiltinFn::kSign:
return "sign";
+ case BuiltinFn::kAbs:
+ return "abs";
case BuiltinFn::kSdot:
return "sdot";
case BuiltinFn::kUdot:
@@ -170,6 +172,7 @@
case BuiltinFn::kNormalize:
case BuiltinFn::kInverse:
case BuiltinFn::kSign:
+ case BuiltinFn::kAbs:
break;
}
return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.cc.tmpl b/src/tint/lang/spirv/builtin_fn.cc.tmpl
index 7ffcae8..d4dead77 100644
--- a/src/tint/lang/spirv/builtin_fn.cc.tmpl
+++ b/src/tint/lang/spirv/builtin_fn.cc.tmpl
@@ -75,6 +75,7 @@
case BuiltinFn::kNormalize:
case BuiltinFn::kInverse:
case BuiltinFn::kSign:
+ case BuiltinFn::kAbs:
break;
}
return core::ir::Instruction::Accesses{};
diff --git a/src/tint/lang/spirv/builtin_fn.h b/src/tint/lang/spirv/builtin_fn.h
index 55084ef..7b5b2c6 100644
--- a/src/tint/lang/spirv/builtin_fn.h
+++ b/src/tint/lang/spirv/builtin_fn.h
@@ -84,6 +84,7 @@
kNormalize,
kInverse,
kSign,
+ kAbs,
kSdot,
kUdot,
kNone,
diff --git a/src/tint/lang/spirv/intrinsic/data.cc b/src/tint/lang/spirv/intrinsic/data.cc
index 30b644d..deee63b 100644
--- a/src/tint/lang/spirv/intrinsic/data.cc
+++ b/src/tint/lang/spirv/intrinsic/data.cc
@@ -5503,12 +5503,19 @@
},
{
/* [36] */
+ /* fn abs<R : iu32>[T : iu32](T) -> R */
+ /* fn abs<R : iu32>[T : iu32, N : num](vec<N, T>) -> vec<N, R> */
+ /* num overloads */ 2,
+ /* overloads */ OverloadIndex(154),
+ },
+ {
+ /* [37] */
/* fn sdot(u32, u32, u32) -> i32 */
/* num overloads */ 1,
/* overloads */ OverloadIndex(167),
},
{
- /* [37] */
+ /* [38] */
/* fn udot(u32, u32, u32) -> u32 */
/* num overloads */ 1,
/* overloads */ OverloadIndex(168),
diff --git a/src/tint/lang/spirv/reader/import_glsl_std450_test.cc b/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
index f2cc8f3..c69e6b3 100644
--- a/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
+++ b/src/tint/lang/spirv/reader/import_glsl_std450_test.cc
@@ -696,8 +696,7 @@
INSTANTIATE_TEST_SUITE_P(SpirvReader,
SpirvReaderTest_GlslStd450_Inting_Inting,
- ::testing::Values(GlslStd450Case{"SAbs", "abs"},
- GlslStd450Case{"FindILsb", "firstTrailingBit"},
+ ::testing::Values(GlslStd450Case{"FindILsb", "firstTrailingBit"},
GlslStd450Case{"FindSMsb", "firstLeadingBit"},
GlslStd450Case{"SSign", "sign"}));
@@ -932,26 +931,6 @@
)");
}
-// Check that we convert signedness of operands and result type.
-// This is needed for each of the integer-based extended instructions.
-
-TEST_F(SpirvReaderTest, DISABLED_RectifyOperandsAndResult_SAbs) {
- EXPECT_IR(Preamble() + R"(
- %1 = OpExtInst %uint %glsl SAbs %uint_10
- %2 = OpExtInst %v2uint %glsl SAbs %v2uint_10_20
- OpReturn
- OpFunctionEnd
- )",
- R"(
-%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
- $B1: {
- let x_1 = bitcast<u32>(abs(bitcast<i32>(u1)));
- let x_2 = bitcast<vec2u>(abs(bitcast<vec2i>(v2u1)));
- }
-}
-)");
-}
-
TEST_F(SpirvReaderTest, DISABLED_RectifyOperandsAndResult_SMax) {
EXPECT_IR(Preamble() + R"(
%1 = OpExtInst %uint %glsl SMax %uint_10 %uint_15
diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc
index 4746914..1680866 100644
--- a/src/tint/lang/spirv/reader/lower/builtins.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins.cc
@@ -70,6 +70,9 @@
case spirv::BuiltinFn::kSign:
Sign(builtin);
break;
+ case spirv::BuiltinFn::kAbs:
+ Abs(builtin);
+ break;
default:
TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
}
@@ -101,6 +104,28 @@
call->Destroy();
}
+ void Abs(spirv::ir::BuiltinCall* call) {
+ auto* arg = call->Args()[0];
+
+ b.InsertBefore(call, [&] {
+ auto* result_ty = call->Result(0)->Type();
+ if (arg->Type()->IsUnsignedIntegerScalarOrVector()) {
+ arg = b.Bitcast(ty.MatchWidth(ty.i32(), result_ty), arg)->Result(0);
+ }
+ auto* new_call =
+ b.Call(result_ty, core::BuiltinFn::kAbs, Vector<core::ir::Value*, 1>{arg});
+
+ core::ir::Value* replacement = new_call->Result(0);
+ // If the call is a `u32` result type, we need to cast it to `i32`.
+ if (result_ty->DeepestElement() == ty.u32()) {
+ new_call->Result(0)->SetType(ty.MatchWidth(ty.i32(), result_ty));
+ replacement = b.Bitcast(result_ty, replacement)->Result(0);
+ }
+ call->Result(0)->ReplaceAllUsesWith(replacement);
+ });
+ call->Destroy();
+ }
+
void Normalize(spirv::ir::BuiltinCall* call) {
auto* arg = call->Args()[0];
diff --git a/src/tint/lang/spirv/reader/lower/builtins_test.cc b/src/tint/lang/spirv/reader/lower/builtins_test.cc
index d622215..c30e589 100644
--- a/src/tint/lang/spirv/reader/lower/builtins_test.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins_test.cc
@@ -1003,5 +1003,161 @@
EXPECT_EQ(expect, str());
}
+TEST_F(SpirvParser_BuiltinsTest, SAbs_UnsignedToUnsigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.u32(), spirv::BuiltinFn::kAbs,
+ Vector<const core::type::Type*, 1>{ty.u32()}, 10_u);
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<u32>(), spirv::BuiltinFn::kAbs,
+ Vector<const core::type::Type*, 1>{ty.u32()},
+ b.Splat(ty.vec2<u32>(), 10_u));
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = spirv.abs<u32> 10u
+ %3:vec2<u32> = spirv.abs<u32> vec2<u32>(10u)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = bitcast 10u
+ %3:i32 = abs %2
+ %4:u32 = bitcast %3
+ %5:vec2<i32> = bitcast vec2<u32>(10u)
+ %6:vec2<i32> = abs %5
+ %7:vec2<u32> = bitcast %6
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, SAbs_UnsignedToSigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.i32(), spirv::BuiltinFn::kAbs,
+ Vector<const core::type::Type*, 1>{ty.i32()}, 10_u);
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<i32>(), spirv::BuiltinFn::kAbs,
+ Vector<const core::type::Type*, 1>{ty.i32()},
+ b.Splat(ty.vec2<u32>(), 10_u));
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = spirv.abs<i32> 10u
+ %3:vec2<i32> = spirv.abs<i32> vec2<u32>(10u)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = bitcast 10u
+ %3:i32 = abs %2
+ %4:vec2<i32> = bitcast vec2<u32>(10u)
+ %5:vec2<i32> = abs %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, SAbs_SignedToSigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.i32(), spirv::BuiltinFn::kAbs,
+ Vector<const core::type::Type*, 1>{ty.i32()}, 10_i);
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<i32>(), spirv::BuiltinFn::kAbs,
+ Vector<const core::type::Type*, 1>{ty.i32()},
+ b.Splat(ty.vec2<i32>(), 10_i));
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = spirv.abs<i32> 10i
+ %3:vec2<i32> = spirv.abs<i32> vec2<i32>(10i)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = abs 10i
+ %3:vec2<i32> = abs vec2<i32>(10i)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvParser_BuiltinsTest, SAbs_SignedToUnsigned) {
+ auto* ep = b.ComputeFunction("foo");
+
+ b.Append(ep->Block(), [&] { //
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.u32(), spirv::BuiltinFn::kAbs,
+ Vector<const core::type::Type*, 1>{ty.u32()}, 10_i);
+ b.CallExplicit<spirv::ir::BuiltinCall>(ty.vec2<u32>(), spirv::BuiltinFn::kAbs,
+ Vector<const core::type::Type*, 1>{ty.u32()},
+ b.Splat(ty.vec2<i32>(), 10_i));
+ b.Return(ep);
+ });
+
+ auto* src = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = spirv.abs<u32> 10i
+ %3:vec2<u32> = spirv.abs<u32> vec2<i32>(10i)
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ Run(Builtins);
+
+ auto* expect = R"(
+%foo = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = abs 10i
+ %3:u32 = bitcast %2
+ %4:vec2<i32> = abs vec2<i32>(10i)
+ %5:vec2<u32> = bitcast %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(expect, str());
+}
+
} // namespace
} // namespace tint::spirv::reader::lower
diff --git a/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc b/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
index b69f6a7..26ee356 100644
--- a/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
+++ b/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc
@@ -41,8 +41,12 @@
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %int = OpTypeInt 32 1
%float = OpTypeFloat 32
+ %v2int = OpTypeVector %int 2
+ %v2uint = OpTypeVector %uint 2
%v2float = OpTypeVector %float 2
%v3float = OpTypeVector %float 3
%v4float = OpTypeVector %float 4
@@ -50,10 +54,19 @@
%mat3v3float = OpTypeMatrix %v3float 3
%mat4v4float = OpTypeMatrix %v4float 4
+ %int_10 = OpConstant %int 10
+ %int_20 = OpConstant %int 20
+
+ %uint_10 = OpConstant %uint 10
+ %uint_20 = OpConstant %uint 20
+
%float_50 = OpConstant %float 50
%float_60 = OpConstant %float 60
%float_70 = OpConstant %float 70
+ %v2int_10_20 = OpConstantComposite %v2int %int_10 %int_20
+ %v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20
+
%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
%v3float_50_60_70 = OpConstantComposite %v3float %float_50 %float_60 %float_70
%v4float_50_50_50_50 = OpConstantComposite %v4float %float_50 %float_50 %float_50 %float_50
@@ -143,5 +156,93 @@
)");
}
+TEST_F(SpirvParserTest, GlslStd450_SAbs_UnsignedToUnsigned) {
+ EXPECT_IR(Preamble() + R"(
+ %1 = OpExtInst %uint %glsl SAbs %uint_10
+ %2 = OpExtInst %v2uint %glsl SAbs %v2uint_10_20
+ %3 = OpCopyObject %uint %1
+ %4 = OpCopyObject %v2uint %2
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = spirv.abs<u32> 10u
+ %3:vec2<u32> = spirv.abs<u32> vec2<u32>(10u, 20u)
+ %4:u32 = let %2
+ %5:vec2<u32> = let %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, GlslStd450_SAbs_UnsignedToSigned) {
+ EXPECT_IR(Preamble() + R"(
+ %1 = OpExtInst %int %glsl SAbs %uint_10
+ %2 = OpExtInst %v2int %glsl SAbs %v2uint_10_20
+ %3 = OpCopyObject %int %1
+ %4 = OpCopyObject %v2int %2
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = spirv.abs<i32> 10u
+ %3:vec2<i32> = spirv.abs<i32> vec2<u32>(10u, 20u)
+ %4:i32 = let %2
+ %5:vec2<i32> = let %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, GlslStd450_SAbs_SignedToUnsigned) {
+ EXPECT_IR(Preamble() + R"(
+ %1 = OpExtInst %uint %glsl SAbs %int_10
+ %2 = OpExtInst %v2uint %glsl SAbs %v2int_10_20
+ %3 = OpCopyObject %uint %1
+ %4 = OpCopyObject %v2uint %2
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = spirv.abs<u32> 10i
+ %3:vec2<u32> = spirv.abs<u32> vec2<i32>(10i, 20i)
+ %4:u32 = let %2
+ %5:vec2<u32> = let %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, GlslStd450_SAbs_SignedToSigned) {
+ EXPECT_IR(Preamble() + R"(
+ %1 = OpExtInst %int %glsl SAbs %int_10
+ %2 = OpExtInst %v2int %glsl SAbs %v2int_10_20
+ %3 = OpCopyObject %int %1
+ %4 = OpCopyObject %v2int %2
+ OpReturn
+ OpFunctionEnd
+ )",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = spirv.abs<i32> 10i
+ %3:vec2<i32> = spirv.abs<i32> vec2<i32>(10i, 20i)
+ %4:i32 = let %2
+ %5:vec2<i32> = let %3
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index db2d433..ea3ce72 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -664,7 +664,6 @@
case GLSLstd450Exp2:
return core::BuiltinFn::kExp2;
case GLSLstd450FAbs:
- case GLSLstd450SAbs:
return core::BuiltinFn::kAbs;
case GLSLstd450FSign:
return core::BuiltinFn::kSign;
@@ -758,6 +757,8 @@
spirv::BuiltinFn GetGlslStd450SpirvEquivalentFuncName(uint32_t ext_opcode) {
switch (ext_opcode) {
+ case GLSLstd450SAbs:
+ return spirv::BuiltinFn::kAbs;
case GLSLstd450SSign:
return spirv::BuiltinFn::kSign;
case GLSLstd450Normalize:
@@ -772,7 +773,7 @@
Vector<const core::type::Type*, 1> GlslStd450ExplicitParams(uint32_t ext_opcode,
const core::type::Type* result_ty) {
- if (ext_opcode != GLSLstd450SSign) {
+ if (ext_opcode != GLSLstd450SSign && ext_opcode != GLSLstd450SAbs) {
return {};
}
return {result_ty->DeepestElement()};
diff --git a/src/tint/lang/spirv/spirv.def b/src/tint/lang/spirv/spirv.def
index 39aeff1..e3c6855 100644
--- a/src/tint/lang/spirv/spirv.def
+++ b/src/tint/lang/spirv/spirv.def
@@ -328,6 +328,9 @@
implicit(T: iu32) fn sign<R: iu32>(T) -> R
implicit(T: iu32, N: num) fn sign<R: iu32>(vec<N, T>) -> vec<N, R>
+implicit(T: iu32) fn abs<R: iu32>(T) -> R
+implicit(T: iu32, N: num) fn abs<R: iu32>(vec<N, T>) -> vec<N, R>
+
////////////////////////////////////////////////////////////////////////////////
// SPV_KHR_integer_dot_product instructions
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index a37bccd..f701b80 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -1286,6 +1286,9 @@
};
switch (builtin->Func()) {
+ case spirv::BuiltinFn::kAbs:
+ ext_inst(GLSLstd450SAbs);
+ break;
case spirv::BuiltinFn::kArrayLength:
op = spv::Op::OpArrayLength;
break;