tint/hlsl: for default-only switch, only emit condition if it has side-effects

This fixes edge cases such as the condition expression being a type cast,
which DXC apparently parses as a variable re-declaration. Example:

fn foo(x : f32) {
  switch (i32(x)) {
    default {
    }
  }
}

was emitted as HLSL:

void foo(float x) {
  int(x);
  do {
  } while (false);
}

DXC parses the statement `int(x);` as the declaration `int x;`, which
re-declares the parameter `x`.

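We fix this by only emitting the condition expression if it has
side-effects (which currently means it contains a call expression).
Conditions without side-effects are dropped entirely, since evaluating
them as a standalone statement has no observable effect.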

Bug: tint:1820
Change-Id: I7e4320fa09ea2d634c9e324cb0b752b0ee7dcde9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118161
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index bde9d70..d49bcc9 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -3859,10 +3859,9 @@
     // default case body. We work around this here by emitting the default case
     // without the switch.
 
-    // Emit the switch condition as-is in case it has side-effects (e.g.
-    // function call). Note that's it's fine not to assign the result of the
-    // expression.
-    {
+    // Emit the switch condition as-is if it has side-effects (e.g.
+    // function call). Note that we can ignore the result of the expression (if any).
+    if (auto* sem_cond = builder_.Sem().Get(stmt->condition); sem_cond->HasSideEffects()) {
         auto out = line();
         if (!EmitExpression(out, stmt->condition)) {
             return false;
diff --git a/src/tint/writer/hlsl/generator_impl_switch_test.cc b/src/tint/writer/hlsl/generator_impl_switch_test.cc
index 698083a..6426b7c 100644
--- a/src/tint/writer/hlsl/generator_impl_switch_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_switch_test.cc
@@ -66,7 +66,16 @@
 )");
 }
 
-TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_OnlyDefaultCase) {
+TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_OnlyDefaultCase_NoSideEffectsCondition) {
+    // var<private> cond : i32;
+    // var<private> a : i32;
+    // fn test() {
+    //   switch(cond) {
+    //     default: {
+    //       a = 42;
+    //     }
+    //   }
+    // }
     GlobalVar("cond", ty.i32(), type::AddressSpace::kPrivate);
     GlobalVar("a", ty.i32(), type::AddressSpace::kPrivate);
     auto* s = Switch(  //
@@ -79,7 +88,45 @@
     gen.increment_indent();
 
     ASSERT_TRUE(gen.EmitStatement(s)) << gen.error();
-    EXPECT_EQ(gen.result(), R"(  cond;
+    EXPECT_EQ(gen.result(), R"(  do {
+    a = 42;
+  } while (false);
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_OnlyDefaultCase_SideEffectsCondition) {
+    // var<private> global : i32;
+    // fn bar() -> i32 {
+    //   global = 84;
+    //   return global;
+    // }
+    //
+    // var<private> a : i32;
+    // fn test() {
+    //   switch(bar()) {
+    //     default: {
+    //       a = 42;
+    //     }
+    //   }
+    // }
+    GlobalVar("global", ty.i32(), type::AddressSpace::kPrivate);
+    Func("bar", {}, ty.i32(),
+         utils::Vector{                               //
+                       Assign("global", Expr(84_i)),  //
+                       Return("global")});
+
+    GlobalVar("a", ty.i32(), type::AddressSpace::kPrivate);
+    auto* s = Switch(  //
+        Call("bar"),   //
+        DefaultCase(Block(Assign(Expr("a"), Expr(42_i)))));
+    WrapInFunction(s);
+
+    GeneratorImpl& gen = Build();
+
+    gen.increment_indent();
+
+    ASSERT_TRUE(gen.EmitStatement(s)) << gen.error();
+    EXPECT_EQ(gen.result(), R"(  bar();
   do {
     a = 42;
   } while (false);
diff --git a/test/tint/bug/tint/1820.wgsl b/test/tint/bug/tint/1820.wgsl
new file mode 100644
index 0000000..781a980
--- /dev/null
+++ b/test/tint/bug/tint/1820.wgsl
@@ -0,0 +1,22 @@
+fn foo(x : f32) {
+  switch (i32(x)) {
+    default {
+    }
+  }
+}
+
+var<private> global : i32;
+fn baz(x : i32) -> i32 {
+    global = 42;
+    return x;
+}
+
+fn bar(x : f32) {
+  switch (baz(i32(x))) {
+    default {
+    }
+  }
+}
+
+fn main() {
+}
diff --git a/test/tint/bug/tint/1820.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/1820.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..5e7dc39
--- /dev/null
+++ b/test/tint/bug/tint/1820.wgsl.expected.dxc.hlsl
@@ -0,0 +1,25 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+  return;
+}
+
+void foo(float x) {
+  do {
+  } while (false);
+}
+
+static int global = 0;
+
+int baz(int x) {
+  global = 42;
+  return x;
+}
+
+void bar(float x) {
+  baz(int(x));
+  do {
+  } while (false);
+}
+
+void main() {
+}
diff --git a/test/tint/bug/tint/1820.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/1820.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..5e7dc39
--- /dev/null
+++ b/test/tint/bug/tint/1820.wgsl.expected.fxc.hlsl
@@ -0,0 +1,25 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+  return;
+}
+
+void foo(float x) {
+  do {
+  } while (false);
+}
+
+static int global = 0;
+
+int baz(int x) {
+  global = 42;
+  return x;
+}
+
+void bar(float x) {
+  baz(int(x));
+  do {
+  } while (false);
+}
+
+void main() {
+}
diff --git a/test/tint/bug/tint/1820.wgsl.expected.glsl b/test/tint/bug/tint/1820.wgsl.expected.glsl
new file mode 100644
index 0000000..9fa7dd0
--- /dev/null
+++ b/test/tint/bug/tint/1820.wgsl.expected.glsl
@@ -0,0 +1,31 @@
+#version 310 es
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void unused_entry_point() {
+  return;
+}
+void foo(float x) {
+  switch(int(x)) {
+    default: {
+      break;
+    }
+  }
+}
+
+int global = 0;
+int baz(int x) {
+  global = 42;
+  return x;
+}
+
+void bar(float x) {
+  switch(baz(int(x))) {
+    default: {
+      break;
+    }
+  }
+}
+
+void tint_symbol() {
+}
+
diff --git a/test/tint/bug/tint/1820.wgsl.expected.msl b/test/tint/bug/tint/1820.wgsl.expected.msl
new file mode 100644
index 0000000..e811b1b
--- /dev/null
+++ b/test/tint/bug/tint/1820.wgsl.expected.msl
@@ -0,0 +1,28 @@
+#include <metal_stdlib>
+
+using namespace metal;
+void foo(float x) {
+  switch(int(x)) {
+    default: {
+      break;
+    }
+  }
+}
+
+int baz(int x) {
+  thread int tint_symbol_1 = 0;
+  tint_symbol_1 = 42;
+  return x;
+}
+
+void bar(float x) {
+  switch(baz(int(x))) {
+    default: {
+      break;
+    }
+  }
+}
+
+void tint_symbol() {
+}
+
diff --git a/test/tint/bug/tint/1820.wgsl.expected.spvasm b/test/tint/bug/tint/1820.wgsl.expected.spvasm
new file mode 100644
index 0000000..1a4bc10
--- /dev/null
+++ b/test/tint/bug/tint/1820.wgsl.expected.spvasm
@@ -0,0 +1,65 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 31
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
+               OpExecutionMode %unused_entry_point LocalSize 1 1 1
+               OpName %global "global"
+               OpName %unused_entry_point "unused_entry_point"
+               OpName %foo "foo"
+               OpName %x "x"
+               OpName %baz "baz"
+               OpName %x_0 "x"
+               OpName %bar "bar"
+               OpName %x_1 "x"
+               OpName %main "main"
+        %int = OpTypeInt 32 1
+%_ptr_Private_int = OpTypePointer Private %int
+          %4 = OpConstantNull %int
+     %global = OpVariable %_ptr_Private_int Private %4
+       %void = OpTypeVoid
+          %5 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+          %9 = OpTypeFunction %void %float
+         %17 = OpTypeFunction %int %int
+     %int_42 = OpConstant %int 42
+%unused_entry_point = OpFunction %void None %5
+          %8 = OpLabel
+               OpReturn
+               OpFunctionEnd
+        %foo = OpFunction %void None %9
+          %x = OpFunctionParameter %float
+         %13 = OpLabel
+         %15 = OpConvertFToS %int %x
+               OpSelectionMerge %14 None
+               OpSwitch %15 %16
+         %16 = OpLabel
+               OpBranch %14
+         %14 = OpLabel
+               OpReturn
+               OpFunctionEnd
+        %baz = OpFunction %int None %17
+        %x_0 = OpFunctionParameter %int
+         %20 = OpLabel
+               OpStore %global %int_42
+               OpReturnValue %x_0
+               OpFunctionEnd
+        %bar = OpFunction %void None %9
+        %x_1 = OpFunctionParameter %float
+         %24 = OpLabel
+         %27 = OpConvertFToS %int %x_1
+         %26 = OpFunctionCall %int %baz %27
+               OpSelectionMerge %25 None
+               OpSwitch %26 %28
+         %28 = OpLabel
+               OpBranch %25
+         %25 = OpLabel
+               OpReturn
+               OpFunctionEnd
+       %main = OpFunction %void None %5
+         %30 = OpLabel
+               OpReturn
+               OpFunctionEnd
diff --git a/test/tint/bug/tint/1820.wgsl.expected.wgsl b/test/tint/bug/tint/1820.wgsl.expected.wgsl
new file mode 100644
index 0000000..5fa9cd9
--- /dev/null
+++ b/test/tint/bug/tint/1820.wgsl.expected.wgsl
@@ -0,0 +1,23 @@
+fn foo(x : f32) {
+  switch(i32(x)) {
+    default: {
+    }
+  }
+}
+
+var<private> global : i32;
+
+fn baz(x : i32) -> i32 {
+  global = 42;
+  return x;
+}
+
+fn bar(x : f32) {
+  switch(baz(i32(x))) {
+    default: {
+    }
+  }
+}
+
+fn main() {
+}
diff --git a/test/tint/statements/switch/only_default_case.wgsl.expected.dxc.hlsl b/test/tint/statements/switch/only_default_case.wgsl.expected.dxc.hlsl
index 128b8a2..ded0d66 100644
--- a/test/tint/statements/switch/only_default_case.wgsl.expected.dxc.hlsl
+++ b/test/tint/statements/switch/only_default_case.wgsl.expected.dxc.hlsl
@@ -2,7 +2,6 @@
 void f() {
   int i = 0;
   int result = 0;
-  i;
   do {
     result = 44;
     break;
diff --git a/test/tint/statements/switch/only_default_case.wgsl.expected.fxc.hlsl b/test/tint/statements/switch/only_default_case.wgsl.expected.fxc.hlsl
index 128b8a2..ded0d66 100644
--- a/test/tint/statements/switch/only_default_case.wgsl.expected.fxc.hlsl
+++ b/test/tint/statements/switch/only_default_case.wgsl.expected.fxc.hlsl
@@ -2,7 +2,6 @@
 void f() {
   int i = 0;
   int result = 0;
-  i;
   do {
     result = 44;
     break;