Convert CaseSelector to IR.
This CL converts the case selectors over from ast CaseSelectors to IR
CaseSelectors. They work the same way in that a `nullptr` value signals
a `default` selector but they only store the resulting `constant::Value`
instead of the `ast::Expression`.
Bug: tint:1718
Change-Id: Ied62d661e03a7f8da4c1e1bdaccc04f21ab38111
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/116364
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc
index e65c2fc..6bf98ef 100644
--- a/src/tint/ir/builder.cc
+++ b/src/tint/ir/builder.cc
@@ -77,7 +77,7 @@
return ir_switch;
}
-Block* Builder::CreateCase(Switch* s, utils::VectorRef<const ast::CaseSelector*> selectors) {
+Block* Builder::CreateCase(Switch* s, utils::VectorRef<Switch::CaseSelector> selectors) {
s->cases.Push(Switch::Case{selectors, CreateBlock()});
Block* b = s->cases.Back().start_target;
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index ad9e2e1..e444996 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -39,9 +39,6 @@
namespace tint {
class Program;
} // namespace tint
-namespace tint::ast {
-class CaseSelector;
-} // namespace tint::ast
namespace tint::ir {
@@ -87,7 +84,7 @@
/// @param s the switch to create the case into
/// @param selectors the case selectors for the case statement
/// @returns the start block for the case flow node
- Block* CreateCase(Switch* s, utils::VectorRef<const ast::CaseSelector*> selectors);
+ Block* CreateCase(Switch* s, utils::VectorRef<Switch::CaseSelector> selectors);
/// Branches the given block to the given flow node.
/// @param from the block to branch from
diff --git a/src/tint/ir/builder_impl.cc b/src/tint/ir/builder_impl.cc
index 7355db8..ccbe0e9 100644
--- a/src/tint/ir/builder_impl.cc
+++ b/src/tint/ir/builder_impl.cc
@@ -44,6 +44,7 @@
#include "src/tint/program.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/module.h"
+#include "src/tint/sem/switch_statement.h"
namespace tint::ir {
namespace {
@@ -437,9 +438,19 @@
{
FlowStackScope scope(this, switch_node);
- for (const auto* c : stmt->body) {
- current_flow_block = builder.CreateCase(switch_node, c->selectors);
- if (!EmitStatement(c->body)) {
+ const auto* sem = builder.ir.program->Sem().Get(stmt);
+ for (const auto* c : sem->Cases()) {
+ utils::Vector<Switch::CaseSelector, 4> selectors;
+ for (const auto* selector : c->Selectors()) {
+ if (selector->IsDefault()) {
+ selectors.Push({nullptr});
+ } else {
+ selectors.Push({selector->Value()->Clone(clone_ctx_)});
+ }
+ }
+
+ current_flow_block = builder.CreateCase(switch_node, selectors);
+ if (!EmitStatement(c->Body()->Declaration())) {
return false;
}
BranchToIfNeeded(switch_node->merge_target);
diff --git a/src/tint/ir/builder_impl_test.cc b/src/tint/ir/builder_impl_test.cc
index ced23f6..6b6cd20 100644
--- a/src/tint/ir/builder_impl_test.cc
+++ b/src/tint/ir/builder_impl_test.cc
@@ -16,6 +16,7 @@
#include "src/tint/ast/case_selector.h"
#include "src/tint/ast/int_literal_expression.h"
+#include "src/tint/constant/scalar.h"
namespace tint::ir {
namespace {
@@ -1151,15 +1152,15 @@
auto* func = m.functions[0];
ASSERT_EQ(1u, flow->cases[0].selectors.Length());
- ASSERT_TRUE(flow->cases[0].selectors[0]->expr->Is<ast::IntLiteralExpression>());
- EXPECT_EQ(0_i, flow->cases[0].selectors[0]->expr->As<ast::IntLiteralExpression>()->value);
+ ASSERT_TRUE(flow->cases[0].selectors[0].val->Is<constant::Scalar<tint::i32>>());
+ EXPECT_EQ(0_i, flow->cases[0].selectors[0].val->As<constant::Scalar<tint::i32>>()->ValueOf());
ASSERT_EQ(1u, flow->cases[1].selectors.Length());
- ASSERT_TRUE(flow->cases[1].selectors[0]->expr->Is<ast::IntLiteralExpression>());
- EXPECT_EQ(1_i, flow->cases[1].selectors[0]->expr->As<ast::IntLiteralExpression>()->value);
+ ASSERT_TRUE(flow->cases[1].selectors[0].val->Is<constant::Scalar<tint::i32>>());
+ EXPECT_EQ(1_i, flow->cases[1].selectors[0].val->As<constant::Scalar<tint::i32>>()->ValueOf());
ASSERT_EQ(1u, flow->cases[2].selectors.Length());
- EXPECT_TRUE(flow->cases[2].selectors[0]->IsDefault());
+ EXPECT_TRUE(flow->cases[2].selectors[0].IsDefault());
EXPECT_EQ(1u, flow->inbound_branches.Length());
EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length());
@@ -1205,7 +1206,7 @@
auto* func = m.functions[0];
ASSERT_EQ(1u, flow->cases[0].selectors.Length());
- EXPECT_TRUE(flow->cases[0].selectors[0]->IsDefault());
+ EXPECT_TRUE(flow->cases[0].selectors[0].IsDefault());
EXPECT_EQ(1u, flow->inbound_branches.Length());
EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length());
@@ -1257,11 +1258,11 @@
auto* func = m.functions[0];
ASSERT_EQ(1u, flow->cases[0].selectors.Length());
- ASSERT_TRUE(flow->cases[0].selectors[0]->expr->Is<ast::IntLiteralExpression>());
- EXPECT_EQ(0_i, flow->cases[0].selectors[0]->expr->As<ast::IntLiteralExpression>()->value);
+ ASSERT_TRUE(flow->cases[0].selectors[0].val->Is<constant::Scalar<tint::i32>>());
+ EXPECT_EQ(0_i, flow->cases[0].selectors[0].val->As<constant::Scalar<tint::i32>>()->ValueOf());
ASSERT_EQ(1u, flow->cases[1].selectors.Length());
- EXPECT_TRUE(flow->cases[1].selectors[0]->IsDefault());
+ EXPECT_TRUE(flow->cases[1].selectors[0].IsDefault());
EXPECT_EQ(1u, flow->inbound_branches.Length());
EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length());
@@ -1323,11 +1324,11 @@
auto* func = m.functions[0];
ASSERT_EQ(1u, flow->cases[0].selectors.Length());
- ASSERT_TRUE(flow->cases[0].selectors[0]->expr->Is<ast::IntLiteralExpression>());
- EXPECT_EQ(0_i, flow->cases[0].selectors[0]->expr->As<ast::IntLiteralExpression>()->value);
+ ASSERT_TRUE(flow->cases[0].selectors[0].val->Is<constant::Scalar<tint::i32>>());
+ EXPECT_EQ(0_i, flow->cases[0].selectors[0].val->As<constant::Scalar<tint::i32>>()->ValueOf());
ASSERT_EQ(1u, flow->cases[1].selectors.Length());
- EXPECT_TRUE(flow->cases[1].selectors[0]->IsDefault());
+ EXPECT_TRUE(flow->cases[1].selectors[0].IsDefault());
EXPECT_EQ(1u, flow->inbound_branches.Length());
EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length());
diff --git a/src/tint/ir/switch.h b/src/tint/ir/switch.h
index 7fff0a0..0a8d989 100644
--- a/src/tint/ir/switch.h
+++ b/src/tint/ir/switch.h
@@ -15,13 +15,13 @@
#ifndef SRC_TINT_IR_SWITCH_H_
#define SRC_TINT_IR_SWITCH_H_
+#include "src/tint/constant/value.h"
#include "src/tint/ir/block.h"
#include "src/tint/ir/flow_node.h"
#include "src/tint/ir/value.h"
// Forward declarations
namespace tint::ast {
-class CaseSelector;
class SwitchStatement;
} // namespace tint::ast
@@ -30,10 +30,19 @@
/// Flow node representing a switch statement
class Switch : public Castable<Switch, FlowNode> {
public:
+ /// A case selector
+ struct CaseSelector {
+ /// @returns true if this is a default selector
+ bool IsDefault() const { return val == nullptr; }
+
+ /// The selector value, or nullptr if this is the default selector
+ constant::Value* val = nullptr;
+ };
+
/// A case label in the struct
struct Case {
/// The case selector for this node
- utils::Vector<const ast::CaseSelector*, 4> selectors;
+ utils::Vector<CaseSelector, 4> selectors;
/// The start block for the case block.
Block* start_target;
};