[tint][ir] Clone ControlInstruction results

Fixes a crash in DirectVariableAccess when a short-circuiting
operation is present.

Fixed: tint:2054
Change-Id: I3b64251c08d602b82b5371e1fce5c8cce0cbf7d7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/155860
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/if.cc b/src/tint/lang/core/ir/if.cc
index e54d00f..15d1e5d5 100644
--- a/src/tint/lang/core/ir/if.cc
+++ b/src/tint/lang/core/ir/if.cc
@@ -14,13 +14,13 @@
 
 #include "src/tint/lang/core/ir/if.h"
 
-TINT_INSTANTIATE_TYPEINFO(tint::core::ir::If);
-
 #include "src/tint/lang/core/ir/clone_context.h"
 #include "src/tint/lang/core/ir/module.h"
 #include "src/tint/lang/core/ir/multi_in_block.h"
 #include "src/tint/utils/ice/ice.h"
 
+TINT_INSTANTIATE_TYPEINFO(tint::core::ir::If);
+
 namespace tint::core::ir {
 
 If::If(Value* cond, ir::Block* t, ir::Block* f) : true_(t), false_(f) {
@@ -58,6 +58,9 @@
 
     true_->CloneInto(ctx, new_true);
     false_->CloneInto(ctx, new_false);
+
+    new_if->SetResults(ctx.Clone(results_));
+
     return new_if;
 }
 
diff --git a/src/tint/lang/core/ir/if_test.cc b/src/tint/lang/core/ir/if_test.cc
index 01a8538..9ff0208 100644
--- a/src/tint/lang/core/ir/if_test.cc
+++ b/src/tint/lang/core/ir/if_test.cc
@@ -92,5 +92,27 @@
     EXPECT_EQ(new_if, new_if->True()->Front()->As<ExitIf>()->If());
 }
 
+TEST_F(IR_IfTest, CloneWithResults) {
+    If* new_if = nullptr;
+    auto* r0 = b.InstructionResult(ty.i32());
+    auto* r1 = b.InstructionResult(ty.f32());
+    {
+        auto* if_ = b.If(true);
+        if_->SetResults(Vector{r0, r1});
+        b.Append(if_->True(), [&] { b.ExitIf(if_, b.Constant(42_i), b.Constant(42_f)); });
+        new_if = clone_ctx.Clone(if_);
+    }
+
+    ASSERT_EQ(2u, new_if->Results().Length());
+    auto* new_r0 = new_if->Results()[0]->As<InstructionResult>();
+    ASSERT_NE(new_r0, nullptr);
+    ASSERT_NE(new_r0, r0);
+    EXPECT_EQ(new_r0->Type(), ty.i32());
+    auto* new_r1 = new_if->Results()[1]->As<InstructionResult>();
+    ASSERT_NE(new_r1, nullptr);
+    ASSERT_NE(new_r1, r1);
+    EXPECT_EQ(new_r1->Type(), ty.f32());
+}
+
 }  // namespace
 }  // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/loop.cc b/src/tint/lang/core/ir/loop.cc
index ada2516..11b42db 100644
--- a/src/tint/lang/core/ir/loop.cc
+++ b/src/tint/lang/core/ir/loop.cc
@@ -14,8 +14,6 @@
 
 #include "src/tint/lang/core/ir/loop.h"
 
-#include <utility>
-
 #include "src/tint/lang/core/ir/clone_context.h"
 #include "src/tint/lang/core/ir/module.h"
 #include "src/tint/lang/core/ir/multi_in_block.h"
@@ -56,6 +54,8 @@
     body_->CloneInto(ctx, new_body);
     continuing_->CloneInto(ctx, new_continuing);
 
+    new_loop->SetResults(ctx.Clone(results_));
+
     return new_loop;
 }
 
diff --git a/src/tint/lang/core/ir/loop_test.cc b/src/tint/lang/core/ir/loop_test.cc
index 89699f0..53773f2 100644
--- a/src/tint/lang/core/ir/loop_test.cc
+++ b/src/tint/lang/core/ir/loop_test.cc
@@ -117,5 +117,27 @@
     EXPECT_EQ(new_loop, new_loop->Body()->Back()->As<NextIteration>()->Loop());
 }
 
+TEST_F(IR_LoopTest, CloneWithResults) {
+    Loop* new_loop = nullptr;
+    auto* r0 = b.InstructionResult(ty.i32());
+    auto* r1 = b.InstructionResult(ty.f32());
+    {
+        auto* loop = b.Loop();
+        loop->SetResults(Vector{r0, r1});
+        b.Append(loop->Body(), [&] { b.ExitLoop(loop, b.Constant(42_i), b.Constant(42_f)); });
+        new_loop = clone_ctx.Clone(loop);
+    }
+
+    ASSERT_EQ(2u, new_loop->Results().Length());
+    auto* new_r0 = new_loop->Results()[0]->As<InstructionResult>();
+    ASSERT_NE(new_r0, nullptr);
+    ASSERT_NE(new_r0, r0);
+    EXPECT_EQ(new_r0->Type(), ty.i32());
+    auto* new_r1 = new_loop->Results()[1]->As<InstructionResult>();
+    ASSERT_NE(new_r1, nullptr);
+    ASSERT_NE(new_r1, r1);
+    EXPECT_EQ(new_r1->Type(), ty.f32());
+}
+
 }  // namespace
 }  // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/switch.cc b/src/tint/lang/core/ir/switch.cc
