writer/hlsl: Support overridable workgroup sizes

Use the WGSL_SPEC_CONSTANT preprocessor macros as parameters to
[numthreads()] when the dimension is overridable.

Remove the macro #undef to make this possible.

Bug: tint:713
Change-Id: Icd927044a64a8b8a2f029f9e2db8168ec6a861de
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51264
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 9d58070..9db64f0 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -47,6 +47,7 @@
 const char kTintStructInVarPrefix[] = "tint_in";
 const char kTintStructOutVarPrefix[] = "tint_out";
 const char kTempNamePrefix[] = "tint_tmp";
+const char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
 
 bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
   if (stmts->empty()) {
@@ -1994,18 +1995,26 @@
   auto* func_sem = builder_.Sem().Get(func);
 
   if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
+    // Emit the workgroup_size attribute.
     auto wgsize = func_sem->workgroup_size();
-    if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
-        wgsize[2].overridable_const) {
-      // TODO(crbug.com/tint/713): Handle overridable constants.
-      TINT_UNIMPLEMENTED(builder_.Diagnostics())
-          << "pipeline-overridable workgroup sizes are not implemented";
+    out << "[numthreads(";
+    for (int i = 0; i < 3; i++) {
+      if (i > 0) {
+        out << ", ";
+      }
+
+      if (wgsize[i].overridable_const) {
+        auto* sem_const = builder_.Sem().Get(wgsize[i].overridable_const);
+        if (!sem_const->IsPipelineConstant()) {
+          TINT_ICE(builder_.Diagnostics())
+              << "expected a pipeline-overridable constant";
+        }
+        out << kSpecConstantPrefix << sem_const->ConstantId();
+      } else {
+        out << std::to_string(wgsize[i].value);
+      }
     }
-    uint32_t x = wgsize[0].value;
-    uint32_t y = wgsize[1].value;
-    uint32_t z = wgsize[2].value;
-    out << "[numthreads(" << std::to_string(x) << ", " << std::to_string(y)
-        << ", " << std::to_string(z) << ")]" << std::endl;
+    out << ")]" << std::endl;
     make_indent(out);
   }
 
@@ -2721,10 +2730,10 @@
   if (sem->IsPipelineConstant()) {
     auto const_id = sem->ConstantId();
 
-    out << "#ifndef WGSL_SPEC_CONSTANT_" << const_id << std::endl;
+    out << "#ifndef " << kSpecConstantPrefix << const_id << std::endl;
 
     if (var->constructor() != nullptr) {
-      out << "#define WGSL_SPEC_CONSTANT_" << const_id << " "
+      out << "#define " << kSpecConstantPrefix << const_id << " "
           << constructor_out.str() << std::endl;
     } else {
       out << "#error spec constant required for constant id " << const_id
@@ -2736,9 +2745,8 @@
                   builder_.Symbols().NameFor(var->symbol()))) {
       return false;
     }
