[tint][ir] Clone ControlInstruction results
Fixes a crash in DirectVariableAccess when a short-circuiting
operation is present.
Fixed: tint:2054
Change-Id: I3b64251c08d602b82b5371e1fce5c8cce0cbf7d7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/155860
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/core/ir/if.cc b/src/tint/lang/core/ir/if.cc
index e54d00f..15d1e5d5 100644
--- a/src/tint/lang/core/ir/if.cc
+++ b/src/tint/lang/core/ir/if.cc
@@ -14,13 +14,13 @@
#include "src/tint/lang/core/ir/if.h"
-TINT_INSTANTIATE_TYPEINFO(tint::core::ir::If);
-
#include "src/tint/lang/core/ir/clone_context.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/multi_in_block.h"
#include "src/tint/utils/ice/ice.h"
+TINT_INSTANTIATE_TYPEINFO(tint::core::ir::If);
+
namespace tint::core::ir {
If::If(Value* cond, ir::Block* t, ir::Block* f) : true_(t), false_(f) {
@@ -58,6 +58,9 @@
true_->CloneInto(ctx, new_true);
false_->CloneInto(ctx, new_false);
+
+ new_if->SetResults(ctx.Clone(results_));
+
return new_if;
}
diff --git a/src/tint/lang/core/ir/if_test.cc b/src/tint/lang/core/ir/if_test.cc
index 01a8538..9ff0208 100644
--- a/src/tint/lang/core/ir/if_test.cc
+++ b/src/tint/lang/core/ir/if_test.cc
@@ -92,5 +92,27 @@
EXPECT_EQ(new_if, new_if->True()->Front()->As<ExitIf>()->If());
}
+TEST_F(IR_IfTest, CloneWithResults) {
+ If* new_if = nullptr;
+ auto* r0 = b.InstructionResult(ty.i32());
+ auto* r1 = b.InstructionResult(ty.f32());
+ {
+ auto* if_ = b.If(true);
+ if_->SetResults(Vector{r0, r1});
+ b.Append(if_->True(), [&] { b.ExitIf(if_, b.Constant(42_i), b.Constant(42_f)); });
+ new_if = clone_ctx.Clone(if_);
+ }
+
+ ASSERT_EQ(2u, new_if->Results().Length());
+ auto* new_r0 = new_if->Results()[0]->As<InstructionResult>();
+ ASSERT_NE(new_r0, nullptr);
+ ASSERT_NE(new_r0, r0);
+ EXPECT_EQ(new_r0->Type(), ty.i32());
+ auto* new_r1 = new_if->Results()[1]->As<InstructionResult>();
+ ASSERT_NE(new_r1, nullptr);
+ ASSERT_NE(new_r1, r1);
+ EXPECT_EQ(new_r1->Type(), ty.f32());
+}
+
} // namespace
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/loop.cc b/src/tint/lang/core/ir/loop.cc
index ada2516..11b42db 100644
--- a/src/tint/lang/core/ir/loop.cc
+++ b/src/tint/lang/core/ir/loop.cc
@@ -14,8 +14,6 @@
#include "src/tint/lang/core/ir/loop.h"
-#include <utility>
-
#include "src/tint/lang/core/ir/clone_context.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/multi_in_block.h"
@@ -56,6 +54,8 @@
body_->CloneInto(ctx, new_body);
continuing_->CloneInto(ctx, new_continuing);
+ new_loop->SetResults(ctx.Clone(results_));
+
return new_loop;
}
diff --git a/src/tint/lang/core/ir/loop_test.cc b/src/tint/lang/core/ir/loop_test.cc
index 89699f0..53773f2 100644
--- a/src/tint/lang/core/ir/loop_test.cc
+++ b/src/tint/lang/core/ir/loop_test.cc
@@ -117,5 +117,27 @@
EXPECT_EQ(new_loop, new_loop->Body()->Back()->As<NextIteration>()->Loop());
}
+TEST_F(IR_LoopTest, CloneWithResults) {
+ Loop* new_loop = nullptr;
+ auto* r0 = b.InstructionResult(ty.i32());
+ auto* r1 = b.InstructionResult(ty.f32());
+ {
+ auto* loop = b.Loop();
+ loop->SetResults(Vector{r0, r1});
+ b.Append(loop->Body(), [&] { b.ExitLoop(loop, b.Constant(42_i), b.Constant(42_f)); });
+ new_loop = clone_ctx.Clone(loop);
+ }
+
+ ASSERT_EQ(2u, new_loop->Results().Length());
+ auto* new_r0 = new_loop->Results()[0]->As<InstructionResult>();
+ ASSERT_NE(new_r0, nullptr);
+ ASSERT_NE(new_r0, r0);
+ EXPECT_EQ(new_r0->Type(), ty.i32());
+ auto* new_r1 = new_loop->Results()[1]->As<InstructionResult>();
+ ASSERT_NE(new_r1, nullptr);
+ ASSERT_NE(new_r1, r1);
+ EXPECT_EQ(new_r1->Type(), ty.f32());
+}
+
} // namespace
} // namespace tint::core::ir
diff --git a/src/tint/lang/core/ir/switch.cc b/src/tint/lang/core/ir/switch.cc
index cae733b..395b668 100644
--- a/src/tint/lang/core/ir/switch.cc
+++ b/src/tint/lang/core/ir/switch.cc
@@ -54,6 +54,9 @@
}
new_switch->cases_.Push(new_case);
}
+
+ new_switch->SetResults(ctx.Clone(results_));
+
return new_switch;
}
diff --git a/src/tint/lang/core/ir/switch_test.cc b/src/tint/lang/core/ir/switch_test.cc
index 83cc223..944ba97 100644
--- a/src/tint/lang/core/ir/switch_test.cc
+++ b/src/tint/lang/core/ir/switch_test.cc
@@ -101,5 +101,30 @@
EXPECT_EQ(new_switch, case_.block->Front()->As<ExitSwitch>()->Switch());
}
+TEST_F(IR_SwitchTest, CloneWithResults) {
+ Switch* new_switch = nullptr;
+ auto* r0 = b.InstructionResult(ty.i32());
+ auto* r1 = b.InstructionResult(ty.f32());
+ {
+ auto* switch_ = b.Switch(1_i);
+ switch_->SetResults(Vector{r0, r1});
+
+ auto* blk = b.Block();
+ b.Append(blk, [&] { b.ExitSwitch(switch_, b.Constant(42_i), b.Constant(42_f)); });
+ switch_->Cases().Push(Switch::Case{{Switch::CaseSelector{b.Constant(3_i)}}, blk});
+ new_switch = clone_ctx.Clone(switch_);
+ }
+
+ ASSERT_EQ(2u, new_switch->Results().Length());
+ auto* new_r0 = new_switch->Results()[0]->As<InstructionResult>();
+ ASSERT_NE(new_r0, nullptr);
+ ASSERT_NE(new_r0, r0);
+ EXPECT_EQ(new_r0->Type(), ty.i32());
+ auto* new_r1 = new_switch->Results()[1]->As<InstructionResult>();
+ ASSERT_NE(new_r1, nullptr);
+ ASSERT_NE(new_r1, r1);
+ EXPECT_EQ(new_r1->Type(), ty.f32());
+}
+
} // namespace
} // namespace tint::core::ir
diff --git a/test/tint/bug/tint/2054.wgsl b/test/tint/bug/tint/2054.wgsl
new file mode 100644
index 0000000..a526fa0
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl
@@ -0,0 +1,16 @@
+@group(0) @binding(0)
+var<storage, read_write> out : f32;
+
+fn bar(p : ptr<function, f32>) {
+ let a : f32 = 1.0;
+ let b : f32 = 2.0;
+ let cond = (a >= 0) && (b >= 0);
+ *p = select(a, b, cond);
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ var param : f32;
+ bar(¶m);
+ out = param;
+}
diff --git a/test/tint/bug/tint/2054.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/2054.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..f8409af
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.dxc.hlsl
@@ -0,0 +1,20 @@
+RWByteAddressBuffer tint_symbol : register(u0);
+
+void bar(inout float p) {
+ const float a = 1.0f;
+ const float b = 2.0f;
+ bool tint_tmp = (a >= 0.0f);
+ if (tint_tmp) {
+ tint_tmp = (b >= 0.0f);
+ }
+ const bool cond = (tint_tmp);
+ p = (cond ? b : a);
+}
+
+[numthreads(1, 1, 1)]
+void foo() {
+ float param = 0.0f;
+ bar(param);
+ tint_symbol.Store(0u, asuint(param));
+ return;
+}
diff --git a/test/tint/bug/tint/2054.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/2054.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..f8409af
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.fxc.hlsl
@@ -0,0 +1,20 @@
+RWByteAddressBuffer tint_symbol : register(u0);
+
+void bar(inout float p) {
+ const float a = 1.0f;
+ const float b = 2.0f;
+ bool tint_tmp = (a >= 0.0f);
+ if (tint_tmp) {
+ tint_tmp = (b >= 0.0f);
+ }
+ const bool cond = (tint_tmp);
+ p = (cond ? b : a);
+}
+
+[numthreads(1, 1, 1)]
+void foo() {
+ float param = 0.0f;
+ bar(param);
+ tint_symbol.Store(0u, asuint(param));
+ return;
+}
diff --git a/test/tint/bug/tint/2054.wgsl.expected.glsl b/test/tint/bug/tint/2054.wgsl.expected.glsl
new file mode 100644
index 0000000..0216cd0
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.glsl
@@ -0,0 +1,28 @@
+#version 310 es
+
+layout(binding = 0, std430) buffer tint_symbol_block_ssbo {
+ float inner;
+} tint_symbol;
+
+void bar(inout float p) {
+ float a = 1.0f;
+ float b = 2.0f;
+ bool tint_tmp = (a >= 0.0f);
+ if (tint_tmp) {
+ tint_tmp = (b >= 0.0f);
+ }
+ bool cond = (tint_tmp);
+ p = (cond ? b : a);
+}
+
+void foo() {
+ float param = 0.0f;
+ bar(param);
+ tint_symbol.inner = param;
+}
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void main() {
+ foo();
+ return;
+}
diff --git a/test/tint/bug/tint/2054.wgsl.expected.msl b/test/tint/bug/tint/2054.wgsl.expected.msl
new file mode 100644
index 0000000..ba30f4e
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.msl
@@ -0,0 +1,17 @@
+#include <metal_stdlib>
+
+using namespace metal;
+void bar(thread float* const p) {
+ float const a = 1.0f;
+ float const b = 2.0f;
+ bool const cond = ((a >= 0.0f) && (b >= 0.0f));
+ *(p) = select(a, b, cond);
+}
+
+kernel void foo(device float* tint_symbol [[buffer(0)]]) {
+ float param = 0.0f;
+ bar(&(param));
+ *(tint_symbol) = param;
+ return;
+}
+
diff --git a/test/tint/bug/tint/2054.wgsl.expected.spvasm b/test/tint/bug/tint/2054.wgsl.expected.spvasm
new file mode 100644
index 0000000..2c6e927
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.spvasm
@@ -0,0 +1,59 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 33
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %foo "foo"
+ OpExecutionMode %foo LocalSize 1 1 1
+ OpName %out_block "out_block"
+ OpMemberName %out_block 0 "inner"
+ OpName %out "out"
+ OpName %bar "bar"
+ OpName %p "p"
+ OpName %foo "foo"
+ OpName %param "param"
+ OpDecorate %out_block Block
+ OpMemberDecorate %out_block 0 Offset 0
+ OpDecorate %out DescriptorSet 0
+ OpDecorate %out Binding 0
+ %float = OpTypeFloat 32
+ %out_block = OpTypeStruct %float
+%_ptr_StorageBuffer_out_block = OpTypePointer StorageBuffer %out_block
+ %out = OpVariable %_ptr_StorageBuffer_out_block StorageBuffer
+ %void = OpTypeVoid
+%_ptr_Function_float = OpTypePointer Function %float
+ %5 = OpTypeFunction %void %_ptr_Function_float
+ %float_1 = OpConstant %float 1
+ %float_2 = OpConstant %float 2
+ %13 = OpConstantNull %float
+ %bool = OpTypeBool
+ %22 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+%_ptr_StorageBuffer_float = OpTypePointer StorageBuffer %float
+ %bar = OpFunction %void None %5
+ %p = OpFunctionParameter %_ptr_Function_float
+ %10 = OpLabel
+ %14 = OpFOrdGreaterThanEqual %bool %float_1 %13
+ OpSelectionMerge %16 None
+ OpBranchConditional %14 %17 %16
+ %17 = OpLabel
+ %18 = OpFOrdGreaterThanEqual %bool %float_2 %13
+ OpBranch %16
+ %16 = OpLabel
+ %19 = OpPhi %bool %14 %10 %18 %17
+ %21 = OpSelect %float %19 %float_2 %float_1
+ OpStore %p %21
+ OpReturn
+ OpFunctionEnd
+ %foo = OpFunction %void None %22
+ %24 = OpLabel
+ %param = OpVariable %_ptr_Function_float Function %13
+ %26 = OpFunctionCall %void %bar %param
+ %31 = OpAccessChain %_ptr_StorageBuffer_float %out %uint_0
+ %32 = OpLoad %float %param
+ OpStore %31 %32
+ OpReturn
+ OpFunctionEnd
diff --git a/test/tint/bug/tint/2054.wgsl.expected.wgsl b/test/tint/bug/tint/2054.wgsl.expected.wgsl
new file mode 100644
index 0000000..e6b417f
--- /dev/null
+++ b/test/tint/bug/tint/2054.wgsl.expected.wgsl
@@ -0,0 +1,15 @@
+@group(0) @binding(0) var<storage, read_write> out : f32;
+
+fn bar(p : ptr<function, f32>) {
+ let a : f32 = 1.0;
+ let b : f32 = 2.0;
+ let cond = ((a >= 0) && (b >= 0));
+ *(p) = select(a, b, cond);
+}
+
+@compute @workgroup_size(1)
+fn foo() {
+ var param : f32;
+ bar(&(param));
+ out = param;
+}