Tint/Dawn: SPIRV AST writer to support experimental subgroup UCF

This CL makes the Tint SPIRV AST writer require the
SPV_KHR_subgroup_uniform_control_flow SPIRV extension and the
SubgroupUniformControlFlowKHR execution mode for all compute
entry points, if the device has the feature
Chromium-experimental-subgroup-uniform-control-flow. This CL also
implements the related Tint SPIRV writer unit test.

Bug: dawn:464

Change-Id: Ibcb2f539d9dc7af36753acdd7d797ed919f0e87a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/148841
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index c727acc..7251ba7 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -179,6 +179,7 @@
     X(bool, clampFragDepth)                                                                      \
     X(bool, disableImageRobustness)                                                              \
     X(bool, disableRuntimeSizedArrayIndexClamping)                                               \
+    X(bool, experimentalRequireSubgroupUniformControlFlow)                                       \
     X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, platform)
 
 DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS);
@@ -272,6 +273,11 @@
         GetDevice()->IsToggleEnabled(Toggle::VulkanUseBufferRobustAccess2);
     req.platform = UnsafeUnkeyedValue(GetDevice()->GetPlatform());
     req.substituteOverrideConfig = std::move(substituteOverrideConfig);
+    // Set subgroup uniform control flow flag for subgroup experiment, if device has
+    // Chromium-experimental-subgroup-uniform-control-flow feature. (dawn:464)
+    if (GetDevice()->HasFeature(Feature::ChromiumExperimentalSubgroupUniformControlFlow)) {
+        req.experimentalRequireSubgroupUniformControlFlow = true;
+    }
 
     const CombinedLimits& limits = GetDevice()->GetLimits();
     req.limits = LimitsForCompilationRequest::Create(limits.v1);
@@ -344,6 +350,8 @@
             options.disable_image_robustness = r.disableImageRobustness;
             options.disable_runtime_sized_array_index_clamping =
                 r.disableRuntimeSizedArrayIndexClamping;
+            options.experimental_require_subgroup_uniform_control_flow =
+                r.experimentalRequireSubgroupUniformControlFlow;
 
             TRACE_EVENT0(r.platform.UnsafeGetValue(), General, "tint::spirv::writer::Generate()");
             auto tintResult = tint::spirv::writer::Generate(&program, options);
diff --git a/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc b/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
index c0f818d..c36f0bc 100644
--- a/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/spirv/writer/ast_printer/ast_printer.cc
@@ -179,8 +179,12 @@
     return result;
 }
 
