[msl] Add polyfill for quantizeToF16
Convert the argument to f16 and then back to f32.
Bug: 42251016
Change-Id: I7006e0ce91e34a8a2a372b5a1f5d6b24bc5b0c5a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/195214
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/msl/writer/raise/builtin_polyfill.cc b/src/tint/lang/msl/writer/raise/builtin_polyfill.cc
index 0b75610..0d2e61b 100644
--- a/src/tint/lang/msl/writer/raise/builtin_polyfill.cc
+++ b/src/tint/lang/msl/writer/raise/builtin_polyfill.cc
@@ -99,6 +99,7 @@
case core::BuiltinFn::kDistance:
case core::BuiltinFn::kDot:
case core::BuiltinFn::kLength:
+ case core::BuiltinFn::kQuantizeToF16:
case core::BuiltinFn::kTextureDimensions:
case core::BuiltinFn::kTextureGather:
case core::BuiltinFn::kTextureGatherCompare:
@@ -173,6 +174,9 @@
case core::BuiltinFn::kLength:
Length(builtin);
break;
+ case core::BuiltinFn::kQuantizeToF16:
+ QuantizeToF16(builtin);
+ break;
// Texture builtins.
case core::BuiltinFn::kTextureDimensions:
@@ -367,6 +371,23 @@
builtin->Destroy();
}
+ /// Polyfill a quantizeToF16 call.
+ /// @param builtin the builtin call instruction
+ void QuantizeToF16(core::ir::CoreBuiltinCall* builtin) {
+ auto* arg = builtin->Args()[0];
+ auto* type_f32 = arg->Type();
+ const core::type::Type* type_f16 = ty.f16();
+ if (auto* vec = type_f32->As<core::type::Vector>()) {
+ type_f16 = ty.vec(ty.f16(), vec->Width());
+ }
+
+ // Convert the argument to f16 and then back again.
+ b.InsertBefore(builtin, [&] {
+ b.ConvertWithResult(builtin->DetachResult(), b.Convert(type_f16, arg));
+ });
+ builtin->Destroy();
+ }
+
/// Replace a textureDimensions call with the equivalent MSL intrinsics.
/// @param builtin the builtin call instruction
void TextureDimensions(core::ir::CoreBuiltinCall* builtin) {
diff --git a/src/tint/lang/msl/writer/raise/builtin_polyfill_test.cc b/src/tint/lang/msl/writer/raise/builtin_polyfill_test.cc
index d9bc55b..467bb5e 100644
--- a/src/tint/lang/msl/writer/raise/builtin_polyfill_test.cc
+++ b/src/tint/lang/msl/writer/raise/builtin_polyfill_test.cc
@@ -1088,6 +1088,74 @@
EXPECT_EQ(expect, str());
}
+TEST_F(MslWriter_BuiltinPolyfillTest, QuantizeToF16_Scalar) {
+ auto* value = b.FunctionParam<f32>("value");
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({value});
+ b.Append(func->Block(), [&] {
+ auto* result = b.Call<f32>(core::BuiltinFn::kQuantizeToF16, value);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%value:f32):f32 {
+ $B1: {
+ %3:f32 = quantizeToF16 %value
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%value:f32):f32 {
+ $B1: {
+ %3:f16 = convert %value
+ %4:f32 = convert %3
+ ret %4
+ }
+}
+)";
+
+ Run(BuiltinPolyfill);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_BuiltinPolyfillTest, QuantizeToF16_Vector) {
+ auto* value = b.FunctionParam<vec4<f32>>("value");
+ auto* func = b.Function("foo", ty.vec4<f32>());
+ func->SetParams({value});
+ b.Append(func->Block(), [&] {
+ auto* result = b.Call<vec4<f32>>(core::BuiltinFn::kQuantizeToF16, value);
+ b.Return(func, result);
+ });
+
+ auto* src = R"(
+%foo = func(%value:vec4<f32>):vec4<f32> {
+ $B1: {
+ %3:vec4<f32> = quantizeToF16 %value
+ ret %3
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%value:vec4<f32>):vec4<f32> {
+ $B1: {
+ %3:vec4<f16> = convert %value
+ %4:vec4<f32> = convert %3
+ ret %4
+ }
+}
+)";
+
+ Run(BuiltinPolyfill);
+
+ EXPECT_EQ(expect, str());
+}
+
TEST_F(MslWriter_BuiltinPolyfillTest, TextureDimensions_1d) {
auto* t = b.FunctionParam(
"t", ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k1d, ty.f32()));
diff --git a/test/tint/builtins/gen/var/quantizeToF16/12e50e.wgsl.expected.ir.msl b/test/tint/builtins/gen/var/quantizeToF16/12e50e.wgsl.expected.ir.msl
index ec7fec1..60aadc8 100644
--- a/test/tint/builtins/gen/var/quantizeToF16/12e50e.wgsl.expected.ir.msl
+++ b/test/tint/builtins/gen/var/quantizeToF16/12e50e.wgsl.expected.ir.msl
@@ -1,43 +1,44 @@
-SKIP: FAILED
+#include <metal_stdlib>
+using namespace metal;
-../../src/tint/lang/msl/writer/printer/printer.cc:500 internal compiler error: $B1: { # root
- %prevent_dce:ptr<storage, f32, read_write> = var @binding_point(2, 0)
+struct tint_module_vars_struct {
+ device float* prevent_dce;
+};
+
+struct VertexOutput {
+ float4 pos;
+ float prevent_dce;
+};
+
+struct vertex_main_outputs {
+ float4 VertexOutput_pos [[position]];
+ float VertexOutput_prevent_dce [[user(locn0)]] [[flat]];
+};
+
+float quantizeToF16_12e50e() {
+ float arg_0 = 1.0f;
+ float res = float(half(arg_0));
+ return res;
}
-%quantizeToF16_12e50e = func():void {
- $B2: {
- %arg_0:ptr<function, f32, read_write> = var, 1.0f
- %4:f32 = load %arg_0
- %5:f32 = quantizeToF16 %4
- %res:ptr<function, f32, read_write> = var, %5
- %7:f32 = load %res
- store %prevent_dce, %7
- ret
- }
-}
-%vertex_main = @vertex func():vec4<f32> [@position] {
- $B3: {
- %9:void = call %quantizeToF16_12e50e
- ret vec4<f32>(0.0f)
- }
-}
-%fragment_main = @fragment func():void {
- $B4: {
- %11:void = call %quantizeToF16_12e50e
- ret
- }
-}
-%compute_main = @compute @workgroup_size(1, 1, 1) func():void {
- $B5: {
- %13:void = call %quantizeToF16_12e50e
- ret
- }
+fragment void fragment_main(device float* prevent_dce [[buffer(0)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce};
+ (*tint_module_vars.prevent_dce) = quantizeToF16_12e50e();
}
-unhandled variable address space
-********************************************************************
-* The tint shader compiler has encountered an unexpected error. *
-* *
-* Please help us fix this issue by submitting a bug report at *
-* crbug.com/tint with the source program that triggered the bug. *
-********************************************************************
+kernel void compute_main(device float* prevent_dce [[buffer(0)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce};
+ (*tint_module_vars.prevent_dce) = quantizeToF16_12e50e();
+}
+
+VertexOutput vertex_main_inner() {
+ VertexOutput out = {};
+ out.pos = float4(0.0f);
+ out.prevent_dce = quantizeToF16_12e50e();
+ return out;
+}
+
+vertex vertex_main_outputs vertex_main() {
+ VertexOutput const v = vertex_main_inner();
+ return vertex_main_outputs{.VertexOutput_pos=v.pos, .VertexOutput_prevent_dce=v.prevent_dce};
+}
diff --git a/test/tint/builtins/gen/var/quantizeToF16/2cddf3.wgsl.expected.ir.msl b/test/tint/builtins/gen/var/quantizeToF16/2cddf3.wgsl.expected.ir.msl
index 97979d6..bb54664 100644
--- a/test/tint/builtins/gen/var/quantizeToF16/2cddf3.wgsl.expected.ir.msl
+++ b/test/tint/builtins/gen/var/quantizeToF16/2cddf3.wgsl.expected.ir.msl
@@ -1,43 +1,44 @@
-SKIP: FAILED
+#include <metal_stdlib>
+using namespace metal;
-../../src/tint/lang/msl/writer/printer/printer.cc:500 internal compiler error: $B1: { # root
- %prevent_dce:ptr<storage, vec2<f32>, read_write> = var @binding_point(2, 0)
+struct tint_module_vars_struct {
+ device float2* prevent_dce;
+};
+
+struct VertexOutput {
+ float4 pos;
+ float2 prevent_dce;
+};
+
+struct vertex_main_outputs {
+ float4 VertexOutput_pos [[position]];
+ float2 VertexOutput_prevent_dce [[user(locn0)]] [[flat]];
+};
+
+float2 quantizeToF16_2cddf3() {
+ float2 arg_0 = float2(1.0f);
+ float2 res = float2(half2(arg_0));
+ return res;
}
-%quantizeToF16_2cddf3 = func():void {
- $B2: {
- %arg_0:ptr<function, vec2<f32>, read_write> = var, vec2<f32>(1.0f)
- %4:vec2<f32> = load %arg_0
- %5:vec2<f32> = quantizeToF16 %4
- %res:ptr<function, vec2<f32>, read_write> = var, %5
- %7:vec2<f32> = load %res
- store %prevent_dce, %7
- ret
- }
-}
-%vertex_main = @vertex func():vec4<f32> [@position] {
- $B3: {
- %9:void = call %quantizeToF16_2cddf3
- ret vec4<f32>(0.0f)
- }
-}
-%fragment_main = @fragment func():void {
- $B4: {
- %11:void = call %quantizeToF16_2cddf3
- ret
- }
-}
-%compute_main = @compute @workgroup_size(1, 1, 1) func():void {
- $B5: {
- %13:void = call %quantizeToF16_2cddf3
- ret
- }
+fragment void fragment_main(device float2* prevent_dce [[buffer(0)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce};
+ (*tint_module_vars.prevent_dce) = quantizeToF16_2cddf3();
}
-unhandled variable address space
-********************************************************************
-* The tint shader compiler has encountered an unexpected error. *
-* *
-* Please help us fix this issue by submitting a bug report at *
-* crbug.com/tint with the source program that triggered the bug. *
-********************************************************************
+kernel void compute_main(device float2* prevent_dce [[buffer(0)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce};
+ (*tint_module_vars.prevent_dce) = quantizeToF16_2cddf3();
+}
+
+VertexOutput vertex_main_inner() {
+ VertexOutput out = {};
+ out.pos = float4(0.0f);
+ out.prevent_dce = quantizeToF16_2cddf3();
+ return out;
+}
+
+vertex vertex_main_outputs vertex_main() {
+ VertexOutput const v = vertex_main_inner();
+ return vertex_main_outputs{.VertexOutput_pos=v.pos, .VertexOutput_prevent_dce=v.prevent_dce};
+}
diff --git a/test/tint/builtins/gen/var/quantizeToF16/cba294.wgsl.expected.ir.msl b/test/tint/builtins/gen/var/quantizeToF16/cba294.wgsl.expected.ir.msl
index f4fe649..643d6f8 100644
--- a/test/tint/builtins/gen/var/quantizeToF16/cba294.wgsl.expected.ir.msl
+++ b/test/tint/builtins/gen/var/quantizeToF16/cba294.wgsl.expected.ir.msl
@@ -1,43 +1,44 @@
-SKIP: FAILED
+#include <metal_stdlib>
+using namespace metal;
-../../src/tint/lang/msl/writer/printer/printer.cc:500 internal compiler error: $B1: { # root
- %prevent_dce:ptr<storage, vec4<f32>, read_write> = var @binding_point(2, 0)
+struct tint_module_vars_struct {
+ device float4* prevent_dce;
+};
+
+struct VertexOutput {
+ float4 pos;
+ float4 prevent_dce;
+};
+
+struct vertex_main_outputs {
+ float4 VertexOutput_pos [[position]];
+ float4 VertexOutput_prevent_dce [[user(locn0)]] [[flat]];
+};
+
+float4 quantizeToF16_cba294() {
+ float4 arg_0 = float4(1.0f);
+ float4 res = float4(half4(arg_0));
+ return res;
}
-%quantizeToF16_cba294 = func():void {
- $B2: {
- %arg_0:ptr<function, vec4<f32>, read_write> = var, vec4<f32>(1.0f)
- %4:vec4<f32> = load %arg_0
- %5:vec4<f32> = quantizeToF16 %4
- %res:ptr<function, vec4<f32>, read_write> = var, %5
- %7:vec4<f32> = load %res
- store %prevent_dce, %7
- ret
- }
-}
-%vertex_main = @vertex func():vec4<f32> [@position] {
- $B3: {
- %9:void = call %quantizeToF16_cba294
- ret vec4<f32>(0.0f)
- }
-}
-%fragment_main = @fragment func():void {
- $B4: {
- %11:void = call %quantizeToF16_cba294
- ret
- }
-}
-%compute_main = @compute @workgroup_size(1, 1, 1) func():void {
- $B5: {
- %13:void = call %quantizeToF16_cba294
- ret
- }
+fragment void fragment_main(device float4* prevent_dce [[buffer(0)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce};
+ (*tint_module_vars.prevent_dce) = quantizeToF16_cba294();
}
-unhandled variable address space
-********************************************************************
-* The tint shader compiler has encountered an unexpected error. *
-* *
-* Please help us fix this issue by submitting a bug report at *
-* crbug.com/tint with the source program that triggered the bug. *
-********************************************************************
+kernel void compute_main(device float4* prevent_dce [[buffer(0)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce};
+ (*tint_module_vars.prevent_dce) = quantizeToF16_cba294();
+}
+
+VertexOutput vertex_main_inner() {
+ VertexOutput out = {};
+ out.pos = float4(0.0f);
+ out.prevent_dce = quantizeToF16_cba294();
+ return out;
+}
+
+vertex vertex_main_outputs vertex_main() {
+ VertexOutput const v = vertex_main_inner();
+ return vertex_main_outputs{.VertexOutput_pos=v.pos, .VertexOutput_prevent_dce=v.prevent_dce};
+}
diff --git a/test/tint/builtins/gen/var/quantizeToF16/e8fd14.wgsl.expected.ir.msl b/test/tint/builtins/gen/var/quantizeToF16/e8fd14.wgsl.expected.ir.msl
index 1e11ec8..236518a 100644
--- a/test/tint/builtins/gen/var/quantizeToF16/e8fd14.wgsl.expected.ir.msl
+++ b/test/tint/builtins/gen/var/quantizeToF16/e8fd14.wgsl.expected.ir.msl
@@ -1,43 +1,44 @@
-SKIP: FAILED
+#include <metal_stdlib>
+using namespace metal;
-../../src/tint/lang/msl/writer/printer/printer.cc:500 internal compiler error: $B1: { # root
- %prevent_dce:ptr<storage, vec3<f32>, read_write> = var @binding_point(2, 0)
+struct tint_module_vars_struct {
+ device float3* prevent_dce;
+};
+
+struct VertexOutput {
+ float4 pos;
+ float3 prevent_dce;
+};
+
+struct vertex_main_outputs {
+ float4 VertexOutput_pos [[position]];
+ float3 VertexOutput_prevent_dce [[user(locn0)]] [[flat]];
+};
+
+float3 quantizeToF16_e8fd14() {
+ float3 arg_0 = float3(1.0f);
+ float3 res = float3(half3(arg_0));
+ return res;
}
-%quantizeToF16_e8fd14 = func():void {
- $B2: {
- %arg_0:ptr<function, vec3<f32>, read_write> = var, vec3<f32>(1.0f)
- %4:vec3<f32> = load %arg_0
- %5:vec3<f32> = quantizeToF16 %4
- %res:ptr<function, vec3<f32>, read_write> = var, %5
- %7:vec3<f32> = load %res
- store %prevent_dce, %7
- ret
- }
-}
-%vertex_main = @vertex func():vec4<f32> [@position] {
- $B3: {
- %9:void = call %quantizeToF16_e8fd14
- ret vec4<f32>(0.0f)
- }
-}
-%fragment_main = @fragment func():void {
- $B4: {
- %11:void = call %quantizeToF16_e8fd14
- ret
- }
-}
-%compute_main = @compute @workgroup_size(1, 1, 1) func():void {
- $B5: {
- %13:void = call %quantizeToF16_e8fd14
- ret
- }
+fragment void fragment_main(device float3* prevent_dce [[buffer(0)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce};
+ (*tint_module_vars.prevent_dce) = quantizeToF16_e8fd14();
}
-unhandled variable address space
-********************************************************************
-* The tint shader compiler has encountered an unexpected error. *
-* *
-* Please help us fix this issue by submitting a bug report at *
-* crbug.com/tint with the source program that triggered the bug. *
-********************************************************************
+kernel void compute_main(device float3* prevent_dce [[buffer(0)]]) {
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce};
+ (*tint_module_vars.prevent_dce) = quantizeToF16_e8fd14();
+}
+
+VertexOutput vertex_main_inner() {
+ VertexOutput out = {};
+ out.pos = float4(0.0f);
+ out.prevent_dce = quantizeToF16_e8fd14();
+ return out;
+}
+
+vertex vertex_main_outputs vertex_main() {
+ VertexOutput const v = vertex_main_inner();
+ return vertex_main_outputs{.VertexOutput_pos=v.pos, .VertexOutput_prevent_dce=v.prevent_dce};
+}