[IR] Add switch control flow node.

This CL updates the IR builder to create control flow nodes for
a switch statement and the contained case statements.

Bug: tint:1718
Change-Id: I05b73db11ab14676cc123f436ae5912b1dbee0d5
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107801
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc
index 594b279..b023a0a 100644
--- a/src/tint/ir/builder.cc
+++ b/src/tint/ir/builder.cc
@@ -71,6 +71,21 @@
     return ir_loop;
 }
 
+Switch* Builder::CreateSwitch(const ast::SwitchStatement* stmt) {
+    auto* ir_switch = ir.flow_nodes.Create<Switch>(stmt);
+    ir_switch->merge_target = CreateBlock();
+    return ir_switch;
+}
+
+Block* Builder::CreateCase(Switch* s, const utils::VectorRef<const ast::CaseSelector*> selectors) {
+    s->cases.Push(Switch::Case{selectors, CreateBlock()});
+
+    Block* b = s->cases.Back().start_target;
+    // Switch branches into the case block
+    b->inbound_branches.Push(s);
+    return b;
+}
+
 void Builder::Branch(Block* from, FlowNode* to) {
     TINT_ASSERT(IR, from);
     TINT_ASSERT(IR, to);
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index a60e7b0..23d4da4 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -26,6 +26,9 @@
 namespace tint {
 class Program;
 }  // namespace tint
+namespace tint::ast {
+class CaseSelector;
+}  // namespace tint::ast
 
 namespace tint::ir {
 
@@ -62,6 +65,17 @@
     /// @returns the flow node
     Loop* CreateLoop(const ast::LoopStatement* stmt);
 
+    /// Creates a switch flow node for the given ast::SwitchStatement
+    /// @param stmt the ast::SwitchStatment
+    /// @returns the flow node
+    Switch* CreateSwitch(const ast::SwitchStatement* stmt);
+
+    /// Creates a case flow node for the given case branch.
+    /// @param s the switch to create the case into
+    /// @param selectors the case selectors for the case statement
+    /// @returns the start block for the case flow node
+    Block* CreateCase(Switch* s, const utils::VectorRef<const ast::CaseSelector*> selectors);
+
     /// Branches the given block to the given flow node.
     /// @param from the block to branch from
     /// @param to the node to branch too
diff --git a/src/tint/ir/builder_impl.cc b/src/tint/ir/builder_impl.cc
index 9c5caff..895e316 100644
--- a/src/tint/ir/builder_impl.cc
+++ b/src/tint/ir/builder_impl.cc
@@ -24,6 +24,7 @@
 #include "src/tint/ast/return_statement.h"
 #include "src/tint/ast/statement.h"
 #include "src/tint/ast/static_assert.h"
+#include "src/tint/ast/switch_statement.h"
 #include "src/tint/ir/function.h"
 #include "src/tint/ir/if.h"
 #include "src/tint/ir/loop.h"
@@ -209,7 +210,7 @@
         //        [&](const ast::ForLoopStatement* l) { },
         //        [&](const ast::WhileStatement* l) { },
         [&](const ast::ReturnStatement* r) { return EmitReturn(r); },
-        //        [&](const ast::SwitchStatement* s) { },
+        [&](const ast::SwitchStatement* s) { return EmitSwitch(s); },
         //        [&](const ast::VariableDeclStatement* v) { },
         [&](const ast::StaticAssert*) {
             return true;  // Not emitted
@@ -254,15 +255,6 @@
     }
     current_flow_block_ = nullptr;
 
-    // If both branches went somewhere, then they both returned, continued or broke. So,
-    // there is no need for the if merge-block and there is nothing to branch to the merge
-    // block anyway.
-    if (IsBranched(if_node->true_target) && IsBranched(if_node->false_target)) {
-        return true;
-    }
-
-    current_flow_block_ = if_node->merge_target;
-
     // If the true branch did not execute control flow, then go to the merge target
     if (!IsBranched(if_node->true_target)) {
         builder_.Branch(if_node->true_target, if_node->merge_target);
@@ -272,6 +264,13 @@
         builder_.Branch(if_node->false_target, if_node->merge_target);
     }
 
+    // If both branches went somewhere, then they both returned, continued or broke. So,
+    // there is no need for the if merge-block and there is nothing to branch to the merge
+    // block anyway.
+    if (IsConnected(if_node->merge_target)) {
+        current_flow_block_ = if_node->merge_target;
+    }
+
     return true;
 }
 
@@ -313,6 +312,35 @@
     return true;
 }
 
