[ir][spirv-writer] Scalarize quantizeToF16
The vector form crashes certain NVIDIA drivers on Linux.
Bug: tint:1906
Change-Id: Ia3eb4434d7220c88566fd20d47efe778f007ef7f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/151883
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/spirv/writer/raise/builtin_polyfill.cc b/src/tint/lang/spirv/writer/raise/builtin_polyfill.cc
index 7a3bd38..fea3909 100644
--- a/src/tint/lang/spirv/writer/raise/builtin_polyfill.cc
+++ b/src/tint/lang/spirv/writer/raise/builtin_polyfill.cc
@@ -91,6 +91,11 @@
case core::Function::kTextureStore:
worklist.Push(builtin);
break;
+ case core::Function::kQuantizeToF16:
+ if (builtin->Result()->Type()->Is<core::type::Vector>()) {
+ worklist.Push(builtin);
+ }
+ break;
default:
break;
}
@@ -147,6 +152,9 @@
case core::Function::kTextureStore:
replacement = TextureStore(builtin);
break;
+ case core::Function::kQuantizeToF16:
+ replacement = QuantizeToF16Vec(builtin);
+ break;
default:
break;
}
@@ -818,6 +826,29 @@
extract->InsertBefore(builtin);
return extract->Result();
}
+
+ /// Scalarize the vector form of a `quantizeToF16()` builtin, one element at a time.
+ /// The vector form crashes certain NVIDIA drivers on Linux. See crbug.com/tint/1741.
+ /// @param builtin the builtin call instruction, whose first argument must be a vector
+ /// @returns the result of the vector construct built from the quantized elements
+ core::ir::Value* QuantizeToF16Vec(core::ir::CoreBuiltinCall* builtin) {
+ auto* arg = builtin->Args()[0];
+ auto* vec = arg->Type()->As<core::type::Vector>();
+ TINT_ASSERT(vec);
+
+ // Extract each element, quantize it with the scalar builtin, then rebuild the vector.
+ Vector<core::ir::Value*, 4> args;
+ for (uint32_t i = 0; i < vec->Width(); i++) {
+ auto* el = b.Access(ty.f32(), arg, u32(i));
+ auto* scalar_call = b.Call(ty.f32(), core::Function::kQuantizeToF16, el);
+ args.Push(scalar_call->Result());
+ el->InsertBefore(builtin);
+ scalar_call->InsertBefore(builtin);
+ }
+ auto* construct = b.Construct(vec, std::move(args));
+ construct->InsertBefore(builtin);
+ return construct->Result();
+ }
};
} // namespace
diff --git a/src/tint/lang/spirv/writer/raise/builtin_polyfill_test.cc b/src/tint/lang/spirv/writer/raise/builtin_polyfill_test.cc
index 0fa6eaf..d20393f 100644
--- a/src/tint/lang/spirv/writer/raise/builtin_polyfill_test.cc
+++ b/src/tint/lang/spirv/writer/raise/builtin_polyfill_test.cc
@@ -2804,5 +2804,74 @@
EXPECT_EQ(expect, str());
}
+TEST_F(SpirvWriter_BuiltinPolyfillTest, QuantizeToF16_Scalar) {
+ auto* arg = b.FunctionParam("arg", ty.f32());
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({arg});
+
+ b.Append(func->Block(), [&] {
+ auto* result = b.Call(ty.f32(), core::Function::kQuantizeToF16, arg);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg:f32):f32 -> %b1 {
+ %b1 = block {
+ %3:f32 = quantizeToF16 %arg
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());  // Check the IR before the transform runs.
+
+ auto* expect = src;  // The scalar call is expected to be left unchanged.
+
+ Run(BuiltinPolyfill);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_BuiltinPolyfillTest, QuantizeToF16_Vector) {
+ auto* arg = b.FunctionParam("arg", ty.vec4<f32>());
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({arg});
+
+ b.Append(func->Block(), [&] {
+ auto* result = b.Call(ty.vec4<f32>(), core::Function::kQuantizeToF16, arg);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%arg:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %3:vec4<f32> = quantizeToF16 %arg
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());  // Check the IR before the transform runs.
+
+ auto* expect = R"(
+%foo = func(%arg:vec4<f32>):vec4<f32> -> %b1 {
+ %b1 = block {
+ %3:f32 = access %arg, 0u
+ %4:f32 = quantizeToF16 %3
+ %5:f32 = access %arg, 1u
+ %6:f32 = quantizeToF16 %5
+ %7:f32 = access %arg, 2u
+ %8:f32 = quantizeToF16 %7
+ %9:f32 = access %arg, 3u
+ %10:f32 = quantizeToF16 %9
+ %11:vec4<f32> = construct %4, %6, %8, %10
+ ret %11
+ }
+}
+)";
+
+ Run(BuiltinPolyfill);  // Each element is quantized separately, then re-vectorized.
+
+ EXPECT_EQ(expect, str());
+}
+
} // namespace
} // namespace tint::spirv::writer::raise