-    out << " " << builder_.Symbols().NameFor(var->symbol())
-        << " = WGSL_SPEC_CONSTANT_" << const_id << ";" << std::endl;
-    out << "#undef WGSL_SPEC_CONSTANT_" << const_id << std::endl;
+    out << " " << builder_.Symbols().NameFor(var->symbol()) << " = "
+        << kSpecConstantPrefix << const_id << ";" << std::endl;
   } else {
     out << "static const ";
     if (!EmitType(out, type, sem->StorageClass(), sem->AccessControl(),
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 502b60d..9ee9788 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -919,11 +919,8 @@
 }
 
 TEST_F(HlslGeneratorImplTest_Function,
-       Emit_Decoration_EntryPoint_Compute_WithWorkgroup) {
-  Func("main", ast::VariableList{}, ty.void_(),
-       {
-           Return(),
-       },
+       Emit_Decoration_EntryPoint_Compute_WithWorkgroup_Literal) {
+  Func("main", ast::VariableList{}, ty.void_(), {},
        {
            Stage(ast::PipelineStage::kCompute),
            WorkgroupSize(2, 4, 6),
@@ -942,6 +939,69 @@
   Validate();
 }
 
+TEST_F(HlslGeneratorImplTest_Function,
+       Emit_Decoration_EntryPoint_Compute_WithWorkgroup_Const) {
+  GlobalConst("width", ty.i32(), Construct(ty.i32(), 2));
+  GlobalConst("height", ty.i32(), Construct(ty.i32(), 3));
+  GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4));
+  Func("main", ast::VariableList{}, ty.void_(), {},
+       {
+           Stage(ast::PipelineStage::kCompute),
+           WorkgroupSize("width", "height", "depth"),
+       });
+
+  GeneratorImpl& gen = Build();
+
+  ASSERT_TRUE(gen.Generate(out)) << gen.error();
+  EXPECT_EQ(result(), R"(static const int width = int(2);
+static const int height = int(3);
+static const int depth = int(4);
+[numthreads(2, 3, 4)]
+void main() {
+  return;
+}
+
+)");
+
+  Validate();
+}
+
+TEST_F(HlslGeneratorImplTest_Function,
+       Emit_Decoration_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
+  GlobalConst("width", ty.i32(), Construct(ty.i32(), 2), {Override(7u)});
+  GlobalConst("height", ty.i32(), Construct(ty.i32(), 3), {Override(8u)});
+  GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4), {Override(9u)});
+  Func("main", ast::VariableList{}, ty.void_(), {},
+       {
+           Stage(ast::PipelineStage::kCompute),
+           WorkgroupSize("width", "height", "depth"),
+       });
+
+  GeneratorImpl& gen = Build();
+
+  ASSERT_TRUE(gen.Generate(out)) << gen.error();
+  EXPECT_EQ(result(), R"(#ifndef WGSL_SPEC_CONSTANT_7
+#define WGSL_SPEC_CONSTANT_7 int(2)
+#endif
+static const int width = WGSL_SPEC_CONSTANT_7;
+#ifndef WGSL_SPEC_CONSTANT_8
+#define WGSL_SPEC_CONSTANT_8 int(3)
+#endif
+static const int height = WGSL_SPEC_CONSTANT_8;
+#ifndef WGSL_SPEC_CONSTANT_9
+#define WGSL_SPEC_CONSTANT_9 int(4)
+#endif
+static const int depth = WGSL_SPEC_CONSTANT_9;
+[numthreads(WGSL_SPEC_CONSTANT_7, WGSL_SPEC_CONSTANT_8, WGSL_SPEC_CONSTANT_9)]
+void main() {
+  return;
+}
+
+)");
+
+  Validate();
+}
+
 TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
   Func("my_func", ast::VariableList{Param("a", ty.array<f32, 5>())}, ty.void_(),
        {
diff --git a/src/writer/hlsl/generator_impl_module_constant_test.cc b/src/writer/hlsl/generator_impl_module_constant_test.cc
index 6e82c94..8ff9c88 100644
--- a/src/writer/hlsl/generator_impl_module_constant_test.cc
+++ b/src/writer/hlsl/generator_impl_module_constant_test.cc
@@ -45,7 +45,6 @@
 #define WGSL_SPEC_CONSTANT_23 3.0f
 #endif
 static const float pos = WGSL_SPEC_CONSTANT_23;
-#undef WGSL_SPEC_CONSTANT_23
 )");
 }
 
@@ -62,7 +61,6 @@
 #error spec constant required for constant id 23
 #endif
 static const float pos = WGSL_SPEC_CONSTANT_23;
-#undef WGSL_SPEC_CONSTANT_23
 )");
 }
 
@@ -84,12 +82,10 @@
 #define WGSL_SPEC_CONSTANT_0 3.0f
 #endif
 static const float a = WGSL_SPEC_CONSTANT_0;
-#undef WGSL_SPEC_CONSTANT_0
 #ifndef WGSL_SPEC_CONSTANT_1
 #define WGSL_SPEC_CONSTANT_1 2.0f
 #endif
 static const float b = WGSL_SPEC_CONSTANT_1;
-#undef WGSL_SPEC_CONSTANT_1
 )");
 }