[spirv-writer] Only use StorageInputOutput16 if needed
We only need to emit this capability if there is an Input or Output
variable with an f16 type.
Bug: tint:2161
Change-Id: Ibcc65b1a6280247df27ad78f82b2aa3421b1a737
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/173702
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/spirv/writer/function_test.cc b/src/tint/lang/spirv/writer/function_test.cc
index 2658f3f..24c80bf 100644
--- a/src/tint/lang/spirv/writer/function_test.cc
+++ b/src/tint/lang/spirv/writer/function_test.cc
@@ -346,6 +346,58 @@
)");
}
+TEST_F(SpirvWriterTest, Function_ShaderIO_F16_Input) {
+ auto* input = b.FunctionParam("input", ty.vec4<f16>());
+ input->SetLocation(1, std::nullopt);
+ auto* func = b.Function("main", ty.vec4<f32>(), core::ir::Function::PipelineStage::kFragment);
+ func->SetReturnLocation(2, std::nullopt);
+ func->SetParams({input});
+ b.Append(func->Block(), [&] { //
+ b.Return(func, b.Convert(ty.vec4<f32>(), input));
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpCapability StorageInputOutput16");
+ EXPECT_INST(R"(OpEntryPoint Fragment %main "main" %main_loc1_Input %main_loc2_Output)");
+ EXPECT_INST("%main_loc1_Input = OpVariable %_ptr_Input_v4half Input");
+ EXPECT_INST("%main_loc2_Output = OpVariable %_ptr_Output_v4float Output");
+ EXPECT_INST(R"(
+ %main = OpFunction %void None %16
+ %17 = OpLabel
+ %18 = OpLoad %v4half %main_loc1_Input
+ %19 = OpFunctionCall %v4float %main_inner %18
+ OpStore %main_loc2_Output %19
+ OpReturn
+ OpFunctionEnd
+)");
+}
+
+TEST_F(SpirvWriterTest, Function_ShaderIO_F16_Output) {
+ auto* input = b.FunctionParam("input", ty.vec4<f32>());
+ input->SetLocation(1, std::nullopt);
+ auto* func = b.Function("main", ty.vec4<f16>(), core::ir::Function::PipelineStage::kFragment);
+ func->SetReturnLocation(2, std::nullopt);
+ func->SetParams({input});
+ b.Append(func->Block(), [&] { //
+ b.Return(func, b.Convert(ty.vec4<f16>(), input));
+ });
+
+ ASSERT_TRUE(Generate()) << Error() << output_;
+ EXPECT_INST("OpCapability StorageInputOutput16");
+ EXPECT_INST(R"(OpEntryPoint Fragment %main "main" %main_loc1_Input %main_loc2_Output)");
+ EXPECT_INST("%main_loc1_Input = OpVariable %_ptr_Input_v4float Input");
+ EXPECT_INST("%main_loc2_Output = OpVariable %_ptr_Output_v4half Output");
+ EXPECT_INST(R"(
+ %main = OpFunction %void None %16
+ %17 = OpLabel
+ %18 = OpLoad %v4float %main_loc1_Input
+ %19 = OpFunctionCall %v4half %main_inner %18
+ OpStore %main_loc2_Output %19
+ OpReturn
+ OpFunctionEnd
+)");
+}
+
TEST_F(SpirvWriterTest, Function_ShaderIO_DualSourceBlend) {
auto* outputs =
ty.Struct(mod.symbols.New("Outputs"), {
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index ef41faf..9fc98fd 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -493,7 +493,6 @@
module_.PushCapability(SpvCapabilityFloat16);
module_.PushCapability(SpvCapabilityUniformAndStorageBuffer16BitAccess);
module_.PushCapability(SpvCapabilityStorageBuffer16BitAccess);
- module_.PushCapability(SpvCapabilityStorageInputOutput16);
module_.PushType(spv::Op::OpTypeFloat, {id, 16u});
},
[&](const core::type::Vector* vec) {
@@ -2065,6 +2064,9 @@
}
case core::AddressSpace::kIn: {
TINT_ASSERT(!current_function_);
+ if (store_ty->DeepestElement()->Is<core::type::F16>()) {
+ module_.PushCapability(SpvCapabilityStorageInputOutput16);
+ }
module_.PushType(spv::Op::OpVariable, {ty, id, U32Operand(SpvStorageClassInput)});
EmitIOAttributes(id, var->Attributes(), core::AddressSpace::kIn);
break;
@@ -2089,6 +2091,9 @@
}
case core::AddressSpace::kOut: {
TINT_ASSERT(!current_function_);
+ if (store_ty->DeepestElement()->Is<core::type::F16>()) {
+ module_.PushCapability(SpvCapabilityStorageInputOutput16);
+ }
module_.PushType(spv::Op::OpVariable, {ty, id, U32Operand(SpvStorageClassOutput)});
EmitIOAttributes(id, var->Attributes(), core::AddressSpace::kOut);
break;
diff --git a/src/tint/lang/spirv/writer/type_test.cc b/src/tint/lang/spirv/writer/type_test.cc
index 08fd086..27fdc1b 100644
--- a/src/tint/lang/spirv/writer/type_test.cc
+++ b/src/tint/lang/spirv/writer/type_test.cc
@@ -98,7 +98,6 @@
EXPECT_INST("OpCapability Float16");
EXPECT_INST("OpCapability UniformAndStorageBuffer16BitAccess");
EXPECT_INST("OpCapability StorageBuffer16BitAccess");
- EXPECT_INST("OpCapability StorageInputOutput16");
EXPECT_INST("%half = OpTypeFloat 16");
}