[spirv-writer] Add fallthrough support

This CL adds support for the fallthrough statement in a `case` block.

Bug: tint:5
Change-Id: I282643a304846a19212d41bd8bd20a60398bd793
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22220
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index d633f9c..7169411 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -89,6 +89,10 @@
   return model;
 }
 
+bool LastIsFallthrough(const ast::StatementList& stmts) {
+  return !stmts.empty() && stmts.back()->IsFallthrough();
+}
+
 // A terminator is anything which will case a SPIR-V terminator to be emitted.
 // This means things like breaks, fallthroughs and continues which all emit an
 // OpBranch or return for the OpReturn emission.
@@ -1395,7 +1399,13 @@
       return false;
     }
 
-    if (!LastIsTerminator(item->body())) {
+    if (LastIsFallthrough(item->body())) {
+      if (i == (body.size() - 1)) {
+        error_ = "fallthrough of last case statement is disallowed";
+        return false;
+      }
+      push_function_inst(spv::Op::OpBranch, {Operand::Int(case_ids[i + 1])});
+    } else if (!LastIsTerminator(item->body())) {
       push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
     }
   }
@@ -1491,6 +1501,10 @@
   if (stmt->IsContinue()) {
     return GenerateContinueStatement(stmt->AsContinue());
   }
+  if (stmt->IsFallthrough()) {
+    // Do nothing here, the fallthrough gets handled by the switch code.
+    return true;
+  }
   if (stmt->IsIf()) {
     return GenerateIfStatement(stmt->AsIf());
   }
diff --git a/src/writer/spirv/builder_switch_test.cc b/src/writer/spirv/builder_switch_test.cc
index 1284345..d717c08 100644
--- a/src/writer/spirv/builder_switch_test.cc
+++ b/src/writer/spirv/builder_switch_test.cc
@@ -19,6 +19,7 @@
 #include "src/ast/bool_literal.h"
 #include "src/ast/break_statement.h"
 #include "src/ast/case_statement.h"
+#include "src/ast/fallthrough_statement.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/if_statement.h"
 #include "src/ast/int_literal.h"
@@ -321,15 +322,145 @@
 )");
 }
 
-TEST_F(BuilderTest, DISABLED_Switch_CaseWithFallthrough) {
-  // switch (a) {
+TEST_F(BuilderTest, Switch_CaseWithFallthrough) {
+  ast::type::I32Type i32;
+
+  // switch(a) {
   //   case 1:
-  //     v = 1;
-  //     fallthrough;
+  //      v = 1;
+  //      fallthrough;
   //   case 2:
-  //     v = 2;
+  //      v = 2;
+  //   default:
+  //      v = 3;
   //  }
-  FAIL();
+
+  auto v =
+      std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
+  auto a =
+      std::make_unique<ast::Variable>("a", ast::StorageClass::kPrivate, &i32);
+
+  ast::StatementList case_1_body;
+  case_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("v"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::IntLiteral>(&i32, 1))));
+  case_1_body.push_back(std::make_unique<ast::FallthroughStatement>());
+
+  ast::StatementList case_2_body;
+  case_2_body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("v"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::IntLiteral>(&i32, 2))));
+
+  ast::StatementList default_body;
+  default_body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("v"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::IntLiteral>(&i32, 3))));
+
+  ast::CaseStatementList cases;
+  cases.push_back(std::make_unique<ast::CaseStatement>(
+      std::make_unique<ast::IntLiteral>(&i32, 1), std::move(case_1_body)));
+  cases.push_back(std::make_unique<ast::CaseStatement>(
+      std::make_unique<ast::IntLiteral>(&i32, 2), std::move(case_2_body)));
+  cases.push_back(
+      std::make_unique<ast::CaseStatement>(std::move(default_body)));
+
+  ast::SwitchStatement expr(std::make_unique<ast::IdentifierExpression>("a"),
+                            std::move(cases));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(v.get());
+  td.RegisterVariableForTesting(a.get());
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  ast::Function func("a_func", {}, &i32);
+
+  Builder b(&mod);
+  ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error();
+  ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error();
+  ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+
+  EXPECT_TRUE(b.GenerateSwitchStatement(&expr)) << b.error();
+
+  EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
+OpName %5 "a"
+OpName %7 "a_func"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Private %4
+%5 = OpVariable %2 Private %4
+%6 = OpTypeFunction %3
+%14 = OpConstant %3 1
+%15 = OpConstant %3 2
+%16 = OpConstant %3 3
+%7 = OpFunction %3 None %6
+%8 = OpLabel
+%10 = OpLoad %3 %5
+OpSelectionMerge %9 None
+OpSwitch %10 %11 1 %12 2 %13
+%12 = OpLabel
+OpStore %1 %14
+OpBranch %13
+%13 = OpLabel
+OpStore %1 %15
+OpBranch %9
+%11 = OpLabel
+OpStore %1 %16
+OpBranch %9
+%9 = OpLabel
+OpFunctionEnd
+)");
+}
+
+TEST_F(BuilderTest, Switch_CaseFallthroughLastStatement) {
+  ast::type::I32Type i32;
+
+  // switch(a) {
+  //   case 1:
+  //      v = 1;
+  //      fallthrough;
+  //  }
+
+  auto v =
+      std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
+  auto a =
+      std::make_unique<ast::Variable>("a", ast::StorageClass::kPrivate, &i32);
+
+  ast::StatementList case_1_body;
+  case_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("v"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::IntLiteral>(&i32, 1))));
+  case_1_body.push_back(std::make_unique<ast::FallthroughStatement>());
+
+  ast::CaseStatementList cases;
+  cases.push_back(std::make_unique<ast::CaseStatement>(
+      std::make_unique<ast::IntLiteral>(&i32, 1), std::move(case_1_body)));
+
+  ast::SwitchStatement expr(std::make_unique<ast::IdentifierExpression>("a"),
+                            std::move(cases));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(v.get());
+  td.RegisterVariableForTesting(a.get());
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  ast::Function func("a_func", {}, &i32);
+
+  Builder b(&mod);
+  ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error();
+  ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error();
+  ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+
+  EXPECT_FALSE(b.GenerateSwitchStatement(&expr)) << b.error();
+  EXPECT_EQ(b.error(), "fallthrough of last case statement is disallowed");
 }
 
 // TODO(dsinclair): Implement when parsing is handled for multi-value