[spirv-writer][ast] Polyfill f16 shader IO
When requested, emit f32 types for f16 shader IO and convert to/from f16 at the entry point.
Bug: tint:2161
Change-Id: I85e15f9aa7f858de4688f852d683ce9fb194e953
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/173705
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc b/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
index d76a591..4272172 100644
--- a/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
@@ -188,7 +188,7 @@
     data.Add<ast::transform::CanonicalizeEntryPointIO::Config>(
         ast::transform::CanonicalizeEntryPointIO::Config(
             ast::transform::CanonicalizeEntryPointIO::ShaderStyle::kSpirv, 0xFFFFFFFF,
-            options.emit_vertex_point_size));
+            options.emit_vertex_point_size, !options.use_storage_input_output_16));
 
     SanitizedResult result;
     ast::transform::DataMap outputs;
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
index 842e41b..db9a682 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
@@ -339,6 +339,17 @@
                     value = b.IndexAccessor(value, 0_i);
                 }
             }
+
+            // Polyfill f16 inputs: declare the variable as f32 and convert the value back to f16.
+            if (cfg.polyfill_f16_io && type->DeepestElement()->Is<core::type::F16>()) {
+                value = b.Call(ast_type, value);
+
+                ast_type = b.ty.f32();
+                if (auto* vec = type->As<core::type::Vector>()) {
+                    ast_type = b.ty.vec(ast_type, vec->Width());
+                }
+            }
+
             b.GlobalVar(symbol, ast_type, core::AddressSpace::kIn, std::move(attrs));
             return value;
         } else if (cfg.shader_style == ShaderStyle::kMsl &&
@@ -406,9 +417,27 @@
             }
         }
 
+        ast::Type ast_type;
+
+        // Polyfill f16 outputs: declare the output as f32 and convert the f16 value to f32.
+        if (cfg.shader_style == ShaderStyle::kSpirv && cfg.polyfill_f16_io &&
+            type->DeepestElement()->Is<core::type::F16>()) {
+            auto make_ast_type = [&] {
+                auto ty = b.ty.f32();
+                if (auto* vec = type->As<core::type::Vector>()) {
+                    ty = b.ty.vec(ty, vec->Width());
+                }
+                return ty;
+            };
+            ast_type = make_ast_type();
+            value = b.Call(make_ast_type(), value);
+        } else {
+            ast_type = CreateASTTypeFor(ctx, type);
+        }
+
         OutputValue output;
         output.name = name;
-        output.type = CreateASTTypeFor(ctx, type);
+        output.type = ast_type;
         output.attributes = std::move(attrs);
         output.value = value;
         output.location = location;
@@ -984,10 +1013,12 @@
 
 CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
                                          uint32_t sample_mask,
-                                         bool emit_point_size)
+                                         bool emit_point_size,
+                                         bool polyfill_f16)
     : shader_style(style),
       fixed_sample_mask(sample_mask),
-      emit_vertex_point_size(emit_point_size) {}
+      emit_vertex_point_size(emit_point_size),
+      polyfill_f16_io(polyfill_f16) {}
 
 CanonicalizeEntryPointIO::Config::Config(const Config&) = default;
 CanonicalizeEntryPointIO::Config::~Config() = default;
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
index d63d479..ddf761d 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
@@ -118,9 +118,11 @@
         /// @param style the approach to use for emitting shader IO.
         /// @param sample_mask an optional sample mask to combine with shader masks
         /// @param emit_vertex_point_size `true` to generate a pointsize builtin
+        /// @param polyfill_f16_io `true` to replace f16 IO types with f32 and convert values
         explicit Config(ShaderStyle style,
                         uint32_t sample_mask = 0xFFFFFFFF,
-                        bool emit_vertex_point_size = false);
+                        bool emit_vertex_point_size = false,
+                        bool polyfill_f16_io = false);
 
         /// Copy constructor
         Config(const Config&);
@@ -137,6 +139,9 @@
         /// Set to `true` to generate a pointsize builtin and have it set to 1.0
         /// from all vertex shaders in the module.
         const bool emit_vertex_point_size;
+
+        /// Set to `true` to replace f16 IO types with f32 types and convert them.
+        const bool polyfill_f16_io = false;
     };
 
     /// HLSLWaveIntrinsic is an InternalAttribute that is used to decorate a stub function so that
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
index 7c33b29..ec83e9e 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
@@ -4420,5 +4420,60 @@
     EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(CanonicalizeEntryPointIOTest, F16_Polyfill_Spirv) {
+    auto* src = R"(
+enable f16;
+
+struct Outputs {
+  @location(1) a : f16,
+  @location(2) b : vec4<f16>,
+}
+
+@fragment
+fn frag_main(@location(1) loc1 : f16,
+             @location(2) loc2 : vec4<f16>) -> Outputs {
+  return Outputs(loc1 * 2, loc2 * 3);
+}
+)";
+
+    auto* expect = R"(
+enable f16;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> loc1_1 : f32;
+
+@location(2) @internal(disable_validation__ignore_address_space) var<__in> loc2_1 : vec4<f32>;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__out> a_1 : f32;
+
+@location(2) @internal(disable_validation__ignore_address_space) var<__out> b_1 : vec4<f32>;
+
+struct Outputs {
+  a : f16,
+  b : vec4<f16>,
+}
+
+fn frag_main_inner(loc1 : f16, loc2 : vec4<f16>) -> Outputs {
+  return Outputs((loc1 * 2), (loc2 * 3));
+}
+
+@fragment
+fn frag_main() {
+  let inner_result = frag_main_inner(f16(loc1_1), vec4<f16>(loc2_1));
+  a_1 = f32(inner_result.a);
+  b_1 = vec4<f32>(inner_result.b);
+}
+)";
+
+    DataMap data;
+
+    data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+                                               /* fixed_sample_mask */ 0xFFFFFFFF,
+                                               /* emit_vertex_point_size */ false,
+                                               /* polyfill_f16_io */ true);
+    auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+    EXPECT_EQ(expect, str(got));
+}
+
 }  // namespace
 }  // namespace tint::ast::transform