index cae733b..395b668 100644
--- a/src/tint/lang/core/ir/switch.cc
+++ b/src/tint/lang/core/ir/switch.cc
@@ -54,6 +54,9 @@
         }
         new_switch->cases_.Push(new_case);
     }
+
+    new_switch->SetResults(ctx.Clone(results_));
+
     return new_switch;
 }
 
diff --git a/src/tint/lang/core/ir/switch_test.cc b/src/tint/lang/core/ir/switch_test.cc
index 83cc223..944ba97 100644
--- a/src/tint/lang/core/ir/switch_test.cc
+++ b/src/tint/lang/core/ir/switch_test.cc
@@ -101,5 +101,30 @@
     EXPECT_EQ(new_switch, case_.block->Front()->As<ExitSwitch>()->Switch());
 }
 
+TEST_F(IR_SwitchTest, CloneWithResults) {
+    Switch* new_switch = nullptr;
+    auto* r0 = b.InstructionResult(ty.i32());
+    auto* r1 = b.InstructionResult(ty.f32());
+    {
+        auto* switch_ = b.Switch(1_i);
+        switch_->SetResults(Vector{r0, r1});
+
+        auto* blk = b.Block();
+        b.Append(blk, [&] { b.ExitSwitch(switch_, b.Constant(42_i), b.Constant(42_f)); });
+        switch_->Cases().Push(Switch::Case{{Switch::CaseSelector{b.Constant(3_i)}}, blk});
+        new_switch = clone_ctx.Clone(switch_);
+    }
+
+    ASSERT_EQ(2u, new_switch->Results().Length());
+    auto* new_r0 = new_switch->Results()[0]->As<InstructionResult>();
+    ASSERT_NE(new_r0, nullptr);
+    ASSERT_NE(new_r0, r0);
+    EXPECT_EQ(new_r0->Type(), ty.i32());
+    auto* new_r1 = new_switch->Results()[1]->As<InstructionResult>();
+    ASSERT_NE(new_r1, nullptr);
+    ASSERT_NE(new_r1, r1);
+    EXPECT_EQ(new_r1->Type(), ty.f32());
+}
+
 }  // namespace
 }  // namespace tint::core::ir
