Add case_selectors update

This CL adds the missing case_selectors option from the grammar updates.

Change-Id: Ia6c110e917dd574711d396fb34ad53a2a67cf1fe
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22306
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/case_statement.cc b/src/ast/case_statement.cc
index 8aca58c..0989217 100644
--- a/src/ast/case_statement.cc
+++ b/src/ast/case_statement.cc
@@ -19,15 +19,14 @@
 
 CaseStatement::CaseStatement() : Statement() {}
 
-CaseStatement::CaseStatement(std::unique_ptr<Literal> condition,
-                             StatementList body)
-    : Statement(), condition_(std::move(condition)), body_(std::move(body)) {}
+CaseStatement::CaseStatement(CaseSelectorList conditions, StatementList body)
+    : Statement(), conditions_(std::move(conditions)), body_(std::move(body)) {}
 
 CaseStatement::CaseStatement(const Source& source,
-                             std::unique_ptr<Literal> condition,
+                             CaseSelectorList conditions,
                              StatementList body)
     : Statement(source),
-      condition_(std::move(condition)),
+      conditions_(std::move(conditions)),
       body_(std::move(body)) {}
 
 CaseStatement::CaseStatement(CaseStatement&&) = default;
@@ -52,7 +51,16 @@
   if (IsDefault()) {
     out << "Default{" << std::endl;
   } else {
-    out << "Case " << condition_->to_str() << "{" << std::endl;
+    out << "Case ";
+    bool first = true;
+    for (const auto& lit : conditions_) {
+      if (!first)
+        out << ", ";
+
+      first = false;
+      out << lit->to_str();
+    }
+    out << "{" << std::endl;
   }
 
   for (const auto& stmt : body_)
