Store expressions in switch case statements.
This CL changes switch case statements to store an Expression instead
of an IntLiteralExpression. The SEM is updated to store the
materialized constant value of each selector instead of reading the
literal value from the AST directly.
Bug: tint:1633
Change-Id: Id79dabb806be1049f775299732bc1c7b1bf0c05f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/106300
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Auto-Submit: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/ast/case_statement.cc b/src/tint/ast/case_statement.cc
index 9f1c20e..3125f05 100644
--- a/src/tint/ast/case_statement.cc
+++ b/src/tint/ast/case_statement.cc
@@ -25,7 +25,7 @@
CaseStatement::CaseStatement(ProgramID pid,
NodeID nid,
const Source& src,
- utils::VectorRef<const IntLiteralExpression*> s,
+ utils::VectorRef<const Expression*> s,
const BlockStatement* b)
: Base(pid, nid, src), selectors(std::move(s)), body(b) {
TINT_ASSERT(AST, body);
diff --git a/src/tint/ast/case_statement.h b/src/tint/ast/case_statement.h
index 47d2097..eda9c01 100644
--- a/src/tint/ast/case_statement.h
+++ b/src/tint/ast/case_statement.h
@@ -18,7 +18,7 @@
#include <vector>
#include "src/tint/ast/block_statement.h"
-#include "src/tint/ast/int_literal_expression.h"
+#include "src/tint/ast/expression.h"
namespace tint::ast {
@@ -34,7 +34,7 @@
CaseStatement(ProgramID pid,
NodeID nid,
const Source& src,
- utils::VectorRef<const IntLiteralExpression*> selectors,
+ utils::VectorRef<const Expression*> selectors,
const BlockStatement* body);
/// Move constructor
CaseStatement(CaseStatement&&);
@@ -50,7 +50,7 @@
const CaseStatement* Clone(CloneContext* ctx) const override;
/// The case selectors, empty if none set
- const utils::Vector<const IntLiteralExpression*, 4> selectors;
+ const utils::Vector<const Expression*, 4> selectors;
/// The case body
const BlockStatement* const body;
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index f300710..892cf27 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -54,6 +54,7 @@
#include "src/tint/ast/if_statement.h"
#include "src/tint/ast/increment_decrement_statement.h"
#include "src/tint/ast/index_accessor_expression.h"
+#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/ast/interpolate_attribute.h"
#include "src/tint/ast/invariant_attribute.h"
#include "src/tint/ast/let.h"
@@ -2846,7 +2847,7 @@
/// @param body the case body
/// @returns the case statement pointer
const ast::CaseStatement* Case(const Source& source,
- utils::VectorRef<const ast::IntLiteralExpression*> selectors,
+ utils::VectorRef<const ast::Expression*> selectors,
const ast::BlockStatement* body = nullptr) {
return create<ast::CaseStatement>(source, std::move(selectors), body ? body : Block());
}
@@ -2855,7 +2856,7 @@
/// @param selectors list of selectors
/// @param body the case body
/// @returns the case statement pointer
- const ast::CaseStatement* Case(utils::VectorRef<const ast::IntLiteralExpression*> selectors,
+ const ast::CaseStatement* Case(utils::VectorRef<const ast::Expression*> selectors,
const ast::BlockStatement* body = nullptr) {
return create<ast::CaseStatement>(std::move(selectors), body ? body : Block());
}
@@ -2864,7 +2865,7 @@
/// @param selector a single case selector
/// @param body the case body
/// @returns the case statement pointer
- const ast::CaseStatement* Case(const ast::IntLiteralExpression* selector,
+ const ast::CaseStatement* Case(const ast::Expression* selector,
const ast::BlockStatement* body = nullptr) {
return Case(utils::Vector{selector}, body);
}
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index 9c38818..2cc35d8 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -3024,7 +3024,7 @@
for (size_t i = last_clause_index;; --i) {
// Create a list of integer literals for the selector values leading to
// this case clause.
- utils::Vector<const ast::IntLiteralExpression*, 4> selectors;
+ utils::Vector<const ast::Expression*, 4> selectors;
const bool has_selectors = clause_heads[i]->case_values.has_value();
if (has_selectors) {
auto values = clause_heads[i]->case_values.value();
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index d8d7994..cb8e217 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -2148,21 +2148,19 @@
}
// case_selectors
-// : const_literal (COMMA const_literal)* COMMA?
+// : expression (COMMA expression)* COMMA?
Expect<ParserImpl::CaseSelectorList> ParserImpl::expect_case_selectors() {
CaseSelectorList selectors;
while (continue_parsing()) {
- auto cond = const_literal();
- if (cond.errored) {
+ auto expr = expression();
+ if (expr.errored) {
return Failure::kErrored;
- } else if (!cond.matched) {
- break;
- } else if (!cond->Is<ast::IntLiteralExpression>()) {
- return add_error(cond.value->source, "invalid case selector must be an integer value");
}
-
- selectors.Push(cond.value->As<ast::IntLiteralExpression>());
+ if (!expr.matched) {
+ break;
+ }
+ selectors.Push(expr.value);
if (!match(Token::Type::kComma)) {
break;
diff --git a/src/tint/reader/wgsl/parser_impl.h b/src/tint/reader/wgsl/parser_impl.h
index 60b94de..3dbd780 100644
--- a/src/tint/reader/wgsl/parser_impl.h
+++ b/src/tint/reader/wgsl/parser_impl.h
@@ -74,7 +74,7 @@
/// Pre-determined small vector sizes for AST pointers
//! @cond Doxygen_Suppress
using AttributeList = utils::Vector<const ast::Attribute*, 4>;
- using CaseSelectorList = utils::Vector<const ast::IntLiteralExpression*, 4>;
+ using CaseSelectorList = utils::Vector<const ast::Expression*, 4>;
using CaseStatementList = utils::Vector<const ast::CaseStatement*, 4>;
using ExpressionList = utils::Vector<const ast::Expression*, 8>;
using ParameterList = utils::Vector<const ast::Parameter*, 8>;
diff --git a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
index 9416a4d..657b522 100644
--- a/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_error_msg_test.cc
@@ -1339,14 +1339,6 @@
)");
}
-TEST_F(ParserImplErrorTest, SwitchStmtInvalidCase2) {
- EXPECT("fn f() { switch(1) { case false: } }",
- R"(test.wgsl:1:27 error: invalid case selector must be an integer value
-fn f() { switch(1) { case false: } }
- ^^^^^
-)");
-}
-
TEST_F(ParserImplErrorTest, SwitchStmtCaseMissingLBrace) {
EXPECT("fn f() { switch(1) { case 1: } }",
R"(test.wgsl:1:30 error: expected '{' for case statement
diff --git a/src/tint/reader/wgsl/parser_impl_switch_body_test.cc b/src/tint/reader/wgsl/parser_impl_switch_body_test.cc
index 076b1c3..2b8b0bd 100644
--- a/src/tint/reader/wgsl/parser_impl_switch_body_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_switch_body_test.cc
@@ -26,10 +26,42 @@
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
EXPECT_FALSE(e->IsDefault());
+
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 1u);
- EXPECT_EQ(stmt->selectors[0]->value, 1);
- EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
+ ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
+
+ auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
+ EXPECT_EQ(expr->value, 1);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
+ ASSERT_EQ(e->body->statements.Length(), 1u);
+ EXPECT_TRUE(e->body->statements[0]->Is<ast::AssignmentStatement>());
+}
+
+TEST_F(ParserImplTest, SwitchBody_Case_Expression) {
+ auto p = parser("case 1 + 2 { a = 4; }");
+ auto e = p->switch_body();
+ EXPECT_FALSE(p->has_error()) << p->error();
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ ASSERT_NE(e.value, nullptr);
+ ASSERT_TRUE(e->Is<ast::CaseStatement>());
+ EXPECT_FALSE(e->IsDefault());
+
+ auto* stmt = e->As<ast::CaseStatement>();
+ ASSERT_EQ(stmt->selectors.Length(), 1u);
+ ASSERT_TRUE(stmt->selectors[0]->Is<ast::BinaryExpression>());
+ auto* expr = stmt->selectors[0]->As<ast::BinaryExpression>();
+
+ EXPECT_EQ(ast::BinaryOp::kAdd, expr->op);
+ auto* v = expr->lhs->As<ast::IntLiteralExpression>();
+ ASSERT_NE(nullptr, v);
+ EXPECT_EQ(v->value, 1);
+
+ v = expr->rhs->As<ast::IntLiteralExpression>();
+ ASSERT_NE(nullptr, v);
+ EXPECT_EQ(v->value, 2);
+
ASSERT_EQ(e->body->statements.Length(), 1u);
EXPECT_TRUE(e->body->statements[0]->Is<ast::AssignmentStatement>());
}
@@ -43,10 +75,14 @@
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
EXPECT_FALSE(e->IsDefault());
+
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 1u);
- EXPECT_EQ(stmt->selectors[0]->value, 1);
- EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
+ ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
+
+ auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
+ EXPECT_EQ(expr->value, 1);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
ASSERT_EQ(e->body->statements.Length(), 1u);
EXPECT_TRUE(e->body->statements[0]->Is<ast::AssignmentStatement>());
}
@@ -62,9 +98,16 @@
EXPECT_FALSE(e->IsDefault());
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 2u);
- EXPECT_EQ(stmt->selectors[0]->value, 1);
- EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
- EXPECT_EQ(stmt->selectors[1]->value, 2);
+ ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
+
+ auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
+ EXPECT_EQ(expr->value, 1);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
+
+ ASSERT_TRUE(stmt->selectors[1]->Is<ast::IntLiteralExpression>());
+ expr = stmt->selectors[1]->As<ast::IntLiteralExpression>();
+ EXPECT_EQ(expr->value, 2);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
TEST_F(ParserImplTest, SwitchBody_Case_TrailingComma_WithColon) {
@@ -76,15 +119,23 @@
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->Is<ast::CaseStatement>());
EXPECT_FALSE(e->IsDefault());
+
auto* stmt = e->As<ast::CaseStatement>();
ASSERT_EQ(stmt->selectors.Length(), 2u);
- EXPECT_EQ(stmt->selectors[0]->value, 1);
- EXPECT_EQ(stmt->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
- EXPECT_EQ(stmt->selectors[1]->value, 2);
+ ASSERT_TRUE(stmt->selectors[0]->Is<ast::IntLiteralExpression>());
+
+ auto* expr = stmt->selectors[0]->As<ast::IntLiteralExpression>();
+ EXPECT_EQ(expr->value, 1);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
+
+ ASSERT_TRUE(stmt->selectors[1]->Is<ast::IntLiteralExpression>());
+ expr = stmt->selectors[1]->As<ast::IntLiteralExpression>();
+ EXPECT_EQ(expr->value, 2);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
-TEST_F(ParserImplTest, SwitchBody_Case_InvalidConstLiteral) {
- auto p = parser("case a == 4: { a = 4; }");
+TEST_F(ParserImplTest, SwitchBody_Case_Invalid) {
+ auto p = parser("case if: { a = 4; }");
auto e = p->switch_body();
EXPECT_TRUE(p->has_error());
EXPECT_TRUE(e.errored);
@@ -93,16 +144,6 @@
EXPECT_EQ(p->error(), "1:6: unable to parse case selectors");
}
-TEST_F(ParserImplTest, SwitchBody_Case_InvalidSelector_bool) {
- auto p = parser("case true: { a = 4; }");
- auto e = p->switch_body();
- EXPECT_TRUE(p->has_error());
- EXPECT_TRUE(e.errored);
- EXPECT_FALSE(e.matched);
- EXPECT_EQ(e.value, nullptr);
- EXPECT_EQ(p->error(), "1:6: invalid case selector must be an integer value");
-}
-
TEST_F(ParserImplTest, SwitchBody_Case_MissingConstLiteral) {
auto p = parser("case: { a = 4; }");
auto e = p->switch_body();
@@ -164,10 +205,16 @@
EXPECT_FALSE(e->IsDefault());
ASSERT_EQ(e->body->statements.Length(), 0u);
ASSERT_EQ(e->selectors.Length(), 2u);
- ASSERT_EQ(e->selectors[0]->value, 1);
- EXPECT_EQ(e->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
- ASSERT_EQ(e->selectors[1]->value, 2);
- EXPECT_EQ(e->selectors[1]->suffix, ast::IntLiteralExpression::Suffix::kNone);
+ ASSERT_TRUE(e->selectors[0]->Is<ast::IntLiteralExpression>());
+
+ auto* expr = e->selectors[0]->As<ast::IntLiteralExpression>();
+ ASSERT_EQ(expr->value, 1);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
+
+ ASSERT_TRUE(e->selectors[1]->Is<ast::IntLiteralExpression>());
+ expr = e->selectors[1]->As<ast::IntLiteralExpression>();
+ ASSERT_EQ(expr->value, 2);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectors_WithColon) {
@@ -181,10 +228,16 @@
EXPECT_FALSE(e->IsDefault());
ASSERT_EQ(e->body->statements.Length(), 0u);
ASSERT_EQ(e->selectors.Length(), 2u);
- ASSERT_EQ(e->selectors[0]->value, 1);
- EXPECT_EQ(e->selectors[0]->suffix, ast::IntLiteralExpression::Suffix::kNone);
- ASSERT_EQ(e->selectors[1]->value, 2);
- EXPECT_EQ(e->selectors[1]->suffix, ast::IntLiteralExpression::Suffix::kNone);
+ ASSERT_TRUE(e->selectors[0]->Is<ast::IntLiteralExpression>());
+
+ auto* expr = e->selectors[0]->As<ast::IntLiteralExpression>();
+ ASSERT_EQ(expr->value, 1);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
+
+ ASSERT_TRUE(e->selectors[1]->Is<ast::IntLiteralExpression>());
+ expr = e->selectors[1]->As<ast::IntLiteralExpression>();
+ ASSERT_EQ(expr->value, 2);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
TEST_F(ParserImplTest, SwitchBody_Case_MultipleSelectorsMissingComma) {
diff --git a/src/tint/resolver/control_block_validation_test.cc b/src/tint/resolver/control_block_validation_test.cc
index 9b0d289..403d0bc 100644
--- a/src/tint/resolver/control_block_validation_test.cc
+++ b/src/tint/resolver/control_block_validation_test.cc
@@ -25,7 +25,7 @@
class ResolverControlBlockValidationTest : public TestHelper, public testing::Test {};
-TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpressionNoneIntegerType_Fail) {
+TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpression_F32) {
// var a : f32 = 3.14;
// switch (a) {
// default: {}
@@ -43,6 +43,24 @@
"scalar integer type");
}
+TEST_F(ResolverControlBlockValidationTest, SwitchSelectorExpression_bool) {
+ // var a : bool = false;
+ // switch (a) {
+ // default: {}
+ // }
+ auto* var = Var("a", ty.bool_(), Expr(false));
+
+ auto* block = Block(Decl(var), Switch(Expr(Source{{12, 34}}, "a"), //
+ DefaultCase()));
+
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: switch statement selector expression must be of a "
+ "scalar integer type");
+}
+
TEST_F(ResolverControlBlockValidationTest, SwitchWithoutDefault_Fail) {
// var a : i32 = 2;
// switch (a) {
@@ -213,8 +231,8 @@
// }
auto* var = Var("a", ty.i32(), Expr(2_i));
- auto* block = Block(Decl(var), Switch("a", //
- Case(Source{{12, 34}}, utils::Vector{Expr(1_u)}), //
+ auto* block = Block(Decl(var), Switch("a", //
+ Case(Expr(Source{{12, 34}}, 1_u)), //
DefaultCase()));
WrapInFunction(block);
@@ -234,7 +252,7 @@
auto* block = Block(Decl(var), //
Switch("a", //
- Case(Source{{12, 34}}, utils::Vector{Expr(-1_i)}), //
+ Case(utils::Vector{Expr(Source{{12, 34}}, -1_i)}), //
DefaultCase()));
WrapInFunction(block);
@@ -332,6 +350,74 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
+TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_Pass) {
+ // var a : i32 = 2;
+ // switch (a) {
+ // default: {}
+ // case 5 + 6: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2_i));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ DefaultCase(Source{{12, 34}}), //
+ Case(Add(5_i, 6_i))));
+ WrapInFunction(block);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_MixI32_Abstract) {
+ // var a = 2;
+ // switch (a) {
+ // default: {}
+ // case 5i + 6i: {}
+ // }
+ auto* var = Var("a", Expr(2_a));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ DefaultCase(Source{{12, 34}}), //
+ Case(Add(5_i, 6_i))));
+ WrapInFunction(block);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_MixU32_Abstract) {
+ // var a = 2u;
+ // switch (a) {
+ // default: {}
+ // case 5 + 6: {}
+ // }
+ auto* var = Var("a", Expr(2_u));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ DefaultCase(Source{{12, 34}}), //
+ Case(Add(5_a, 6_a))));
+ WrapInFunction(block);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverControlBlockValidationTest, SwitchCase_Expression_Multiple) {
+ // var a = 2u;
+ // switch (a) {
+ // default: {}
+ // case 5u + 6u, 7u + 9u, 2u * 4u: {}
+ // }
+ auto* var = Var("a", Expr(2_u));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ DefaultCase(Source{{12, 34}}), //
+ Case(utils::Vector{Add(5_u, 6_u), Add(7_u, 9_u), Mul(2_u, 4_u)})));
+ WrapInFunction(block);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
TEST_F(ResolverControlBlockValidationTest, SwitchCaseAlias_Pass) {
// type MyInt = u32;
// var v: MyInt;
@@ -349,5 +435,85 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
+TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelector_Expression_Fail) {
+ // var a : i32 = 2i;
+ // switch (a) {
+ // case 10i: {}
+ // case 5i+5i: {}
+ // default: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2_i));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ Case(Expr(Source{{12, 34}}, 10_i)),
+ Case(Add(Source{{56, 78}}, 5_i, 5_i)), DefaultCase()));
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "56:78 error: duplicate switch case '10'\n"
+ "12:34 note: previous case declared here");
+}
+
+TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorSameCase_BothExpression_Fail) {
+ // var a : i32 = 2i;
+ // switch (a) {
+ // case 5i+5i, 6i+4i: {}
+ // default: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2_i));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i),
+ Add(Source{{12, 34}}, 6_i, 4_i)}),
+ DefaultCase()));
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: duplicate switch case '10'\n"
+ "56:78 note: previous case declared here");
+}
+
+TEST_F(ResolverControlBlockValidationTest, NonUniqueCaseSelectorSame_Case_Expression_Fail) {
+ // var a : i32 = 2i;
+ // switch (a) {
+ // case 5i+5i, 10i: {}
+ // default: {}
+ // }
+ auto* var = Var("a", ty.i32(), Expr(2_i));
+
+ auto* block = Block(
+ Decl(var), //
+ Switch("a", //
+ Case(utils::Vector{Add(Source{{56, 78}}, 5_i, 5_i), Expr(Source{{12, 34}}, 10_i)}),
+ DefaultCase()));
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: duplicate switch case '10'\n"
+ "56:78 note: previous case declared here");
+}
+
+TEST_F(ResolverControlBlockValidationTest, Switch_OverrideCondition_Fail) {
+ // override b : i32 = 2i;
+ // var a : i32 = 2i;
+ // switch (a) { case b: {} default: {} }
+ // (`b` is an override, not a constant expression, so it is not a valid case selector)
+ auto* var = Var("a", ty.i32(), Expr(2_i));
+ Override("b", ty.i32(), Expr(2_i));
+
+ auto* block = Block(Decl(var), //
+ Switch("a", //
+ Case(Expr(Source{{12, 34}}, "b")), DefaultCase()));
+ WrapInFunction(block);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: case selector must be a constant expression");
+}
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index c252473..b2961bd 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1234,18 +1234,34 @@
});
}
-sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
+sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt, const sem::Type* ty) {
auto* sem =
builder_->create<sem::CaseStatement>(stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
sem->Selectors().reserve(stmt->selectors.Length());
for (auto* sel : stmt->selectors) {
- auto* expr = Expression(sel);
- if (!expr) {
+ ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "case selector"};
+ TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
+
+ // The selector's sem expression was already created when the enclosing switch
+ // determined the common selector type, so it is only looked up here.
+ auto* materialized = Materialize(sem_.Get(sel), ty);
+ if (!materialized) {
return false;
}
- sem->Selectors().emplace_back(expr);
+ if (!materialized->Type()->IsAnyOf<sem::I32, sem::U32>()) {
+ AddError("case selector must be an i32 or u32 value", sel->source);
+ return false;
+ }
+ auto const_value = materialized->ConstantValue();
+ if (!const_value) {
+ AddError("case selector must be a constant expression", sel->source);
+ return false;
+ }
+
+ sem->Selectors().emplace_back(const_value);
}
+
Mark(stmt->body);
auto* body = BlockStatement(stmt->body);
if (!body) {
@@ -3082,27 +3098,19 @@
auto* cond_ty = cond->Type()->UnwrapRef();
- utils::Vector<const sem::Type*, 8> types;
- types.Push(cond_ty);
-
- utils::Vector<sem::CaseStatement*, 4> cases;
- cases.Reserve(stmt->body.Length());
- for (auto* case_stmt : stmt->body) {
- Mark(case_stmt);
- auto* c = CaseStatement(case_stmt);
- if (!c) {
- return false;
- }
- for (auto* expr : c->Selectors()) {
- types.Push(expr->Type()->UnwrapRef());
- }
- cases.Push(c);
- behaviors.Add(c->Behaviors());
- sem->Cases().emplace_back(c);
- }
-
// Determine the common type across all selectors and the switch expression
// This must materialize to an integer scalar (non-abstract).
+ utils::Vector<const sem::Type*, 8> types;
+ types.Push(cond_ty);
+ for (auto* case_stmt : stmt->body) {
+ for (auto* expr : case_stmt->selectors) {
+ auto* sem_expr = Expression(expr);
+ if (!sem_expr) {
+ return false;
+ }
+ types.Push(sem_expr->Type()->UnwrapRef());
+ }
+ }
auto* common_ty = sem::Type::Common(types);
if (!common_ty || !common_ty->is_integer_scalar()) {
// No common type found or the common type was abstract.
@@ -3113,13 +3121,18 @@
if (!cond) {
return false;
}
- for (auto* c : cases) {
- for (auto*& sel : c->Selectors()) { // Note: pointer reference
- sel = Materialize(sel, common_ty);
- if (!sel) {
- return false;
- }
+
+ utils::Vector<sem::CaseStatement*, 4> cases;
+ cases.Reserve(stmt->body.Length());
+ for (auto* case_stmt : stmt->body) {
+ Mark(case_stmt);
+ auto* c = CaseStatement(case_stmt, common_ty);
+ if (!c) {
+ return false;
}
+ cases.Push(c);
+ behaviors.Add(c->Behaviors());
+ sem->Cases().emplace_back(c);
}
if (behaviors.Contains(sem::Behavior::kBreak)) {
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index c25b48e..e38d62d 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -209,7 +209,7 @@
sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
sem::Statement* BreakStatement(const ast::BreakStatement*);
sem::Statement* CallStatement(const ast::CallStatement*);
- sem::CaseStatement* CaseStatement(const ast::CaseStatement*);
+ sem::CaseStatement* CaseStatement(const ast::CaseStatement*, const sem::Type*);
sem::Statement* CompoundAssignmentStatement(const ast::CompoundAssignmentStatement*);
sem::Statement* ContinueStatement(const ast::ContinueStatement*);
sem::Statement* DiscardStatement(const ast::DiscardStatement*);
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index 5e4494c..4538bf4 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -132,8 +132,7 @@
ASSERT_EQ(sem->Cases().size(), 2u);
EXPECT_EQ(sem->Cases()[0]->Declaration(), cse);
ASSERT_EQ(sem->Cases()[0]->Selectors().size(), 1u);
- EXPECT_EQ(sem->Cases()[0]->Selectors()[0]->Declaration(), sel);
EXPECT_EQ(sem->Cases()[1]->Declaration(), def);
EXPECT_EQ(sem->Cases()[1]->Selectors().size(), 0u);
}
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index e14cdd6..f738274 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -2338,22 +2338,36 @@
has_default = true;
}
- for (auto* selector : case_stmt->selectors) {
- if (cond_ty != sem_.TypeOf(selector)) {
+ auto* case_sem = sem_.Get<sem::CaseStatement>(case_stmt);
+
+ auto& case_selectors = case_stmt->selectors;
+ auto& selector_values = case_sem->Selectors();
+ TINT_ASSERT(Resolver, case_selectors.Length() == selector_values.size());
+ for (size_t i = 0; i < selector_values.size(); ++i) {
+ auto* selector = selector_values[i];
+ if (cond_ty != selector->Type()) {
AddError(
"the case selector values must have the same type as the selector expression.",
- case_stmt->source);
+ case_selectors[i]->source);
return false;
}
- auto it = selectors.find(selector->value);
+ auto value = selector->As<uint32_t>();
+ auto it = selectors.find(value);
if (it != selectors.end()) {
- auto val = std::to_string(selector->value);
- AddError("duplicate switch case '" + val + "'", selector->source);
+ std::string err = "duplicate switch case '";
+ if (selector->Type()->Is<sem::I32>()) {
+ err += std::to_string(selector->As<int32_t>());
+ } else {
+ err += std::to_string(value);
+ }
+ err += "'";
+
+ AddError(err, case_selectors[i]->source);
AddNote("previous case declared here", it->second);
return false;
}
- selectors.emplace(selector->value, selector->source);
+ selectors.emplace(value, case_selectors[i]->source);
}
}
diff --git a/src/tint/sem/switch_statement.h b/src/tint/sem/switch_statement.h
index a6b5c00..7028c05 100644
--- a/src/tint/sem/switch_statement.h
+++ b/src/tint/sem/switch_statement.h
@@ -26,6 +26,7 @@
} // namespace tint::ast
namespace tint::sem {
class CaseStatement;
+class Constant;
class Expression;
} // namespace tint::sem
@@ -82,14 +83,14 @@
const BlockStatement* Body() const { return body_; }
/// @returns the selectors for the case
- std::vector<const Expression*>& Selectors() { return selectors_; }
+ std::vector<const Constant*>& Selectors() { return selectors_; }
/// @returns the selectors for the case
- const std::vector<const Expression*>& Selectors() const { return selectors_; }
+ const std::vector<const Constant*>& Selectors() const { return selectors_; }
private:
const BlockStatement* body_ = nullptr;
- std::vector<const Expression*> selectors_;
+ std::vector<const Constant*> selectors_;
};
} // namespace tint::sem
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index 49da7d6..cf75678 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -43,6 +43,7 @@
#include "src/tint/sem/statement.h"
#include "src/tint/sem/storage_texture.h"
#include "src/tint/sem/struct.h"
+#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
@@ -1689,14 +1690,15 @@
if (stmt->IsDefault()) {
line() << "default: {";
} else {
- for (auto* selector : stmt->selectors) {
+ auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
+ for (auto* selector : sem->Selectors()) {
auto out = line();
out << "case ";
- if (!EmitLiteral(out, selector)) {
+ if (!EmitConstant(out, selector)) {
return false;
}
out << ":";
- if (selector == stmt->selectors.Back()) {
+ if (selector == sem->Selectors().back()) {
out << " {";
}
}
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 439f0ee..7bdb8e9 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -44,6 +44,7 @@
#include "src/tint/sem/statement.h"
#include "src/tint/sem/storage_texture.h"
#include "src/tint/sem/struct.h"
+#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
@@ -2564,14 +2565,15 @@
if (stmt->IsDefault()) {
line() << "default: {";
} else {
- for (auto* selector : stmt->selectors) {
+ auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
+ for (auto* selector : sem->Selectors()) {
auto out = line();
out << "case ";
- if (!EmitLiteral(out, selector)) {
+ if (!EmitConstant(out, selector)) {
return false;
}
out << ":";
- if (selector == stmt->selectors.Back()) {
+ if (selector == sem->Selectors().back()) {
out << " {";
}
}
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index 081444f..f765d93 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -52,6 +52,7 @@
#include "src/tint/sem/sampled_texture.h"
#include "src/tint/sem/storage_texture.h"
#include "src/tint/sem/struct.h"
+#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/u32.h"
@@ -1591,14 +1592,15 @@
if (stmt->IsDefault()) {
line() << "default: {";
} else {
- for (auto* selector : stmt->selectors) {
+ auto* sem = builder_.Sem().Get<sem::CaseStatement>(stmt);
+ for (auto* selector : sem->Selectors()) {
auto out = line();
out << "case ";
- if (!EmitLiteral(out, selector)) {
+ if (!EmitConstant(out, selector)) {
return false;
}
out << ":";
- if (selector == stmt->selectors.Back()) {
+ if (selector == sem->Selectors().back()) {
out << " {";
}
}
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index 96edf6c..b72de67 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -39,6 +39,7 @@
#include "src/tint/sem/sampled_texture.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/struct.h"
+#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
@@ -3464,14 +3465,10 @@
auto block_id = std::get<uint32_t>(block);
case_ids.push_back(block_id);
- for (auto* selector : item->selectors) {
- auto* int_literal = selector->As<ast::IntLiteralExpression>();
- if (!int_literal) {
- error_ = "expected integer literal for switch case label";
- return false;
- }
- params.push_back(Operand(static_cast<uint32_t>(int_literal->value)));
+ auto* sem = builder_.Sem().Get<sem::CaseStatement>(item);
+ for (auto* selector : sem->Selectors()) {
+ params.push_back(Operand(selector->As<uint32_t>()));
params.push_back(Operand(block_id));
}
}
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index 7795455..7183c74 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -1030,13 +1030,13 @@
out << "case ";
bool first = true;
- for (auto* selector : stmt->selectors) {
+ for (auto* expr : stmt->selectors) {
if (!first) {
out << ", ";
}
first = false;
- if (!EmitLiteral(out, selector)) {
+ if (!EmitExpression(out, expr)) {
return false;
}
}