[msl] Always use core polyfill for DP4 builtins
The MSL-specific polyfill used in both the IR and AST backends is
susceptible to signed integer overflow, which is undefined behavior
(UB). We know that this polyfill does not work correctly on at least
one device (crbug.com/355485146), so always use the core polyfill
instead.
Bug: 42251016
Change-Id: I7b04a85e3a73c73ad9d879c24c26b0973f563dd3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/205754
Reviewed-by: David Neto <dneto@google.com>
Auto-Submit: James Price <jrprice@google.com>
diff --git a/src/dawn/native/metal/PhysicalDeviceMTL.mm b/src/dawn/native/metal/PhysicalDeviceMTL.mm
index a128a1e..04b782f 100644
--- a/src/dawn/native/metal/PhysicalDeviceMTL.mm
+++ b/src/dawn/native/metal/PhysicalDeviceMTL.mm
@@ -552,14 +552,6 @@
deviceToggles->Default(
Toggle::MetalUseBothDepthAndStencilAttachmentsForCombinedDepthStencilFormats, true);
}
-
- // Packed 4x8 integer dot products fail on Macbook Pro 16" with AMD Radeon Pro 5300M,
- // which are the RDNA1 architecture.
- // Conservatively, polyfill these functions on RDNA1 and RDNA2.
- // crbug.com/355485146
- if (gpu_info::IsAMDRDNA1(vendorId, deviceId) || gpu_info::IsAMDRDNA2(vendorId, deviceId)) {
- deviceToggles->Default(Toggle::PolyFillPacked4x8DotProduct, true);
- }
#endif
}
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index 6e087e8..0151a18 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -290,8 +290,6 @@
req.use_tint_ir = device->IsToggleEnabled(Toggle::UseTintIR);
req.tintOptions.disable_polyfill_integer_div_mod =
device->IsToggleEnabled(Toggle::DisablePolyfillsOnIntegerDivisonAndModulo);
- req.tintOptions.polyfill_dot_4x8_packed =
- device->IsToggleEnabled(Toggle::PolyFillPacked4x8DotProduct);
const CombinedLimits& limits = device->GetLimits();
req.limits = LimitsForCompilationRequest::Create(limits.v1);
diff --git a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
index 9296b43..68a20c8 100644
--- a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
@@ -202,7 +202,7 @@
polyfills.sign_int = true;
polyfills.texture_sample_base_clamp_to_edge_2d_f32 = true;
polyfills.workgroup_uniform_load = true;
- polyfills.dot_4x8_packed = options.polyfill_dot_4x8_packed;
+ polyfills.dot_4x8_packed = true;
polyfills.pack_unpack_4x8 = true;
polyfills.pack_4xu8_clamp = true;
data.Add<ast::transform::BuiltinPolyfill::Config>(polyfills);
@@ -736,10 +736,6 @@
return EmitDegreesCall(out, expr, builtin);
case wgsl::BuiltinFn::kRadians:
return EmitRadiansCall(out, expr, builtin);
- case wgsl::BuiltinFn::kDot4I8Packed:
- return EmitDot4I8PackedCall(out, expr, builtin);
- case wgsl::BuiltinFn::kDot4U8Packed:
- return EmitDot4U8PackedCall(out, expr, builtin);
case wgsl::BuiltinFn::kPack2X16Float:
case wgsl::BuiltinFn::kUnpack2X16Float: {
@@ -1594,32 +1590,6 @@
return true;
}
-bool ASTPrinter::EmitDot4I8PackedCall(StringStream& out,
- const ast::CallExpression* expr,
- const sem::BuiltinFn* builtin) {
- return CallBuiltinHelper(
- out, expr, builtin, [&](TextBuffer* b, const std::vector<std::string>& params) {
- Line(b) << "char4 vec1 = as_type<char4>(" << params[0] << ");";
- Line(b) << "char4 vec2 = as_type<char4>(" << params[1] << ");";
- Line(b) << "return vec1[0] * vec2[0] + vec1[1] * vec2[1] + vec1[2] * vec2[2] + vec1[3] "
- "* vec2[3];";
- return true;
- });
-}
-
-bool ASTPrinter::EmitDot4U8PackedCall(StringStream& out,
- const ast::CallExpression* expr,
- const sem::BuiltinFn* builtin) {
- return CallBuiltinHelper(
- out, expr, builtin, [&](TextBuffer* b, const std::vector<std::string>& params) {
- Line(b) << "uchar4 vec1 = as_type<uchar4>(" << params[0] << ");";
- Line(b) << "uchar4 vec2 = as_type<uchar4>(" << params[1] << ");";
- Line(b) << "return vec1[0] * vec2[0] + vec1[1] * vec2[1] + vec1[2] * vec2[2] + vec1[3] "
- "* vec2[3];";
- return true;
- });
-}
-
bool ASTPrinter::EmitModfCall(StringStream& out,
const ast::CallExpression* expr,
const sem::BuiltinFn* builtin) {
diff --git a/src/tint/lang/msl/writer/ast_printer/builtin_test.cc b/src/tint/lang/msl/writer/ast_printer/builtin_test.cc
index 0539763..fff3db0 100644
--- a/src/tint/lang/msl/writer/ast_printer/builtin_test.cc
+++ b/src/tint/lang/msl/writer/ast_printer/builtin_test.cc
@@ -1142,41 +1142,11 @@
)");
}
-TEST_F(MslASTPrinterTest, PolyfillDot4I8Packed_False) {
+TEST_F(MslASTPrinterTest, PolyfillDot4I8Packed) {
WrapInFunction(Decl(Let("zero", Expr(0_u))), //
Decl(Let("v", Call("dot4I8Packed", "zero", Expr(1_u)))));
- Options options;
- options.polyfill_dot_4x8_packed = false;
- ASTPrinter& gen = SanitizeAndBuild(options);
-
- ASSERT_TRUE(gen.Generate()) << gen.Diagnostics();
- EXPECT_EQ(gen.Result(), R"(#include <metal_stdlib>
-
-using namespace metal;
-
-int tint_dot4I8Packed(uint param_0, uint param_1) {
- char4 vec1 = as_type<char4>(param_0);
- char4 vec2 = as_type<char4>(param_1);
- return vec1[0] * vec2[0] + vec1[1] * vec2[1] + vec1[2] * vec2[2] + vec1[3] * vec2[3];
-}
-
-kernel void test_function() {
- uint const zero = 0u;
- int const v = tint_dot4I8Packed(zero, 1u);
- return;
-}
-
-)");
-}
-
-TEST_F(MslASTPrinterTest, PolyfillDot4I8Packed_True) {
- WrapInFunction(Decl(Let("zero", Expr(0_u))), //
- Decl(Let("v", Call("dot4I8Packed", "zero", Expr(1_u)))));
-
- Options options;
- options.polyfill_dot_4x8_packed = true;
- ASTPrinter& gen = SanitizeAndBuild(options);
+ ASTPrinter& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.Diagnostics();
EXPECT_EQ(gen.Result(), R"(#include <metal_stdlib>
@@ -1202,41 +1172,11 @@
)");
}
-TEST_F(MslASTPrinterTest, PolyfillDot4U8Packed_False) {
+TEST_F(MslASTPrinterTest, PolyfillDot4U8Packed) {
WrapInFunction(Decl(Let("zero", Expr(0_u))), //
Decl(Let("v", Call("dot4U8Packed", "zero", Expr(1_u)))));
- Options options;
- options.polyfill_dot_4x8_packed = false;
- ASTPrinter& gen = SanitizeAndBuild(options);
-
- ASSERT_TRUE(gen.Generate()) << gen.Diagnostics();
- EXPECT_EQ(gen.Result(), R"(#include <metal_stdlib>
-
-using namespace metal;
-
-uint tint_dot4U8Packed(uint param_0, uint param_1) {
- uchar4 vec1 = as_type<uchar4>(param_0);
- uchar4 vec2 = as_type<uchar4>(param_1);
- return vec1[0] * vec2[0] + vec1[1] * vec2[1] + vec1[2] * vec2[2] + vec1[3] * vec2[3];
-}
-
-kernel void test_function() {
- uint const zero = 0u;
- uint const v = tint_dot4U8Packed(zero, 1u);
- return;
-}
-
-)");
-}
-
-TEST_F(MslASTPrinterTest, PolyfillDot4U8Packed_True) {
- WrapInFunction(Decl(Let("zero", Expr(0_u))), //
- Decl(Let("v", Call("dot4U8Packed", "zero", Expr(1_u)))));
-
- Options options;
- options.polyfill_dot_4x8_packed = true;
- ASTPrinter& gen = SanitizeAndBuild(options);
+ ASTPrinter& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.Diagnostics();
EXPECT_EQ(gen.Result(), R"(#include <metal_stdlib>
diff --git a/src/tint/lang/msl/writer/common/options.h b/src/tint/lang/msl/writer/common/options.h
index fb57e0d..bd783c9 100644
--- a/src/tint/lang/msl/writer/common/options.h
+++ b/src/tint/lang/msl/writer/common/options.h
@@ -166,9 +166,6 @@
/// The bindings
Bindings bindings;
- /// Set to `true` to polyfill dot4I8Packed() dot4U8Packed().
- bool polyfill_dot_4x8_packed = false;
-
/// Reflect the fields of this class so that it can be used by tint::ForeachField()
TINT_REFLECT(Options,
disable_robustness,
@@ -179,8 +176,7 @@
fixed_sample_mask,
pixel_local_attachments,
array_length_from_uniform,
- bindings,
- polyfill_dot_4x8_packed);
+ bindings);
};
} // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/raise/builtin_polyfill.cc b/src/tint/lang/msl/writer/raise/builtin_polyfill.cc
index e024197..c706f6a 100644
--- a/src/tint/lang/msl/writer/raise/builtin_polyfill.cc
+++ b/src/tint/lang/msl/writer/raise/builtin_polyfill.cc
@@ -79,9 +79,6 @@
/// A map from an integer vector type to a dot polyfill.
Hashmap<const core::type::Vector*, core::ir::Function*, 4> integer_dot_polyfills{};
- /// A map from an integer type to a packed 8-bit dot polyfill.
- Hashmap<const core::type::Type*, core::ir::Function*, 2> packed_8bit_integer_dot_polyfills{};
-
/// Process the module.
void Process() {
// Find the builtins that need replacing.
@@ -102,8 +99,6 @@
case core::BuiltinFn::kAtomicXor:
case core::BuiltinFn::kDistance:
case core::BuiltinFn::kDot:
- case core::BuiltinFn::kDot4I8Packed:
- case core::BuiltinFn::kDot4U8Packed:
case core::BuiltinFn::kFrexp:
case core::BuiltinFn::kLength:
case core::BuiltinFn::kModf:
@@ -184,10 +179,6 @@
case core::BuiltinFn::kDot:
Dot(builtin);
break;
- case core::BuiltinFn::kDot4I8Packed:
- case core::BuiltinFn::kDot4U8Packed:
- Dot4x8Packed(builtin);
- break;
case core::BuiltinFn::kFrexp:
Frexp(builtin);
break;
@@ -395,51 +386,6 @@
builtin->Destroy();
}
- /// Polyfill a packed 8-bit dot product call.
- /// @param builtin the builtin call instruction
- void Dot4x8Packed(core::ir::CoreBuiltinCall* builtin) {
- b.InsertBefore(builtin, [&] {
- auto* arg0 = builtin->Args()[0];
- auto* arg1 = builtin->Args()[1];
- auto* int32 = builtin->Result(0)->Type();
- auto* int8 = int32->Is<core::type::I32>()
- ? static_cast<const core::type::Type*>(ty.i8())
- : static_cast<const core::type::Type*>(ty.u8());
- // Calls to packed 8-bit dot products are polyfilled by casting to [u]char4, performing
- // the dot product, and converting the result to a {i,u}32:
- // uchar4 vec1 = as_type<uchar4>(param_0);
- // uchar4 vec2 = as_type<uchar4>(param_1);
- // result = uint(vec1[0] * vec2[0] + vec1[1] * vec2[1]
- // + vec1[2] * vec2[2] + vec1[3] * vec2[3]);
- auto* polyfill = packed_8bit_integer_dot_polyfills.GetOrAdd(int32, [&] {
- auto* lhs_32 = b.FunctionParam("lhs", ty.u32());
- auto* rhs_32 = b.FunctionParam("rhs", ty.u32());
- auto* func = b.Function("tint_packed_8bit_dot", int32);
- func->SetParams({lhs_32, rhs_32});
- b.Append(func->Block(), [&] {
- auto* lhs = b.Bitcast(ty.vec4(int8), lhs_32);
- auto* rhs = b.Bitcast(ty.vec4(int8), rhs_32);
- core::ir::Value* sum = nullptr;
- for (uint32_t i = 0; i < 4; i++) {
- auto* l = b.Access(int8, lhs, u32(i));
- auto* r = b.Access(int8, rhs, u32(i));
- auto* mul = b.Binary<ir::Binary>(core::BinaryOp::kMultiply, int8, l, r);
- if (sum) {
- auto* add = b.Binary<ir::Binary>(core::BinaryOp::kAdd, int8, sum, mul);
- sum = add->Result(0);
- } else {
- sum = mul->Result(0);
- }
- }
- b.Return(func, b.Convert(int32, sum));
- });
- return func;
- });
- b.CallWithResult(builtin->DetachResult(), polyfill, arg0, arg1);
- });
- builtin->Destroy();
- }
-
/// Polyfill a frexp call.
/// @param builtin the builtin call instruction
void Frexp(core::ir::CoreBuiltinCall* builtin) {
diff --git a/src/tint/lang/msl/writer/raise/builtin_polyfill_test.cc b/src/tint/lang/msl/writer/raise/builtin_polyfill_test.cc
index a2701e4..01fbf25 100644
--- a/src/tint/lang/msl/writer/raise/builtin_polyfill_test.cc
+++ b/src/tint/lang/msl/writer/raise/builtin_polyfill_test.cc
@@ -1022,211 +1022,6 @@
EXPECT_EQ(expect, str());
}
-TEST_F(MslWriter_BuiltinPolyfillTest, Dot4I8Packed) {
- auto* value0 = b.FunctionParam<u32>("value0");
- auto* value1 = b.FunctionParam<u32>("value1");
- auto* func = b.Function("foo", ty.i32());
- func->SetParams({value0, value1});
- b.Append(func->Block(), [&] {
- auto* result = b.Call<i32>(core::BuiltinFn::kDot4I8Packed, value0, value1);
- b.Return(func, result);
- });
-
- auto* src = R"(
-%foo = func(%value0:u32, %value1:u32):i32 {
- $B1: {
- %4:i32 = dot4I8Packed %value0, %value1
- ret %4
- }
-}
-)";
- EXPECT_EQ(src, str());
-
- auto* expect = R"(
-%foo = func(%value0:u32, %value1:u32):i32 {
- $B1: {
- %4:i32 = call %tint_packed_8bit_dot, %value0, %value1
- ret %4
- }
-}
-%tint_packed_8bit_dot = func(%lhs:u32, %rhs:u32):i32 {
- $B2: {
- %8:vec4<i8> = bitcast %lhs
- %9:vec4<i8> = bitcast %rhs
- %10:i8 = access %8, 0u
- %11:i8 = access %9, 0u
- %12:i8 = mul %10, %11
- %13:i8 = access %8, 1u
- %14:i8 = access %9, 1u
- %15:i8 = mul %13, %14
- %16:i8 = add %12, %15
- %17:i8 = access %8, 2u
- %18:i8 = access %9, 2u
- %19:i8 = mul %17, %18
- %20:i8 = add %16, %19
- %21:i8 = access %8, 3u
- %22:i8 = access %9, 3u
- %23:i8 = mul %21, %22
- %24:i8 = add %20, %23
- %25:i32 = convert %24
- ret %25
- }
-}
-)";
-
- capabilities.Add(core::ir::Capability::kAllow8BitIntegers);
- Run(BuiltinPolyfill);
-
- EXPECT_EQ(expect, str());
-}
-
-TEST_F(MslWriter_BuiltinPolyfillTest, Dot4U8Packed) {
- auto* value0 = b.FunctionParam<u32>("value0");
- auto* value1 = b.FunctionParam<u32>("value1");
- auto* func = b.Function("foo", ty.u32());
- func->SetParams({value0, value1});
- b.Append(func->Block(), [&] {
- auto* result = b.Call<u32>(core::BuiltinFn::kDot4U8Packed, value0, value1);
- b.Return(func, result);
- });
-
- auto* src = R"(
-%foo = func(%value0:u32, %value1:u32):u32 {
- $B1: {
- %4:u32 = dot4U8Packed %value0, %value1
- ret %4
- }
-}
-)";
- EXPECT_EQ(src, str());
-
- auto* expect = R"(
-%foo = func(%value0:u32, %value1:u32):u32 {
- $B1: {
- %4:u32 = call %tint_packed_8bit_dot, %value0, %value1
- ret %4
- }
-}
-%tint_packed_8bit_dot = func(%lhs:u32, %rhs:u32):u32 {
- $B2: {
- %8:vec4<u8> = bitcast %lhs
- %9:vec4<u8> = bitcast %rhs
- %10:u8 = access %8, 0u
- %11:u8 = access %9, 0u
- %12:u8 = mul %10, %11
- %13:u8 = access %8, 1u
- %14:u8 = access %9, 1u
- %15:u8 = mul %13, %14
- %16:u8 = add %12, %15
- %17:u8 = access %8, 2u
- %18:u8 = access %9, 2u
- %19:u8 = mul %17, %18
- %20:u8 = add %16, %19
- %21:u8 = access %8, 3u
- %22:u8 = access %9, 3u
- %23:u8 = mul %21, %22
- %24:u8 = add %20, %23
- %25:u32 = convert %24
- ret %25
- }
-}
-)";
-
- capabilities.Add(core::ir::Capability::kAllow8BitIntegers);
- Run(BuiltinPolyfill);
-
- EXPECT_EQ(expect, str());
-}
-
-TEST_F(MslWriter_BuiltinPolyfillTest, Dot4x8Packed_MultipleCalls) {
- auto* v = b.FunctionParam<u32>("v");
- auto* func = b.Function("foo", ty.void_());
- func->SetParams({v, v});
- b.Append(func->Block(), [&] {
- b.Call<i32>(core::BuiltinFn::kDot4I8Packed, v, v);
- b.Call<i32>(core::BuiltinFn::kDot4I8Packed, v, v);
- b.Call<u32>(core::BuiltinFn::kDot4U8Packed, v, v);
- b.Call<u32>(core::BuiltinFn::kDot4U8Packed, v, v);
- b.Return(func);
- });
-
- auto* src = R"(
-%foo = func(%v:u32%v:u32):void {
- $B1: {
- %3:i32 = dot4I8Packed %v, %v
- %4:i32 = dot4I8Packed %v, %v
- %5:u32 = dot4U8Packed %v, %v
- %6:u32 = dot4U8Packed %v, %v
- ret
- }
-}
-)";
- EXPECT_EQ(src, str());
-
- auto* expect = R"(
-%foo = func(%v:u32%v:u32):void {
- $B1: {
- %3:i32 = call %tint_packed_8bit_dot, %v, %v
- %5:i32 = call %tint_packed_8bit_dot, %v, %v
- %6:u32 = call %tint_packed_8bit_dot_1, %v, %v
- %8:u32 = call %tint_packed_8bit_dot_1, %v, %v
- ret
- }
-}
-%tint_packed_8bit_dot = func(%lhs:u32, %rhs:u32):i32 {
- $B2: {
- %11:vec4<i8> = bitcast %lhs
- %12:vec4<i8> = bitcast %rhs
- %13:i8 = access %11, 0u
- %14:i8 = access %12, 0u
- %15:i8 = mul %13, %14
- %16:i8 = access %11, 1u
- %17:i8 = access %12, 1u
- %18:i8 = mul %16, %17
- %19:i8 = add %15, %18
- %20:i8 = access %11, 2u
- %21:i8 = access %12, 2u
- %22:i8 = mul %20, %21
- %23:i8 = add %19, %22
- %24:i8 = access %11, 3u
- %25:i8 = access %12, 3u
- %26:i8 = mul %24, %25
- %27:i8 = add %23, %26
- %28:i32 = convert %27
- ret %28
- }
-}
-%tint_packed_8bit_dot_1 = func(%lhs_1:u32, %rhs_1:u32):u32 { # %tint_packed_8bit_dot_1: 'tint_packed_8bit_dot', %lhs_1: 'lhs', %rhs_1: 'rhs'
- $B3: {
- %31:vec4<u8> = bitcast %lhs_1
- %32:vec4<u8> = bitcast %rhs_1
- %33:u8 = access %31, 0u
- %34:u8 = access %32, 0u
- %35:u8 = mul %33, %34
- %36:u8 = access %31, 1u
- %37:u8 = access %32, 1u
- %38:u8 = mul %36, %37
- %39:u8 = add %35, %38
- %40:u8 = access %31, 2u
- %41:u8 = access %32, 2u
- %42:u8 = mul %40, %41
- %43:u8 = add %39, %42
- %44:u8 = access %31, 3u
- %45:u8 = access %32, 3u
- %46:u8 = mul %44, %45
- %47:u8 = add %43, %46
- %48:u32 = convert %47
- ret %48
- }
-}
-)";
-
- capabilities.Add(core::ir::Capability::kAllow8BitIntegers);
- Run(BuiltinPolyfill);
-
- EXPECT_EQ(expect, str());
-}
-
TEST_F(MslWriter_BuiltinPolyfillTest, Frexp_Scalar) {
auto* value = b.FunctionParam<f32>("value");
auto* func = b.Function("foo", ty.f32());
diff --git a/src/tint/lang/msl/writer/raise/raise.cc b/src/tint/lang/msl/writer/raise/raise.cc
index 0230db1..21cb674 100644
--- a/src/tint/lang/msl/writer/raise/raise.cc
+++ b/src/tint/lang/msl/writer/raise/raise.cc
@@ -84,6 +84,7 @@
core::ir::transform::BuiltinPolyfillConfig core_polyfills{};
core_polyfills.clamp_int = true;
core_polyfills.degrees = true;
+ core_polyfills.dot_4x8_packed = true;
core_polyfills.extract_bits = core::ir::transform::BuiltinPolyfillLevel::kClampOrRangeCheck;
core_polyfills.first_leading_bit = true;
core_polyfills.first_trailing_bit = true;
diff --git a/test/tint/builtins/gen/var/dot4I8Packed/881e62.wgsl.expected.ir.msl b/test/tint/builtins/gen/var/dot4I8Packed/881e62.wgsl.expected.ir.msl
index a3c751f..36225a6 100644
--- a/test/tint/builtins/gen/var/dot4I8Packed/881e62.wgsl.expected.ir.msl
+++ b/test/tint/builtins/gen/var/dot4I8Packed/881e62.wgsl.expected.ir.msl
@@ -15,14 +15,21 @@
int VertexOutput_prevent_dce [[user(locn0)]] [[flat]];
};
-int tint_packed_8bit_dot(uint lhs, uint rhs) {
- return int(((((as_type<char4>(lhs)[0u] * as_type<char4>(rhs)[0u]) + (as_type<char4>(lhs)[1u] * as_type<char4>(rhs)[1u])) + (as_type<char4>(lhs)[2u] * as_type<char4>(rhs)[2u])) + (as_type<char4>(lhs)[3u] * as_type<char4>(rhs)[3u])));
+int tint_dot(int4 lhs, int4 rhs) {
+ return ((((lhs * rhs)[0u] + (lhs * rhs)[1u]) + (lhs * rhs)[2u]) + (lhs * rhs)[3u]);
}
int dot4I8Packed_881e62() {
uint arg_0 = 1u;
uint arg_1 = 1u;
- int res = tint_packed_8bit_dot(arg_0, arg_1);
+ uint const v = arg_0;
+ uint const v_1 = arg_1;
+ uint4 const v_2 = uint4(24u, 16u, 8u, 0u);
+ int4 const v_3 = as_type<int4>((uint4(v) << v_2));
+ int4 const v_4 = (v_3 >> uint4(24u));
+ uint4 const v_5 = uint4(24u, 16u, 8u, 0u);
+ int4 const v_6 = as_type<int4>((uint4(v_1) << v_5));
+ int res = tint_dot(v_4, (v_6 >> uint4(24u)));
return res;
}
@@ -44,9 +51,9 @@
}
vertex vertex_main_outputs vertex_main() {
- VertexOutput const v = vertex_main_inner();
+ VertexOutput const v_7 = vertex_main_inner();
vertex_main_outputs tint_wrapper_result = {};
- tint_wrapper_result.VertexOutput_pos = v.pos;
- tint_wrapper_result.VertexOutput_prevent_dce = v.prevent_dce;
+ tint_wrapper_result.VertexOutput_pos = v_7.pos;
+ tint_wrapper_result.VertexOutput_prevent_dce = v_7.prevent_dce;
return tint_wrapper_result;
}
diff --git a/test/tint/builtins/gen/var/dot4I8Packed/881e62.wgsl.expected.msl b/test/tint/builtins/gen/var/dot4I8Packed/881e62.wgsl.expected.msl
index f17d4ac..f6089ba 100644
--- a/test/tint/builtins/gen/var/dot4I8Packed/881e62.wgsl.expected.msl
+++ b/test/tint/builtins/gen/var/dot4I8Packed/881e62.wgsl.expected.msl
@@ -2,16 +2,20 @@
using namespace metal;
-int tint_dot4I8Packed(uint param_0, uint param_1) {
- char4 vec1 = as_type<char4>(param_0);
- char4 vec2 = as_type<char4>(param_1);
- return vec1[0] * vec2[0] + vec1[1] * vec2[1] + vec1[2] * vec2[2] + vec1[3] * vec2[3];
+template<typename T>
+T tint_dot4(vec<T,4> a, vec<T,4> b) {
+ return a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3];
+}
+int tint_dot4_i8_packed(uint a, uint b) {
+ int4 const a_i8 = (as_type<int4>((uint4(a) << uint4(24u, 16u, 8u, 0u))) >> uint4(24u));
+ int4 const b_i8 = (as_type<int4>((uint4(b) << uint4(24u, 16u, 8u, 0u))) >> uint4(24u));
+ return tint_dot4(a_i8, b_i8);
}
int dot4I8Packed_881e62() {
uint arg_0 = 1u;
uint arg_1 = 1u;
- int res = tint_dot4I8Packed(arg_0, arg_1);
+ int res = tint_dot4_i8_packed(arg_0, arg_1);
return res;
}
diff --git a/test/tint/builtins/gen/var/dot4U8Packed/fbed7b.wgsl.expected.ir.msl b/test/tint/builtins/gen/var/dot4U8Packed/fbed7b.wgsl.expected.ir.msl
index 9b457cd..37d41fc 100644
--- a/test/tint/builtins/gen/var/dot4U8Packed/fbed7b.wgsl.expected.ir.msl
+++ b/test/tint/builtins/gen/var/dot4U8Packed/fbed7b.wgsl.expected.ir.msl
@@ -15,14 +15,21 @@
uint VertexOutput_prevent_dce [[user(locn0)]] [[flat]];
};
-uint tint_packed_8bit_dot(uint lhs, uint rhs) {
- return uint(((((as_type<uchar4>(lhs)[0u] * as_type<uchar4>(rhs)[0u]) + (as_type<uchar4>(lhs)[1u] * as_type<uchar4>(rhs)[1u])) + (as_type<uchar4>(lhs)[2u] * as_type<uchar4>(rhs)[2u])) + (as_type<uchar4>(lhs)[3u] * as_type<uchar4>(rhs)[3u])));
+uint tint_dot(uint4 lhs, uint4 rhs) {
+ return ((((lhs * rhs)[0u] + (lhs * rhs)[1u]) + (lhs * rhs)[2u]) + (lhs * rhs)[3u]);
}
uint dot4U8Packed_fbed7b() {
uint arg_0 = 1u;
uint arg_1 = 1u;
- uint res = tint_packed_8bit_dot(arg_0, arg_1);
+ uint const v = arg_0;
+ uint const v_1 = arg_1;
+ uint4 const v_2 = uint4(0u, 8u, 16u, 24u);
+ uint4 const v_3 = (uint4(v) >> v_2);
+ uint4 const v_4 = (v_3 & uint4(255u));
+ uint4 const v_5 = uint4(0u, 8u, 16u, 24u);
+ uint4 const v_6 = (uint4(v_1) >> v_5);
+ uint res = tint_dot(v_4, (v_6 & uint4(255u)));
return res;
}
@@ -44,9 +51,9 @@
}
vertex vertex_main_outputs vertex_main() {
- VertexOutput const v = vertex_main_inner();
+ VertexOutput const v_7 = vertex_main_inner();
vertex_main_outputs tint_wrapper_result = {};
- tint_wrapper_result.VertexOutput_pos = v.pos;
- tint_wrapper_result.VertexOutput_prevent_dce = v.prevent_dce;
+ tint_wrapper_result.VertexOutput_pos = v_7.pos;
+ tint_wrapper_result.VertexOutput_prevent_dce = v_7.prevent_dce;
return tint_wrapper_result;
}
diff --git a/test/tint/builtins/gen/var/dot4U8Packed/fbed7b.wgsl.expected.msl b/test/tint/builtins/gen/var/dot4U8Packed/fbed7b.wgsl.expected.msl
index dda21ac..384b67c 100644
--- a/test/tint/builtins/gen/var/dot4U8Packed/fbed7b.wgsl.expected.msl
+++ b/test/tint/builtins/gen/var/dot4U8Packed/fbed7b.wgsl.expected.msl
@@ -2,16 +2,20 @@
using namespace metal;
-uint tint_dot4U8Packed(uint param_0, uint param_1) {
- uchar4 vec1 = as_type<uchar4>(param_0);
- uchar4 vec2 = as_type<uchar4>(param_1);
- return vec1[0] * vec2[0] + vec1[1] * vec2[1] + vec1[2] * vec2[2] + vec1[3] * vec2[3];
+template<typename T>
+T tint_dot4(vec<T,4> a, vec<T,4> b) {
+ return a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3];
+}
+uint tint_dot4_u8_packed(uint a, uint b) {
+ uint4 const a_u8 = ((uint4(a) >> uint4(24u, 16u, 8u, 0u)) & uint4(255u));
+ uint4 const b_u8 = ((uint4(b) >> uint4(24u, 16u, 8u, 0u)) & uint4(255u));
+ return tint_dot4(a_u8, b_u8);
}
uint dot4U8Packed_fbed7b() {
uint arg_0 = 1u;
uint arg_1 = 1u;
- uint res = tint_dot4U8Packed(arg_0, arg_1);
+ uint res = tint_dot4_u8_packed(arg_0, arg_1);
return res;
}