-ASTPrinter::ASTPrinter(const Program* program, bool zero_initialize_workgroup_memory)
-    : builder_(program, zero_initialize_workgroup_memory) {}
+ASTPrinter::ASTPrinter(const Program* program,
+                       bool zero_initialize_workgroup_memory,
+                       bool experimental_require_subgroup_uniform_control_flow)
+    : builder_(program,
+               zero_initialize_workgroup_memory,
+               experimental_require_subgroup_uniform_control_flow) {}
 
 bool ASTPrinter::Generate() {
     if (builder_.Build()) {
diff --git a/src/tint/lang/spirv/writer/ast_printer/ast_printer.h b/src/tint/lang/spirv/writer/ast_printer/ast_printer.h
index a5687db..7c69ef6 100644
--- a/src/tint/lang/spirv/writer/ast_printer/ast_printer.h
+++ b/src/tint/lang/spirv/writer/ast_printer/ast_printer.h
@@ -43,7 +43,12 @@
     /// @param program the program to generate
     /// @param zero_initialize_workgroup_memory `true` to initialize all the
     /// variables in the Workgroup address space with OpConstantNull
-    ASTPrinter(const Program* program, bool zero_initialize_workgroup_memory);
+    /// @param experimental_require_subgroup_uniform_control_flow `true` to require
+    /// `SPV_KHR_subgroup_uniform_control_flow` extension and `SubgroupUniformControlFlowKHR`
+    /// execution mode for compute stage entry points.
+    ASTPrinter(const Program* program,
+               bool zero_initialize_workgroup_memory,
+               bool experimental_require_subgroup_uniform_control_flow);
 
     /// @returns true on successful generation; false otherwise
     bool Generate();
diff --git a/src/tint/lang/spirv/writer/ast_printer/builder.cc b/src/tint/lang/spirv/writer/ast_printer/builder.cc
index a06081e..65a13af 100644
--- a/src/tint/lang/spirv/writer/ast_printer/builder.cc
+++ b/src/tint/lang/spirv/writer/ast_printer/builder.cc
@@ -248,10 +248,14 @@
 
 Builder::AccessorInfo::~AccessorInfo() {}
 
-Builder::Builder(const Program* program, bool zero_initialize_workgroup_memory)
+Builder::Builder(const Program* program,
+                 bool zero_initialize_workgroup_memory,
+                 bool experimental_require_subgroup_uniform_control_flow)
     : builder_(ProgramBuilder::Wrap(program)),
       scope_stack_{Scope{}},
-      zero_initialize_workgroup_memory_(zero_initialize_workgroup_memory) {}
+      zero_initialize_workgroup_memory_(zero_initialize_workgroup_memory),
+      experimental_require_subgroup_uniform_control_flow_(
+          experimental_require_subgroup_uniform_control_flow) {}
 
 Builder::~Builder() = default;
 
@@ -280,6 +284,11 @@
         GenerateExtension(ext);
     }
 
+    // Emit SPV_KHR_subgroup_uniform_control_flow extension if required.
+    if (experimental_require_subgroup_uniform_control_flow_) {
+        module_.PushExtension("SPV_KHR_subgroup_uniform_control_flow");
+    }
+
     for (auto* var : builder_.AST().GlobalVariables()) {
         if (!GenerateGlobalVariable(var)) {
             return false;
@@ -483,9 +492,18 @@
         if (builtin == core::BuiltinValue::kFragDepth) {
             module_.PushExecutionMode(spv::Op::OpExecutionMode,
                                       {Operand(id), U32Operand(SpvExecutionModeDepthReplacing)});
+            break;
         }
     }
 
+    // Use SubgroupUniformControlFlow execution mode for compute stage if required.
+    if (experimental_require_subgroup_uniform_control_flow_ &&
+        func->PipelineStage() == ast::PipelineStage::kCompute) {
+        module_.PushExecutionMode(
+            spv::Op::OpExecutionMode,
+            {Operand(id), U32Operand(SpvExecutionModeSubgroupUniformControlFlowKHR)});
+    }
+
     return true;
 }
 
diff --git a/src/tint/lang/spirv/writer/ast_printer/builder.h b/src/tint/lang/spirv/writer/ast_printer/builder.h
index adca66c..0f0c84b 100644
--- a/src/tint/lang/spirv/writer/ast_printer/builder.h
+++ b/src/tint/lang/spirv/writer/ast_printer/builder.h
@@ -82,7 +82,12 @@
     /// @param program the program
     /// @param zero_initialize_workgroup_memory `true` to initialize all the
     /// variables in the Workgroup address space with OpConstantNull
-    explicit Builder(const Program* program, bool zero_initialize_workgroup_memory = false);
+    /// @param experimental_require_subgroup_uniform_control_flow `true` to require
+    /// `SPV_KHR_subgroup_uniform_control_flow` extension and `SubgroupUniformControlFlowKHR`
+    /// execution mode for compute stage entry points.
+    explicit Builder(const Program* program,
+                     bool zero_initialize_workgroup_memory = false,
+                     bool experimental_require_subgroup_uniform_control_flow = false);
     ~Builder();
 
     /// Generates the SPIR-V instructions for the given program
@@ -537,6 +542,7 @@
     std::vector<uint32_t> merge_stack_;
     std::vector<uint32_t> continue_stack_;
     bool zero_initialize_workgroup_memory_ = false;
+    bool experimental_require_subgroup_uniform_control_flow_ = false;
 
     struct ContinuingInfo {
         ContinuingInfo(const ast::Statement* last_statement,
diff --git a/src/tint/lang/spirv/writer/ast_printer/entry_point_test.cc b/src/tint/lang/spirv/writer/ast_printer/entry_point_test.cc
index 95c910f..fdced66 100644
--- a/src/tint/lang/spirv/writer/ast_printer/entry_point_test.cc
+++ b/src/tint/lang/spirv/writer/ast_printer/entry_point_test.cc
@@ -208,6 +208,11 @@
     // fn frag_main(inputs : Interface) -> @builtin(frag_depth) f32 {
     //   return inputs.value;
     // }
+    //
+    // @compute @workgroup_size(1)
+    // fn compute_main() {
+    //   return;
+    // }
 
     auto* interface =
         Structure("Interface",
@@ -232,6 +237,9 @@
              Builtin(core::BuiltinValue::kFragDepth),
          });
 
+    Func("compute_main", tint::Empty, ty.void_(), Vector{Return()},
+         Vector{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1_u)});
+
     Builder& b = SanitizeAndBuild();
 
     ASSERT_TRUE(b.Build()) << b.Diagnostics();
@@ -240,8 +248,10 @@
 OpMemoryModel Logical GLSL450
 OpEntryPoint Vertex %23 "vert_main" %1 %5 %9
 OpEntryPoint Fragment %34 "frag_main" %10 %12 %14
+OpEntryPoint GLCompute %40 "compute_main"
 OpExecutionMode %34 OriginUpperLeft
 OpExecutionMode %34 DepthReplacing
+OpExecutionMode %40 LocalSize 1 1 1
 OpName %1 "value_1"
 OpName %5 "pos_1"
 OpName %9 "vertex_point_size"
@@ -256,6 +266,7 @@
 OpName %30 "frag_main_inner"
 OpName %31 "inputs"
 OpName %34 "frag_main"
+OpName %40 "compute_main"
 OpDecorate %1 Location 1
 OpDecorate %5 BuiltIn Position
 OpDecorate %9 BuiltIn PointSize
@@ -315,6 +326,160 @@
 OpStore %14 %36
 OpReturn
 OpFunctionEnd
+%40 = OpFunction %22 None %21
+%41 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+
+    Validate(b);
+}
+
+// Tests that SPIRV generation with experimental_require_subgroup_uniform_control_flow set to
+// true in spirv::writer::Options requires the "SPV_KHR_subgroup_uniform_control_flow" extension
+// and uses SubgroupUniformControlFlowKHR execution mode on compute stage entry points.
+TEST_F(SpirvASTPrinterTest, EntryPoint_ExperimentalSubgroupUniformControlFlow) {
+    // struct Interface {
+    //   @location(1) value : f32;
+    //   @builtin(position) pos : vec4<f32>;
+    // };
+    //
+    // @vertex
+    // fn vert_main() -> Interface {
+    //   return Interface(42.0, vec4<f32>());
+    // }
+    //
+    // @fragment
+    // fn frag_main(inputs : Interface) -> @builtin(frag_depth) f32 {
+    //   return inputs.value;
+    // }
+    //
+    // @compute @workgroup_size(1)
+    // fn compute_main() {
+    //   return;
+    // }
+
+    auto* interface =
+        Structure("Interface",
+                  Vector{
+                      Member("value", ty.f32(), Vector{Location(1_u)}),
+                      Member("pos", ty.vec4<f32>(), Vector{Builtin(core::BuiltinValue::kPosition)}),
+                  });
+
+    auto* vert_retval = Call(ty.Of(interface), 42_f, Call<vec4<f32>>());
+    Func("vert_main", tint::Empty, ty.Of(interface), Vector{Return(vert_retval)},
+         Vector{
+             Stage(ast::PipelineStage::kVertex),
+         });
+
+    auto* frag_inputs = Param("inputs", ty.Of(interface));
+    Func("frag_main", Vector{frag_inputs}, ty.f32(),
+         Vector{
+             Return(MemberAccessor(Expr("inputs"), "value")),
+         },
+         Vector{Stage(ast::PipelineStage::kFragment)},
+         Vector{
+             Builtin(core::BuiltinValue::kFragDepth),
+         });
+
+    Func("compute_main", tint::Empty, ty.void_(), Vector{Return()},
+         Vector{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1_u)});
+
+    Options options = DefaultOptions();
+    options.experimental_require_subgroup_uniform_control_flow = true;
+
+    Builder& b = SanitizeAndBuild(options);
+
+    ASSERT_TRUE(b.Build()) << b.Diagnostics();
+
+    EXPECT_EQ(DumpModule(b.Module()), R"(OpCapability Shader
+OpExtension "SPV_KHR_subgroup_uniform_control_flow"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Vertex %23 "vert_main" %1 %5 %9
+OpEntryPoint Fragment %34 "frag_main" %10 %12 %14
+OpEntryPoint GLCompute %40 "compute_main"
+OpExecutionMode %34 OriginUpperLeft
+OpExecutionMode %34 DepthReplacing
+OpExecutionMode %40 LocalSize 1 1 1
+OpExecutionMode %40 SubgroupUniformControlFlowKHR
+OpName %1 "value_1"
+OpName %5 "pos_1"
+OpName %9 "vertex_point_size"
+OpName %10 "value_2"
+OpName %12 "pos_2"
+OpName %14 "value_3"
+OpName %16 "Interface"
+OpMemberName %16 0 "value"
+OpMemberName %16 1 "pos"
+OpName %17 "vert_main_inner"
+OpName %23 "vert_main"
+OpName %30 "frag_main_inner"
+OpName %31 "inputs"
+OpName %34 "frag_main"
+OpName %40 "compute_main"
+OpDecorate %1 Location 1
+OpDecorate %5 BuiltIn Position
+OpDecorate %9 BuiltIn PointSize
+OpDecorate %10 Location 1
+OpDecorate %12 BuiltIn FragCoord
+OpDecorate %14 BuiltIn FragDepth
+OpMemberDecorate %16 0 Offset 0
+OpMemberDecorate %16 1 Offset 16
+%3 = OpTypeFloat 32
+%2 = OpTypePointer Output %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Output %4
+%7 = OpTypeVector %3 4
+%6 = OpTypePointer Output %7
+%8 = OpConstantNull %7
+%5 = OpVariable %6 Output %8
+%9 = OpVariable %2 Output %4
+%11 = OpTypePointer Input %3
+%10 = OpVariable %11 Input
+%13 = OpTypePointer Input %7
+%12 = OpVariable %13 Input
+%14 = OpVariable %2 Output %4
+%16 = OpTypeStruct %3 %7
+%15 = OpTypeFunction %16
+%19 = OpConstant %3 42
+%20 = OpConstantComposite %16 %19 %8
+%22 = OpTypeVoid
+%21 = OpTypeFunction %22
+%28 = OpConstant %3 1
+%29 = OpTypeFunction %3 %16
+%17 = OpFunction %16 None %15
+%18 = OpLabel
+OpReturnValue %20
+OpFunctionEnd
+%23 = OpFunction %22 None %21
+%24 = OpLabel
+%25 = OpFunctionCall %16 %17
+%26 = OpCompositeExtract %3 %25 0
+OpStore %1 %26
+%27 = OpCompositeExtract %7 %25 1
+OpStore %5 %27
+OpStore %9 %28
+OpReturn
+OpFunctionEnd
+%30 = OpFunction %3 None %29
+%31 = OpFunctionParameter %16
+%32 = OpLabel
+%33 = OpCompositeExtract %3 %31 0
+OpReturnValue %33
+OpFunctionEnd
+%34 = OpFunction %22 None %21
+%35 = OpLabel
+%37 = OpLoad %3 %10
+%38 = OpLoad %7 %12
+%39 = OpCompositeConstruct %16 %37 %38
+%36 = OpFunctionCall %3 %30 %39
+OpStore %14 %36
+OpReturn
+OpFunctionEnd
+%40 = OpFunction %22 None %21
+%41 = OpLabel
+OpReturn
+OpFunctionEnd
 )");
 
     Validate(b);
