Add case_selectors update
This CL adds the missing case_selectors option from the grammar updates.
Change-Id: Ia6c110e917dd574711d396fb34ad53a2a67cf1fe
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22306
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/case_statement.cc b/src/ast/case_statement.cc
index 8aca58c..0989217 100644
--- a/src/ast/case_statement.cc
+++ b/src/ast/case_statement.cc
@@ -19,15 +19,14 @@
CaseStatement::CaseStatement() : Statement() {}
-CaseStatement::CaseStatement(std::unique_ptr<Literal> condition,
- StatementList body)
- : Statement(), condition_(std::move(condition)), body_(std::move(body)) {}
+CaseStatement::CaseStatement(CaseSelectorList conditions, StatementList body)
+ : Statement(), conditions_(std::move(conditions)), body_(std::move(body)) {}
CaseStatement::CaseStatement(const Source& source,
- std::unique_ptr<Literal> condition,
+ CaseSelectorList conditions,
StatementList body)
: Statement(source),
- condition_(std::move(condition)),
+ conditions_(std::move(conditions)),
body_(std::move(body)) {}
CaseStatement::CaseStatement(CaseStatement&&) = default;
@@ -52,7 +51,16 @@
if (IsDefault()) {
out << "Default{" << std::endl;
} else {
- out << "Case " << condition_->to_str() << "{" << std::endl;
+ out << "Case ";
+ bool first = true;
+ for (const auto& lit : conditions_) {
+ if (!first)
+ out << ", ";
+
+ first = false;
+ out << lit->to_str();
+ }
+ out << "{" << std::endl;
}
for (const auto& stmt : body_)
diff --git a/src/ast/case_statement.h b/src/ast/case_statement.h
index 5d2c24e..11dcd33 100644
--- a/src/ast/case_statement.h
+++ b/src/ast/case_statement.h
@@ -27,35 +27,38 @@
namespace tint {
namespace ast {
+/// A list of case literals
+using CaseSelectorList = std::vector<std::unique_ptr<ast::Literal>>;
+
/// A case statement
class CaseStatement : public Statement {
public:
/// Constructor
CaseStatement();
/// Constructor
- /// @param condition the case condition
+ /// @param conditions the case conditions
/// @param body the case body
- CaseStatement(std::unique_ptr<Literal> condition, StatementList body);
+ CaseStatement(CaseSelectorList conditions, StatementList body);
/// Constructor
/// @param source the source information
- /// @param condition the case condition
+ /// @param conditions the case conditions
/// @param body the case body
CaseStatement(const Source& source,
- std::unique_ptr<Literal> condition,
+ CaseSelectorList conditions,
StatementList body);
/// Move constructor
CaseStatement(CaseStatement&&);
~CaseStatement() override;
- /// Sets the condition for the case statement
- /// @param condition the condition to set
- void set_condition(std::unique_ptr<Literal> condition) {
- condition_ = std::move(condition);
+ /// Sets the conditions for the case statement
+ /// @param conditions the conditions to set
+ void set_conditions(CaseSelectorList conditions) {
+ conditions_ = std::move(conditions);
}
- /// @returns the case condition or nullptr if none set
- Literal* condition() const { return condition_.get(); }
+ /// @returns the case condition, empty if none set
+ const CaseSelectorList& conditions() const { return conditions_; }
/// @returns true if this is a default statement
- bool IsDefault() const { return condition_ == nullptr; }
+ bool IsDefault() const { return conditions_.empty(); }
/// Sets the case body
/// @param body the case body
@@ -77,7 +80,7 @@
private:
CaseStatement(const CaseStatement&) = delete;
- std::unique_ptr<Literal> condition_;
+ CaseSelectorList conditions_;
StatementList body_;
};
diff --git a/src/ast/case_statement_test.cc b/src/ast/case_statement_test.cc
index 4e841ac..8653cd7 100644
--- a/src/ast/case_statement_test.cc
+++ b/src/ast/case_statement_test.cc
@@ -17,8 +17,10 @@
#include "gtest/gtest.h"
#include "src/ast/bool_literal.h"
#include "src/ast/if_statement.h"
+#include "src/ast/int_literal.h"
#include "src/ast/kill_statement.h"
#include "src/ast/type/bool_type.h"
+#include "src/ast/type/i32_type.h"
namespace tint {
namespace ast {
@@ -28,22 +30,28 @@
TEST_F(CaseStatementTest, Creation) {
ast::type::BoolType bool_type;
- auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+
+ CaseSelectorList b;
+ b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
StatementList stmts;
stmts.push_back(std::make_unique<KillStatement>());
- auto* bool_ptr = b.get();
+ auto* bool_ptr = b.back().get();
auto* kill_ptr = stmts[0].get();
CaseStatement c(std::move(b), std::move(stmts));
- EXPECT_EQ(c.condition(), bool_ptr);
+ ASSERT_EQ(c.conditions().size(), 1);
+ EXPECT_EQ(c.conditions()[0].get(), bool_ptr);
ASSERT_EQ(c.body().size(), 1u);
EXPECT_EQ(c.body()[0].get(), kill_ptr);
}
TEST_F(CaseStatementTest, Creation_WithSource) {
ast::type::BoolType bool_type;
- auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+ CaseSelectorList b;
+ b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
StatementList stmts;
stmts.push_back(std::make_unique<KillStatement>());
@@ -64,9 +72,11 @@
TEST_F(CaseStatementTest, IsDefault_WithCondition) {
ast::type::BoolType bool_type;
- auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+ CaseSelectorList b;
+ b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
CaseStatement c;
- c.set_condition(std::move(b));
+ c.set_conditions(std::move(b));
EXPECT_FALSE(c.IsDefault());
}
@@ -82,7 +92,9 @@
TEST_F(CaseStatementTest, IsValid_NullBodyStatement) {
ast::type::BoolType bool_type;
- auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+ CaseSelectorList b;
+ b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
StatementList stmts;
stmts.push_back(std::make_unique<KillStatement>());
stmts.push_back(nullptr);
@@ -93,20 +105,24 @@
TEST_F(CaseStatementTest, IsValid_InvalidBodyStatement) {
ast::type::BoolType bool_type;
- auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+ CaseSelectorList b;
+ b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
StatementList stmts;
stmts.push_back(std::make_unique<IfStatement>());
- CaseStatement c(std::move(b), std::move(stmts));
+ CaseStatement c({std::move(b)}, std::move(stmts));
EXPECT_FALSE(c.IsValid());
}
TEST_F(CaseStatementTest, ToStr_WithCondition) {
ast::type::BoolType bool_type;
- auto b = std::make_unique<BoolLiteral>(&bool_type, true);
+ CaseSelectorList b;
+ b.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
+
StatementList stmts;
stmts.push_back(std::make_unique<KillStatement>());
- CaseStatement c(std::move(b), std::move(stmts));
+ CaseStatement c({std::move(b)}, std::move(stmts));
std::ostringstream out;
c.to_str(out, 2);
@@ -116,10 +132,28 @@
)");
}
+TEST_F(CaseStatementTest, ToStr_WithMultipleConditions) {
+ ast::type::I32Type i32;
+
+ CaseSelectorList b;
+ b.push_back(std::make_unique<IntLiteral>(&i32, 1));
+ b.push_back(std::make_unique<IntLiteral>(&i32, 2));
+ StatementList stmts;
+ stmts.push_back(std::make_unique<KillStatement>());
+ CaseStatement c(std::move(b), std::move(stmts));
+
+ std::ostringstream out;
+ c.to_str(out, 2);
+ EXPECT_EQ(out.str(), R"( Case 1, 2{
+ Kill{}
+ }
+)");
+}
+
TEST_F(CaseStatementTest, ToStr_WithoutCondition) {
StatementList stmts;
stmts.push_back(std::make_unique<KillStatement>());
- CaseStatement c(nullptr, std::move(stmts));
+ CaseStatement c(CaseSelectorList{}, std::move(stmts));
std::ostringstream out;
c.to_str(out, 2);
diff --git a/src/ast/switch_statement_test.cc b/src/ast/switch_statement_test.cc
index c4adb28..512df39 100644
--- a/src/ast/switch_statement_test.cc
+++ b/src/ast/switch_statement_test.cc
@@ -30,7 +30,8 @@
TEST_F(SwitchStatementTest, Creation) {
ast::type::BoolType bool_type;
- auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+ CaseSelectorList lit;
+ lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
auto ident = std::make_unique<IdentifierExpression>("ident");
CaseStatementList body;
body.push_back(
@@ -61,7 +62,8 @@
TEST_F(SwitchStatementTest, IsValid) {
ast::type::BoolType bool_type;
- auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+ CaseSelectorList lit;
+ lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
auto ident = std::make_unique<IdentifierExpression>("ident");
CaseStatementList body;
body.push_back(
@@ -73,7 +75,8 @@
TEST_F(SwitchStatementTest, IsValid_Null_Condition) {
ast::type::BoolType bool_type;
- auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+ CaseSelectorList lit;
+ lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
CaseStatementList body;
body.push_back(
std::make_unique<CaseStatement>(std::move(lit), StatementList()));
@@ -85,7 +88,8 @@
TEST_F(SwitchStatementTest, IsValid_Invalid_Condition) {
ast::type::BoolType bool_type;
- auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+ CaseSelectorList lit;
+ lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
auto ident = std::make_unique<IdentifierExpression>("");
CaseStatementList body;
body.push_back(
@@ -97,7 +101,8 @@
TEST_F(SwitchStatementTest, IsValid_Null_BodyStatement) {
ast::type::BoolType bool_type;
- auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+ CaseSelectorList lit;
+ lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
auto ident = std::make_unique<IdentifierExpression>("ident");
CaseStatementList body;
body.push_back(
@@ -115,8 +120,8 @@
case_body.push_back(nullptr);
CaseStatementList body;
- body.push_back(
- std::make_unique<CaseStatement>(nullptr, std::move(case_body)));
+ body.push_back(std::make_unique<CaseStatement>(CaseSelectorList{},
+ std::move(case_body)));
SwitchStatement stmt(std::move(ident), std::move(body));
EXPECT_FALSE(stmt.IsValid());
@@ -138,7 +143,8 @@
TEST_F(SwitchStatementTest, ToStr) {
ast::type::BoolType bool_type;
- auto lit = std::make_unique<BoolLiteral>(&bool_type, true);
+ CaseSelectorList lit;
+ lit.push_back(std::make_unique<BoolLiteral>(&bool_type, true));
auto ident = std::make_unique<IdentifierExpression>("ident");
CaseStatementList body;
body.push_back(
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index c02a375..e81ef0c 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -1800,7 +1800,7 @@
}
// switch_body
-// : CASE const_literal COLON BRACKET_LEFT case_body BRACKET_RIGHT
+// : CASE case_selectors COLON BRACKET_LEFT case_body BRACKET_RIGHT
// | DEFAULT COLON BRACKET_LEFT case_body BRACKET_RIGHT
std::unique_ptr<ast::CaseStatement> ParserImpl::switch_body() {
auto t = peek();
@@ -1813,14 +1813,14 @@
auto stmt = std::make_unique<ast::CaseStatement>();
stmt->set_source(source);
if (t.IsCase()) {
- auto cond = const_literal();
+ auto cond = case_selectors();
if (has_error())
return nullptr;
- if (cond == nullptr) {
+ if (cond.empty()) {
set_error(peek(), "unable to parse case conditional");
return nullptr;
}
- stmt->set_condition(std::move(cond));
+ stmt->set_conditions(std::move(cond));
}
t = next();
@@ -1850,6 +1850,24 @@
return stmt;
}
+// case_selectors
+// : const_literal (COMMA const_literal)*
+ast::CaseSelectorList ParserImpl::case_selectors() {
+ ast::CaseSelectorList selectors;
+
+ for (;;) {
+ auto cond = const_literal();
+ if (has_error())
+ return {};
+ if (cond == nullptr)
+ break;
+
+ selectors.push_back(std::move(cond));
+ }
+
+ return selectors;
+}
+
// case_body
// :
// | statement case_body
diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h
index 826ea8a..718f3c2 100644
--- a/src/reader/wgsl/parser_impl.h
+++ b/src/reader/wgsl/parser_impl.h
@@ -23,6 +23,7 @@
#include "src/ast/assignment_statement.h"
#include "src/ast/builtin.h"
+#include "src/ast/case_statement.h"
#include "src/ast/constructor_expression.h"
#include "src/ast/else_statement.h"
#include "src/ast/entry_point.h"
@@ -216,6 +217,9 @@
/// Parses a `switch_body` grammar element
/// @returns the parsed statement or nullptr
std::unique_ptr<ast::CaseStatement> switch_body();
+ /// Parses a `case_selectors` grammar element
+ /// @returns the list of literals
+ ast::CaseSelectorList case_selectors();
/// Parses a `case_body` grammar element
/// @returns the parsed statements
ast::StatementList case_body();
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 5104709..bd4a04f 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -165,8 +165,9 @@
body.push_back(std::make_unique<ast::AssignmentStatement>(std::move(lhs),
std::move(rhs)));
- ast::CaseStatement cse(std::make_unique<ast::IntLiteral>(&i32, 3),
- std::move(body));
+ ast::CaseSelectorList lit;
+ lit.push_back(std::make_unique<ast::IntLiteral>(&i32, 3));
+ ast::CaseStatement cse(std::move(lit), std::move(body));
EXPECT_TRUE(td()->DetermineResultType(&cse));
ASSERT_NE(lhs_ptr->result_type(), nullptr);
@@ -355,9 +356,12 @@
body.push_back(std::make_unique<ast::AssignmentStatement>(std::move(lhs),
std::move(rhs)));
+ ast::CaseSelectorList lit;
+ lit.push_back(std::make_unique<ast::IntLiteral>(&i32, 3));
+
ast::CaseStatementList cases;
- cases.push_back(std::make_unique<ast::CaseStatement>(
- std::make_unique<ast::IntLiteral>(&i32, 3), std::move(body)));
+ cases.push_back(
+ std::make_unique<ast::CaseStatement>(std::move(lit), std::move(body)));
ast::SwitchStatement stmt(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 2)),
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index c83759e..50dad68 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -724,9 +724,18 @@
} else {
out_ << "case ";
- if (!EmitLiteral(stmt->condition())) {
- return false;
+ bool first = true;
+ for (const auto& lit : stmt->conditions()) {
+ if (!first) {
+ out_ << ", ";
+ }
+
+ first = false;
+ if (!EmitLiteral(lit.get())) {
+ return false;
+ }
}
+
out_ << ":";
}
diff --git a/src/writer/wgsl/generator_impl_case_test.cc b/src/writer/wgsl/generator_impl_case_test.cc
index d1c3fd1..5a1a906 100644
--- a/src/writer/wgsl/generator_impl_case_test.cc
+++ b/src/writer/wgsl/generator_impl_case_test.cc
@@ -31,12 +31,13 @@
TEST_F(GeneratorImplTest, Emit_Case) {
ast::type::I32Type i32;
- auto cond = std::make_unique<ast::IntLiteral>(&i32, 5);
ast::StatementList body;
body.push_back(std::make_unique<ast::BreakStatement>());
- ast::CaseStatement c(std::move(cond), std::move(body));
+ ast::CaseSelectorList lit;
+ lit.push_back(std::make_unique<ast::IntLiteral>(&i32, 5));
+ ast::CaseStatement c(std::move(lit), std::move(body));
GeneratorImpl g;
g.increment_indent();
@@ -48,6 +49,27 @@
)");
}
+TEST_F(GeneratorImplTest, Emit_Case_MultipleSelectors) {
+ ast::type::I32Type i32;
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::BreakStatement>());
+
+ ast::CaseSelectorList lit;
+ lit.push_back(std::make_unique<ast::IntLiteral>(&i32, 5));
+ lit.push_back(std::make_unique<ast::IntLiteral>(&i32, 6));
+ ast::CaseStatement c(std::move(lit), std::move(body));
+
+ GeneratorImpl g;
+ g.increment_indent();
+
+ ASSERT_TRUE(g.EmitCase(&c)) << g.error();
+ EXPECT_EQ(g.result(), R"( case 5, 6: {
+ break;
+ }
+)");
+}
+
TEST_F(GeneratorImplTest, Emit_Case_Default) {
ast::CaseStatement c;
diff --git a/src/writer/wgsl/generator_impl_switch_test.cc b/src/writer/wgsl/generator_impl_switch_test.cc
index b784f21..b149541 100644
--- a/src/writer/wgsl/generator_impl_switch_test.cc
+++ b/src/writer/wgsl/generator_impl_switch_test.cc
@@ -37,7 +37,9 @@
def->set_body(std::move(def_body));
ast::type::I32Type i32;
- auto case_val = std::make_unique<ast::IntLiteral>(&i32, 5);
+ ast::CaseSelectorList case_val;
+ case_val.push_back(std::make_unique<ast::IntLiteral>(&i32, 5));
+
ast::StatementList case_body;
case_body.push_back(std::make_unique<ast::BreakStatement>());