+bool BuilderImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
+    auto* switch_node = builder_.CreateSwitch(stmt);
+
+    // TODO(dsinclair): Emit the condition expression into the current block
+
+    BranchTo(switch_node);
+
+    ast_to_flow_[stmt] = switch_node;
+
+    {
+        FlowStackScope scope(this, switch_node);
+
+        for (const auto* c : stmt->body) {
+            current_flow_block_ = builder_.CreateCase(switch_node, c->selectors);
+            if (!EmitStatement(c->body)) {
+                return false;
+            }
+            BranchToIfNeeded(switch_node->merge_target);
+        }
+    }
+    current_flow_block_ = nullptr;
+
+    if (IsConnected(switch_node->merge_target)) {
+        current_flow_block_ = switch_node->merge_target;
+    }
+
+    return true;
+}
+
 bool BuilderImpl::EmitReturn(const ast::ReturnStatement*) {
     // TODO(dsinclair): Emit the return value ....
 
diff --git a/src/tint/ir/builder_impl.h b/src/tint/ir/builder_impl.h
index 1ca2482..ac76bd7 100644
--- a/src/tint/ir/builder_impl.h
+++ b/src/tint/ir/builder_impl.h
@@ -102,6 +102,11 @@
     /// @returns true if successful, false otherwise.
     bool EmitLoop(const ast::LoopStatement* stmt);
 
+    /// Emits a switch statement
+    /// @param stmt the switch statement
+    /// @returns true if successfull, false otherwise.
+    bool EmitSwitch(const ast::SwitchStatement* stmt);
+
     /// Emits a break statement
     /// @param stmt the break statement
     /// @returns true if successfull, false otherwise.
diff --git a/src/tint/ir/builder_impl_test.cc b/src/tint/ir/builder_impl_test.cc
index eb932da..826cd77e 100644
--- a/src/tint/ir/builder_impl_test.cc
+++ b/src/tint/ir/builder_impl_test.cc
@@ -14,9 +14,14 @@
 
 #include "src/tint/ir/test_helper.h"
 
+#include "src/tint/ast/case_selector.h"
+#include "src/tint/ast/int_literal_expression.h"
+
 namespace tint::ir {
 namespace {
 
+using namespace tint::number_suffixes;  // NOLINT
+
 using IRBuilderImplTest = TestHelper;
 
 TEST_F(IRBuilderImplTest, Func) {
@@ -817,5 +822,222 @@
     EXPECT_EQ(loop_flow_a->merge_target->branch_target, func->end_target);
 }
 
+TEST_F(IRBuilderImplTest, Switch) {
+    // func -> switch -> case 1
+    //                -> case 2
+    //                -> default
+    //
+    //   [case 1] -> switch merge
+    //   [case 2] -> switch merge
+    //   [default] -> switch merge
+    //   [switch merge] -> func end
+    //
+    auto* ast_switch = Switch(
+        1_i, utils::Vector{Case(utils::Vector{CaseSelector(0_i)}, Block()),
+                           Case(utils::Vector{CaseSelector(1_i)}, Block()), DefaultCase(Block())});
+
+    WrapInFunction(ast_switch);
+    auto& b = Build();
+
+    auto r = b.Build();
+    ASSERT_TRUE(r) << b.error();
+    auto m = r.Move();
+
+    auto* ir_switch = b.FlowNodeForAstNode(ast_switch);
+    ASSERT_NE(ir_switch, nullptr);
+    ASSERT_TRUE(ir_switch->Is<ir::Switch>());
+
+    auto* flow = ir_switch->As<ir::Switch>();
+    ASSERT_NE(flow->merge_target, nullptr);
+    ASSERT_EQ(3u, flow->cases.Length());
+
+    ASSERT_EQ(1u, m.functions.Length());
+    auto* func = m.functions[0];
+
+    ASSERT_EQ(1u, flow->cases[0].selectors.Length());
+    ASSERT_TRUE(flow->cases[0].selectors[0]->expr->Is<ast::IntLiteralExpression>());
+    EXPECT_EQ(0_i, flow->cases[0].selectors[0]->expr->As<ast::IntLiteralExpression>()->value);
+
+    ASSERT_EQ(1u, flow->cases[1].selectors.Length());
+    ASSERT_TRUE(flow->cases[1].selectors[0]->expr->Is<ast::IntLiteralExpression>());
+    EXPECT_EQ(1_i, flow->cases[1].selectors[0]->expr->As<ast::IntLiteralExpression>()->value);
+
+    ASSERT_EQ(1u, flow->cases[2].selectors.Length());
+    EXPECT_TRUE(flow->cases[2].selectors[0]->IsDefault());
+
+    EXPECT_EQ(1u, flow->inbound_branches.Length());
+    EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length());
+    EXPECT_EQ(1u, flow->cases[1].start_target->inbound_branches.Length());
+    EXPECT_EQ(1u, flow->cases[2].start_target->inbound_branches.Length());
+    EXPECT_EQ(3u, flow->merge_target->inbound_branches.Length());
+    EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
+
+    EXPECT_EQ(func->start_target->branch_target, ir_switch);
+    EXPECT_EQ(flow->cases[0].start_target->branch_target, flow->merge_target);
+    EXPECT_EQ(flow->cases[1].start_target->branch_target, flow->merge_target);
+    EXPECT_EQ(flow->cases[2].start_target->branch_target, flow->merge_target);
+    EXPECT_EQ(flow->merge_target->branch_target, func->end_target);
+}
+
+TEST_F(IRBuilderImplTest, Switch_OnlyDefault) {
+    // func -> switch -> default -> switch merge -> func end
+    //
+    auto* ast_switch = Switch(1_i, utils::Vector{DefaultCase(Block())});
+
+    WrapInFunction(ast_switch);
+    auto& b = Build();
+
+    auto r = b.Build();
+    ASSERT_TRUE(r) << b.error();
+    auto m = r.Move();
+
+    auto* ir_switch = b.FlowNodeForAstNode(ast_switch);
+    ASSERT_NE(ir_switch, nullptr);
+    ASSERT_TRUE(ir_switch->Is<ir::Switch>());
+
+    auto* flow = ir_switch->As<ir::Switch>();
+    ASSERT_NE(flow->merge_target, nullptr);
+    ASSERT_EQ(1u, flow->cases.Length());
+
+    ASSERT_EQ(1u, m.functions.Length());
+    auto* func = m.functions[0];
+
+    ASSERT_EQ(1u, flow->cases[0].selectors.Length());
+    EXPECT_TRUE(flow->cases[0].selectors[0]->IsDefault());
+
+    EXPECT_EQ(1u, flow->inbound_branches.Length());
+    EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length());
+    EXPECT_EQ(1u, flow->merge_target->inbound_branches.Length());
+    EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
+
+    EXPECT_EQ(func->start_target->branch_target, ir_switch);
+    EXPECT_EQ(flow->cases[0].start_target->branch_target, flow->merge_target);
+    EXPECT_EQ(flow->merge_target->branch_target, func->end_target);
+}
+
+TEST_F(IRBuilderImplTest, Switch_WithBreak) {
+    // {
+    //   switch(1) {
+    //     case 0: {
+    //       break;
+    //       if true { return;}   // Dead code
+    //     }
+    //     default: {}
+    //   }
+    // }
+    //
+    // func -> switch -> case 1
+    //                -> default
+    //
+    //   [case 1] -> switch merge
+    //   [default] -> switch merge
+    //   [switch merge] -> func end
+    auto* ast_switch = Switch(1_i, utils::Vector{Case(utils::Vector{CaseSelector(0_i)},
+                                                      Block(Break(), If(true, Block(Return())))),
+                                                 DefaultCase(Block())});
+
+    WrapInFunction(ast_switch);
+    auto& b = Build();
+
+    auto r = b.Build();
+    ASSERT_TRUE(r) << b.error();
+    auto m = r.Move();
+
+    auto* ir_switch = b.FlowNodeForAstNode(ast_switch);
+    ASSERT_NE(ir_switch, nullptr);
+    ASSERT_TRUE(ir_switch->Is<ir::Switch>());
+
+    auto* flow = ir_switch->As<ir::Switch>();
+    ASSERT_NE(flow->merge_target, nullptr);
+    ASSERT_EQ(2u, flow->cases.Length());
+
+    ASSERT_EQ(1u, m.functions.Length());
+    auto* func = m.functions[0];
+
+    ASSERT_EQ(1u, flow->cases[0].selectors.Length());
+    ASSERT_TRUE(flow->cases[0].selectors[0]->expr->Is<ast::IntLiteralExpression>());
+    EXPECT_EQ(0_i, flow->cases[0].selectors[0]->expr->As<ast::IntLiteralExpression>()->value);
+
+    ASSERT_EQ(1u, flow->cases[1].selectors.Length());
+    EXPECT_TRUE(flow->cases[1].selectors[0]->IsDefault());
+
+    EXPECT_EQ(1u, flow->inbound_branches.Length());
+    EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length());
+    EXPECT_EQ(1u, flow->cases[1].start_target->inbound_branches.Length());
+    EXPECT_EQ(2u, flow->merge_target->inbound_branches.Length());
+    // This is 1 because the if is dead-code eliminated and the return doesn't happen.
+    EXPECT_EQ(1u, func->end_target->inbound_branches.Length());
+
+    EXPECT_EQ(func->start_target->branch_target, ir_switch);
+    EXPECT_EQ(flow->cases[0].start_target->branch_target, flow->merge_target);
+    EXPECT_EQ(flow->cases[1].start_target->branch_target, flow->merge_target);
+    EXPECT_EQ(flow->merge_target->branch_target, func->end_target);
+}
+
+TEST_F(IRBuilderImplTest, Switch_AllReturn) {
+    // {
+    //   switch(1) {
+    //     case 0: {
+    //       return;
+    //     }
+    //     default: {
+    //       return;
+    //     }
+    //   }
+    //   if true { return; }  // Dead code
+    // }
+    //
+    // func -> switch -> case 1
+    //                -> default
+    //
+    //   [case 1] -> func end
+    //   [default] -> func end
+    //   [switch merge] -> nullptr
+    //
+    auto* ast_switch =
+        Switch(1_i, utils::Vector{Case(utils::Vector{CaseSelector(0_i)}, Block(Return())),
+                                  DefaultCase(Block(Return()))});
+
+    auto* ast_if = If(true, Block(Return()));
+
+    WrapInFunction(ast_switch, ast_if);
+    auto& b = Build();
+
+    auto r = b.Build();
+    ASSERT_TRUE(r) << b.error();
+    auto m = r.Move();
+
+    ASSERT_EQ(b.FlowNodeForAstNode(ast_if), nullptr);
+
+    auto* ir_switch = b.FlowNodeForAstNode(ast_switch);
+    ASSERT_NE(ir_switch, nullptr);
+    ASSERT_TRUE(ir_switch->Is<ir::Switch>());
+
+    auto* flow = ir_switch->As<ir::Switch>();
+    ASSERT_NE(flow->merge_target, nullptr);
+    ASSERT_EQ(2u, flow->cases.Length());
+
+    ASSERT_EQ(1u, m.functions.Length());
+    auto* func = m.functions[0];
+
+    ASSERT_EQ(1u, flow->cases[0].selectors.Length());
+    ASSERT_TRUE(flow->cases[0].selectors[0]->expr->Is<ast::IntLiteralExpression>());
+    EXPECT_EQ(0_i, flow->cases[0].selectors[0]->expr->As<ast::IntLiteralExpression>()->value);
+
+    ASSERT_EQ(1u, flow->cases[1].selectors.Length());
+    EXPECT_TRUE(flow->cases[1].selectors[0]->IsDefault());
+
+    EXPECT_EQ(1u, flow->inbound_branches.Length());
+    EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length());
+    EXPECT_EQ(1u, flow->cases[1].start_target->inbound_branches.Length());
+    EXPECT_EQ(0u, flow->merge_target->inbound_branches.Length());
+    EXPECT_EQ(2u, func->end_target->inbound_branches.Length());
+
+    EXPECT_EQ(func->start_target->branch_target, ir_switch);
+    EXPECT_EQ(flow->cases[0].start_target->branch_target, func->end_target);
+    EXPECT_EQ(flow->cases[1].start_target->branch_target, func->end_target);
+    EXPECT_EQ(flow->merge_target->branch_target, nullptr);
+}
+
 }  // namespace
 }  // namespace tint::ir
