[validation] validate switch statement

This CL adds validation for the following rules:
v-switch01: switch statement selector expression must be of a scalar integer type
v-0008: switch statement must have exactly one default clause
v-switch03: the case selector values must have the same type as the selector expression the case selectors for a switch statement
v-switch04: a literal value must not appear more than once in the case selectors for a switch statement
v-switch05: a fallthrough statement must not appear as the last statement in last clause of a switch

Bug: tint: 6
Change-Id: I264d5079cc6cb31075965c8721651dc76f3e2a24
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/28062
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/validator_control_block_test.cc b/src/validator_control_block_test.cc
index 60f40e2..7163db2 100644
--- a/src/validator_control_block_test.cc
+++ b/src/validator_control_block_test.cc
@@ -22,6 +22,7 @@
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/sint_literal.h"
 #include "src/ast/switch_statement.h"
+#include "src/ast/type/alias_type.h"
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/u32_type.h"
@@ -37,8 +38,7 @@
 class ValidateControlBlockTest : public ValidatorTestHelper,
                                  public testing::Test {};
 
-TEST_F(ValidateControlBlockTest,
-       DISABLED_SwitchSelectorExpressionNoneIntegerType_Fail) {
+TEST_F(ValidateControlBlockTest, SwitchSelectorExpressionNoneIntegerType_Fail) {
   // var a : f32 = 3.14;
   // switch (a) {
   //   default: {}
@@ -49,7 +49,7 @@
   var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
       std::make_unique<ast::SintLiteral>(&f32, 3.14f)));
 
-  auto cond = std::make_unique<ast::IdentifierExpression>("a");
+  auto cond = std::make_unique<ast::IdentifierExpression>(Source{12, 34}, "a");
   ast::CaseSelectorList default_csl;
   auto block_default = std::make_unique<ast::BlockStatement>();
   ast::CaseStatementList body;
@@ -58,8 +58,8 @@
 
   auto block = std::make_unique<ast::BlockStatement>();
   block->append(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
-  block->append(std::make_unique<ast::SwitchStatement>(
-      Source{12, 34}, std::move(cond), std::move(body)));
+  block->append(
+      std::make_unique<ast::SwitchStatement>(std::move(cond), std::move(body)));
 
   EXPECT_TRUE(td()->DetermineStatements(block.get())) << td()->error();
   EXPECT_FALSE(v()->ValidateStatements(block.get()));
@@ -68,7 +68,7 @@
             "of a scalar integer type");
 }
 
