Tint/Dawn: SPIRV AST writer to support experimental subgroup UCF
This CL make Tint SPIRV AST writer require
SPV_KHR_subgroup_uniform_control_flow SPIRV extension and require
execution mode SubgroupUniformControlFlowKHR for all compute
entrypoints, if the device has feature
Chromium-experimental-subgroup-uniform-control-flow. This CL also
implement the related Tint SPIRV writer unittest.
Bug: dawn:464
Change-Id: Ibcb2f539d9dc7af36753acdd7d797ed919f0e87a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/148841
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index c727acc..7251ba7 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -179,6 +179,7 @@
X(bool, clampFragDepth) \
X(bool, disableImageRobustness) \
X(bool, disableRuntimeSizedArrayIndexClamping) \
+ X(bool, experimentalRequireSubgroupUniformControlFlow) \
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, platform)
DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS);
@@ -272,6 +273,11 @@
GetDevice()->IsToggleEnabled(Toggle::VulkanUseBufferRobustAccess2);
req.platform = UnsafeUnkeyedValue(GetDevice()->GetPlatform());
req.substituteOverrideConfig = std::move(substituteOverrideConfig);
+ // Set subgroup uniform control flow flag for subgroup experiment, if device has
+ // Chromium-experimental-subgroup-uniform-control-flow feature. (dawn:464)
+ if (GetDevice()->HasFeature(Feature::ChromiumExperimentalSubgroupUniformControlFlow)) {
+ req.experimentalRequireSubgroupUniformControlFlow = true;
+ }
const CombinedLimits& limits = GetDevice()->GetLimits();
req.limits = LimitsForCompilationRequest::Create(limits.v1);
@@ -344,6 +350,8 @@
options.disable_image_robustness = r.disableImageRobustness;
options.disable_runtime_sized_array_index_clamping =
r.disableRuntimeSizedArrayIndexClamping;
+ options.experimental_require_subgroup_uniform_control_flow =
+ r.experimentalRequireSubgroupUniformControlFlow;
TRACE_EVENT0(r.platform.UnsafeGetValue(), General, "tint::spirv::writer::Generate()");
auto tintResult = tint::spirv::writer::Generate(&program, options);
diff --git a/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc b/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
index c0f818d..c36f0bc 100644
--- a/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
@@ -179,8 +179,12 @@
return result;
}
-ASTPrinter::ASTPrinter(const Program* program, bool zero_initialize_workgroup_memory)
- : builder_(program, zero_initialize_workgroup_memory) {}
+ASTPrinter::ASTPrinter(const Program* program,
+ bool zero_initialize_workgroup_memory,
+ bool experimental_require_subgroup_uniform_control_flow)
+ : builder_(program,
+ zero_initialize_workgroup_memory,
+ experimental_require_subgroup_uniform_control_flow) {}
bool ASTPrinter::Generate() {
if (builder_.Build()) {
diff --git a/src/tint/lang/spirv/writer/ast_printer/ast_printer.h b/src/tint/lang/spirv/writer/ast_printer/ast_printer.h
index a5687db..7c69ef6 100644
--- a/src/tint/lang/spirv/writer/ast_printer/ast_printer.h
+++ b/src/tint/lang/spirv/writer/ast_printer/ast_printer.h
@@ -43,7 +43,12 @@
/// @param program the program to generate
/// @param zero_initialize_workgroup_memory `true` to initialize all the
/// variables in the Workgroup address space with OpConstantNull
- ASTPrinter(const Program* program, bool zero_initialize_workgroup_memory);
+ /// @param experimental_require_subgroup_uniform_control_flow `true` to require
+ /// `SPV_KHR_subgroup_uniform_control_flow` extension and `SubgroupUniformControlFlowKHR`
+ /// execution mode for compute stage entry points.
+ ASTPrinter(const Program* program,
+ bool zero_initialize_workgroup_memory,
+ bool experimental_require_subgroup_uniform_control_flow);
/// @returns true on successful generation; false otherwise
bool Generate();
diff --git a/src/tint/lang/spirv/writer/ast_printer/builder.cc b/src/tint/lang/spirv/writer/ast_printer/builder.cc
index a06081e..65a13af 100644
--- a/src/tint/lang/spirv/writer/ast_printer/builder.cc
+++ b/src/tint/lang/spirv/writer/ast_printer/builder.cc
@@ -248,10 +248,14 @@
Builder::AccessorInfo::~AccessorInfo() {}
-Builder::Builder(const Program* program, bool zero_initialize_workgroup_memory)
+Builder::Builder(const Program* program,
+ bool zero_initialize_workgroup_memory,
+ bool experimental_require_subgroup_uniform_control_flow)
: builder_(ProgramBuilder::Wrap(program)),
scope_stack_{Scope{}},
- zero_initialize_workgroup_memory_(zero_initialize_workgroup_memory) {}
+ zero_initialize_workgroup_memory_(zero_initialize_workgroup_memory),
+ experimental_require_subgroup_uniform_control_flow_(
+ experimental_require_subgroup_uniform_control_flow) {}
Builder::~Builder() = default;
@@ -280,6 +284,11 @@
GenerateExtension(ext);
}
+ // Emit SPV_KHR_subgroup_uniform_control_flow extension if required.
+ if (experimental_require_subgroup_uniform_control_flow_) {
+ module_.PushExtension("SPV_KHR_subgroup_uniform_control_flow");
+ }
+
for (auto* var : builder_.AST().GlobalVariables()) {
if (!GenerateGlobalVariable(var)) {
return false;
@@ -483,9 +492,18 @@
if (builtin == core::BuiltinValue::kFragDepth) {
module_.PushExecutionMode(spv::Op::OpExecutionMode,
{Operand(id), U32Operand(SpvExecutionModeDepthReplacing)});
+ break;
}
}
+ // Use SubgroupUniformControlFlow execution mode for compute stage if required.
+ if (experimental_require_subgroup_uniform_control_flow_ &&
+ func->PipelineStage() == ast::PipelineStage::kCompute) {
+ module_.PushExecutionMode(
+ spv::Op::OpExecutionMode,
+ {Operand(id), U32Operand(SpvExecutionModeSubgroupUniformControlFlowKHR)});
+ }
+
return true;
}
diff --git a/src/tint/lang/spirv/writer/ast_printer/builder.h b/src/tint/lang/spirv/writer/ast_printer/builder.h
index adca66c..0f0c84b 100644
--- a/src/tint/lang/spirv/writer/ast_printer/builder.h
+++ b/src/tint/lang/spirv/writer/ast_printer/builder.h
@@ -82,7 +82,12 @@
/// @param program the program
/// @param zero_initialize_workgroup_memory `true` to initialize all the
/// variables in the Workgroup address space with OpConstantNull
- explicit Builder(const Program* program, bool zero_initialize_workgroup_memory = false);
+ /// @param experimental_require_subgroup_uniform_control_flow `true` to require
+ /// `SPV_KHR_subgroup_uniform_control_flow` extension and `SubgroupUniformControlFlowKHR`
+ /// execution mode for compute stage entry points.
+ explicit Builder(const Program* program,
+ bool zero_initialize_workgroup_memory = false,
+ bool experimental_require_subgroup_uniform_control_flow = false);
~Builder();
/// Generates the SPIR-V instructions for the given program
@@ -537,6 +542,7 @@
std::vector<uint32_t> merge_stack_;
std::vector<uint32_t> continue_stack_;
bool zero_initialize_workgroup_memory_ = false;
+ bool experimental_require_subgroup_uniform_control_flow_ = false;
struct ContinuingInfo {
ContinuingInfo(const ast::Statement* last_statement,
diff --git a/src/tint/lang/spirv/writer/ast_printer/entry_point_test.cc b/src/tint/lang/spirv/writer/ast_printer/entry_point_test.cc
index 95c910f..fdced66 100644
--- a/src/tint/lang/spirv/writer/ast_printer/entry_point_test.cc
+++ b/src/tint/lang/spirv/writer/ast_printer/entry_point_test.cc
@@ -208,6 +208,11 @@
// fn frag_main(inputs : Interface) -> @builtin(frag_depth) f32 {
// return inputs.value;
// }
+ //
+ // @compute @workgroup_size(1)
+ // fn compute_main() {
+ // return;
+ // }
auto* interface =
Structure("Interface",
@@ -232,6 +237,9 @@
Builtin(core::BuiltinValue::kFragDepth),
});
+ Func("compute_main", tint::Empty, ty.void_(), Vector{Return()},
+ Vector{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1_u)});
+
Builder& b = SanitizeAndBuild();
ASSERT_TRUE(b.Build()) << b.Diagnostics();
@@ -240,8 +248,10 @@
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %23 "vert_main" %1 %5 %9
OpEntryPoint Fragment %34 "frag_main" %10 %12 %14
+OpEntryPoint GLCompute %40 "compute_main"
OpExecutionMode %34 OriginUpperLeft
OpExecutionMode %34 DepthReplacing
+OpExecutionMode %40 LocalSize 1 1 1
OpName %1 "value_1"
OpName %5 "pos_1"
OpName %9 "vertex_point_size"
@@ -256,6 +266,7 @@
OpName %30 "frag_main_inner"
OpName %31 "inputs"
OpName %34 "frag_main"
+OpName %40 "compute_main"
OpDecorate %1 Location 1
OpDecorate %5 BuiltIn Position
OpDecorate %9 BuiltIn PointSize
@@ -315,6 +326,160 @@
OpStore %14 %36
OpReturn
OpFunctionEnd
+%40 = OpFunction %22 None %21
+%41 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+
+ Validate(b);
+}
+
+// Tests SPIRV generation with experimental_require_subgroup_uniform_control_flow in
+// spirv::writer::Options set to true, should require "SPV_KHR_subgroup_uniform_control_flow"
+// extension and use SubgroupUniformControlFlowKHR execution mode on compute stage entry points.
+TEST_F(SpirvASTPrinterTest, EntryPoint_ExperimentalSubgroupUniformControlFlow) {
+ // struct Interface {
+ // @location(1) value : f32;
+ // @builtin(position) pos : vec4<f32>;
+ // };
+ //
+ // @vertex
+ // fn vert_main() -> Interface {
+ // return Interface(42.0, vec4<f32>());
+ // }
+ //
+ // @fragment
+ // fn frag_main(inputs : Interface) -> @builtin(frag_depth) f32 {
+ // return inputs.value;
+ // }
+ //
+ // @compute @workgroup_size(1)
+ // fn compute_main() {
+ // return;
+ // }
+
+ auto* interface =
+ Structure("Interface",
+ Vector{
+ Member("value", ty.f32(), Vector{Location(1_u)}),
+ Member("pos", ty.vec4<f32>(), Vector{Builtin(core::BuiltinValue::kPosition)}),
+ });
+
+ auto* vert_retval = Call(ty.Of(interface), 42_f, Call<vec4<f32>>());
+ Func("vert_main", tint::Empty, ty.Of(interface), Vector{Return(vert_retval)},
+ Vector{
+ Stage(ast::PipelineStage::kVertex),
+ });
+
+ auto* frag_inputs = Param("inputs", ty.Of(interface));
+ Func("frag_main", Vector{frag_inputs}, ty.f32(),
+ Vector{
+ Return(MemberAccessor(Expr("inputs"), "value")),
+ },
+ Vector{Stage(ast::PipelineStage::kFragment)},
+ Vector{
+ Builtin(core::BuiltinValue::kFragDepth),
+ });
+
+ Func("compute_main", tint::Empty, ty.void_(), Vector{Return()},
+ Vector{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1_u)});
+
+ Options options = DefaultOptions();
+ options.experimental_require_subgroup_uniform_control_flow = true;
+
+ Builder& b = SanitizeAndBuild(options);
+
+ ASSERT_TRUE(b.Build()) << b.Diagnostics();
+
+ EXPECT_EQ(DumpModule(b.Module()), R"(OpCapability Shader
+OpExtension "SPV_KHR_subgroup_uniform_control_flow"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Vertex %23 "vert_main" %1 %5 %9
+OpEntryPoint Fragment %34 "frag_main" %10 %12 %14
+OpEntryPoint GLCompute %40 "compute_main"
+OpExecutionMode %34 OriginUpperLeft
+OpExecutionMode %34 DepthReplacing
+OpExecutionMode %40 LocalSize 1 1 1
+OpExecutionMode %40 SubgroupUniformControlFlowKHR
+OpName %1 "value_1"
+OpName %5 "pos_1"
+OpName %9 "vertex_point_size"
+OpName %10 "value_2"
+OpName %12 "pos_2"
+OpName %14 "value_3"
+OpName %16 "Interface"
+OpMemberName %16 0 "value"
+OpMemberName %16 1 "pos"
+OpName %17 "vert_main_inner"
+OpName %23 "vert_main"
+OpName %30 "frag_main_inner"
+OpName %31 "inputs"
+OpName %34 "frag_main"
+OpName %40 "compute_main"
+OpDecorate %1 Location 1
+OpDecorate %5 BuiltIn Position
+OpDecorate %9 BuiltIn PointSize
+OpDecorate %10 Location 1
+OpDecorate %12 BuiltIn FragCoord
+OpDecorate %14 BuiltIn FragDepth
+OpMemberDecorate %16 0 Offset 0
+OpMemberDecorate %16 1 Offset 16
+%3 = OpTypeFloat 32
+%2 = OpTypePointer Output %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Output %4
+%7 = OpTypeVector %3 4
+%6 = OpTypePointer Output %7
+%8 = OpConstantNull %7
+%5 = OpVariable %6 Output %8
+%9 = OpVariable %2 Output %4
+%11 = OpTypePointer Input %3
+%10 = OpVariable %11 Input
+%13 = OpTypePointer Input %7
+%12 = OpVariable %13 Input
+%14 = OpVariable %2 Output %4
+%16 = OpTypeStruct %3 %7
+%15 = OpTypeFunction %16
+%19 = OpConstant %3 42
+%20 = OpConstantComposite %16 %19 %8
+%22 = OpTypeVoid
+%21 = OpTypeFunction %22
+%28 = OpConstant %3 1
+%29 = OpTypeFunction %3 %16
+%17 = OpFunction %16 None %15
+%18 = OpLabel
+OpReturnValue %20
+OpFunctionEnd
+%23 = OpFunction %22 None %21
+%24 = OpLabel
+%25 = OpFunctionCall %16 %17
+%26 = OpCompositeExtract %3 %25 0
+OpStore %1 %26
+%27 = OpCompositeExtract %7 %25 1
+OpStore %5 %27
+OpStore %9 %28
+OpReturn
+OpFunctionEnd
+%30 = OpFunction %3 None %29
+%31 = OpFunctionParameter %16
+%32 = OpLabel
+%33 = OpCompositeExtract %3 %31 0
+OpReturnValue %33
+OpFunctionEnd
+%34 = OpFunction %22 None %21
+%35 = OpLabel
+%37 = OpLoad %3 %10
+%38 = OpLoad %7 %12
+%39 = OpCompositeConstruct %16 %37 %38
+%36 = OpFunctionCall %3 %30 %39
+OpStore %14 %36
+OpReturn
+OpFunctionEnd
+%40 = OpFunction %22 None %21
+%41 = OpLabel
+OpReturn
+OpFunctionEnd
)");
Validate(b);
diff --git a/src/tint/lang/spirv/writer/ast_printer/helper_test.h b/src/tint/lang/spirv/writer/ast_printer/helper_test.h
index cee975d..e97bb77 100644
--- a/src/tint/lang/spirv/writer/ast_printer/helper_test.h
+++ b/src/tint/lang/spirv/writer/ast_printer/helper_test.h
@@ -80,7 +80,12 @@
auto result = Sanitize(program.get(), options);
[&] { ASSERT_TRUE(result.program.IsValid()) << result.program.Diagnostics().str(); }();
*program = std::move(result.program);
- spirv_builder = std::make_unique<Builder>(program.get());
+ bool zero_initialize_workgroup_memory =
+ !options.disable_workgroup_init &&
+ options.use_zero_initialize_workgroup_memory_extension;
+ spirv_builder =
+ std::make_unique<Builder>(program.get(), zero_initialize_workgroup_memory,
+ options.experimental_require_subgroup_uniform_control_flow);
return *spirv_builder;
}
diff --git a/src/tint/lang/spirv/writer/common/options.h b/src/tint/lang/spirv/writer/common/options.h
index 068f394..fce8659 100644
--- a/src/tint/lang/spirv/writer/common/options.h
+++ b/src/tint/lang/spirv/writer/common/options.h
@@ -55,6 +55,11 @@
/// Set to `true` to generate SPIR-V via the Tint IR instead of from the AST.
bool use_tint_ir = false;
+ /// Set to `true` to require `SPV_KHR_subgroup_uniform_control_flow` extension and
+ /// `SubgroupUniformControlFlowKHR` execution mode for compute stage entry points in generated
+ /// SPIRV module. Issue: dawn:464
+ bool experimental_require_subgroup_uniform_control_flow = false;
+
/// Reflect the fields of this class so that it can be used by tint::ForeachField()
TINT_REFLECT(disable_robustness,
emit_vertex_point_size,
@@ -65,7 +70,8 @@
use_zero_initialize_workgroup_memory_extension,
disable_image_robustness,
disable_runtime_sized_array_index_clamping,
- use_tint_ir);
+ use_tint_ir,
+ experimental_require_subgroup_uniform_control_flow);
};
} // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/writer.cc b/src/tint/lang/spirv/writer/writer.cc
index 815c395..81acf02 100644
--- a/src/tint/lang/spirv/writer/writer.cc
+++ b/src/tint/lang/spirv/writer/writer.cc
@@ -78,8 +78,9 @@
}
// Generate the SPIR-V code.
- auto impl = std::make_unique<ASTPrinter>(&sanitized_result.program,
- zero_initialize_workgroup_memory);
+ auto impl = std::make_unique<ASTPrinter>(
+ &sanitized_result.program, zero_initialize_workgroup_memory,
+ options.experimental_require_subgroup_uniform_control_flow);
if (!impl->Generate()) {
return impl->Diagnostics().str();
}