[spirv-writer] Add fallthrough support
This CL adds support for the fallthrough statement in a `case` block.
Bug: tint:5
Change-Id: I282643a304846a19212d41bd8bd20a60398bd793
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22220
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index d633f9c..7169411 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -89,6 +89,10 @@
return model;
}
+bool LastIsFallthrough(const ast::StatementList& stmts) {
+ return !stmts.empty() && stmts.back()->IsFallthrough();
+}
+
// A terminator is anything which will case a SPIR-V terminator to be emitted.
// This means things like breaks, fallthroughs and continues which all emit an
// OpBranch or return for the OpReturn emission.
@@ -1395,7 +1399,13 @@
return false;
}
- if (!LastIsTerminator(item->body())) {
+ if (LastIsFallthrough(item->body())) {
+ if (i == (body.size() - 1)) {
+ error_ = "fallthrough of last case statement is disallowed";
+ return false;
+ }
+ push_function_inst(spv::Op::OpBranch, {Operand::Int(case_ids[i + 1])});
+ } else if (!LastIsTerminator(item->body())) {
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
}
}
@@ -1491,6 +1501,10 @@
if (stmt->IsContinue()) {
return GenerateContinueStatement(stmt->AsContinue());
}
+ if (stmt->IsFallthrough()) {
+ // Do nothing here, the fallthrough gets handled by the switch code.
+ return true;
+ }
if (stmt->IsIf()) {
return GenerateIfStatement(stmt->AsIf());
}
diff --git a/src/writer/spirv/builder_switch_test.cc b/src/writer/spirv/builder_switch_test.cc
index 1284345..d717c08 100644
--- a/src/writer/spirv/builder_switch_test.cc
+++ b/src/writer/spirv/builder_switch_test.cc
@@ -19,6 +19,7 @@
#include "src/ast/bool_literal.h"
#include "src/ast/break_statement.h"
#include "src/ast/case_statement.h"
+#include "src/ast/fallthrough_statement.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
#include "src/ast/int_literal.h"
@@ -321,15 +322,145 @@
)");
}
-TEST_F(BuilderTest, DISABLED_Switch_CaseWithFallthrough) {
- // switch (a) {
+TEST_F(BuilderTest, Switch_CaseWithFallthrough) {
+ ast::type::I32Type i32;
+
+ // switch(a) {
// case 1:
- // v = 1;
- // fallthrough;
+ // v = 1;
+ // fallthrough;
// case 2:
- // v = 2;
+ // v = 2;
+ // default:
+ // v = 3;
// }
- FAIL();
+
+ auto v =
+ std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
+ auto a =
+ std::make_unique<ast::Variable>("a", ast::StorageClass::kPrivate, &i32);
+
+ ast::StatementList case_1_body;
+ case_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 1))));
+ case_1_body.push_back(std::make_unique<ast::FallthroughStatement>());
+
+ ast::StatementList case_2_body;
+ case_2_body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 2))));
+
+ ast::StatementList default_body;
+ default_body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 3))));
+
+ ast::CaseStatementList cases;
+ cases.push_back(std::make_unique<ast::CaseStatement>(
+ std::make_unique<ast::IntLiteral>(&i32, 1), std::move(case_1_body)));
+ cases.push_back(std::make_unique<ast::CaseStatement>(
+ std::make_unique<ast::IntLiteral>(&i32, 2), std::move(case_2_body)));
+ cases.push_back(
+ std::make_unique<ast::CaseStatement>(std::move(default_body)));
+
+ ast::SwitchStatement expr(std::make_unique<ast::IdentifierExpression>("a"),
+ std::move(cases));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(v.get());
+ td.RegisterVariableForTesting(a.get());
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ ast::Function func("a_func", {}, &i32);
+
+ Builder b(&mod);
+ ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error();
+ ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error();
+ ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+
+ EXPECT_TRUE(b.GenerateSwitchStatement(&expr)) << b.error();
+
+ EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
+OpName %5 "a"
+OpName %7 "a_func"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Private %4
+%5 = OpVariable %2 Private %4
+%6 = OpTypeFunction %3
+%14 = OpConstant %3 1
+%15 = OpConstant %3 2
+%16 = OpConstant %3 3
+%7 = OpFunction %3 None %6
+%8 = OpLabel
+%10 = OpLoad %3 %5
+OpSelectionMerge %9 None
+OpSwitch %10 %11 1 %12 2 %13
+%12 = OpLabel
+OpStore %1 %14
+OpBranch %13
+%13 = OpLabel
+OpStore %1 %15
+OpBranch %9
+%11 = OpLabel
+OpStore %1 %16
+OpBranch %9
+%9 = OpLabel
+OpFunctionEnd
+)");
+}
+
+TEST_F(BuilderTest, Switch_CaseFallthroughLastStatement) {
+ ast::type::I32Type i32;
+
+ // switch(a) {
+ // case 1:
+ // v = 1;
+ // fallthrough;
+ // }
+
+ auto v =
+ std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
+ auto a =
+ std::make_unique<ast::Variable>("a", ast::StorageClass::kPrivate, &i32);
+
+ ast::StatementList case_1_body;
+ case_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 1))));
+ case_1_body.push_back(std::make_unique<ast::FallthroughStatement>());
+
+ ast::CaseStatementList cases;
+ cases.push_back(std::make_unique<ast::CaseStatement>(
+ std::make_unique<ast::IntLiteral>(&i32, 1), std::move(case_1_body)));
+
+ ast::SwitchStatement expr(std::make_unique<ast::IdentifierExpression>("a"),
+ std::move(cases));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(v.get());
+ td.RegisterVariableForTesting(a.get());
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ ast::Function func("a_func", {}, &i32);
+
+ Builder b(&mod);
+ ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error();
+ ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error();
+ ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+
+ EXPECT_FALSE(b.GenerateSwitchStatement(&expr)) << b.error();
+ EXPECT_EQ(b.error(), "fallthrough of last case statement is disallowed");
}
// TODO(dsinclair): Implement when parsing is handled for multi-value