-TEST_F(ValidateControlBlockTest, DISABLED_SwitchWithoutDefault_Fail) {
+TEST_F(ValidateControlBlockTest, SwitchWithoutDefault_Fail) {
   // var a : i32 = 2;
   // switch (a) {
   //   case 1: {}
@@ -94,11 +94,11 @@
   EXPECT_TRUE(td()->DetermineStatements(block.get())) << td()->error();
   EXPECT_FALSE(v()->ValidateStatements(block.get()));
   EXPECT_EQ(v()->error(),
-            "12:34: v-switch02: switch statement must have exactly one default "
+            "12:34: v-0008: switch statement must have exactly one default "
             "clause");
 }
 
-TEST_F(ValidateControlBlockTest, DISABLED_SwitchWithTwoDefault_Fail) {
+TEST_F(ValidateControlBlockTest, SwitchWithTwoDefault_Fail) {
   // var a : i32 = 2;
   // switch (a) {
   //   default: {}
@@ -128,7 +128,46 @@
   ast::CaseSelectorList default_csl_2;
   auto block_default_2 = std::make_unique<ast::BlockStatement>();
   switch_body.push_back(std::make_unique<ast::CaseStatement>(
-      Source{12, 34}, std::move(default_csl_2), std::move(block_default_2)));
+      std::move(default_csl_2), std::move(block_default_2)));
+
+  auto block = std::make_unique<ast::BlockStatement>();
+  block->append(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+  block->append(std::make_unique<ast::SwitchStatement>(
+      Source{12, 34}, std::move(cond), std::move(switch_body)));
+
+  EXPECT_TRUE(td()->DetermineStatements(block.get())) << td()->error();
+  EXPECT_FALSE(v()->ValidateStatements(block.get()));
+  EXPECT_EQ(v()->error(),
+            "12:34: v-0008: switch statement must have exactly one default "
+            "clause");
+}
+
+TEST_F(ValidateControlBlockTest,
+       SwitchConditionTypeMustMatchSelectorType2_Fail) {
+  // var a : i32 = 2;
+  // switch (a) {
+  //   case 1: {}
+  //   default: {}
+  // }
+  ast::type::U32Type u32;
+  ast::type::I32Type i32;
+  auto var =
+      std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &i32);
+  var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::SintLiteral>(&i32, 2)));
+
+  ast::CaseStatementList switch_body;
+  auto cond = std::make_unique<ast::IdentifierExpression>("a");
+
+  ast::CaseSelectorList csl;
+  csl.push_back(std::make_unique<ast::UintLiteral>(&u32, 1));
+  switch_body.push_back(std::make_unique<ast::CaseStatement>(
+      Source{12, 34}, std::move(csl), std::make_unique<ast::BlockStatement>()));
+
+  ast::CaseSelectorList default_csl;
+  auto block_default = std::make_unique<ast::BlockStatement>();
+  switch_body.push_back(std::make_unique<ast::CaseStatement>(
+      std::move(default_csl), std::move(block_default)));
 
   auto block = std::make_unique<ast::BlockStatement>();
   block->append(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
@@ -138,12 +177,12 @@
   EXPECT_TRUE(td()->DetermineStatements(block.get())) << td()->error();
   EXPECT_FALSE(v()->ValidateStatements(block.get()));
   EXPECT_EQ(v()->error(),
-            "12:34: v-switch02: switch statement must have exactly one default "
-            "clause");
+            "12:34: v-switch03: the case selector values must have the same "
+            "type as the selector expression.");
 }
 
 TEST_F(ValidateControlBlockTest,
-       DISABLED_SwitchConditionTypeMustMatchSelectorType_Fail) {
+       SwitchConditionTypeMustMatchSelectorType_Fail) {
   // var a : u32 = 2;
   // switch (a) {
   //   case -1: {}
@@ -176,14 +215,57 @@
 
   EXPECT_TRUE(td()->DetermineStatements(block.get())) << td()->error();
   EXPECT_FALSE(v()->ValidateStatements(block.get()));
-  EXPECT_EQ(
-      v()->error(),
-      "12:34: v-switch03: the case selector values must have the same type "
-      "as the selector expression.");
+  EXPECT_EQ(v()->error(),
+            "12:34: v-switch03: the case selector values must have the same "
+            "type as the selector expression.");
 }
 
-TEST_F(ValidateControlBlockTest,
-       DISABLED_NonUniqueCaseSelectorLiteralValue_Fail) {
+TEST_F(ValidateControlBlockTest, NonUniqueCaseSelectorValueUint_Fail) {
+  // var a : u32 = 3;
+  // switch (a) {
+  //   case 0: {}
+  //   case 2, 2: {}
+  //   default: {}
+  // }
+  ast::type::U32Type u32;
+  auto var =
+      std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &u32);
+  var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::UintLiteral>(&u32, 3)));
+
+  ast::CaseStatementList switch_body;
+  auto cond = std::make_unique<ast::IdentifierExpression>("a");
+
+  ast::CaseSelectorList csl_1;
+  csl_1.push_back(std::make_unique<ast::UintLiteral>(&u32, 0));
+  switch_body.push_back(std::make_unique<ast::CaseStatement>(
+      std::move(csl_1), std::make_unique<ast::BlockStatement>()));
+
+  ast::CaseSelectorList csl_2;
+  csl_2.push_back(std::make_unique<ast::UintLiteral>(&u32, 2));
+  csl_2.push_back(std::make_unique<ast::UintLiteral>(&u32, 2));
+  switch_body.push_back(std::make_unique<ast::CaseStatement>(
+      Source{12, 34}, std::move(csl_2),
+      std::make_unique<ast::BlockStatement>()));
+
+  ast::CaseSelectorList default_csl;
+  auto block_default = std::make_unique<ast::BlockStatement>();
+  switch_body.push_back(std::make_unique<ast::CaseStatement>(
+      std::move(default_csl), std::move(block_default)));
+
+  auto block = std::make_unique<ast::BlockStatement>();
+  block->append(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+  block->append(std::make_unique<ast::SwitchStatement>(std::move(cond),
+                                                       std::move(switch_body)));
+
+  EXPECT_TRUE(td()->DetermineStatements(block.get())) << td()->error();
+  EXPECT_FALSE(v()->ValidateStatements(block.get()));
+  EXPECT_EQ(v()->error(),
+            "12:34: v-switch04: a literal value must not appear more than once "
+            "in the case selectors for a switch statement: '2'");
+}
+
+TEST_F(ValidateControlBlockTest, NonUniqueCaseSelectorValueSint_Fail) {
   // var a : i32 = 2;
   // switch (a) {
   //   case 10: {}
@@ -210,7 +292,7 @@
   csl_2.push_back(std::make_unique<ast::SintLiteral>(&i32, 2));
   csl_2.push_back(std::make_unique<ast::SintLiteral>(&i32, 10));
   switch_body.push_back(std::make_unique<ast::CaseStatement>(
-      Source{12, 34}, std::move(csl_1),
+      Source{12, 34}, std::move(csl_2),
       std::make_unique<ast::BlockStatement>()));
 
   ast::CaseSelectorList default_csl;
@@ -228,11 +310,10 @@
   EXPECT_EQ(
       v()->error(),
       "12:34: v-switch04: a literal value must not appear more than once in "
-      "the case selectors for a switch statement");
+      "the case selectors for a switch statement: '10'");
 }
 
