[WGSL] Allow default as a case selector
This CL updates the WGSL parser to parse `default` as a case selector
value.
Bug: tint:1633
Change-Id: I57661d25924e36bec5c03f96399c557fb7bbf760
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/106382
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 5ef8252..4af06d2 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -215,6 +215,8 @@
"ast/call_expression.h",
"ast/call_statement.cc",
"ast/call_statement.h",
+ "ast/case_selector.cc",
+ "ast/case_selector.h",
"ast/case_statement.cc",
"ast/case_statement.h",
"ast/compound_assignment_statement.cc",
@@ -1021,6 +1023,7 @@
"ast/builtin_value_test.cc",
"ast/call_expression_test.cc",
"ast/call_statement_test.cc",
+ "ast/case_selector_test.cc",
"ast/case_statement_test.cc",
"ast/compound_assignment_statement_test.cc",
"ast/continue_statement_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 0811598..512c906 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -83,6 +83,8 @@
ast/call_expression.h
ast/call_statement.cc
ast/call_statement.h
+ ast/case_selector.cc
+ ast/case_selector.h
ast/case_statement.cc
ast/case_statement.h
ast/compound_assignment_statement.cc
@@ -713,6 +715,7 @@
ast/builtin_value_test.cc
ast/call_expression_test.cc
ast/call_statement_test.cc
+ ast/case_selector_test.cc
ast/case_statement_test.cc
ast/compound_assignment_statement_test.cc
ast/continue_statement_test.cc
diff --git a/src/tint/ast/case_selector.cc b/src/tint/ast/case_selector.cc
new file mode 100644
index 0000000..8622d3a
--- /dev/null
+++ b/src/tint/ast/case_selector.cc
@@ -0,0 +1,39 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/case_selector.h"
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::CaseSelector);
+
+namespace tint::ast {
+
+CaseSelector::CaseSelector(ProgramID pid, NodeID nid, const Source& src, const ast::Expression* e)
+ : Base(pid, nid, src), expr(e) {}
+
+CaseSelector::CaseSelector(CaseSelector&&) = default;
+
+CaseSelector::~CaseSelector() = default;
+
+const CaseSelector* CaseSelector::Clone(CloneContext* ctx) const {
+ // Clone arguments outside of create() call to have deterministic ordering
+ auto src = ctx->Clone(source);
+ auto ex = ctx->Clone(expr);
+ return ctx->dst->create<CaseSelector>(src, ex);
+}
+
+} // namespace tint::ast
diff --git a/src/tint/ast/case_selector.h b/src/tint/ast/case_selector.h
new file mode 100644
index 0000000..b4c3ca7
--- /dev/null
+++ b/src/tint/ast/case_selector.h
@@ -0,0 +1,52 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_AST_CASE_SELECTOR_H_
+#define SRC_TINT_AST_CASE_SELECTOR_H_
+
+#include <vector>
+
+#include "src/tint/ast/block_statement.h"
+#include "src/tint/ast/expression.h"
+
+namespace tint::ast {
+
+/// A case selector
+class CaseSelector final : public Castable<CaseSelector, Node> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param src the source of this node
+ /// @param expr the selector expression, |nullptr| for a `default` selector
+ CaseSelector(ProgramID pid, NodeID nid, const Source& src, const Expression* expr = nullptr);
+ /// Move constructor
+ CaseSelector(CaseSelector&&);
+ ~CaseSelector() override;
+
+ /// @returns true if this is a default statement
+ bool IsDefault() const { return expr == nullptr; }
+
+ /// Clones this node and all transitive child nodes using the `CloneContext` `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned node
+ const CaseSelector* Clone(CloneContext* ctx) const override;
+
+ /// The selector, nullptr for a default selector
+ const Expression* const expr = nullptr;
+};
+
+} // namespace tint::ast
+
+#endif // SRC_TINT_AST_CASE_SELECTOR_H_
diff --git a/src/tint/ast/case_selector_test.cc b/src/tint/ast/case_selector_test.cc
new file mode 100644
index 0000000..16e74cc
--- /dev/null
+++ b/src/tint/ast/case_selector_test.cc
@@ -0,0 +1,40 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/case_selector.h"
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/ast/test_helper.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast {
+namespace {
+
+using CaseSelectorTest = TestHelper;
+
+TEST_F(CaseSelectorTest, NonDefault) {
+ auto* e = Expr(2_i);
+ auto* c = CaseSelector(e);
+ EXPECT_FALSE(c->IsDefault());
+ EXPECT_EQ(e, c->expr);
+}
+
+TEST_F(CaseSelectorTest, Default) {
+ auto* c = DefaultCaseSelector();
+ EXPECT_TRUE(c->IsDefault());
+}
+
+} // namespace
+} // namespace tint::ast
diff --git a/src/tint/ast/case_statement.cc b/src/tint/ast/case_statement.cc
index 3125f05..7b2e798 100644
--- a/src/tint/ast/case_statement.cc
+++ b/src/tint/ast/case_statement.cc
@@ -25,10 +25,11 @@
CaseStatement::CaseStatement(ProgramID pid,
NodeID nid,
const Source& src,
- utils::VectorRef<const Expression*> s,
+ utils::VectorRef<const CaseSelector*> s,
const BlockStatement* b)
: Base(pid, nid, src), selectors(std::move(s)), body(b) {
TINT_ASSERT(AST, body);
+ TINT_ASSERT(AST, !selectors.IsEmpty());
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id);
for (auto* selector : selectors) {
TINT_ASSERT(AST, selector);
@@ -40,6 +41,15 @@
CaseStatement::~CaseStatement() = default;
+bool CaseStatement::ContainsDefault() const {
+ for (const auto* sel : selectors) {
+ if (sel->IsDefault()) {
+ return true;
+ }
+ }
+ return false;
+}
+
const CaseStatement* CaseStatement::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source);
diff --git a/src/tint/ast/case_statement.h b/src/tint/ast/case_statement.h
index eda9c01..acd502a 100644
--- a/src/tint/ast/case_statement.h
+++ b/src/tint/ast/case_statement.h
@@ -18,7 +18,7 @@
#include <vector>
#include "src/tint/ast/block_statement.h"
-#include "src/tint/ast/expression.h"
+#include "src/tint/ast/case_selector.h"
namespace tint::ast {
@@ -34,23 +34,23 @@
CaseStatement(ProgramID pid,
NodeID nid,
const Source& src,
- utils::VectorRef<const Expression*> selectors,
+ utils::VectorRef<const CaseSelector*> selectors,
const BlockStatement* body);
/// Move constructor
CaseStatement(CaseStatement&&);
~CaseStatement() override;
- /// @returns true if this is a default statement
- bool IsDefault() const { return selectors.IsEmpty(); }
-
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const CaseStatement* Clone(CloneContext* ctx) const override;
+ /// @returns true if this item contains a default selector
+ bool ContainsDefault() const;
+
/// The case selectors, empty if none set
- const utils::Vector<const Expression*, 4> selectors;
+ const utils::Vector<const CaseSelector*, 4> selectors;
/// The case body
const BlockStatement* const body;
diff --git a/src/tint/ast/case_statement_test.cc b/src/tint/ast/case_statement_test.cc
index dc3c88a..04b887b 100644
--- a/src/tint/ast/case_statement_test.cc
+++ b/src/tint/ast/case_statement_test.cc
@@ -27,7 +27,7 @@
using CaseStatementTest = TestHelper;
TEST_F(CaseStatementTest, Creation_i32) {
- auto* selector = Expr(2_i);
+ auto* selector = CaseSelector(2_i);
utils::Vector b{selector};
auto* discard = create<DiscardStatement>();
@@ -41,7 +41,7 @@
}
TEST_F(CaseStatementTest, Creation_u32) {
- auto* selector = Expr(2_u);
+ auto* selector = CaseSelector(2_u);
utils::Vector b{selector};
auto* discard = create<DiscardStatement>();
@@ -54,8 +54,20 @@
EXPECT_EQ(c->body->statements[0], discard);
}
+TEST_F(CaseStatementTest, ContainsDefault_WithDefault) {
+ utils::Vector b{CaseSelector(2_u), DefaultCaseSelector()};
+ auto* c = create<CaseStatement>(b, create<BlockStatement>(utils::Empty));
+ EXPECT_TRUE(c->ContainsDefault());
+}
+
+TEST_F(CaseStatementTest, ContainsDefault_WithOutDefault) {
+ utils::Vector b{CaseSelector(2_u), CaseSelector(3_u)};
+ auto* c = create<CaseStatement>(b, create<BlockStatement>(utils::Empty));
+ EXPECT_FALSE(c->ContainsDefault());
+}
+
TEST_F(CaseStatementTest, Creation_WithSource) {
- utils::Vector b{Expr(2_i)};
+ utils::Vector b{CaseSelector(2_i)};
auto* body = create<BlockStatement>(utils::Vector{
create<DiscardStatement>(),
@@ -66,22 +78,9 @@
EXPECT_EQ(src.range.begin.column, 2u);
}
-TEST_F(CaseStatementTest, IsDefault_WithoutSelectors) {
- auto* body = create<BlockStatement>(utils::Vector{
- create<DiscardStatement>(),
- });
- auto* c = create<CaseStatement>(utils::Empty, body);
- EXPECT_TRUE(c->IsDefault());
-}
-
-TEST_F(CaseStatementTest, IsDefault_WithSelectors) {
- utils::Vector b{Expr(2_i)};
- auto* c = create<CaseStatement>(b, create<BlockStatement>(utils::Empty));
- EXPECT_FALSE(c->IsDefault());
-}
-
TEST_F(CaseStatementTest, IsCase) {
- auto* c = create<CaseStatement>(utils::Empty, create<BlockStatement>(utils::Empty));
+ auto* c = create<CaseStatement>(utils::Vector{DefaultCaseSelector()},
+ create<BlockStatement>(utils::Empty));
EXPECT_TRUE(c->Is<CaseStatement>());
}
@@ -89,7 +88,7 @@
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
- b.create<CaseStatement>(utils::Empty, nullptr);
+ b.create<CaseStatement>(utils::Vector{b.DefaultCaseSelector()}, nullptr);
},
"internal compiler error");
}
@@ -98,7 +97,7 @@
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
- b.create<CaseStatement>(utils::Vector<const ast::IntLiteralExpression*, 1>{nullptr},
+ b.create<CaseStatement>(utils::Vector<const ast::CaseSelector*, 1>{nullptr},
b.create<BlockStatement>(utils::Empty));
},
"internal compiler error");
@@ -109,7 +108,8 @@
{
ProgramBuilder b1;
ProgramBuilder b2;
- b1.create<CaseStatement>(utils::Empty, b2.create<BlockStatement>(utils::Empty));
+ b1.create<CaseStatement>(utils::Vector{b1.DefaultCaseSelector()},
+ b2.create<BlockStatement>(utils::Empty));
},
"internal compiler error");
}
@@ -119,7 +119,7 @@
{
ProgramBuilder b1;
ProgramBuilder b2;
- b1.create<CaseStatement>(utils::Vector{b2.Expr(2_i)},
+ b1.create<CaseStatement>(utils::Vector{b2.CaseSelector(b2.Expr(2_i))},
b1.create<BlockStatement>(utils::Empty));
},
"internal compiler error");
diff --git a/src/tint/ast/switch_statement_test.cc b/src/tint/ast/switch_statement_test.cc
index 0f66c61..00c515e 100644
--- a/src/tint/ast/switch_statement_test.cc
+++ b/src/tint/ast/switch_statement_test.cc
@@ -25,7 +25,7 @@
using SwitchStatementTest = TestHelper;
TEST_F(SwitchStatementTest, Creation) {
- auto* case_stmt = create<CaseStatement>(utils::Vector{Expr(1_u)}, Block());
+ auto* case_stmt = create<CaseStatement>(utils::Vector{CaseSelector(1_u)}, Block());
auto* ident = Expr("ident");
utils::Vector body{case_stmt};
@@ -44,7 +44,7 @@
}
TEST_F(SwitchStatementTest, IsSwitch) {
- utils::Vector lit{Expr(2_i)};
+ utils::Vector lit{CaseSelector(2_i)};
auto* ident = Expr("ident");
utils::Vector body{create<CaseStatement>(lit, Block())};
@@ -58,7 +58,8 @@
{
ProgramBuilder b;
CaseStatementList cases;
- cases.Push(b.create<CaseStatement>(utils::Vector{b.Expr(1_i)}, b.Block()));
+ cases.Push(
+ b.create<CaseStatement>(utils::Vector{b.CaseSelector(b.Expr(1_i))}, b.Block()));
b.create<SwitchStatement>(nullptr, cases);
},
"internal compiler error");
@@ -82,7 +83,7 @@
b1.create<SwitchStatement>(b2.Expr(true), utils::Vector{
b1.create<CaseStatement>(
utils::Vector{
- b1.Expr(1_i),
+ b1.CaseSelector(b1.Expr(1_i)),
},
b1.Block()),
});
@@ -98,7 +99,7 @@
b1.create<SwitchStatement>(b1.Expr(true), utils::Vector{
b2.create<CaseStatement>(
utils::Vector{
- b2.Expr(1_i),
+ b2.CaseSelector(b2.Expr(1_i)),
},
b2.Block()),
});
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 892cf27..f07a1f8 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -2847,7 +2847,7 @@
/// @param body the case body
/// @returns the case statement pointer
const ast::CaseStatement* Case(const Source& source,
- utils::VectorRef<const ast::Expression*> selectors,
+ utils::VectorRef<const ast::CaseSelector*> selectors,
const ast::BlockStatement* body = nullptr) {
return create<ast::CaseStatement>(source, std::move(selectors), body ? body : Block());
}
@@ -2856,7 +2856,7 @@
/// @param selectors list of selectors
/// @param body the case body
/// @returns the case statement pointer
- const ast::CaseStatement* Case(utils::VectorRef<const ast::Expression*> selectors,
+ const ast::CaseStatement* Case(utils::VectorRef<const ast::CaseSelector*> selectors,
const ast::BlockStatement* body = nullptr) {
return create<ast::CaseStatement>(std::move(selectors), body ? body : Block());
}
@@ -2865,7 +2865,7 @@
/// @param selector a single case selector
/// @param body the case body
/// @returns the case statement pointer
- const ast::CaseStatement* Case(const ast::Expression* selector,
+ const ast::CaseStatement* Case(const ast::CaseSelector* selector,
const ast::BlockStatement* body = nullptr) {
return Case(utils::Vector{selector}, body);
}
@@ -2876,16 +2876,44 @@
/// @returns the case statement pointer
const ast::CaseStatement* DefaultCase(const Source& source,
const ast::BlockStatement* body = nullptr) {
- return Case(source, utils::Empty, body);
+ return Case(source, utils::Vector{DefaultCaseSelector(source)}, body);
}
/// Convenience function that creates a 'default' ast::CaseStatement
/// @param body the case body
/// @returns the case statement pointer
const ast::CaseStatement* DefaultCase(const ast::BlockStatement* body = nullptr) {
- return Case(utils::Empty, body);
+ return Case(utils::Vector{DefaultCaseSelector()}, body);
}
+ /// Convenience function that creates a case selector
+ /// @param source the source information
+ /// @param expr the selector expression
+ /// @returns the selector pointer
+ template <typename EXPR>
+ const ast::CaseSelector* CaseSelector(const Source& source, EXPR&& expr) {
+ return create<ast::CaseSelector>(source, Expr(std::forward<EXPR>(expr)));
+ }
+
+ /// Convenience function that creates a case selector
+ /// @param expr the selector expression
+ /// @returns the selector pointer
+ template <typename EXPR>
+ const ast::CaseSelector* CaseSelector(EXPR&& expr) {
+ return create<ast::CaseSelector>(source_, Expr(std::forward<EXPR>(expr)));
+ }
+
+ /// Convenience function that creates a default case selector
+ /// @param source the source information
+ /// @returns the selector pointer
+ const ast::CaseSelector* DefaultCaseSelector(const Source& source) {
+ return create<ast::CaseSelector>(source, nullptr);
+ }
+
+ /// Convenience function that creates a default case selector
+ /// @returns the selector pointer
+ const ast::CaseSelector* DefaultCaseSelector() { return create<ast::CaseSelector>(nullptr); }
+
/// Creates an ast::FallthroughStatement
/// @param source the source information
/// @returns the fallthrough statement pointer
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index 2cc35d8..0856eff 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -3024,7 +3024,7 @@
for (size_t i = last_clause_index;; --i) {
// Create a list of integer literals for the selector values leading to
// this case clause.
- utils::Vector<const ast::Expression*, 4> selectors;
+ utils::Vector<const ast::CaseSelector*, 4> selectors;
const bool has_selectors = clause_heads[i]->case_values.has_value();
if (has_selectors) {
auto values = clause_heads[i]->case_values.value();
@@ -3034,15 +3034,26 @@
// The Tint AST handles 32-bit values.
const uint32_t value32 = uint32_t(value & 0xFFFFFFFF);
if (selector.type->IsUnsignedScalarOrVector()) {
- selectors.Push(create<ast::IntLiteralExpression>(
- Source{}, value32, ast::IntLiteralExpression::Suffix::kU));
+ selectors.Push(create<ast::CaseSelector>(
+ Source{}, create<ast::IntLiteralExpression>(
+ Source{}, value32, ast::IntLiteralExpression::Suffix::kU)));
} else {
- selectors.Push(
+ selectors.Push(create<ast::CaseSelector>(
+ Source{},
create<ast::IntLiteralExpression>(Source{}, static_cast<int32_t>(value32),
- ast::IntLiteralExpression::Suffix::kI));
+ ast::IntLiteralExpression::Suffix::kI)));
}
}
+
+ if ((default_info == clause_heads[i]) && construct->ContainsPos(default_info->pos)) {
+ // Generate a default selector
+ selectors.Push(create<ast::CaseSelector>(Source{}));
+ }
+ } else {
+ // Generate a default selector
+ selectors.Push(create<ast::CaseSelector>(Source{}));
}
+ TINT_ASSERT(Reader, !selectors.IsEmpty());
// Where does this clause end?
const auto end_id =
@@ -3057,17 +3068,6 @@
swch->cases[case_idx] = create<ast::CaseStatement>(Source{}, selectors, body);
});
- if ((default_info == clause_heads[i]) && has_selectors &&
- construct->ContainsPos(default_info->pos)) {
- // Generate a default clause with a just fallthrough.
- auto* stmts = create<ast::BlockStatement>(
- Source{}, StatementList{
- create<ast::FallthroughStatement>(Source{}),
- });
- auto* case_stmt = create<ast::CaseStatement>(Source{}, utils::Empty, stmts);
- swch->cases.Push(case_stmt);
- }
-
if (i == 0) {
break;
}
diff --git a/src/tint/reader/spirv/function_cfg_test.cc b/src/tint/reader/spirv/function_cfg_test.cc
index 8356e37..538657b 100644
--- a/src/tint/reader/spirv/function_cfg_test.cc
+++ b/src/tint/reader/spirv/function_cfg_test.cc
@@ -9349,10 +9349,7 @@
case 20u: {
var_1 = 20u;
}
- default: {
- fallthrough;
- }
- case 30u: {
+ case 30u, default: {
var_1 = 30u;
}
}
diff --git a/src/tint/reader/spirv/function_var_test.cc b/src/tint/reader/spirv/function_var_test.cc
index 5d156eb..e3d2acf 100644
--- a/src/tint/reader/spirv/function_var_test.cc
+++ b/src/tint/reader/spirv/function_var_test.cc
@@ -1393,10 +1393,7 @@
auto got = test::ToString(p->program(), ast_body);
auto* expect = R"(var x_41 : u32;
switch(1u) {
- default: {
- fallthrough;
- }
- case 0u: {
+ case 0u, default: {
fallthrough;
}
case 1u: {
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index cb8e217..f962482 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -2129,6 +2129,9 @@
}
selector_list = std::move(selectors.value);
+ } else {
+ // Push the default case selector
+ selector_list.Push(create<ast::CaseSelector>(t.source()));
}
// Consume the optional colon if present.
@@ -2148,12 +2151,12 @@
}
// case_selectors
-// : expression (COMMA expression)* COMMA?
+// : case_selector (COMMA case_selector)* COMMA?
Expect<ParserImpl::CaseSelectorList> ParserImpl::expect_case_selectors() {
CaseSelectorList selectors;
while (continue_parsing()) {
- auto expr = expression();
+ auto expr = case_selector();
if (expr.errored) {
return Failure::kErrored;
}
@@ -2168,12 +2171,32 @@
}
if (selectors.IsEmpty()) {
- return add_error(peek(), "unable to parse case selectors");
+ return add_error(peek(), "expected case selector expression or `default`");
}
return selectors;
}
+// case_selector
+// : DEFAULT
+// | expression
+Maybe<const ast::CaseSelector*> ParserImpl::case_selector() {
+ auto& p = peek();
+
+ if (match(Token::Type::kDefault)) {
+ return create<ast::CaseSelector>(p.source());
+ }
+
+ auto expr = expression();
+ if (expr.errored) {
+ return Failure::kErrored;
+ }
+ if (!expr.matched) {
+ return Failure::kNoMatch;
+ }
+ return create<ast::CaseSelector>(p.source(), expr.value);
+}
+
// case_body
// :
// | statement case_body
diff --git a/src/tint/reader/wgsl/parser_impl.h b/src/tint/reader/wgsl/parser_impl.h
index 3dbd780..691cc5e 100644
--- a/src/tint/reader/wgsl/parser_impl.h
+++ b/src/tint/reader/wgsl/parser_impl.h
@@ -74,7 +74,7 @@
/// Pre-determined small vector sizes for AST pointers
//! @cond Doxygen_Suppress
using AttributeList = utils::Vector<const ast::Attribute*, 4>;
- using CaseSelectorList = utils::Vector<const ast::Expression*, 4>;
+ using CaseSelectorList = utils::Vector<const ast::CaseSelector*, 4>;
using CaseStatementList = utils::Vector<const ast::CaseStatement*, 4>;
using ExpressionList = utils::Vector<const ast::Expression*, 8>;
using ParameterList = utils::Vector<const ast::Parameter*, 8>;
@@ -573,6 +573,9 @@
/// Parses a `case_selectors` grammar element
/// @returns the list of literals
Expect<CaseSelectorList> expect_case_selectors();
+ /// Parses a `case_selector` grammar element
+ /// @returns the selector
+ Maybe<const ast::CaseSelector*> case_selector();
/// Parses a `case_body` grammar element
/// @returns the parsed statements
Maybe<const ast::BlockStatement*> case_body();
diff --git a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
index 657b522..1a88856 100644
--- a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
@@ -1333,7 +1333,7 @@
TEST_F(ParserImplErrorTest, SwitchStmtInvalidCase) {
EXPECT("fn f() { switch(1) { case ^: } }",
- R"(test.wgsl:1:27 error: unable to parse case selectors
+ R"(test.wgsl:1:27 error: expected case selector expression or `default`
fn f() { switch(1) { case ^: } }
^
)");
diff --git a/src/tint/reader/wgsl/parser_impl_statement_test.cc b/src/tint/reader/wgsl/parser_impl_statement_test.cc
index 7bb51d3..f78e944 100644
--- a/src/tint/reader/wgsl/parser_impl_statement_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_statement_test.cc
@@ -143,7 +143,7 @@
EXPECT_TRUE(e.errored);
EXPECT_FALSE(e.matched);
EXPECT_EQ(e.value, nullptr);
- EXPECT_EQ(p->error(), "1:18: unable to parse case selectors");
+ EXPECT_EQ(p->error(), "1:18: expected case selector expression or `default`");
}
TEST_F(ParserImplTest, Statement_Loop) {
diff --git a/src/tint/reader/wgsl/parser_impl_switch_body_test.cc b/src/tint/reader/wgsl/parser_impl_switch_body_test.cc
index 2b8b0bd..61ec524 100644
--- a/src/tint/reader/wgsl/parser_impl_switch_body_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_switch_body_test.cc
@@ -25,13 +25,16 @@
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
- EXPECT_FALSE(e->IsDefault());
+ EXPECT_FALSE(e->ContainsDefault());
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 1u);
- ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
- auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
+ auto* sel = stmt->selectors[0];
+ EXPECT_FALSE(sel->IsDefault());
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+
+ auto* expr = sel->expr->As<ast::IntLiteralExpression>();
EXPECT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_EQ(e->body->statements.Length(), 1u);
@@ -46,12 +49,16 @@
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
- EXPECT_FALSE(e->IsDefault());
+ EXPECT_FALSE(e->ContainsDefault());
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 1u);
- ASSERT_TRUE(stmt->selectors[0]->Is<ast::BinaryExpression>());
- auto* expr = stmt->selectors[0]->As<ast::BinaryExpression>();
+
+ auto* sel = stmt->selectors[0];
+ EXPECT_FALSE(sel->IsDefault());
+
+ ASSERT_TRUE(sel->expr->Is<ast::BinaryExpression>());
+ auto* expr = sel->expr->As<ast::BinaryExpression>();
EXPECT_EQ(ast::BinaryOp::kAdd, expr->op);
auto* v = expr->lhs->As<ast::IntLiteralExpression>();
@@ -74,13 +81,16 @@
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
- EXPECT_FALSE(e->IsDefault());
+ EXPECT_FALSE(e->ContainsDefault());
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 1u);
- ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
+ auto* sel = stmt->selectors[0];
+ EXPECT_FALSE(sel->IsDefault());
- auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+ auto* expr = sel->expr->As<ast::IntLiteralExpression>();
+
EXPECT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_EQ(e->body->statements.Length(), 1u);
@@ -95,17 +105,20 @@
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
- EXPECT_FALSE(e->IsDefault());
+ EXPECT_FALSE(e->ContainsDefault());
+
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 2u);
- ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
+ auto* sel = stmt->selectors[0];
- auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+ auto* expr = sel->expr->As<ast::IntLiteralExpression>();
EXPECT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
- ASSERT_TRUE(stmt->selectors[1]->Is<ast::IntLiteralExpression>());
- expr = stmt->selectors[1]->As<ast::IntLiteralExpression>();
+ sel = stmt->selectors[1];
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+ expr = sel->expr->As<ast::IntLiteralExpression>();
EXPECT_EQ(expr->value, 2);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
@@ -118,18 +131,20 @@
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
- EXPECT_FALSE(e->IsDefault());
+ EXPECT_FALSE(e->ContainsDefault());
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 2u);
- ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
+ auto* sel = stmt->selectors[0];
- auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+ auto* expr = sel->expr->As<ast::IntLiteralExpression>();
EXPECT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
- ASSERT_TRUE(stmt->selectors[1]->Is<ast::IntLiteralExpression>());
- expr = stmt->selectors[1]->As<ast::IntLiteralExpression>();
+ sel = stmt->selectors[1];
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+ expr = sel->expr->As<ast::IntLiteralExpression>();
EXPECT_EQ(expr->value, 2);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
@@ -141,7 +156,7 @@
EXPECT_TRUE(e.errored);
EXPECT_FALSE(e.matched);
EXPECT_EQ(e.value, nullptr);
- EXPECT_EQ(p->error(), "1:6: unable to parse case selectors");
+ EXPECT_EQ(p->error(), "1:6: expected case selector expression or `default`");
}
TEST_F(ParserImplTest, SwitchBody_Case_MissingConstLiteral) {
@@ -151,7 +166,7 @@
EXPECT_TRUE(e.errored);
EXPECT_FALSE(e.matched);
EXPECT_EQ(e.value, nullptr);
- EXPECT_EQ(p->error(), "1:5: unable to parse case selectors");
+ EXPECT_EQ(p->error(), "1:5: expected case selector expression or `default`");
}
TEST_F(ParserImplTest, SwitchBody_Case_MissingBracketLeft) {
@@ -202,17 +217,46 @@
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
- EXPECT_FALSE(e->IsDefault());
+ EXPECT_FALSE(e->ContainsDefault());
ASSERT_EQ(e->body->statements.Length(), 0u);
ASSERT_EQ(e->selectors.Length(), 2u);
- ASSERT_TRUE(e->selectors[0]->Is<ast::IntLiteralExpression>());
- auto* expr = e->selectors[0]->As<ast::IntLiteralExpression>();
+ auto* sel = e->selectors[0];
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+ auto* expr = sel->expr->As<ast::IntLiteralExpression>();
ASSERT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
- ASSERT_TRUE(e->selectors[1]->Is<ast::IntLiteralExpression>());
- expr = e->selectors[1]->As<ast::IntLiteralExpression>();
+ sel = e->selectors[1];
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+ expr = sel->expr->As<ast::IntLiteralExpression>();
+ ASSERT_EQ(expr->value, 2);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
+}
+
+TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors_with_default) {
+ auto p = parser("case 1, default, 2 { }");
+ auto e = p->switch_body();
+ EXPECT_FALSE(p->has_error()) << p->error();
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ ASSERT_NE(e.value, nullptr);
+ ASSERT_TRUE(e->Is<ast::CaseStatement>());
+ EXPECT_TRUE(e->ContainsDefault());
+ ASSERT_EQ(e->body->statements.Length(), 0u);
+ ASSERT_EQ(e->selectors.Length(), 3u);
+
+ auto* sel = e->selectors[0];
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+ auto* expr = sel->expr->As<ast::IntLiteralExpression>();
+ ASSERT_EQ(expr->value, 1);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
+
+ EXPECT_TRUE(e->selectors[1]->IsDefault());
+
+ sel = e->selectors[2];
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+ expr = sel->expr->As<ast::IntLiteralExpression>();
ASSERT_EQ(expr->value, 2);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
@@ -225,17 +269,19 @@
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
- EXPECT_FALSE(e->IsDefault());
+ EXPECT_FALSE(e->ContainsDefault());
ASSERT_EQ(e->body->statements.Length(), 0u);
ASSERT_EQ(e->selectors.Length(), 2u);
- ASSERT_TRUE(e->selectors[0]->Is<ast::IntLiteralExpression>());
- auto* expr = e->selectors[0]->As<ast::IntLiteralExpression>();
+ auto* sel = e->selectors[0];
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+ auto* expr = sel->expr->As<ast::IntLiteralExpression>();
ASSERT_EQ(expr->value, 1);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
- ASSERT_TRUE(e->selectors[1]->Is<ast::IntLiteralExpression>());
- expr = e->selectors[1]->As<ast::IntLiteralExpression>();
+ sel = e->selectors[1];
+ ASSERT_TRUE(sel->expr->Is<ast::IntLiteralExpression>());
+ expr = sel->expr->As<ast::IntLiteralExpression>();
ASSERT_EQ(expr->value, 2);
EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
@@ -257,7 +303,7 @@
EXPECT_TRUE(e.errored);
EXPECT_FALSE(e.matched);
EXPECT_EQ(e.value, nullptr);
- EXPECT_EQ(p->error(), "1:6: unable to parse case selectors");
+ EXPECT_EQ(p->error(), "1:6: expected case selector expression or `default`");
}
TEST_F(ParserImplTest, SwitchBody_Default) {
@@ -268,7 +314,7 @@
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
- EXPECT_TRUE(e->IsDefault());
+ EXPECT_TRUE(e->ContainsDefault());
ASSERT_EQ(e->body->statements.Length(), 1u);
EXPECT_TRUE(e->body->statements[0]->Is<ast::AssignmentStatement>());
}
@@ -281,7 +327,7 @@
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
- EXPECT_TRUE(e->IsDefault());
+ EXPECT_TRUE(e->ContainsDefault());
ASSERT_EQ(e->body->statements.Length(), 1u);
EXPECT_TRUE(e->body->statements[0]->Is<ast::AssignmentStatement>());
}
diff --git a/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc b/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc
index 014d850..a5e3dd3 100644
--- a/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_switch_stmt_test.cc
@@ -29,8 +29,8 @@
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::SwitchStatement>());
ASSERT_EQ(e->body.Length(), 2u);
- EXPECT_FALSE(e->body[0]->IsDefault());
- EXPECT_FALSE(e->body[1]->IsDefault());
+ EXPECT_FALSE(e->body[0]->ContainsDefault());
+ EXPECT_FALSE(e->body[1]->ContainsDefault());
}
TEST_F(ParserImplTest, SwitchStmt_Empty) {
@@ -58,9 +58,24 @@
ASSERT_TRUE(e->Is<ast::SwitchStatement>());
ASSERT_EQ(e->body.Length(), 3u);
- ASSERT_FALSE(e->body[0]->IsDefault());
- ASSERT_TRUE(e->body[1]->IsDefault());
- ASSERT_FALSE(e->body[2]->IsDefault());
+ ASSERT_FALSE(e->body[0]->ContainsDefault());
+ ASSERT_TRUE(e->body[1]->ContainsDefault());
+ ASSERT_FALSE(e->body[2]->ContainsDefault());
+}
+
+TEST_F(ParserImplTest, SwitchStmt_Default_Mixed) {
+ auto p = parser(R"(switch a {
+ case 1, default, 2: {}
+})");
+ auto e = p->switch_statement();
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(e.value, nullptr);
+ ASSERT_TRUE(e->Is<ast::SwitchStatement>());
+
+ ASSERT_EQ(e->body.Length(), 1u);
+ ASSERT_TRUE(e->body[0]->ContainsDefault());
}
TEST_F(ParserImplTest, SwitchStmt_WithParens) {
@@ -123,7 +138,7 @@
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "2:7: unable to parse case selectors");
+ EXPECT_EQ(p->error(), "2:7: expected case selector expression or `default`");
}
} // namespace
diff --git a/src/tint/resolver/compound_statement_test.cc b/src/tint/resolver/compound_statement_test.cc
index 0a96ced..b5a5ba9 100644
--- a/src/tint/resolver/compound_statement_test.cc
+++ b/src/tint/resolver/compound_statement_test.cc
@@ -390,8 +390,8 @@
auto* stmt_a = Ignore(1_i);
auto* stmt_b = Ignore(1_i);
auto* stmt_c = Ignore(1_i);
- auto* swi = Switch(expr, Case(Expr(1_i), Block(stmt_a)), Case(Expr(2_i), Block(stmt_b)),
- DefaultCase(Block(stmt_c)));
+ auto* swi = Switch(expr, Case(CaseSelector(1_i), Block(stmt_a)),
+ Case(CaseSelector(2_i), Block(stmt_b)), DefaultCase(Block(stmt_c)));
WrapInFunction(swi);
ASSERT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/control_block_validation_test.cc b/src/tint/resolver/control_block_validation_test.cc
index 403d0bc..8cbf506 100644
--- a/src/tint/resolver/control_block_validation_test.cc
+++ b/src/tint/resolver/control_block_validation_test.cc
@@ -70,7 +70,7 @@
auto* block = Block(Decl(var), //
Switch(Source{{12, 34}}, "a", //
- Case(Expr(1_i))));
+ Case(CaseSelector(1_i))));
WrapInFunction(block);
@@ -87,16 +87,79 @@
// }
auto* var = Var("a", ty.i32(), Expr(2_i));
- auto* block = Block(Decl(var), //
- Switch("a", //
- DefaultCase(), //
- Case(Expr(1_i)), //
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ DefaultCase(Source{{9, 2}}), //
+ Case(CaseSelector(1_i)), //
DefaultCase(Source{{12, 34}})));
WrapInFunction(block);
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: switch statement must have exactly one default clause");
+ EXPECT_EQ(r()->error(), R"(12:34 error: switch statement must have exactly one default clause
+9:2 note: previous default case)");
+}
+
+TEST_F(ResolverControlBlockValidationTest, SwitchWithTwoDefault_OneInCase_Fail) {
+ // var a : i32 = 2;
+ // switch (a) {
+ // case 1, default: {}
+ // default: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2_i));
+
+ auto* block = Block(
+ Decl(var), //
+ Switch("a", //
+ Case(utils::Vector{CaseSelector(1_i), DefaultCaseSelector(Source{{9, 2}})}), //
+ DefaultCase(Source{{12, 34}})));
+
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: switch statement must have exactly one default clause
+9:2 note: previous default case)");
+}
+
+TEST_F(ResolverControlBlockValidationTest, SwitchWithTwoDefault_SameCase) {
+ // var a : i32 = 2;
+ // switch (a) {
+ // case default, 1, default: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2_i));
+
+ auto* block =
+ Block(Decl(var), //
+ Switch("a", //
+ Case(utils::Vector{DefaultCaseSelector(Source{{9, 2}}), CaseSelector(1_i),
+ DefaultCaseSelector(Source{{12, 34}})})));
+
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: switch statement must have exactly one default clause
+9:2 note: previous default case)");
+}
+
+TEST_F(ResolverControlBlockValidationTest, SwitchWithTwoDefault_DifferentMultiCase) {
+ // var a : i32 = 2;
+ // switch (a) {
+ // case 1, default: {}
+ // case default, 2: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2_i));
+
+ auto* block = Block(
+ Decl(var), //
+ Switch("a", //
+ Case(utils::Vector{CaseSelector(1_i), DefaultCaseSelector(Source{{9, 2}})}),
+ Case(utils::Vector{DefaultCaseSelector(Source{{12, 34}}), CaseSelector(2_i)})));
+
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: switch statement must have exactly one default clause
+9:2 note: previous default case)");
}
TEST_F(ResolverControlBlockValidationTest, UnreachableCode_Loop_continue) {
@@ -187,9 +250,9 @@
auto* decl_z = Decl(Var("z", ty.i32()));
auto* brk = Break();
auto* assign_z = Assign(Source{{12, 34}}, "z", 1_i);
- WrapInFunction( //
- Block(Switch(1_i, //
- Case(Expr(1_i), Block(decl_z, brk, assign_z)), //
+ WrapInFunction( //
+ Block(Switch(1_i, //
+ Case(CaseSelector(1_i), Block(decl_z, brk, assign_z)), //
DefaultCase())));
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -210,11 +273,11 @@
auto* decl_z = Decl(Var("z", ty.i32()));
auto* brk = Break();
auto* assign_z = Assign(Source{{12, 34}}, "z", 1_i);
- WrapInFunction(
- Loop(Block(Switch(1_i, //
- Case(Expr(1_i), Block(decl_z, Block(Block(Block(brk))), assign_z)),
- DefaultCase()), //
- Break())));
+ WrapInFunction(Loop(
+ Block(Switch(1_i, //
+ Case(CaseSelector(1_i), Block(decl_z, Block(Block(Block(brk))), assign_z)),
+ DefaultCase()), //
+ Break())));
ASSERT_TRUE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable");
@@ -231,8 +294,8 @@
// }
auto* var = Var("a", ty.i32(), Expr(2_i));
- auto* block = Block(Decl(var), Switch("a", //
- Case(Expr(Source{{12, 34}}, 1_u)), //
+ auto* block = Block(Decl(var), Switch("a", //
+ Case(CaseSelector(Source{{12, 34}}, 1_u)), //
DefaultCase()));
WrapInFunction(block);
@@ -250,9 +313,9 @@
// }
auto* var = Var("a", ty.u32(), Expr(2_u));
- auto* block = Block(Decl(var), //
- Switch("a", //
- Case(utils::Vector{Expr(Source{{12, 34}}, -1_i)}), //
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ Case(CaseSelector(Source{{12, 34}}, -1_i)), //
DefaultCase()));
WrapInFunction(block);
@@ -273,11 +336,11 @@
auto* block = Block(Decl(var), //
Switch("a", //
- Case(Expr(0_u)),
+ Case(CaseSelector(0_u)),
Case(utils::Vector{
- Expr(Source{{12, 34}}, 2_u),
- Expr(3_u),
- Expr(Source{{56, 78}}, 2_u),
+ CaseSelector(Source{{12, 34}}, 2_u),
+ CaseSelector(3_u),
+ CaseSelector(Source{{56, 78}}, 2_u),
}),
DefaultCase()));
WrapInFunction(block);
@@ -299,12 +362,12 @@
auto* block = Block(Decl(var), //
Switch("a", //
- Case(Expr(Source{{12, 34}}, -10_i)),
+ Case(CaseSelector(Source{{12, 34}}, -10_i)),
Case(utils::Vector{
- Expr(0_i),
- Expr(1_i),
- Expr(2_i),
- Expr(Source{{56, 78}}, -10_i),
+ CaseSelector(0_i),
+ CaseSelector(1_i),
+ CaseSelector(2_i),
+ CaseSelector(Source{{56, 78}}, -10_i),
}),
DefaultCase()));
WrapInFunction(block);
@@ -344,7 +407,7 @@
auto* block = Block(Decl(var), //
Switch("a", //
DefaultCase(Source{{12, 34}}), //
- Case(Expr(5_i))));
+ Case(CaseSelector(5_i))));
WrapInFunction(block);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -361,7 +424,7 @@
auto* block = Block(Decl(var), //
Switch("a", //
DefaultCase(Source{{12, 34}}), //
- Case(Add(5_i, 6_i))));
+ Case(CaseSelector(Add(5_i, 6_i)))));
WrapInFunction(block);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -378,7 +441,7 @@
auto* block = Block(Decl(var), //
Switch("a", //
DefaultCase(Source{{12, 34}}), //
- Case(Add(5_i, 6_i))));
+ Case(CaseSelector(Add(5_i, 6_i)))));
WrapInFunction(block);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -395,7 +458,7 @@
auto* block = Block(Decl(var), //
Switch("a", //
DefaultCase(Source{{12, 34}}), //
- Case(Add(5_a, 6_a))));
+ Case(CaseSelector(Add(5_a, 6_a)))));
WrapInFunction(block);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -409,10 +472,12 @@
// }
auto* var = Var("a", Expr(2_u));
- auto* block = Block(Decl(var), //
- Switch("a", //
- DefaultCase(Source{{12, 34}}), //
- Case(utils::Vector{Add(5_u, 6_u), Add(7_u, 9_u), Mul(2_u, 4_u)})));
+ auto* block =
+ Block(Decl(var), //
+ Switch("a", //
+ DefaultCase(Source{{12, 34}}), //
+ Case(utils::Vector{CaseSelector(Add(5_u, 6_u)), CaseSelector(Add(7_u, 9_u)),
+ CaseSelector(Mul(2_u, 4_u))})));
WrapInFunction(block);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -446,8 +511,8 @@
auto* block = Block(Decl(var), //
Switch("a", //
- Case(Expr(Source{{12, 34}}, 10_i)),
- Case(Add(Source{{56, 78}}, 5_i, 5_i)), DefaultCase()));
+ Case(CaseSelector(Source{{12, 34}}, 10_i)),
+ Case(CaseSelector(Source{{56, 78}}, Add(5_i, 5_i))), DefaultCase()));
WrapInFunction(block);
EXPECT_FALSE(r()->Resolve());
@@ -466,8 +531,8 @@
auto* block = Block(Decl(var), //
Switch("a", //
- Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i),
- Add(Source{{12, 34}}, 6_i, 4_i)}),
+ Case(utils::Vector{CaseSelector(Source{{56, 78}}, Add(5_i, 5_i)),
+ CaseSelector(Source{{12, 34}}, Add(6_i, 4_i))}),
DefaultCase()));
WrapInFunction(block);
@@ -485,11 +550,11 @@
// }
auto* var = Var("a", ty.i32(), Expr(2_i));
- auto* block = Block(
- Decl(var), //
- Switch("a", //
- Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i), Expr(Source{{12, 34}}, 10_i)}),
- DefaultCase()));
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ Case(utils::Vector{CaseSelector(Source{{56, 78}}, Add(5_i, 5_i)),
+ CaseSelector(Source{{12, 34}}, 10_i)}),
+ DefaultCase()));
WrapInFunction(block);
EXPECT_FALSE(r()->Resolve());
@@ -508,7 +573,7 @@
auto* block = Block(Decl(var), //
Switch("a", //
- Case(Expr(Source{{12, 34}}, "b")), DefaultCase()));
+ Case(CaseSelector(Source{{12, 34}}, "b")), DefaultCase()));
WrapInFunction(block);
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index e84eec5..edc111c 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -299,7 +299,7 @@
TraverseExpression(s->condition);
for (auto* c : s->body) {
for (auto* sel : c->selectors) {
- TraverseExpression(sel);
+ TraverseExpression(sel->expr);
}
TraverseStatement(c->body);
}
diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc
index 724a527..a984c20 100644
--- a/src/tint/resolver/dependency_graph_test.cc
+++ b/src/tint/resolver/dependency_graph_test.cc
@@ -1262,9 +1262,9 @@
Loop(Block(Assign(V, V)), //
Block(Assign(V, V))), //
Switch(V, //
- Case(Expr(1_i), //
+ Case(CaseSelector(1_i), //
Block(Assign(V, V))), //
- Case(Expr(2_i), //
+ Case(CaseSelector(2_i), //
Block(Fallthrough())), //
DefaultCase(Block(Assign(V, V)))), //
Return(V), //
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index 251fe82..1b96a27 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -356,26 +356,30 @@
WrapInFunction(Add(target_expr(), abstract_expr));
break;
case Method::kSwitchCond:
- WrapInFunction(Switch(abstract_expr, //
- Case(target_expr()->As<ast::IntLiteralExpression>()), //
- DefaultCase()));
+ WrapInFunction(
+ Switch(abstract_expr, //
+ Case(CaseSelector(target_expr()->As<ast::IntLiteralExpression>())), //
+ DefaultCase()));
break;
case Method::kSwitchCase:
- WrapInFunction(Switch(target_expr(), //
- Case(abstract_expr->As<ast::IntLiteralExpression>()), //
- DefaultCase()));
+ WrapInFunction(
+ Switch(target_expr(), //
+ Case(CaseSelector(abstract_expr->As<ast::IntLiteralExpression>())), //
+ DefaultCase()));
break;
case Method::kSwitchCondWithAbstractCase:
- WrapInFunction(Switch(abstract_expr, //
- Case(Expr(123_a)), //
- Case(target_expr()->As<ast::IntLiteralExpression>()), //
- DefaultCase()));
+ WrapInFunction(
+ Switch(abstract_expr, //
+ Case(CaseSelector(123_a)), //
+ Case(CaseSelector(target_expr()->As<ast::IntLiteralExpression>())), //
+ DefaultCase()));
break;
case Method::kSwitchCaseWithAbstractCase:
- WrapInFunction(Switch(target_expr(), //
- Case(Expr(123_a)), //
- Case(abstract_expr->As<ast::IntLiteralExpression>()), //
- DefaultCase()));
+ WrapInFunction(
+ Switch(target_expr(), //
+ Case(CaseSelector(123_a)), //
+ Case(CaseSelector(abstract_expr->As<ast::IntLiteralExpression>())), //
+ DefaultCase()));
break;
case Method::kWorkgroupSize:
Func("f", utils::Empty, ty.void_(), utils::Empty,
@@ -903,9 +907,10 @@
break;
}
case Method::kSwitch: {
- WrapInFunction(Switch(abstract_expr(),
- Case(abstract_expr()->As<ast::IntLiteralExpression>()),
- DefaultCase()));
+ WrapInFunction(
+ Switch(abstract_expr(),
+ Case(CaseSelector(abstract_expr()->As<ast::IntLiteralExpression>())),
+ DefaultCase()));
break;
}
case Method::kWorkgroupSize: {
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index b2961bd..21501c5 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1240,26 +1240,31 @@
return StatementScope(stmt, sem, [&] {
sem->Selectors().reserve(stmt->selectors.Length());
for (auto* sel : stmt->selectors) {
+ Mark(sel);
+
ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "case selector"};
TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
- // The sem statement is created in the switch when attempting to determine the common
- // type.
- auto* materialized = Materialize(sem_.Get(sel), ty);
- if (!materialized) {
- return false;
- }
- if (!materialized->Type()->IsAnyOf<sem::I32, sem::U32>()) {
- AddError("case selector must be an i32 or u32 value", sel->source);
- return false;
- }
- auto const_value = materialized->ConstantValue();
- if (!const_value) {
- AddError("case selector must be a constant expression", sel->source);
- return false;
+ const sem::Constant* const_value = nullptr;
+ if (!sel->IsDefault()) {
+ // The sem statement was created in the switch when attempting to determine the
+ // common type.
+ auto* materialized = Materialize(sem_.Get(sel->expr), ty);
+ if (!materialized) {
+ return false;
+ }
+ if (!materialized->Type()->IsAnyOf<sem::I32, sem::U32>()) {
+ AddError("case selector must be an i32 or u32 value", sel->source);
+ return false;
+ }
+ const_value = materialized->ConstantValue();
+ if (!const_value) {
+ AddError("case selector must be a constant expression", sel->source);
+ return false;
+ }
}
- sem->Selectors().emplace_back(const_value);
+ sem->Selectors().emplace_back(builder_->create<sem::CaseSelector>(sel, const_value));
}
Mark(stmt->body);
@@ -3103,8 +3108,11 @@
utils::Vector<const sem::Type*, 8> types;
types.Push(cond_ty);
for (auto* case_stmt : stmt->body) {
- for (auto* expr : case_stmt->selectors) {
- auto* sem_expr = Expression(expr);
+ for (auto* sel : case_stmt->selectors) {
+ if (sel->IsDefault()) {
+ continue;
+ }
+ auto* sem_expr = Expression(sel->expr);
types.Push(sem_expr->Type()->UnwrapRef());
}
}
@@ -3127,9 +3135,6 @@
if (!c) {
return false;
}
- for (auto* expr : c->Selectors()) {
- types.Push(expr->Type()->UnwrapRef());
- }
cases.Push(c);
behaviors.Add(c->Behaviors());
sem->Cases().emplace_back(c);
diff --git a/src/tint/resolver/resolver_behavior_test.cc b/src/tint/resolver/resolver_behavior_test.cc
index 5002857..cf91501 100644
--- a/src/tint/resolver/resolver_behavior_test.cc
+++ b/src/tint/resolver/resolver_behavior_test.cc
@@ -633,7 +633,7 @@
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultEmpty) {
- auto* stmt = Switch(1_i, Case(Expr(0_i), Block()), DefaultCase(Block()));
+ auto* stmt = Switch(1_i, Case(CaseSelector(0_i), Block()), DefaultCase(Block()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -643,7 +643,7 @@
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultDiscard) {
- auto* stmt = Switch(1_i, Case(Expr(0_i), Block()), DefaultCase(Block(Discard())));
+ auto* stmt = Switch(1_i, Case(CaseSelector(0_i), Block()), DefaultCase(Block(Discard())));
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
@@ -655,7 +655,7 @@
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Empty_DefaultReturn) {
- auto* stmt = Switch(1_i, Case(Expr(0_i), Block()), DefaultCase(Block(Return())));
+ auto* stmt = Switch(1_i, Case(CaseSelector(0_i), Block()), DefaultCase(Block(Return())));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -665,7 +665,7 @@
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultEmpty) {
- auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block()));
+ auto* stmt = Switch(1_i, Case(CaseSelector(0_i), Block(Discard())), DefaultCase(Block()));
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
@@ -677,7 +677,8 @@
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultDiscard) {
- auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block(Discard())));
+ auto* stmt =
+ Switch(1_i, Case(CaseSelector(0_i), Block(Discard())), DefaultCase(Block(Discard())));
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
@@ -689,7 +690,8 @@
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_DefaultReturn) {
- auto* stmt = Switch(1_i, Case(Expr(0_i), Block(Discard())), DefaultCase(Block(Return())));
+ auto* stmt =
+ Switch(1_i, Case(CaseSelector(0_i), Block(Discard())), DefaultCase(Block(Return())));
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
utils::Vector{Stage(ast::PipelineStage::kFragment)});
@@ -701,9 +703,9 @@
}
TEST_F(ResolverBehaviorTest, StmtSwitch_CondLiteral_Case0Discard_Case1Return_DefaultEmpty) {
- auto* stmt = Switch(1_i, //
- Case(Expr(0_i), Block(Discard())), //
- Case(Expr(1_i), Block(Return())), //
+ auto* stmt = Switch(1_i, //
+ Case(CaseSelector(0_i), Block(Discard())), //
+ Case(CaseSelector(1_i), Block(Return())), //
DefaultCase(Block()));
Func("F", utils::Empty, ty.void_(), utils::Vector{stmt},
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index 4538bf4..0af3894 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -112,7 +112,7 @@
auto* assign = Assign(lhs, rhs);
auto* block = Block(assign);
- auto* sel = Expr(3_i);
+ auto* sel = CaseSelector(3_i);
auto* cse = Case(sel, block);
auto* def = DefaultCase();
auto* cond_var = Var("c", ty.i32());
@@ -132,7 +132,7 @@
ASSERT_EQ(sem->Cases().size(), 2u);
EXPECT_EQ(sem->Cases()[0]->Declaration(), cse);
ASSERT_EQ(sem->Cases()[0]->Selectors().size(), 1u);
- EXPECT_EQ(sem->Cases()[1]->Selectors().size(), 0u);
+ EXPECT_EQ(sem->Cases()[1]->Selectors().size(), 1u);
}
TEST_F(ResolverTest, Stmt_Block) {
@@ -251,7 +251,7 @@
auto* lhs = Expr("v");
auto* rhs = Expr(2.3_f);
auto* case_block = Block(Assign(lhs, rhs));
- auto* stmt = Switch(Expr(2_i), Case(Expr(3_i), case_block), DefaultCase());
+ auto* stmt = Switch(Expr(2_i), Case(CaseSelector(3_i), case_block), DefaultCase());
WrapInFunction(v, stmt);
EXPECT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc
index fb210c0..afd9e91 100644
--- a/src/tint/resolver/validation_test.cc
+++ b/src/tint/resolver/validation_test.cc
@@ -1039,11 +1039,11 @@
}
TEST_F(ResolverValidationTest, Stmt_BreakInSwitch) {
- WrapInFunction(Loop(Block(Switch(Expr(1_i), //
- Case(Expr(1_i), //
- Block(Break())), //
- DefaultCase()), //
- Break()))); //
+ WrapInFunction(Loop(Block(Switch(Expr(1_i), //
+ Case(CaseSelector(1_i), //
+ Block(Break())), //
+ DefaultCase()), //
+ Break()))); //
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index f738274..4489543 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -2324,54 +2324,50 @@
return false;
}
- bool has_default = false;
+ const sem::CaseSelector* default_selector = nullptr;
std::unordered_map<int64_t, Source> selectors;
for (auto* case_stmt : s->body) {
- if (case_stmt->IsDefault()) {
- if (has_default) {
- // More than one default clause
- AddError("switch statement must have exactly one default clause",
- case_stmt->source);
- return false;
- }
- has_default = true;
- }
-
auto* case_sem = sem_.Get<sem::CaseStatement>(case_stmt);
+ for (auto* selector : case_sem->Selectors()) {
+ if (selector->IsDefault()) {
+ if (default_selector != nullptr) {
+ // More than one default clause
+ AddError("switch statement must have exactly one default clause",
+ selector->Declaration()->source);
- auto& case_selectors = case_stmt->selectors;
- auto& selector_values = case_sem->Selectors();
- TINT_ASSERT(Resolver, case_selectors.Length() == selector_values.size());
- for (size_t i = 0; i < case_sem->Selectors().size(); ++i) {
- auto* selector = selector_values[i];
- if (cond_ty != selector->Type()) {
+ AddNote("previous default case", default_selector->Declaration()->source);
+ return false;
+ }
+ default_selector = selector;
+ continue;
+ }
+
+ auto* decl_ty = selector->Value()->Type();
+ if (cond_ty != decl_ty) {
AddError(
"the case selector values must have the same type as the selector expression.",
- case_selectors[i]->source);
+ selector->Declaration()->source);
return false;
}
- auto value = selector->As<uint32_t>();
+ auto value = selector->Value()->As<uint32_t>();
auto it = selectors.find(value);
if (it != selectors.end()) {
- std::string err = "duplicate switch case '";
- if (selector->Type()->Is<sem::I32>()) {
- err += std::to_string(selector->As<int32_t>());
- } else {
- err += std::to_string(value);
- }
- err += "'";
-
- AddError(err, case_selectors[i]->source);
+ AddError("duplicate switch case '" +
+ (decl_ty->IsAnyOf<sem::I32, sem::AbstractNumeric>()
+ ? std::to_string(i32(value))
+ : std::to_string(value)) +
+ "'",
+ selector->Declaration()->source);
AddNote("previous case declared here", it->second);
return false;
}
- selectors.emplace(value, case_selectors[i]->source);
+ selectors.emplace(value, selector->Declaration()->source);
}
}
- if (!has_default) {
+ if (default_selector == nullptr) {
// No default clause
AddError("switch statement must have a default clause", s->source);
return false;
diff --git a/src/tint/sem/switch_statement.cc b/src/tint/sem/switch_statement.cc
index ed3942d..5eb2f09 100644
--- a/src/tint/sem/switch_statement.cc
+++ b/src/tint/sem/switch_statement.cc
@@ -17,6 +17,7 @@
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::CaseStatement);
+TINT_INSTANTIATE_TYPEINFO(tint::sem::CaseSelector);
TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchStatement);
namespace tint::sem {
@@ -48,4 +49,13 @@
return static_cast<const ast::CaseStatement*>(Base::Declaration());
}
+CaseSelector::CaseSelector(const ast::CaseSelector* decl, const Constant* val)
+ : Base(), decl_(decl), val_(val) {}
+
+CaseSelector::~CaseSelector() = default;
+
+const ast::CaseSelector* CaseSelector::Declaration() const {
+ return decl_;
+}
+
} // namespace tint::sem
diff --git a/src/tint/sem/switch_statement.h b/src/tint/sem/switch_statement.h
index 7028c05..929f8cf 100644
--- a/src/tint/sem/switch_statement.h
+++ b/src/tint/sem/switch_statement.h
@@ -22,10 +22,12 @@
// Forward declarations
namespace tint::ast {
class CaseStatement;
+class CaseSelector;
class SwitchStatement;
} // namespace tint::ast
namespace tint::sem {
class CaseStatement;
+class CaseSelector;
class Constant;
class Expression;
} // namespace tint::sem
@@ -83,14 +85,39 @@
const BlockStatement* Body() const { return body_; }
/// @returns the selectors for the case
- std::vector<const Constant*>& Selectors() { return selectors_; }
+ std::vector<const CaseSelector*>& Selectors() { return selectors_; }
/// @returns the selectors for the case
- const std::vector<const Constant*>& Selectors() const { return selectors_; }
+ const std::vector<const CaseSelector*>& Selectors() const { return selectors_; }
private:
const BlockStatement* body_ = nullptr;
- std::vector<const Constant*> selectors_;
+ std::vector<const CaseSelector*> selectors_;
+};
+
+/// Holds semantic information about a switch case selector
+class CaseSelector final : public Castable<CaseSelector, Node> {
+ public:
+ /// Constructor
+ /// @param decl the selector declaration
+ /// @param val the case selector value, nullptr for a default selector
+ explicit CaseSelector(const ast::CaseSelector* decl, const Constant* val = nullptr);
+
+ /// Destructor
+ ~CaseSelector() override;
+
+ /// @returns true if this is a default selector
+ bool IsDefault() const { return val_ == nullptr; }
+
+ /// @returns the case selector declaration
+ const ast::CaseSelector* Declaration() const;
+
+ /// @returns the selector constant value, or nullptr if this is the default selector
+ const Constant* Value() const { return val_; }
+
+ private:
+ const ast::CaseSelector* const decl_;
+ const Constant* const val_;
};
} // namespace tint::sem
diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc
index 17f5a5b..8edf832 100644
--- a/src/tint/transform/std140.cc
+++ b/src/tint/transform/std140.cc
@@ -944,7 +944,7 @@
ret_ty = ty;
}
- auto* case_sel = b.Expr(u32(column_idx));
+ auto* case_sel = b.CaseSelector(b.Expr(u32(column_idx)));
auto* case_body = b.Block(utils::Vector{b.Return(expr)});
cases.Push(b.Case(case_sel, case_body));
}
diff --git a/src/tint/transform/test_helper.h b/src/tint/transform/test_helper.h
index 42218a7..bc82fe5 100644
--- a/src/tint/transform/test_helper.h
+++ b/src/tint/transform/test_helper.h
@@ -115,7 +115,12 @@
/// @return true if the transform should be run for the given input.
template <typename TRANSFORM>
bool ShouldRun(Program&& program, const DataMap& data = {}) {
- EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
+ if (!program.IsValid()) {
+ ADD_FAILURE() << "ShouldRun() called with invalid program: "
+ << program.Diagnostics().str();
+ return false;
+ }
+
const Transform& t = TRANSFORM();
return t.ShouldRun(&program, data);
}
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index cf75678..e2ac5fe 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -1687,20 +1687,21 @@
}
bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
- if (stmt->IsDefault()) {
- line() << "default: {";
- } else {
- auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
- for (auto* selector : sem->Selectors()) {
- auto out = line();
+ auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
+ for (auto* selector : sem->Selectors()) {
+ auto out = line();
+
+ if (selector->IsDefault()) {
+ out << "default";
+ } else {
out << "case ";
- if (!EmitConstant(out, selector)) {
+ if (!EmitConstant(out, selector->Value())) {
return false;
}
- out << ":";
- if (selector == sem->Selectors().back()) {
- out << " {";
- }
+ }
+ out << ":";
+ if (selector == sem->Selectors().back()) {
+ out << " {";
}
}
diff --git a/src/tint/writer/glsl/generator_impl_case_test.cc b/src/tint/writer/glsl/generator_impl_case_test.cc
index 1f60781..4f1c6f3 100644
--- a/src/tint/writer/glsl/generator_impl_case_test.cc
+++ b/src/tint/writer/glsl/generator_impl_case_test.cc
@@ -23,7 +23,8 @@
using GlslGeneratorImplTest_Case = TestHelper;
TEST_F(GlslGeneratorImplTest_Case, Emit_Case) {
- auto* s = Switch(1_i, Case(Expr(5_i), Block(create<ast::BreakStatement>())), DefaultCase());
+ auto* s =
+ Switch(1_i, Case(CaseSelector(5_i), Block(create<ast::BreakStatement>())), DefaultCase());
WrapInFunction(s);
GeneratorImpl& gen = Build();
@@ -38,7 +39,7 @@
}
TEST_F(GlslGeneratorImplTest_Case, Emit_Case_BreaksByDefault) {
- auto* s = Switch(1_i, Case(Expr(5_i), Block()), DefaultCase());
+ auto* s = Switch(1_i, Case(CaseSelector(5_i), Block()), DefaultCase());
WrapInFunction(s);
GeneratorImpl& gen = Build();
@@ -53,8 +54,8 @@
}
TEST_F(GlslGeneratorImplTest_Case, Emit_Case_WithFallthrough) {
- auto* s =
- Switch(1_i, Case(Expr(5_i), Block(create<ast::FallthroughStatement>())), DefaultCase());
+ auto* s = Switch(1_i, Case(CaseSelector(5_i), Block(create<ast::FallthroughStatement>())),
+ DefaultCase());
WrapInFunction(s);
GeneratorImpl& gen = Build();
@@ -72,8 +73,8 @@
auto* s = Switch(1_i,
Case(
utils::Vector{
- Expr(5_i),
- Expr(6_i),
+ CaseSelector(5_i),
+ CaseSelector(6_i),
},
Block(create<ast::BreakStatement>())),
DefaultCase());
diff --git a/src/tint/writer/glsl/generator_impl_switch_test.cc b/src/tint/writer/glsl/generator_impl_switch_test.cc
index 7a2c750..ada3c0b 100644
--- a/src/tint/writer/glsl/generator_impl_switch_test.cc
+++ b/src/tint/writer/glsl/generator_impl_switch_test.cc
@@ -25,21 +25,13 @@
GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate);
auto* def_body = Block(create<ast::BreakStatement>());
- auto* def = create<ast::CaseStatement>(utils::Empty, def_body);
-
- utils::Vector case_val{Expr(5_i)};
+ auto* def = create<ast::CaseStatement>(utils::Vector{DefaultCaseSelector()}, def_body);
auto* case_body = Block(create<ast::BreakStatement>());
-
- auto* case_stmt = create<ast::CaseStatement>(case_val, case_body);
-
- utils::Vector body{
- case_stmt,
- def,
- };
+ auto* case_stmt = create<ast::CaseStatement>(utils::Vector{CaseSelector(5_i)}, case_body);
auto* cond = Expr("cond");
- auto* s = create<ast::SwitchStatement>(cond, body);
+ auto* s = create<ast::SwitchStatement>(cond, utils::Vector{case_stmt, def});
WrapInFunction(s);
GeneratorImpl& gen = Build();
@@ -58,5 +50,30 @@
)");
}
+TEST_F(GlslGeneratorImplTest_Switch, Emit_Switch_MixedDefault) {
+ GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate);
+
+ auto* def_body = Block(create<ast::BreakStatement>());
+ auto* def = create<ast::CaseStatement>(utils::Vector{CaseSelector(5_i), DefaultCaseSelector()},
+ def_body);
+
+ auto* cond = Expr("cond");
+ auto* s = create<ast::SwitchStatement>(cond, utils::Vector{def});
+ WrapInFunction(s);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(s)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( switch(cond) {
+ case 5:
+ default: {
+ break;
+ }
+ }
+)");
+}
+
} // namespace
} // namespace tint::writer::glsl
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 7bdb8e9..6a3c5c4 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -2562,20 +2562,20 @@
bool GeneratorImpl::EmitCase(const ast::SwitchStatement* s, size_t case_idx) {
auto* stmt = s->body[case_idx];
- if (stmt->IsDefault()) {
- line() << "default: {";
- } else {
- auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
- for (auto* selector : sem->Selectors()) {
- auto out = line();
+ auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
+ for (auto* selector : sem->Selectors()) {
+ auto out = line();
+ if (selector->IsDefault()) {
+ out << "default";
+ } else {
out << "case ";
- if (!EmitConstant(out, selector)) {
+ if (!EmitConstant(out, selector->Value())) {
return false;
}
- out << ":";
- if (selector == sem->Selectors().back()) {
- out << " {";
- }
+ }
+ out << ":";
+ if (selector == sem->Selectors().back()) {
+ out << " {";
}
}
@@ -3652,7 +3652,7 @@
}
bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) {
- TINT_ASSERT(Writer, stmt->body.Length() == 1 && stmt->body[0]->IsDefault());
+ TINT_ASSERT(Writer, stmt->body.Length() == 1 && stmt->body[0]->ContainsDefault());
// FXC fails to compile a switch with just a default case, ignoring the
// default case body. We work around this here by emitting the default case
@@ -3685,7 +3685,8 @@
bool GeneratorImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
// BUG(crbug.com/tint/1188): work around default-only switches
- if (stmt->body.Length() == 1 && stmt->body[0]->IsDefault()) {
+ if (stmt->body.Length() == 1 && stmt->body[0]->selectors.Length() == 1 &&
+ stmt->body[0]->ContainsDefault()) {
return EmitDefaultOnlySwitch(stmt);
}
diff --git a/src/tint/writer/hlsl/generator_impl_case_test.cc b/src/tint/writer/hlsl/generator_impl_case_test.cc
index c55f14b..bfcbcaf 100644
--- a/src/tint/writer/hlsl/generator_impl_case_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_case_test.cc
@@ -23,7 +23,8 @@
using HlslGeneratorImplTest_Case = TestHelper;
TEST_F(HlslGeneratorImplTest_Case, Emit_Case) {
- auto* s = Switch(1_i, Case(Expr(5_i), Block(create<ast::BreakStatement>())), DefaultCase());
+ auto* s =
+ Switch(1_i, Case(CaseSelector(5_i), Block(create<ast::BreakStatement>())), DefaultCase());
WrapInFunction(s);
GeneratorImpl& gen = Build();
@@ -38,7 +39,7 @@
}
TEST_F(HlslGeneratorImplTest_Case, Emit_Case_BreaksByDefault) {
- auto* s = Switch(1_i, Case(Expr(5_i), Block()), DefaultCase());
+ auto* s = Switch(1_i, Case(CaseSelector(5_i), Block()), DefaultCase());
WrapInFunction(s);
GeneratorImpl& gen = Build();
@@ -53,9 +54,9 @@
}
TEST_F(HlslGeneratorImplTest_Case, Emit_Case_WithFallthrough) {
- auto* s = Switch(1_i, //
- Case(Expr(4_i), Block(create<ast::FallthroughStatement>())), //
- Case(Expr(5_i), Block(create<ast::ReturnStatement>())), //
+ auto* s = Switch(1_i, //
+ Case(CaseSelector(4_i), Block(create<ast::FallthroughStatement>())), //
+ Case(CaseSelector(5_i), Block(create<ast::ReturnStatement>())), //
DefaultCase());
WrapInFunction(s);
@@ -75,9 +76,10 @@
}
TEST_F(HlslGeneratorImplTest_Case, Emit_Case_MultipleSelectors) {
- auto* s =
- Switch(1_i, Case(utils::Vector{Expr(5_i), Expr(6_i)}, Block(create<ast::BreakStatement>())),
- DefaultCase());
+ auto* s = Switch(1_i,
+ Case(utils::Vector{CaseSelector(5_i), CaseSelector(6_i)},
+ Block(create<ast::BreakStatement>())),
+ DefaultCase());
WrapInFunction(s);
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/hlsl/generator_impl_switch_test.cc b/src/tint/writer/hlsl/generator_impl_switch_test.cc
index 24c17a6..a84e632 100644
--- a/src/tint/writer/hlsl/generator_impl_switch_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_switch_test.cc
@@ -23,9 +23,9 @@
TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch) {
GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate);
- auto* s = Switch( //
- Expr("cond"), //
- Case(Expr(5_i), Block(Break())), //
+ auto* s = Switch( //
+ Expr("cond"), //
+ Case(CaseSelector(5_i), Block(Break())), //
DefaultCase());
WrapInFunction(s);
@@ -45,6 +45,27 @@
)");
}
+TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_MixedDefault) {
+ GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate);
+ auto* s = Switch( //
+ Expr("cond"), //
+ Case(utils::Vector{CaseSelector(5_i), DefaultCaseSelector()}, Block(Break())));
+ WrapInFunction(s);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(s)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( switch(cond) {
+ case 5:
+ default: {
+ break;
+ }
+ }
+)");
+}
+
TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_OnlyDefaultCase) {
GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate);
GlobalVar("a", ty.i32(), ast::AddressSpace::kPrivate);
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index f765d93..976734f 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -1589,20 +1589,21 @@
}
bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
- if (stmt->IsDefault()) {
- line() << "default: {";
- } else {
- auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
- for (auto* selector : sem->Selectors()) {
- auto out = line();
+ auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
+ for (auto* selector : sem->Selectors()) {
+ auto out = line();
+
+ if (selector->IsDefault()) {
+ out << "default";
+ } else {
out << "case ";
- if (!EmitConstant(out, selector)) {
+ if (!EmitConstant(out, selector->Value())) {
return false;
}
- out << ":";
- if (selector == sem->Selectors().back()) {
- out << " {";
- }
+ }
+ out << ":";
+ if (selector == sem->Selectors().back()) {
+ out << " {";
}
}
diff --git a/src/tint/writer/msl/generator_impl_case_test.cc b/src/tint/writer/msl/generator_impl_case_test.cc
index 250d67d..8aae4fe 100644
--- a/src/tint/writer/msl/generator_impl_case_test.cc
+++ b/src/tint/writer/msl/generator_impl_case_test.cc
@@ -23,7 +23,8 @@
using MslGeneratorImplTest = TestHelper;
TEST_F(MslGeneratorImplTest, Emit_Case) {
- auto* s = Switch(1_i, Case(Expr(5_i), Block(create<ast::BreakStatement>())), DefaultCase());
+ auto* s =
+ Switch(1_i, Case(CaseSelector(5_i), Block(create<ast::BreakStatement>())), DefaultCase());
WrapInFunction(s);
GeneratorImpl& gen = Build();
@@ -38,7 +39,7 @@
}
TEST_F(MslGeneratorImplTest, Emit_Case_BreaksByDefault) {
- auto* s = Switch(1_i, Case(Expr(5_i), Block()), DefaultCase());
+ auto* s = Switch(1_i, Case(CaseSelector(5_i), Block()), DefaultCase());
WrapInFunction(s);
GeneratorImpl& gen = Build();
@@ -53,8 +54,8 @@
}
TEST_F(MslGeneratorImplTest, Emit_Case_WithFallthrough) {
- auto* s =
- Switch(1_i, Case(Expr(5_i), Block(create<ast::FallthroughStatement>())), DefaultCase());
+ auto* s = Switch(1_i, Case(CaseSelector(5_i), Block(create<ast::FallthroughStatement>())),
+ DefaultCase());
WrapInFunction(s);
GeneratorImpl& gen = Build();
@@ -72,8 +73,8 @@
auto* s = Switch(1_i,
Case(
utils::Vector{
- Expr(5_i),
- Expr(6_i),
+ CaseSelector(5_i),
+ CaseSelector(6_i),
},
Block(create<ast::BreakStatement>())),
DefaultCase());
diff --git a/src/tint/writer/msl/generator_impl_switch_test.cc b/src/tint/writer/msl/generator_impl_switch_test.cc
index b327f34..6b47d09 100644
--- a/src/tint/writer/msl/generator_impl_switch_test.cc
+++ b/src/tint/writer/msl/generator_impl_switch_test.cc
@@ -25,16 +25,13 @@
auto* cond = Var("cond", ty.i32());
auto* def_body = Block(create<ast::BreakStatement>());
- auto* def = create<ast::CaseStatement>(utils::Empty, def_body);
-
- utils::Vector case_val{Expr(5_i)};
+ auto* def = Case(DefaultCaseSelector(), def_body);
auto* case_body = Block(create<ast::BreakStatement>());
-
- auto* case_stmt = create<ast::CaseStatement>(case_val, case_body);
+ auto* case_stmt = Case(CaseSelector(5_i), case_body);
utils::Vector body{case_stmt, def};
- auto* s = create<ast::SwitchStatement>(Expr(cond), body);
+ auto* s = Switch(cond, body);
WrapInFunction(cond, s);
GeneratorImpl& gen = Build();
@@ -52,5 +49,27 @@
)");
}
+TEST_F(MslGeneratorImplTest, Emit_Switch_MixedDefault) {
+ auto* cond = Var("cond", ty.i32());
+
+ auto* def_body = Block(create<ast::BreakStatement>());
+ auto* def = Case(utils::Vector{CaseSelector(5_i), DefaultCaseSelector()}, def_body);
+
+ auto* s = Switch(cond, def);
+ WrapInFunction(cond, s);
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(s)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( switch(cond) {
+ case 5:
+ default: {
+ break;
+ }
+ }
+)");
+}
+
} // namespace
} // namespace tint::writer::msl
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index b72de67..790ebe6 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -3456,19 +3456,26 @@
std::vector<uint32_t> case_ids;
for (const auto* item : stmt->body) {
- if (item->IsDefault()) {
- case_ids.push_back(default_block_id);
+ auto block_id = default_block_id;
+ if (!item->ContainsDefault()) {
+ auto block = result_op();
+ block_id = std::get<uint32_t>(block);
+ }
+ case_ids.push_back(block_id);
+
+ // If this case statement is only a default selector skip adding the block
+ // as it will be done below.
+ if (item->selectors.Length() == 1 && item->ContainsDefault()) {
continue;
}
- auto block = result_op();
- auto block_id = std::get<uint32_t>(block);
-
- case_ids.push_back(block_id);
-
auto* sem = builder_.Sem().Get<sem::CaseStatement>(item);
for (auto* selector : sem->Selectors()) {
- params.push_back(Operand(selector->As<uint32_t>()));
+ if (selector->IsDefault()) {
+ continue;
+ }
+
+ params.push_back(Operand(selector->Value()->As<uint32_t>()));
params.push_back(Operand(block_id));
}
}
@@ -3490,7 +3497,7 @@
for (uint32_t i = 0; i < body.Length(); i++) {
auto* item = body[i];
- if (item->IsDefault()) {
+ if (item->ContainsDefault()) {
generated_default = true;
}
diff --git a/src/tint/writer/spirv/builder_switch_test.cc b/src/tint/writer/spirv/builder_switch_test.cc
index 4ee8f39..c59c640 100644
--- a/src/tint/writer/spirv/builder_switch_test.cc
+++ b/src/tint/writer/spirv/builder_switch_test.cc
@@ -62,9 +62,9 @@
auto* func = Func("a_func", utils::Empty, ty.void_(),
utils::Vector{
- Switch("a", //
- Case(Expr(1_i), Block(Assign("v", 1_i))), //
- Case(Expr(2_i), Block(Assign("v", 2_i))), //
+ Switch("a", //
+ Case(CaseSelector(1_i), Block(Assign("v", 1_i))), //
+ Case(CaseSelector(2_i), Block(Assign("v", 2_i))), //
DefaultCase()),
});
@@ -119,9 +119,9 @@
auto* func = Func("a_func", utils::Empty, ty.void_(),
utils::Vector{
- Switch("a", //
- Case(Expr(1_u), Block(Assign("v", 1_i))), //
- Case(Expr(2_u), Block(Assign("v", 2_i))), //
+ Switch("a", //
+ Case(CaseSelector(1_u), Block(Assign("v", 1_i))), //
+ Case(CaseSelector(2_u), Block(Assign("v", 2_i))), //
DefaultCase()),
});
@@ -226,11 +226,11 @@
auto* func = Func("a_func", utils::Empty, ty.void_(),
utils::Vector{
- Switch(Expr("a"), //
- Case(Expr(1_i), //
- Block(Assign("v", 1_i))), //
- Case(utils::Vector{Expr(2_i), Expr(3_i)}, //
- Block(Assign("v", 2_i))), //
+ Switch(Expr("a"), //
+ Case(CaseSelector(1_i), //
+ Block(Assign("v", 1_i))), //
+ Case(utils::Vector{CaseSelector(2_i), CaseSelector(3_i)}, //
+ Block(Assign("v", 2_i))), //
DefaultCase(Block(Assign("v", 3_i)))),
});
@@ -273,6 +273,61 @@
)");
}
+TEST_F(BuilderTest, Switch_WithCaseAndMixedDefault) {
+ // switch(a) {
+ // case 1i:
+ // v = 1i;
+ // case 2i, 3i, default:
+ // v = 2i;
+ // }
+
+ auto* v = GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate);
+ auto* a = GlobalVar("a", ty.i32(), ast::AddressSpace::kPrivate);
+
+ auto* func = Func("a_func", utils::Empty, ty.void_(),
+ utils::Vector{Switch(Expr("a"), //
+ Case(CaseSelector(1_i), //
+ Block(Assign("v", 1_i))), //
+ Case(utils::Vector{CaseSelector(2_i), CaseSelector(3_i),
+ DefaultCaseSelector()}, //
+ Block(Assign("v", 2_i))) //
+ )});
+
+ spirv::Builder& b = Build();
+
+ ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
+ ASSERT_TRUE(b.GenerateGlobalVariable(a)) << b.error();
+ ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
+
+ EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
+OpName %5 "a"
+OpName %8 "a_func"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Private %4
+%5 = OpVariable %2 Private %4
+%7 = OpTypeVoid
+%6 = OpTypeFunction %7
+%14 = OpConstant %3 1
+%15 = OpConstant %3 2
+%8 = OpFunction %7 None %6
+%9 = OpLabel
+%11 = OpLoad %3 %5
+OpSelectionMerge %10 None
+OpSwitch %11 %12 1 %13 2 %12 3 %12
+%13 = OpLabel
+OpStore %1 %14
+OpBranch %10
+%12 = OpLabel
+OpStore %1 %15
+OpBranch %10
+%10 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
TEST_F(BuilderTest, Switch_CaseWithFallthrough) {
// switch(a) {
// case 1i:
@@ -290,9 +345,9 @@
auto* func = Func("a_func", utils::Empty, ty.void_(),
utils::Vector{
Switch(Expr("a"), //
- Case(Expr(1_i), //
+ Case(CaseSelector(1_i), //
Block(Assign("v", 1_i), Fallthrough())), //
- Case(Expr(2_i), //
+ Case(CaseSelector(2_i), //
Block(Assign("v", 2_i))), //
DefaultCase(Block(Assign("v", 3_i)))),
});
@@ -351,9 +406,9 @@
auto* func = Func("a_func", utils::Empty, ty.void_(),
utils::Vector{
- Switch("a", //
- Case(Expr(1_i), //
- Block( //
+ Switch("a", //
+ Case(CaseSelector(1_i), //
+ Block( //
If(Expr(true), Block(create<ast::BreakStatement>())),
Assign("v", 1_i))),
DefaultCase()),
@@ -414,9 +469,9 @@
auto* fn = Func("f", utils::Empty, ty.i32(),
utils::Vector{
- Switch(1_i, //
- Case(Expr(1_i), Block(Return(1_i))), //
- Case(Expr(2_i), Block(Fallthrough())), //
+ Switch(1_i, //
+ Case(CaseSelector(1_i), Block(Return(1_i))), //
+ Case(CaseSelector(2_i), Block(Fallthrough())), //
DefaultCase(Block(Return(3_i)))),
});
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index 7183c74..118bd4c 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -1024,26 +1024,28 @@
}
bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
- if (stmt->IsDefault()) {
+ if (stmt->selectors.Length() == 1 && stmt->ContainsDefault()) {
line() << "default: {";
} else {
auto out = line();
out << "case ";
bool first = true;
- for (auto* expr : stmt->selectors) {
+ for (auto* sel : stmt->selectors) {
if (!first) {
out << ", ";
}
first = false;
- if (!EmitExpression(out, expr)) {
+
+ if (sel->IsDefault()) {
+ out << "default";
+ } else if (!EmitExpression(out, sel->expr)) {
return false;
}
}
out << ": {";
}
-
if (!EmitStatementsWithIndent(stmt->body->statements)) {
return false;
}
diff --git a/src/tint/writer/wgsl/generator_impl_case_test.cc b/src/tint/writer/wgsl/generator_impl_case_test.cc
index 6d59970..39c28cb 100644
--- a/src/tint/writer/wgsl/generator_impl_case_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_case_test.cc
@@ -22,7 +22,8 @@
using WgslGeneratorImplTest = TestHelper;
TEST_F(WgslGeneratorImplTest, Emit_Case) {
- auto* s = Switch(1_i, Case(Expr(5_i), Block(create<ast::BreakStatement>())), DefaultCase());
+ auto* s =
+ Switch(1_i, Case(CaseSelector(5_i), Block(create<ast::BreakStatement>())), DefaultCase());
WrapInFunction(s);
GeneratorImpl& gen = Build();
@@ -40,8 +41,8 @@
auto* s = Switch(1_i,
Case(
utils::Vector{
- Expr(5_i),
- Expr(6_i),
+ CaseSelector(5_i),
+ CaseSelector(6_i),
},
Block(create<ast::BreakStatement>())),
DefaultCase());
diff --git a/src/tint/writer/wgsl/generator_impl_fallthrough_test.cc b/src/tint/writer/wgsl/generator_impl_fallthrough_test.cc
index 2b12051..093fba1 100644
--- a/src/tint/writer/wgsl/generator_impl_fallthrough_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_fallthrough_test.cc
@@ -23,8 +23,8 @@
TEST_F(WgslGeneratorImplTest, Emit_Fallthrough) {
auto* f = create<ast::FallthroughStatement>();
- WrapInFunction(Switch(1_i, //
- Case(Expr(1_i), Block(f)), //
+ WrapInFunction(Switch(1_i, //
+ Case(CaseSelector(1_i), Block(f)), //
DefaultCase()));
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/wgsl/generator_impl_switch_test.cc b/src/tint/writer/wgsl/generator_impl_switch_test.cc
index 5cf0ed2..7a39079 100644
--- a/src/tint/writer/wgsl/generator_impl_switch_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_switch_test.cc
@@ -25,13 +25,10 @@
GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate);
auto* def_body = Block(create<ast::BreakStatement>());
- auto* def = create<ast::CaseStatement>(utils::Empty, def_body);
-
- utils::Vector case_val{Expr(5_i)};
+ auto* def = Case(DefaultCaseSelector(), def_body);
auto* case_body = Block(create<ast::BreakStatement>());
-
- auto* case_stmt = create<ast::CaseStatement>(case_val, case_body);
+ auto* case_stmt = Case(utils::Vector{CaseSelector(5_i)}, case_body);
utils::Vector body{
case_stmt,
@@ -39,7 +36,7 @@
};
auto* cond = Expr("cond");
- auto* s = create<ast::SwitchStatement>(cond, body);
+ auto* s = Switch(cond, body);
WrapInFunction(s);
GeneratorImpl& gen = Build();
@@ -58,5 +55,28 @@
)");
}
+TEST_F(WgslGeneratorImplTest, Emit_Switch_MixedDefault) {
+ GlobalVar("cond", ty.i32(), ast::AddressSpace::kPrivate);
+
+ auto* def_body = Block(create<ast::BreakStatement>());
+ auto* def = Case(utils::Vector{CaseSelector(5_i), DefaultCaseSelector()}, def_body);
+
+ auto* cond = Expr("cond");
+ auto* s = Switch(cond, utils::Vector{def});
+ WrapInFunction(s);
+
+ GeneratorImpl& gen = Build();
+
+ gen.increment_indent();
+
+ ASSERT_TRUE(gen.EmitStatement(s)) << gen.error();
+ EXPECT_EQ(gen.result(), R"( switch(cond) {
+ case 5i, default: {
+ break;
+ }
+ }
+)");
+}
+
} // namespace
} // namespace tint::writer::wgsl
diff --git a/test/tint/statements/switch/case_default.wgsl b/test/tint/statements/switch/case_default.wgsl
new file mode 100644
index 0000000..c7ffc49
--- /dev/null
+++ b/test/tint/statements/switch/case_default.wgsl
@@ -0,0 +1,16 @@
+@compute @workgroup_size(1)
+fn f() {
+ var i : i32;
+ var result : i32;
+ switch(i) {
+ case default: {
+ result = 10;
+ }
+ case 1: {
+ result = 22;
+ }
+ case 2: {
+ result = 33;
+ }
+ }
+}
diff --git a/test/tint/statements/switch/case_default.wgsl.expected.dxc.hlsl b/test/tint/statements/switch/case_default.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..059ec98
--- /dev/null
+++ b/test/tint/statements/switch/case_default.wgsl.expected.dxc.hlsl
@@ -0,0 +1,20 @@
+[numthreads(1, 1, 1)]
+void f() {
+ int i = 0;
+ int result = 0;
+ switch(i) {
+ default: {
+ result = 10;
+ break;
+ }
+ case 1: {
+ result = 22;
+ break;
+ }
+ case 2: {
+ result = 33;
+ break;
+ }
+ }
+ return;
+}
diff --git a/test/tint/statements/switch/case_default.wgsl.expected.fxc.hlsl b/test/tint/statements/switch/case_default.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..059ec98
--- /dev/null
+++ b/test/tint/statements/switch/case_default.wgsl.expected.fxc.hlsl
@@ -0,0 +1,20 @@
+[numthreads(1, 1, 1)]
+void f() {
+ int i = 0;
+ int result = 0;
+ switch(i) {
+ default: {
+ result = 10;
+ break;
+ }
+ case 1: {
+ result = 22;
+ break;
+ }
+ case 2: {
+ result = 33;
+ break;
+ }
+ }
+ return;
+}
diff --git a/test/tint/statements/switch/case_default.wgsl.expected.glsl b/test/tint/statements/switch/case_default.wgsl.expected.glsl
new file mode 100644
index 0000000..3da9112
--- /dev/null
+++ b/test/tint/statements/switch/case_default.wgsl.expected.glsl
@@ -0,0 +1,26 @@
+#version 310 es
+
+void f() {
+ int i = 0;
+ int result = 0;
+ switch(i) {
+ default: {
+ result = 10;
+ break;
+ }
+ case 1: {
+ result = 22;
+ break;
+ }
+ case 2: {
+ result = 33;
+ break;
+ }
+ }
+}
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void main() {
+ f();
+ return;
+}
diff --git a/test/tint/statements/switch/case_default.wgsl.expected.msl b/test/tint/statements/switch/case_default.wgsl.expected.msl
new file mode 100644
index 0000000..a76392f
--- /dev/null
+++ b/test/tint/statements/switch/case_default.wgsl.expected.msl
@@ -0,0 +1,23 @@
+#include <metal_stdlib>
+
+using namespace metal;
+kernel void f() {
+ int i = 0;
+ int result = 0;
+ switch(i) {
+ default: {
+ result = 10;
+ break;
+ }
+ case 1: {
+ result = 22;
+ break;
+ }
+ case 2: {
+ result = 33;
+ break;
+ }
+ }
+ return;
+}
+
diff --git a/test/tint/statements/switch/case_default.wgsl.expected.spvasm b/test/tint/statements/switch/case_default.wgsl.expected.spvasm
new file mode 100644
index 0000000..6402c56
--- /dev/null
+++ b/test/tint/statements/switch/case_default.wgsl.expected.spvasm
@@ -0,0 +1,39 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 18
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %f "f"
+ OpExecutionMode %f LocalSize 1 1 1
+ OpName %f "f"
+ OpName %i "i"
+ OpName %result "result"
+ %void = OpTypeVoid
+ %1 = OpTypeFunction %void
+ %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+ %8 = OpConstantNull %int
+ %int_10 = OpConstant %int 10
+ %int_22 = OpConstant %int 22
+ %int_33 = OpConstant %int 33
+ %f = OpFunction %void None %1
+ %4 = OpLabel
+ %i = OpVariable %_ptr_Function_int Function %8
+ %result = OpVariable %_ptr_Function_int Function %8
+ %11 = OpLoad %int %i
+ OpSelectionMerge %10 None
+ OpSwitch %11 %12 1 %13 2 %14
+ %12 = OpLabel
+ OpStore %result %int_10
+ OpBranch %10
+ %13 = OpLabel
+ OpStore %result %int_22
+ OpBranch %10
+ %14 = OpLabel
+ OpStore %result %int_33
+ OpBranch %10
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
diff --git a/test/tint/statements/switch/case_default.wgsl.expected.wgsl b/test/tint/statements/switch/case_default.wgsl.expected.wgsl
new file mode 100644
index 0000000..792708a
--- /dev/null
+++ b/test/tint/statements/switch/case_default.wgsl.expected.wgsl
@@ -0,0 +1,16 @@
+@compute @workgroup_size(1)
+fn f() {
+ var i : i32;
+ var result : i32;
+ switch(i) {
+ default: {
+ result = 10;
+ }
+ case 1: {
+ result = 22;
+ }
+ case 2: {
+ result = 33;
+ }
+ }
+}
diff --git a/test/tint/statements/switch/case_default_mixed.wgsl b/test/tint/statements/switch/case_default_mixed.wgsl
new file mode 100644
index 0000000..e28b24a
--- /dev/null
+++ b/test/tint/statements/switch/case_default_mixed.wgsl
@@ -0,0 +1,16 @@
+@compute @workgroup_size(1)
+fn f() {
+ var i : i32;
+ var result : i32;
+ switch(i) {
+ case 0: {
+ result = 10;
+ }
+ case 1, default: {
+ result = 22;
+ }
+ case 2: {
+ result = 33;
+ }
+ }
+}
diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.dxc.hlsl b/test/tint/statements/switch/case_default_mixed.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..b15ec5b
--- /dev/null
+++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.dxc.hlsl
@@ -0,0 +1,21 @@
+[numthreads(1, 1, 1)]
+void f() {
+ int i = 0;
+ int result = 0;
+ switch(i) {
+ case 0: {
+ result = 10;
+ break;
+ }
+ case 1:
+ default: {
+ result = 22;
+ break;
+ }
+ case 2: {
+ result = 33;
+ break;
+ }
+ }
+ return;
+}
diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.fxc.hlsl b/test/tint/statements/switch/case_default_mixed.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..b15ec5b
--- /dev/null
+++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.fxc.hlsl
@@ -0,0 +1,21 @@
+[numthreads(1, 1, 1)]
+void f() {
+ int i = 0;
+ int result = 0;
+ switch(i) {
+ case 0: {
+ result = 10;
+ break;
+ }
+ case 1:
+ default: {
+ result = 22;
+ break;
+ }
+ case 2: {
+ result = 33;
+ break;
+ }
+ }
+ return;
+}
diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.glsl b/test/tint/statements/switch/case_default_mixed.wgsl.expected.glsl
new file mode 100644
index 0000000..cbd24c0
--- /dev/null
+++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.glsl
@@ -0,0 +1,27 @@
+#version 310 es
+
+void f() {
+ int i = 0;
+ int result = 0;
+ switch(i) {
+ case 0: {
+ result = 10;
+ break;
+ }
+ case 1:
+ default: {
+ result = 22;
+ break;
+ }
+ case 2: {
+ result = 33;
+ break;
+ }
+ }
+}
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void main() {
+ f();
+ return;
+}
diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.msl b/test/tint/statements/switch/case_default_mixed.wgsl.expected.msl
new file mode 100644
index 0000000..f36caad
--- /dev/null
+++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.msl
@@ -0,0 +1,24 @@
+#include <metal_stdlib>
+
+using namespace metal;
+kernel void f() {
+ int i = 0;
+ int result = 0;
+ switch(i) {
+ case 0: {
+ result = 10;
+ break;
+ }
+ case 1:
+ default: {
+ result = 22;
+ break;
+ }
+ case 2: {
+ result = 33;
+ break;
+ }
+ }
+ return;
+}
+
diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.spvasm b/test/tint/statements/switch/case_default_mixed.wgsl.expected.spvasm
new file mode 100644
index 0000000..a212876
--- /dev/null
+++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.spvasm
@@ -0,0 +1,39 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 18
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %f "f"
+ OpExecutionMode %f LocalSize 1 1 1
+ OpName %f "f"
+ OpName %i "i"
+ OpName %result "result"
+ %void = OpTypeVoid
+ %1 = OpTypeFunction %void
+ %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+ %8 = OpConstantNull %int
+ %int_10 = OpConstant %int 10
+ %int_22 = OpConstant %int 22
+ %int_33 = OpConstant %int 33
+ %f = OpFunction %void None %1
+ %4 = OpLabel
+ %i = OpVariable %_ptr_Function_int Function %8
+ %result = OpVariable %_ptr_Function_int Function %8
+ %11 = OpLoad %int %i
+ OpSelectionMerge %10 None
+ OpSwitch %11 %12 0 %13 1 %12 2 %14
+ %13 = OpLabel
+ OpStore %result %int_10
+ OpBranch %10
+ %12 = OpLabel
+ OpStore %result %int_22
+ OpBranch %10
+ %14 = OpLabel
+ OpStore %result %int_33
+ OpBranch %10
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
diff --git a/test/tint/statements/switch/case_default_mixed.wgsl.expected.wgsl b/test/tint/statements/switch/case_default_mixed.wgsl.expected.wgsl
new file mode 100644
index 0000000..920b432
--- /dev/null
+++ b/test/tint/statements/switch/case_default_mixed.wgsl.expected.wgsl
@@ -0,0 +1,16 @@
+@compute @workgroup_size(1)
+fn f() {
+ var i : i32;
+ var result : i32;
+ switch(i) {
+ case 0: {
+ result = 10;
+ }
+ case 1, default: {
+ result = 22;
+ }
+ case 2: {
+ result = 33;
+ }
+ }
+}