diff --git a/test/tint/bug/tint/2054.wgsl b/test/tint/bug/tint/2054.wgsl
new file mode 100644
index 0000000..a526fa0
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl
@@ -0,0 +1,16 @@
+@group(0) @binding(0)
+var<storage, read_write> out : f32;
+
+fn bar(p : ptr<function, f32>) {
+  let a : f32 = 1.0;
+  let b : f32 = 2.0;
+  let cond = (a >= 0) && (b >= 0);
+  *p = select(a, b, cond);
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+  var param : f32;
+  bar(&param);
+  out = param;
+}
diff --git a/test/tint/bug/tint/2054.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/2054.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..f8409af
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.dxc.hlsl
@@ -0,0 +1,20 @@
+RWByteAddressBuffer tint_symbol : register(u0);
+
+void bar(inout float p) {
+  const float a = 1.0f;
+  const float b = 2.0f;
+  bool tint_tmp = (a >= 0.0f);
+  if (tint_tmp) {
+    tint_tmp = (b >= 0.0f);
+  }
+  const bool cond = (tint_tmp);
+  p = (cond ? b : a);
+}
+
+[numthreads(1, 1, 1)]
+void foo() {
+  float param = 0.0f;
+  bar(param);
+  tint_symbol.Store(0u, asuint(param));
+  return;
+}
diff --git a/test/tint/bug/tint/2054.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/2054.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..f8409af
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.fxc.hlsl
@@ -0,0 +1,20 @@
+RWByteAddressBuffer tint_symbol : register(u0);
+
+void bar(inout float p) {
+  const float a = 1.0f;
+  const float b = 2.0f;
+  bool tint_tmp = (a >= 0.0f);
+  if (tint_tmp) {
+    tint_tmp = (b >= 0.0f);
+  }
+  const bool cond = (tint_tmp);
+  p = (cond ? b : a);
+}
+
+[numthreads(1, 1, 1)]
+void foo() {
+  float param = 0.0f;
+  bar(param);
+  tint_symbol.Store(0u, asuint(param));
+  return;
+}
diff --git a/test/tint/bug/tint/2054.wgsl.expected.glsl b/test/tint/bug/tint/2054.wgsl.expected.glsl
new file mode 100644
index 0000000..0216cd0
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.glsl
@@ -0,0 +1,28 @@
+#version 310 es
+
+layout(binding = 0, std430) buffer tint_symbol_block_ssbo {
+  float inner;
+} tint_symbol;
+
+void bar(inout float p) {
+  float a = 1.0f;
+  float b = 2.0f;
+  bool tint_tmp = (a >= 0.0f);
+  if (tint_tmp) {
+    tint_tmp = (b >= 0.0f);
+  }
+  bool cond = (tint_tmp);
+  p = (cond ? b : a);
+}
+
+void foo() {
+  float param = 0.0f;
+  bar(param);
+  tint_symbol.inner = param;
+}
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void main() {
+  foo();
+  return;
+}
diff --git a/test/tint/bug/tint/2054.wgsl.expected.msl b/test/tint/bug/tint/2054.wgsl.expected.msl
new file mode 100644
index 0000000..ba30f4e
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.msl
@@ -0,0 +1,17 @@
+#include <metal_stdlib>
+
+using namespace metal;
+void bar(thread float* const p) {
+  float const a = 1.0f;
+  float const b = 2.0f;
+  bool const cond = ((a >= 0.0f) && (b >= 0.0f));
+  *(p) = select(a, b, cond);
+}
+
+kernel void foo(device float* tint_symbol [[buffer(0)]]) {
+  float param = 0.0f;
+  bar(&(param));
+  *(tint_symbol) = param;
+  return;
+}
+
diff --git a/test/tint/bug/tint/2054.wgsl.expected.spvasm b/test/tint/bug/tint/2054.wgsl.expected.spvasm
new file mode 100644
index 0000000..2c6e927
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.spvasm
@@ -0,0 +1,59 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 33
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %foo "foo"
+               OpExecutionMode %foo LocalSize 1 1 1
+               OpName %out_block "out_block"
+               OpMemberName %out_block 0 "inner"
+               OpName %out "out"
+               OpName %bar "bar"
+               OpName %p "p"
+               OpName %foo "foo"
+               OpName %param "param"
+               OpDecorate %out_block Block
+               OpMemberDecorate %out_block 0 Offset 0
+               OpDecorate %out DescriptorSet 0
+               OpDecorate %out Binding 0
+      %float = OpTypeFloat 32
+  %out_block = OpTypeStruct %float
+%_ptr_StorageBuffer_out_block = OpTypePointer StorageBuffer %out_block
+        %out = OpVariable %_ptr_StorageBuffer_out_block StorageBuffer
+       %void = OpTypeVoid
+%_ptr_Function_float = OpTypePointer Function %float
+          %5 = OpTypeFunction %void %_ptr_Function_float
+    %float_1 = OpConstant %float 1
+    %float_2 = OpConstant %float 2
+         %13 = OpConstantNull %float
+       %bool = OpTypeBool
+         %22 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+%_ptr_StorageBuffer_float = OpTypePointer StorageBuffer %float
+        %bar = OpFunction %void None %5
+          %p = OpFunctionParameter %_ptr_Function_float
+         %10 = OpLabel
+         %14 = OpFOrdGreaterThanEqual %bool %float_1 %13
+               OpSelectionMerge %16 None
+               OpBranchConditional %14 %17 %16
+         %17 = OpLabel
+         %18 = OpFOrdGreaterThanEqual %bool %float_2 %13
+               OpBranch %16
+         %16 = OpLabel
+         %19 = OpPhi %bool %14 %10 %18 %17
+         %21 = OpSelect %float %19 %float_2 %float_1
+               OpStore %p %21
+               OpReturn
+               OpFunctionEnd
+        %foo = OpFunction %void None %22
+         %24 = OpLabel
+      %param = OpVariable %_ptr_Function_float Function %13
+         %26 = OpFunctionCall %void %bar %param
+         %31 = OpAccessChain %_ptr_StorageBuffer_float %out %uint_0
+         %32 = OpLoad %float %param
+               OpStore %31 %32
+               OpReturn
+               OpFunctionEnd
diff --git a/test/tint/bug/tint/2054.wgsl.expected.wgsl b/test/tint/bug/tint/2054.wgsl.expected.wgsl
new file mode 100644
index 0000000..e6b417f
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.wgsl
@@ -0,0 +1,15 @@
+@group(0) @binding(0) var<storage, read_write> out : f32;
+
+fn bar(p : ptr<function, f32>) {
+  let a : f32 = 1.0;
+  let b : f32 = 2.0;
+  let cond = ((a >= 0) && (b >= 0));
+  *(p) = select(a, b, cond);
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+  var param : f32;
+  bar(&(param));
+  out = param;
+}