-TEST_F(ValidateControlBlockTest,
-       DISABLED_LastClauseLastStatementIsFallthrough_Fail) {
+TEST_F(ValidateControlBlockTest, LastClauseLastStatementIsFallthrough_Fail) {
   // var a : i32 = 2;
   // switch (a) {
   //   default: { fallthrough; }
@@ -246,10 +327,11 @@
   auto cond = std::make_unique<ast::IdentifierExpression>("a");
   ast::CaseSelectorList default_csl;
   auto block_default = std::make_unique<ast::BlockStatement>();
-  block_default->append(std::make_unique<ast::FallthroughStatement>());
+  block_default->append(
+      std::make_unique<ast::FallthroughStatement>(Source{12, 34}));
   ast::CaseStatementList body;
   body.push_back(std::make_unique<ast::CaseStatement>(
-      Source{12, 34}, std::move(default_csl), std::move(block_default)));
+      std::move(default_csl), std::move(block_default)));
 
   auto block = std::make_unique<ast::BlockStatement>();
   block->append(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
@@ -296,5 +378,50 @@
   EXPECT_TRUE(v()->ValidateStatements(block.get())) << v()->error();
 }
 
+TEST_F(ValidateControlBlockTest, SwitchCaseAlias_Pass) {
+  // entry_point vertex = main
+  // type MyInt = u32;
+  // fn main()->void {
+  //   var v: MyInt;
+  //   switch(v){
+  //     default: {}
+  //   }
+  // }
+  ast::type::U32Type u32;
+  ast::type::AliasType my_int{"MyInt", &u32};
+
+  auto var =
+      std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &my_int);
+  var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::SintLiteral>(&u32, 2)));
+
+  auto cond = std::make_unique<ast::IdentifierExpression>("a");
+  ast::CaseSelectorList default_csl;
+  auto block_default = std::make_unique<ast::BlockStatement>();
+  ast::CaseStatementList body;
+  body.push_back(std::make_unique<ast::CaseStatement>(
+      Source{12, 34}, std::move(default_csl), std::move(block_default)));
+
+  auto block = std::make_unique<ast::BlockStatement>();
+  block->append(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+  block->append(
+      std::make_unique<ast::SwitchStatement>(std::move(cond), std::move(body)));
+  block->append(std::make_unique<ast::ReturnStatement>());
+
+  ast::type::VoidType void_type;
+  ast::VariableList params;
+  auto func =
+      std::make_unique<ast::Function>("main", std::move(params), &void_type);
+  func->set_body(std::move(block));
+  auto entry_point = std::make_unique<ast::EntryPoint>(
+      ast::PipelineStage::kVertex, "", "main");
+  mod()->AddFunction(std::move(func));
+  mod()->AddAliasType(&my_int);
+  mod()->AddEntryPoint(std::move(entry_point));
+
+  EXPECT_TRUE(td()->Determine()) << td()->error();
+  EXPECT_TRUE(v()->Validate(mod())) << v()->error();
+}
+
 }  // namespace
 }  // namespace tint
diff --git a/src/validator_impl.cc b/src/validator_impl.cc
index 88a1950..a9500eb 100644
--- a/src/validator_impl.cc
+++ b/src/validator_impl.cc
@@ -14,10 +14,17 @@
 
 #include <cassert>
 #include "src/validator_impl.h"
+#include <unordered_set>
 #include "src/ast/call_statement.h"
 #include "src/ast/function.h"
+#include "src/ast/int_literal.h"
 #include "src/ast/intrinsic.h"