diff --git a/src/tint/lang/spirv/writer/ast_printer/helper_test.h b/src/tint/lang/spirv/writer/ast_printer/helper_test.h
index cee975d..e97bb77 100644
--- a/src/tint/lang/spirv/writer/ast_printer/helper_test.h
+++ b/src/tint/lang/spirv/writer/ast_printer/helper_test.h
@@ -80,7 +80,12 @@
         auto result = Sanitize(program.get(), options);
         [&] { ASSERT_TRUE(result.program.IsValid()) << result.program.Diagnostics().str(); }();
         *program = std::move(result.program);
-        spirv_builder = std::make_unique<Builder>(program.get());
+        bool zero_initialize_workgroup_memory =
+            !options.disable_workgroup_init &&
+            options.use_zero_initialize_workgroup_memory_extension;
+        spirv_builder =
+            std::make_unique<Builder>(program.get(), zero_initialize_workgroup_memory,
+                                      options.experimental_require_subgroup_uniform_control_flow);
         return *spirv_builder;
     }
 
diff --git a/src/tint/lang/spirv/writer/common/options.h b/src/tint/lang/spirv/writer/common/options.h
index 068f394..fce8659 100644
--- a/src/tint/lang/spirv/writer/common/options.h
+++ b/src/tint/lang/spirv/writer/common/options.h
@@ -55,6 +55,11 @@
     /// Set to `true` to generate SPIR-V via the Tint IR instead of from the AST.
     bool use_tint_ir = false;
 