diff --git a/src/tint/ir/switch.cc b/src/tint/ir/switch.cc
index 9ad6d30..23b7fbb 100644
--- a/src/tint/ir/switch.cc
+++ b/src/tint/ir/switch.cc
@@ -18,7 +18,7 @@
 
 namespace tint::ir {
 
-Switch::Switch() : Base() {}
+Switch::Switch(const ast::SwitchStatement* stmt) : Base(), source(stmt) {}
 
 Switch::~Switch() = default;
 
diff --git a/src/tint/ir/switch.h b/src/tint/ir/switch.h
index e9de3ae..39d3d06 100644
--- a/src/tint/ir/switch.h
+++ b/src/tint/ir/switch.h
@@ -18,17 +18,38 @@
 #include "src/tint/ir/block.h"
 #include "src/tint/ir/flow_node.h"
 
+// Forward declarations
+namespace tint::ast {
+class CaseSelector;
+class SwitchStatement;
+}  // namespace tint::ast
+
 namespace tint::ir {
 
 /// Flow node representing a switch statement
 class Switch : public Castable<Switch, FlowNode> {
   public:
+    /// A case label in the struct
+    struct Case {
+        /// The case selector for this node
+        const utils::VectorRef<const ast::CaseSelector*> selectors;
+        /// The start block for the case block.
+        Block* start_target;
+    };
+
     /// Constructor
-    Switch();
+    /// @param stmt the originating ast switch statement
+    explicit Switch(const ast::SwitchStatement* stmt);
     ~Switch() override;
 
+    /// The originating switch statment in the AST
+    const ast::SwitchStatement* source;
+
     /// The switch merge target
     Block* merge_target;
+
+    /// The switch case statements
+    utils::Vector<Case, 4> cases;
 };
 
 }  // namespace tint::ir