+#include "src/ast/sint_literal.h"
+#include "src/ast/switch_statement.h"
+#include "src/ast/type/i32_type.h"
+#include "src/ast/type/u32_type.h"
 #include "src/ast/type/void_type.h"
+#include "src/ast/uint_literal.h"
 #include "src/ast/variable_decl_statement.h"
 
 namespace tint {
@@ -239,13 +246,94 @@
   if (stmt->IsCall()) {
     return ValidateCallExpr(stmt->AsCall()->expr());
   }
+  if (stmt->IsSwitch()) {
+    return ValidateSwitch(stmt->AsSwitch());
+  }
+  if (stmt->IsCase()) {
+    return ValidateCase(stmt->AsCase());
+  }
+  return true;
+}
+
+bool ValidatorImpl::ValidateSwitch(const ast::SwitchStatement* s) {
+  if (!ValidateExpression(s->condition())) {
+    return false;
+  }
+
+  auto* cond_type = s->condition()->result_type()->UnwrapAliasPtrAlias();
+  if (!(cond_type->IsI32() || cond_type->IsU32())) {
+    set_error(s->condition()->source(),
+              "v-switch01: switch statement selector expression must be of a "
+              "scalar integer type");
+    return false;
+  }
+
+  int default_counter = 0;
+  std::unordered_set<int32_t> selector_set;
+  for (const auto& case_stmt : s->body()) {
+    if (!ValidateStatement(case_stmt.get())) {
+      return false;
+    }
+
+    if (case_stmt.get()->IsDefault()) {
+      default_counter++;
+    }
+
+    for (const auto& selector : case_stmt.get()->selectors()) {
+      auto* selector_ptr = selector.get();
+      if (cond_type != selector_ptr->type()) {
+        set_error(case_stmt.get()->source(),
+                  "v-switch03: the case selector values must have the same "
+                  "type as the selector expression.");
+        return false;
+      }
+
+      auto v = static_cast<int32_t>(selector_ptr->type()->IsU32()
+                                        ? selector_ptr->AsUint()->value()
+                                        : selector_ptr->AsSint()->value());
+      if (selector_set.count(v)) {
+        auto v_str = selector_ptr->type()->IsU32()
+                         ? selector_ptr->AsUint()->to_str()
+                         : selector_ptr->AsSint()->to_str();
+        set_error(
+            case_stmt.get()->source(),
+            "v-switch04: a literal value must not appear more than once in "
+            "the case selectors for a switch statement: '" +
+                v_str + "'");
+        return false;
+      }
+      selector_set.emplace(v);
+    }
+  }
+
+  if (default_counter != 1) {
+    set_error(s->source(),
+              "v-0008: switch statement must have exactly one default clause");
+    return false;
+  }
+
+  auto* last_clause = s->body().back().get();
+  auto* last_stmt_of_last_clause = last_clause->AsCase()->body()->last();
+  if (last_stmt_of_last_clause && last_stmt_of_last_clause->IsFallthrough()) {
+    set_error(last_stmt_of_last_clause->source(),
+              "v-switch05: a fallthrough statement must not appear as "
+              "the last statement in last clause of a switch");
+    return false;
+  }
+  return true;
+}
+
+bool ValidatorImpl::ValidateCase(const ast::CaseStatement* c) {
+  if (!ValidateStatement(c->body())) {
+    return false;
+  }
   return true;
 }
 
 bool ValidatorImpl::ValidateCallExpr(const ast::CallExpression* expr) {
   if (!expr) {
-    // TODO(sarahM0): Here and other Validate.*: figure out whether return false
-    // or true
+    // TODO(sarahM0): Here and other Validate.*: figure out whether return
+    // false or true
     return false;
   }
 
diff --git a/src/validator_impl.h b/src/validator_impl.h
index 5f085fd..de29088 100644
--- a/src/validator_impl.h
+++ b/src/validator_impl.h
@@ -115,6 +115,14 @@
   /// @param eps the vector of entry points to check
   /// @return true if the validation was successful
   bool ValidateEntryPoints(const ast::EntryPointList& eps);
+  /// Validates switch statements
+  /// @param s the switch statement to check
+  /// @returns true if the valdiation was successful
+  bool ValidateSwitch(const ast::SwitchStatement* s);
+  /// Validates case statements
+  /// @param c the case statement to check
+  /// @returns true if the valdiation was successful
+  bool ValidateCase(const ast::CaseStatement* c);
 
  private:
   std::string error_;