+    /// Set to `true` to require `SPV_KHR_subgroup_uniform_control_flow` extension and
+    /// `SubgroupUniformControlFlowKHR` execution mode for compute stage entry points in generated
+    /// SPIRV module. Issue: dawn:464
+    bool experimental_require_subgroup_uniform_control_flow = false;
+
     /// Reflect the fields of this class so that it can be used by tint::ForeachField()
     TINT_REFLECT(disable_robustness,
                  emit_vertex_point_size,
@@ -65,7 +70,8 @@
                  use_zero_initialize_workgroup_memory_extension,
                  disable_image_robustness,
                  disable_runtime_sized_array_index_clamping,
-                 use_tint_ir);
+                 use_tint_ir,
+                 experimental_require_subgroup_uniform_control_flow);
 };
 
 }  // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/writer.cc b/src/tint/lang/spirv/writer/writer.cc
index 815c395..81acf02 100644
--- a/src/tint/lang/spirv/writer/writer.cc
+++ b/src/tint/lang/spirv/writer/writer.cc
@@ -78,8 +78,9 @@
         }
 
         // Generate the SPIR-V code.
-        auto impl = std::make_unique<ASTPrinter>(&sanitized_result.program,
-                                                 zero_initialize_workgroup_memory);
+        auto impl = std::make_unique<ASTPrinter>(
+            &sanitized_result.program, zero_initialize_workgroup_memory,
+            options.experimental_require_subgroup_uniform_control_flow);
         if (!impl->Generate()) {
             return impl->Diagnostics().str();
         }