Make case selectors an integer value

Change-Id: I819983701ed6cca4eba1a05b4edc5fdff10fa88d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22542
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/case_statement.h b/src/ast/case_statement.h
index 9b4150c..9693bb5 100644
--- a/src/ast/case_statement.h
+++ b/src/ast/case_statement.h
@@ -20,14 +20,14 @@
 #include <vector>
 
 #include "src/ast/expression.h"
-#include "src/ast/literal.h"
+#include "src/ast/int_literal.h"
 #include "src/ast/statement.h"
 
 namespace tint {
 namespace ast {
 
 /// A list of case literals
-using CaseSelectorList = std::vector<std::unique_ptr<ast::Literal>>;
+using CaseSelectorList = std::vector<std::unique_ptr<ast::IntLiteral>>;
 
 /// A case statement
 class CaseStatement : public Statement {
diff --git a/src/ast/case_statement_test.cc b/src/ast/case_statement_test.cc
index cdc5982..72753f1 100644
--- a/src/ast/case_statement_test.cc
+++ b/src/ast/case_statement_test.cc
@@ -15,12 +15,12 @@
 #include "src/ast/case_statement.h"
 
 #include "gtest/gtest.h"
-#include "src/ast/bool_literal.h"
 #include "src/ast/if_statement.h"
 #include "src/ast/kill_statement.h"
 #include "src/ast/sint_literal.h"
-#include "src/ast/type/bool_type.h"
 #include "src/ast/type/i32_type.h"
+#include "src/ast/type/u32_type.h"
+#include "src/ast/uint_literal.h"
 
 namespace tint {
 namespace ast {
@@ -28,29 +28,48 @@
 
 using CaseStatementTest = testing::Test;
 
-TEST_F(CaseStatementTest, Creation) {
-  ast::type::BoolType bool_type;
+TEST_F(CaseStatementTest, Creation_i32) {
+  ast::type::I32Type i32;
 
   CaseSelectorList b;
-  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  b.push_back(std::make_unique<SintLiteral>(&i32, 2));
 
   StatementList stmts;
   stmts.push_back(std::make_unique<KillStatement>());
 
-  auto* bool_ptr = b.back().get();
+  auto* int_ptr = b.back().get();
   auto* kill_ptr = stmts[0].get();
 
   CaseStatement c(std::move(b), std::move(stmts));
   ASSERT_EQ(c.selectors().size(), 1);
-  EXPECT_EQ(c.selectors()[0].get(), bool_ptr);
+  EXPECT_EQ(c.selectors()[0].get(), int_ptr);
+  ASSERT_EQ(c.body().size(), 1u);
+  EXPECT_EQ(c.body()[0].get(), kill_ptr);
+}
+
+TEST_F(CaseStatementTest, Creation_u32) {
+  ast::type::U32Type u32;
+
+  CaseSelectorList b;
+  b.push_back(std::make_unique<UintLiteral>(&u32, 2));
+
+  StatementList stmts;
+  stmts.push_back(std::make_unique<KillStatement>());
+
+  auto* int_ptr = b.back().get();
+  auto* kill_ptr = stmts[0].get();
+
+  CaseStatement c(std::move(b), std::move(stmts));
+  ASSERT_EQ(c.selectors().size(), 1);
+  EXPECT_EQ(c.selectors()[0].get(), int_ptr);
   ASSERT_EQ(c.body().size(), 1u);
   EXPECT_EQ(c.body()[0].get(), kill_ptr);
 }
 
 TEST_F(CaseStatementTest, Creation_WithSource) {
-  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
   CaseSelectorList b;
-  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  b.push_back(std::make_unique<SintLiteral>(&i32, 2));
 
   StatementList stmts;
   stmts.push_back(std::make_unique<KillStatement>());
@@ -71,9 +90,9 @@
 }
 
 TEST_F(CaseStatementTest, IsDefault_WithSelectors) {
-  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
   CaseSelectorList b;
-  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  b.push_back(std::make_unique<SintLiteral>(&i32, 2));
 
   CaseStatement c;
   c.set_selectors(std::move(b));
@@ -91,9 +110,9 @@
 }
 
 TEST_F(CaseStatementTest, IsValid_NullBodyStatement) {
-  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
   CaseSelectorList b;
-  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  b.push_back(std::make_unique<SintLiteral>(&i32, 2));
 
   StatementList stmts;
   stmts.push_back(std::make_unique<KillStatement>());
@@ -104,9 +123,9 @@
 }
 
 TEST_F(CaseStatementTest, IsValid_InvalidBodyStatement) {
-  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
   CaseSelectorList b;
-  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  b.push_back(std::make_unique<SintLiteral>(&i32, 2));
 
   StatementList stmts;
   stmts.push_back(std::make_unique<IfStatement>());
@@ -115,10 +134,10 @@
   EXPECT_FALSE(c.IsValid());
 }
 
-TEST_F(CaseStatementTest, ToStr_WithSelectors) {
-  ast::type::BoolType bool_type;
+TEST_F(CaseStatementTest, ToStr_WithSelectors_i32) {
+  ast::type::I32Type i32;
   CaseSelectorList b;
-  b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  b.push_back(std::make_unique<SintLiteral>(&i32, -2));
 
   StatementList stmts;
   stmts.push_back(std::make_unique<KillStatement>());
@@ -126,7 +145,24 @@
 
   std::ostringstream out;
   c.to_str(out, 2);
-  EXPECT_EQ(out.str(), R"(  Case true{
+  EXPECT_EQ(out.str(), R"(  Case -2{
+    Kill{}
+  }
+)");
+}
+
+TEST_F(CaseStatementTest, ToStr_WithSelectors_u32) {
+  ast::type::U32Type u32;
+  CaseSelectorList b;
+  b.push_back(std::make_unique<UintLiteral>(&u32, 2));
+
+  StatementList stmts;
+  stmts.push_back(std::make_unique<KillStatement>());
+  CaseStatement c({std::move(b)}, std::move(stmts));
+
+  std::ostringstream out;
+  c.to_str(out, 2);
+  EXPECT_EQ(out.str(), R"(  Case 2{
     Kill{}
   }
 )");
diff --git a/src/ast/int_literal_test.cc b/src/ast/int_literal_test.cc
index defdda4..d4cd32b 100644
--- a/src/ast/int_literal_test.cc
+++ b/src/ast/int_literal_test.cc
@@ -15,9 +15,9 @@
 #include "src/ast/int_literal.h"
 
 #include "gtest/gtest.h"
+#include "src/ast/sint_literal.h"
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/u32_type.h"
-#include "src/ast/sint_literal.h"
 #include "src/ast/uint_literal.h"
 
 namespace tint {
diff --git a/src/ast/literal.cc b/src/ast/literal.cc
index d8923f2..3ac8543 100644
--- a/src/ast/literal.cc
+++ b/src/ast/literal.cc
@@ -18,6 +18,7 @@
 
 #include "src/ast/bool_literal.h"
 #include "src/ast/float_literal.h"
+#include "src/ast/int_literal.h"
 #include "src/ast/null_literal.h"
 #include "src/ast/sint_literal.h"
 #include "src/ast/uint_literal.h"
@@ -63,6 +64,11 @@
   return static_cast<FloatLiteral*>(this);
 }
 
+IntLiteral* Literal::AsInt() {
+  assert(IsInt());
+  return static_cast<IntLiteral*>(this);
+}
+
 SintLiteral* Literal::AsSint() {
   assert(IsSint());
   return static_cast<SintLiteral*>(this);
diff --git a/src/ast/literal.h b/src/ast/literal.h
index 89ba4d2..c60f577 100644
--- a/src/ast/literal.h
+++ b/src/ast/literal.h
@@ -26,6 +26,7 @@
 class FloatLiteral;
 class NullLiteral;
 class SintLiteral;
+class IntLiteral;
 class UintLiteral;
 
 /// Base class for a literal value
@@ -50,6 +51,8 @@
   BoolLiteral* AsBool();
   /// @returns the literal as a float literal
   FloatLiteral* AsFloat();
+  /// @returns the literal as an int literal
+  IntLiteral* AsInt();
   /// @returns the literal as a signed int literal
   SintLiteral* AsSint();
   /// @returns the literal as a null literal
diff --git a/src/ast/switch_statement_test.cc b/src/ast/switch_statement_test.cc
index 512df39..8141157 100644
--- a/src/ast/switch_statement_test.cc
+++ b/src/ast/switch_statement_test.cc
@@ -17,10 +17,10 @@
 #include <sstream>
 
 #include "gtest/gtest.h"
-#include "src/ast/bool_literal.h"
 #include "src/ast/case_statement.h"
 #include "src/ast/identifier_expression.h"
-#include "src/ast/type/bool_type.h"
+#include "src/ast/sint_literal.h"
+#include "src/ast/type/i32_type.h"
 
 namespace tint {
 namespace ast {
@@ -29,9 +29,11 @@
 using SwitchStatementTest = testing::Test;
 
 TEST_F(SwitchStatementTest, Creation) {
-  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
+
   CaseSelectorList lit;
-  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  lit.push_back(std::make_unique<SintLiteral>(&i32, 1));
+
   auto ident = std::make_unique<IdentifierExpression>("ident");
   CaseStatementList body;
   body.push_back(
@@ -61,9 +63,11 @@
 }
 
 TEST_F(SwitchStatementTest, IsValid) {
-  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
+
   CaseSelectorList lit;
-  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  lit.push_back(std::make_unique<SintLiteral>(&i32, 2));
+
   auto ident = std::make_unique<IdentifierExpression>("ident");
   CaseStatementList body;
   body.push_back(
@@ -74,9 +78,11 @@
 }
 
 TEST_F(SwitchStatementTest, IsValid_Null_Condition) {
-  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
+
   CaseSelectorList lit;
-  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  lit.push_back(std::make_unique<SintLiteral>(&i32, 2));
+
   CaseStatementList body;
   body.push_back(
       std::make_unique<CaseStatement>(std::move(lit), StatementList()));
@@ -87,9 +93,11 @@
 }
 
 TEST_F(SwitchStatementTest, IsValid_Invalid_Condition) {
-  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
+
   CaseSelectorList lit;
-  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  lit.push_back(std::make_unique<SintLiteral>(&i32, 2));
+
   auto ident = std::make_unique<IdentifierExpression>("");
   CaseStatementList body;
   body.push_back(
@@ -100,9 +108,11 @@
 }
 
 TEST_F(SwitchStatementTest, IsValid_Null_BodyStatement) {
-  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
+
   CaseSelectorList lit;
-  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  lit.push_back(std::make_unique<SintLiteral>(&i32, 2));
+
   auto ident = std::make_unique<IdentifierExpression>("ident");
   CaseStatementList body;
   body.push_back(
@@ -142,9 +152,11 @@
 }
 
 TEST_F(SwitchStatementTest, ToStr) {
-  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
+
   CaseSelectorList lit;
-  lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+  lit.push_back(std::make_unique<SintLiteral>(&i32, 2));
+
   auto ident = std::make_unique<IdentifierExpression>("ident");
   CaseStatementList body;
   body.push_back(
@@ -156,7 +168,7 @@
   EXPECT_EQ(out.str(), R"(  Switch{
     Identifier{ident}
     {
-      Case true{
+      Case 2{
       }
     }
   }
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index 038e397..89edc77 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -1856,13 +1856,19 @@
   ast::CaseSelectorList selectors;
 
   for (;;) {
+    auto t = peek();
     auto cond = const_literal();
     if (has_error())
       return {};
     if (cond == nullptr)
       break;
+    if (!cond->IsInt()) {
+      set_error(t, "invalid case selector must be an integer value");
+      return {};
+    }
 
-    selectors.push_back(std::move(cond));
+    std::unique_ptr<ast::IntLiteral> selector(cond.release()->AsInt());
+    selectors.push_back(std::move(selector));
   }
 
   return selectors;
diff --git a/src/reader/wgsl/parser_impl_switch_body_test.cc b/src/reader/wgsl/parser_impl_switch_body_test.cc
index c9458fd..3a68dd5 100644
--- a/src/reader/wgsl/parser_impl_switch_body_test.cc
+++ b/src/reader/wgsl/parser_impl_switch_body_test.cc
@@ -41,6 +41,14 @@
   EXPECT_EQ(p->error(), "1:6: unable to parse case selectors");
 }
 
+TEST_F(ParserImplTest, SwitchBody_Case_InvalidSelector_bool) {
+  auto* p = parser("case true: { a = 4; }");
+  auto e = p->switch_body();
+  ASSERT_TRUE(p->has_error());
+  ASSERT_EQ(e, nullptr);
+  EXPECT_EQ(p->error(), "1:6: invalid case selector must be an integer value");
+}
+
 TEST_F(ParserImplTest, SwitchBody_Case_MissingConstLiteral) {
   auto* p = parser("case: { a = 4; }");
   auto e = p->switch_body();