[spirv-writer][ast] Polyfill f16 shader IO
Emit f32 types and convert to/from f16 if requested.
Bug: tint:2161
Change-Id: I85e15f9aa7f858de4688f852d683ce9fb194e953
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/173705
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc b/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
index d76a591..4272172 100644
--- a/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
@@ -188,7 +188,7 @@
data.Add<ast::transform::CanonicalizeEntryPointIO::Config>(
ast::transform::CanonicalizeEntryPointIO::Config(
ast::transform::CanonicalizeEntryPointIO::ShaderStyle::kSpirv, 0xFFFFFFFF,
- options.emit_vertex_point_size));
+ options.emit_vertex_point_size, !options.use_storage_input_output_16));
SanitizedResult result;
ast::transform::DataMap outputs;
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
index 842e41b..db9a682 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
@@ -339,6 +339,17 @@
value = b.IndexAccessor(value, 0_i);
}
}
+
+ // Replace f16 types with f32 types if necessary.
+ if (cfg.polyfill_f16_io && type->DeepestElement()->Is<core::type::F16>()) {
+ value = b.Call(ast_type, value);
+
+ ast_type = b.ty.f32();
+ if (auto* vec = type->As<core::type::Vector>()) {
+ ast_type = b.ty.vec(ast_type, vec->Width());
+ }
+ }
+
b.GlobalVar(symbol, ast_type, core::AddressSpace::kIn, std::move(attrs));
return value;
} else if (cfg.shader_style == ShaderStyle::kMsl &&
@@ -406,9 +417,27 @@
}
}
+ ast::Type ast_type;
+
+ // Replace f16 types with f32 types if necessary.
+ if (cfg.shader_style == ShaderStyle::kSpirv && cfg.polyfill_f16_io &&
+ type->DeepestElement()->Is<core::type::F16>()) {
+ auto make_ast_type = [&] {
+ auto ty = b.ty.f32();
+ if (auto* vec = type->As<core::type::Vector>()) {
+ ty = b.ty.vec(ty, vec->Width());
+ }
+ return ty;
+ };
+ ast_type = make_ast_type();
+ value = b.Call(make_ast_type(), value);
+ } else {
+ ast_type = CreateASTTypeFor(ctx, type);
+ }
+
OutputValue output;
output.name = name;
- output.type = CreateASTTypeFor(ctx, type);
+ output.type = ast_type;
output.attributes = std::move(attrs);
output.value = value;
output.location = location;
@@ -984,10 +1013,12 @@
CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
uint32_t sample_mask,
- bool emit_point_size)
+ bool emit_point_size,
+ bool polyfill_f16)
: shader_style(style),
fixed_sample_mask(sample_mask),
- emit_vertex_point_size(emit_point_size) {}
+ emit_vertex_point_size(emit_point_size),
+ polyfill_f16_io(polyfill_f16) {}
CanonicalizeEntryPointIO::Config::Config(const Config&) = default;
CanonicalizeEntryPointIO::Config::~Config() = default;
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
index d63d479..ddf761d 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
@@ -118,9 +118,11 @@
/// @param style the approach to use for emitting shader IO.
/// @param sample_mask an optional sample mask to combine with shader masks
/// @param emit_vertex_point_size `true` to generate a pointsize builtin
+ /// @param polyfill_f16_io `true` to replace f16 types with f32 types
explicit Config(ShaderStyle style,
uint32_t sample_mask = 0xFFFFFFFF,
- bool emit_vertex_point_size = false);
+ bool emit_vertex_point_size = false,
+ bool polyfill_f16_io = false);
/// Copy constructor
Config(const Config&);
@@ -137,6 +139,9 @@
/// Set to `true` to generate a pointsize builtin and have it set to 1.0
/// from all vertex shaders in the module.
const bool emit_vertex_point_size;
+
+ /// Set to `true` to replace f16 IO types with f32 types and convert them.
+ const bool polyfill_f16_io = false;
};
/// HLSLWaveIntrinsic is an InternalAttribute that is used to decorate a stub function so that
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
index 7c33b29..ec83e9e 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
@@ -4420,5 +4420,60 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(CanonicalizeEntryPointIOTest, F16_Polyfill_Spirv) {
+ auto* src = R"(
+enable f16;
+
+struct Outputs {
+ @location(1) a : f16,
+ @location(2) b : vec4<f16>,
+}
+
+@fragment
+fn frag_main(@location(1) loc1 : f16,
+ @location(2) loc2 : vec4<f16>) -> Outputs {
+ return Outputs(loc1 * 2, loc2 * 3);
+}
+)";
+
+ auto* expect = R"(
+enable f16;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__in> loc1_1 : f32;
+
+@location(2) @internal(disable_validation__ignore_address_space) var<__in> loc2_1 : vec4<f32>;
+
+@location(1) @internal(disable_validation__ignore_address_space) var<__out> a_1 : f32;
+
+@location(2) @internal(disable_validation__ignore_address_space) var<__out> b_1 : vec4<f32>;
+
+struct Outputs {
+ a : f16,
+ b : vec4<f16>,
+}
+
+fn frag_main_inner(loc1 : f16, loc2 : vec4<f16>) -> Outputs {
+ return Outputs((loc1 * 2), (loc2 * 3));
+}
+
+@fragment
+fn frag_main() {
+ let inner_result = frag_main_inner(f16(loc1_1), vec4<f16>(loc2_1));
+ a_1 = f32(inner_result.a);
+ b_1 = vec4<f32>(inner_result.b);
+}
+)";
+
+ DataMap data;
+
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kSpirv,
+ /* fixed_sample_mask */ 0xFFFFFFFF,
+ /* emit_vertex_point_size */ false,
+ /* polyfill_f16_io */ true);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
} // namespace
} // namespace tint::ast::transform