diff --git a/src/ast/case_statement.h b/src/ast/case_statement.h
index 5d2c24e..11dcd33 100644
--- a/src/ast/case_statement.h
+++ b/src/ast/case_statement.h
@@ -27,35 +27,38 @@
 namespace tint {
 namespace ast {
 
+/// A list of case literals
+using CaseSelectorList = std::vector<std::unique_ptr<ast::Literal>>;
+
 /// A case statement
 class CaseStatement : public Statement {
  public:
   /// Constructor
   CaseStatement();
   /// Constructor
-  /// @param condition the case condition
+  /// @param conditions the case conditions
   /// @param body the case body
-  CaseStatement(std::unique_ptr<Literal> condition, StatementList body);
+  CaseStatement(CaseSelectorList conditions, StatementList body);
   /// Constructor
   /// @param source the source information
-  /// @param condition the case condition
+  /// @param conditions the case conditions
   /// @param body the case body
   CaseStatement(const Source& source,
-                std::unique_ptr<Literal> condition,
+                CaseSelectorList conditions,
                 StatementList body);
   /// Move constructor
   CaseStatement(CaseStatement&&);
   ~CaseStatement() override;
 
-  /// Sets the condition for the case statement
-  /// @param condition the condition to set
-  void set_condition(std::unique_ptr<Literal> condition) {
-    condition_ = std::move(condition);
+  /// Sets the conditions for the case statement
+  /// @param conditions the conditions to set
+  void set_conditions(CaseSelectorList conditions) {
+    conditions_ = std::move(conditions);
   }
-  /// @returns the case condition or nullptr if none set
-  Literal* condition() const { return condition_.get(); }
+  /// @returns the case condition, empty if none set
+  const CaseSelectorList& conditions() const { return conditions_; }
   /// @returns true if this is a default statement
-  bool IsDefault() const { return condition_ == nullptr; }
+  bool IsDefault() const { return conditions_.empty(); }
 
   /// Sets the case body
   /// @param body the case body
@@ -77,7 +80,7 @@
  private:
   CaseStatement(const CaseStatement&) = delete;
 
-  std::unique_ptr<Literal> condition_;
+  CaseSelectorList conditions_;
   StatementList body_;
 };
 
diff --git a/src/ast/case_statement_test.cc b/src/ast/case_statement_test.cc
index 4e841ac..8653cd7 100644
--- a/src/ast/case_statement_test.cc
+++ b/src/ast/case_statement_test.cc
@@ -17,8 +17,10 @@
 #include "gtest/gtest.h"
 #include "src/ast/bool_literal.h"
 #include "src/ast/if_statement.h"
+#include "src/ast/int_literal.h"
 #include "src/ast/kill_statement.h"
 #include "src/ast/type/bool_type.h"
+#include "src/ast/type/i32_type.h"
 
 namespace tint {
 namespace ast {
@@ -28,22 +30,28 @@
 
 TEST_F(CaseStatementTest, Creation) {
   ast::type::BoolType bool_type;
-  auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+
+  CaseSelectorList b;
+  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
   StatementList stmts;
   stmts.push_back(std::make_unique<KillStatement>());
 
-  auto* bool_ptr = b.get();
+  auto* bool_ptr = b.back().get();
   auto* kill_ptr = stmts[0].get();
 
   CaseStatement c(std::move(b), std::move(stmts));
-  EXPECT_EQ(c.condition(), bool_ptr);
+  ASSERT_EQ(c.conditions().size(), 1);
+  EXPECT_EQ(c.conditions()[0].get(), bool_ptr);
   ASSERT_EQ(c.body().size(), 1u);
   EXPECT_EQ(c.body()[0].get(), kill_ptr);
 }
 
 TEST_F(CaseStatementTest, Creation_WithSource) {
   ast::type::BoolType bool_type;
-  auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+  CaseSelectorList b;
+  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
   StatementList stmts;
   stmts.push_back(std::make_unique<KillStatement>());
 
@@ -64,9 +72,11 @@
 
 TEST_F(CaseStatementTest, IsDefault_WithCondition) {
   ast::type::BoolType bool_type;
-  auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+  CaseSelectorList b;
+  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
   CaseStatement c;
-  c.set_condition(std::move(b));
+  c.set_conditions(std::move(b));
   EXPECT_FALSE(c.IsDefault());
 }
 
@@ -82,7 +92,9 @@
 
 TEST_F(CaseStatementTest, IsValid_NullBodyStatement) {
   ast::type::BoolType bool_type;
-  auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+  CaseSelectorList b;
+  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
   StatementList stmts;
   stmts.push_back(std::make_unique<KillStatement>());
   stmts.push_back(nullptr);
@@ -93,20 +105,24 @@
 
 TEST_F(CaseStatementTest, IsValid_InvalidBodyStatement) {
   ast::type::BoolType bool_type;
-  auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+  CaseSelectorList b;
+  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
   StatementList stmts;
   stmts.push_back(std::make_unique<IfStatement>());
 
-  CaseStatement c(std::move(b), std::move(stmts));
+  CaseStatement c({std::move(b)}, std::move(stmts));
   EXPECT_FALSE(c.IsValid());
 }
 
 TEST_F(CaseStatementTest, ToStr_WithCondition) {
   ast::type::BoolType bool_type;
-  auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+  CaseSelectorList b;
+  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
   StatementList stmts;
   stmts.push_back(std::make_unique<KillStatement>());
-  CaseStatement c(std::move(b), std::move(stmts));
+  CaseStatement c({std::move(b)}, std::move(stmts));
 
   std::ostringstream out;
   c.to_str(out, 2);
@@ -116,10 +132,28 @@
 )");
 }
 
+TEST_F(CaseStatementTest, ToStr_WithMultipleConditions) {
+  ast::type::I32Type i32;
+
+  CaseSelectorList b;
+  b.push_back(std::make_unique<IntLiteral>(&i32, 1));
+  b.push_back(std::make_unique<IntLiteral>(&i32, 2));
+  StatementList stmts;
+  stmts.push_back(std::make_unique<KillStatement>());
+  CaseStatement c(std::move(b), std::move(stmts));
+
+  std::ostringstream out;
+  c.to_str(out, 2);
+  EXPECT_EQ(out.str(), R"(  Case 1, 2{
+    Kill{}
+  }
+)");
+}
+
 TEST_F(CaseStatementTest, ToStr_WithoutCondition) {
   StatementList stmts;
   stmts.push_back(std::make_unique<KillStatement>());
-  CaseStatement c(nullptr, std::move(stmts));
+  CaseStatement c(CaseSelectorList{}, std::move(stmts));
 
   std::ostringstream out;
   c.to_str(out, 2);
diff --git a/src/ast/switch_statement_test.cc b/src/ast/switch_statement_test.cc
index c4adb28..512df39 100644
--- a/src/ast/switch_statement_test.cc
+++ b/src/ast/switch_statement_test.cc
@@ -30,7 +30,8 @@
 
 TEST_F(SwitchStatementTest, Creation) {
   ast::type::BoolType bool_type;
-  auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+  CaseSelectorList lit;
+  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
   auto ident = std::make_unique<IdentifierExpression>("ident");
   CaseStatementList body;
   body.push_back(
@@ -61,7 +62,8 @@
 
 TEST_F(SwitchStatementTest, IsValid) {
   ast::type::BoolType bool_type;
-  auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+  CaseSelectorList lit;
+  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
   auto ident = std::make_unique<IdentifierExpression>("ident");
   CaseStatementList body;
   body.push_back(
@@ -73,7 +75,8 @@
 
 TEST_F(SwitchStatementTest, IsValid_Null_Condition) {
   ast::type::BoolType bool_type;
-  auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+  CaseSelectorList lit;
+  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
   CaseStatementList body;
   body.push_back(
       std::make_unique<CaseStatement>(std::move(lit), StatementList()));
@@ -85,7 +88,8 @@
 
 TEST_F(SwitchStatementTest, IsValid_Invalid_Condition) {
   ast::type::BoolType bool_type;
-  auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+  CaseSelectorList lit;
+  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
   auto ident = std::make_unique<IdentifierExpression>("");
   CaseStatementList body;
   body.push_back(
@@ -97,7 +101,8 @@
 
 TEST_F(SwitchStatementTest, IsValid_Null_BodyStatement) {
   ast::type::BoolType bool_type;
-  auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+  CaseSelectorList lit;
+  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
   auto ident = std::make_unique<IdentifierExpression>("ident");
   CaseStatementList body;
   body.push_back(
@@ -115,8 +120,8 @@
   case_body.push_back(nullptr);
 
   CaseStatementList body;
-  body.push_back(
-      std::make_unique<CaseStatement>(nullptr, std::move(case_body)));
+  body.push_back(std::make_unique<CaseStatement>(CaseSelectorList{},
+                                                 std::move(case_body)));
 
   SwitchStatement stmt(std::move(ident), std::move(body));
   EXPECT_FALSE(stmt.IsValid());
@@ -138,7 +143,8 @@
 
 TEST_F(SwitchStatementTest, ToStr) {
   ast::type::BoolType bool_type;
-  auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+  CaseSelectorList lit;
+  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
   auto ident = std::make_unique<IdentifierExpression>("ident");
   CaseStatementList body;
   body.push_back(
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index c02a375..e81ef0c 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -1800,7 +1800,7 @@
 }
 
 // switch_body
-//   : CASE const_literal COLON BRACKET_LEFT case_body BRACKET_RIGHT
+//   : CASE case_selectors COLON BRACKET_LEFT case_body BRACKET_RIGHT
 //   | DEFAULT COLON BRACKET_LEFT case_body BRACKET_RIGHT
 std::unique_ptr<ast::CaseStatement> ParserImpl::switch_body() {
   auto t = peek();
@@ -1813,14 +1813,14 @@
   auto stmt = std::make_unique<ast::CaseStatement>();
   stmt->set_source(source);
   if (t.IsCase()) {
-    auto cond = const_literal();
+    auto cond = case_selectors();
     if (has_error())
       return nullptr;
-    if (cond == nullptr) {
+    if (cond.empty()) {
       set_error(peek(), "unable to parse case conditional");
       return nullptr;
     }
-    stmt->set_condition(std::move(cond));
+    stmt->set_conditions(std::move(cond));
   }
 
   t = next();
@@ -1850,6 +1850,24 @@
   return stmt;
 }
 
+// case_selectors
+//   : const_literal (COMMA const_literal)*
+ast::CaseSelectorList ParserImpl::case_selectors() {
+  ast::CaseSelectorList selectors;
+
+  for (;;) {
+    auto cond = const_literal();
+    if (has_error())
+      return {};
+    if (cond == nullptr)
+      break;
+
+    selectors.push_back(std::move(cond));
+  }
+
+  return selectors;
+}
+
 // case_body
 //   :
 //   | statement case_body
diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h
index 826ea8a..718f3c2 100644
--- a/src/reader/wgsl/parser_impl.h
+++ b/src/reader/wgsl/parser_impl.h
@@ -23,6 +23,7 @@
 
 #include "src/ast/assignment_statement.h"
 #include "src/ast/builtin.h"
+#include "src/ast/case_statement.h"
 #include "src/ast/constructor_expression.h"
 #include "src/ast/else_statement.h"
 #include "src/ast/entry_point.h"
@@ -216,6 +217,9 @@
   /// Parses a `switch_body` grammar element
   /// @returns the parsed statement or nullptr
   std::unique_ptr<ast::CaseStatement> switch_body();
+  /// Parses a `case_selectors` grammar element
+  /// @returns the list of literals
+  ast::CaseSelectorList case_selectors();
   /// Parses a `case_body` grammar element
   /// @returns the parsed statements
   ast::StatementList case_body();
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 5104709..bd4a04f 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -165,8 +165,9 @@
   body.push_back(std::make_unique<ast::AssignmentStatement>(std::move(lhs),
                                                             std::move(rhs)));
 
-  ast::CaseStatement cse(std::make_unique<ast::IntLiteral>(&i32, 3),
-                         std::move(body));
+  ast::CaseSelectorList lit;
+  lit.push_back(std::make_unique<ast::IntLiteral>(&i32, 3));
+  ast::CaseStatement cse(std::move(lit), std::move(body));
 
   EXPECT_TRUE(td()->DetermineResultType(&cse));
   ASSERT_NE(lhs_ptr->result_type(), nullptr);
@@ -355,9 +356,12 @@
   body.push_back(std::make_unique<ast::AssignmentStatement>(std::move(lhs),
                                                             std::move(rhs)));
 
+  ast::CaseSelectorList lit;
+  lit.push_back(std::make_unique<ast::IntLiteral>(&i32, 3));
+
   ast::CaseStatementList cases;
-  cases.push_back(std::make_unique<ast::CaseStatement>(
-      std::make_unique<ast::IntLiteral>(&i32, 3), std::move(body)));
+  cases.push_back(
+      std::make_unique<ast::CaseStatement>(std::move(lit), std::move(body)));
 
   ast::SwitchStatement stmt(std::make_unique<ast::ScalarConstructorExpression>(
                                 std::make_unique<ast::IntLiteral>(&i32, 2)),
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index c83759e..50dad68 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -724,9 +724,18 @@
   } else {
     out_ << "case ";
 
-    if (!EmitLiteral(stmt->condition())) {
-      return false;
+    bool first = true;
+    for (const auto& lit : stmt->conditions()) {
+      if (!first) {
+        out_ << ", ";
+      }
+
+      first = false;
+      if (!EmitLiteral(lit.get())) {
+        return false;
+      }
     }
+
     out_ << ":";
   }
 
diff --git a/src/writer/wgsl/generator_impl_case_test.cc b/src/writer/wgsl/generator_impl_case_test.cc
index d1c3fd1..5a1a906 100644
--- a/src/writer/wgsl/generator_impl_case_test.cc
+++ b/src/writer/wgsl/generator_impl_case_test.cc
@@ -31,12 +31,13 @@
 
 TEST_F(GeneratorImplTest, Emit_Case) {
   ast::type::I32Type i32;
-  auto cond = std::make_unique<ast::IntLiteral>(&i32, 5);
 
   ast::StatementList body;
   body.push_back(std::make_unique<ast::BreakStatement>());
 
-  ast::CaseStatement c(std::move(cond), std::move(body));
+  ast::CaseSelectorList lit;
+  lit.push_back(std::make_unique<ast::IntLiteral>(&i32, 5));
+  ast::CaseStatement c(std::move(lit), std::move(body));
 
   GeneratorImpl g;
   g.increment_indent();
@@ -48,6 +49,27 @@
 )");
 }
 
+TEST_F(GeneratorImplTest, Emit_Case_MultipleSelectors) {
+  ast::type::I32Type i32;
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::BreakStatement>());
+
+  ast::CaseSelectorList lit;
+  lit.push_back(std::make_unique<ast::IntLiteral>(&i32, 5));
+  lit.push_back(std::make_unique<ast::IntLiteral>(&i32, 6));
+  ast::CaseStatement c(std::move(lit), std::move(body));
+
+  GeneratorImpl g;
+  g.increment_indent();
+
+  ASSERT_TRUE(g.EmitCase(&c)) << g.error();
+  EXPECT_EQ(g.result(), R"(  case 5, 6: {
+    break;
+  }
+)");
+}
+
 TEST_F(GeneratorImplTest, Emit_Case_Default) {
   ast::CaseStatement c;
 
diff --git a/src/writer/wgsl/generator_impl_switch_test.cc b/src/writer/wgsl/generator_impl_switch_test.cc
index b784f21..b149541 100644
--- a/src/writer/wgsl/generator_impl_switch_test.cc
+++ b/src/writer/wgsl/generator_impl_switch_test.cc
@@ -37,7 +37,9 @@
   def->set_body(std::move(def_body));
 
   ast::type::I32Type i32;
-  auto case_val = std::make_unique<ast::IntLiteral>(&i32, 5);
+  ast::CaseSelectorList case_val;
+  case_val.push_back(std::make_unique<ast::IntLiteral>(&i32, 5));
+
   ast::StatementList case_body;
   case_body.push_back(std::make_unique<ast::BreakStatement>());