[spirv-writer] Add elseif support.
This CL adds support for having elseif statements after an if statement.
Bug: tint:5
Change-Id: I3cd3c5bddaa57c998b1a3fbee7bd87536533301d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19500
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 14bcb19..8bf9da4 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -630,8 +630,12 @@
return result_id;
}
-bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
- auto cond_id = GenerateExpression(stmt->condition());
+bool Builder::GenerateConditionalBlock(
+ ast::Expression* cond,
+ const ast::StatementList& true_body,
+ size_t cur_else_idx,
+ const ast::ElseStatementList& else_stmts) {
+ auto cond_id = GenerateExpression(cond);
if (cond_id == 0) {
return false;
}
@@ -646,10 +650,10 @@
auto true_block = result_op();
auto true_block_id = true_block.to_i();
- // if there are no else statements we branch on false to the merge block
+ // if there are no more else statements we branch on false to the merge block
// otherwise we branch to the false block
auto false_block_id =
- stmt->has_else_statements() ? next_id() : merge_block_id;
+ cur_else_idx < else_stmts.size() ? next_id() : merge_block_id;
push_function_inst(spv::Op::OpBranchConditional,
{Operand::Int(cond_id), Operand::Int(true_block_id),
@@ -657,28 +661,33 @@
// Output true block
push_function_inst(spv::Op::OpLabel, {true_block});
- for (const auto& inst : stmt->body()) {
- if (!GenerateStatement(inst.get())) {
- return false;
- }
+ if (!GenerateStatementList(true_body)) {
+ return false;
}
-
// TODO(dsinclair): The branch should be optional based on how the
// StatementList ended ...
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
+ // Start the false block if needed
if (false_block_id != merge_block_id) {
push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)});
- for (const auto& else_stmt : stmt->else_statements()) {
- if (!GenerateElseStatement(else_stmt.get())) {
+ auto* else_stmt = else_stmts[cur_else_idx].get();
+ // Handle the else case by just outputting the statements.
+ if (!else_stmt->HasCondition()) {
+ if (!GenerateStatementList(else_stmt->body())) {
return false;
}
+ // TODO(dsinclair): The branch should be optional based on how the
+ // StatementList ended ...
+ push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
+ } else {
+ if (!GenerateConditionalBlock(else_stmt->condition(), else_stmt->body(),
+ cur_else_idx + 1, else_stmts)) {
+ return false;
+ }
+ push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
}
-
- // TODO(dsinclair): The branch should be optional based on how the
- // StatementList ended ...
- push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
}
// Output the merge block
@@ -687,18 +696,11 @@
return true;
}
-bool Builder::GenerateElseStatement(ast::ElseStatement* stmt) {
- // TODO(dsinclair): handle else if
- if (stmt->HasCondition()) {
- error_ = "else if not handled yet";
+bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
+ if (!GenerateConditionalBlock(stmt->condition(), stmt->body(), 0,
+ stmt->else_statements())) {
return false;
}
-
- for (const auto& inst : stmt->body()) {
- if (!GenerateStatement(inst.get())) {
- return false;
- }
- }
return true;
}
@@ -716,6 +718,15 @@
return true;
}
+bool Builder::GenerateStatementList(const ast::StatementList& list) {
+ for (const auto& inst : list) {
+ if (!GenerateStatement(inst.get())) {
+ return false;
+ }
+ }
+ return true;
+}
+
bool Builder::GenerateStatement(ast::Statement* stmt) {
if (stmt->IsAssign()) {
return GenerateAssignStatement(stmt->AsAssign());
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index dab1e3e..0faf52d 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -22,6 +22,7 @@
#include "spirv/unified1/spirv.h"
#include "src/ast/builtin.h"
+#include "src/ast/else_statement.h"
#include "src/ast/literal.h"
#include "src/ast/module.h"
#include "src/ast/struct_member.h"
@@ -148,10 +149,6 @@
/// @param assign the statement to generate
/// @returns true if the statement was successfully generated
bool GenerateAssignStatement(ast::AssignmentStatement* assign);
- /// Generates an else statement
- /// @param stmt the statement to generate
- /// @returns true on successfull generation
- bool GenerateElseStatement(ast::ElseStatement* stmt);
/// Generates an entry point instruction
/// @param ep the entry point
/// @returns true if the instruction was generated, false otherwise
@@ -209,10 +206,24 @@
/// @param stmt the statement to generate
/// @returns true on success, false otherwise
bool GenerateReturnStatement(ast::ReturnStatement* stmt);
+ /// Generates a conditional section merge block
+ /// @param cond the condition
+ /// @param true_body the statements making up the true block
+ /// @param cur_else_idx the index of the current else statement to process
+ /// @param else_stmts the list of all else statements
+ /// @returns true on success, false on failure
+ bool GenerateConditionalBlock(ast::Expression* cond,
+ const ast::StatementList& true_body,
+ size_t cur_else_idx,
+ const ast::ElseStatementList& else_stmts);
/// Generates a statement
/// @param stmt the statement to generate
/// @returns true if the statement was generated
bool GenerateStatement(ast::Statement* stmt);
+ /// Generates a list of statements
+ /// @param list the statement list to generate
+ /// @returns true on successful generation
+ bool GenerateStatementList(const ast::StatementList& list);
/// Geneates an OpStore
/// @param to the ID to store too
/// @param from the ID to store from
diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc
index 15d6809..e7f8bfb 100644
--- a/src/writer/spirv/builder_if_test.cc
+++ b/src/writer/spirv/builder_if_test.cc
@@ -39,6 +39,8 @@
TEST_F(BuilderTest, If_Empty) {
ast::type::BoolType bool_type;
+ // if (true) {
+ // }
auto cond = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::BoolLiteral>(&bool_type, true));
@@ -68,6 +70,9 @@
ast::type::BoolType bool_type;
ast::type::I32Type i32;
+ // if (true) {
+ // v = 2;
+ // }
auto var =
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
@@ -114,6 +119,11 @@
ast::type::BoolType bool_type;
ast::type::I32Type i32;
+ // if (true) {
+ // v = 2;
+ // } else {
+ // v = 3;
+ // }
auto var =
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
@@ -171,9 +181,186 @@
)");
}
-TEST_F(BuilderTest, DISABLED_If_WithElseIf) {}
+TEST_F(BuilderTest, If_WithElseIf) {
+ ast::type::BoolType bool_type;
+ ast::type::I32Type i32;
-TEST_F(BuilderTest, DISABLED_If_WithMultiple) {}
+ // if (true) {
+ // v = 2;
+ // } elseif (true) {
+ // v = 3;
+ // }
+ auto var =
+ std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 2))));
+
+ ast::StatementList else_body;
+ else_body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 3))));
+
+ auto else_cond = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::BoolLiteral>(&bool_type, true));
+
+ ast::ElseStatementList else_stmts;
+ else_stmts.push_back(std::make_unique<ast::ElseStatement>(
+ std::move(else_cond), std::move(else_body)));
+
+ auto cond = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::BoolLiteral>(&bool_type, true));
+
+ ast::IfStatement expr(std::move(cond), std::move(body));
+ expr.set_else_statements(std::move(else_stmts));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+ td.RegisterVariableForTesting(var.get());
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b;
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+ EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%1 = OpVariable %2 Private
+%4 = OpTypeBool
+%5 = OpConstantTrue %4
+%9 = OpConstant %3 2
+%12 = OpConstant %3 3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpSelectionMerge %6 None
+OpBranchConditional %5 %7 %8
+%7 = OpLabel
+OpStore %1 %9
+OpBranch %6
+%8 = OpLabel
+OpSelectionMerge %10 None
+OpBranchConditional %5 %11 %10
+%11 = OpLabel
+OpStore %1 %12
+OpBranch %10
+%10 = OpLabel
+OpBranch %6
+%6 = OpLabel
+)");
+}
+
+TEST_F(BuilderTest, If_WithMultiple) {
+ ast::type::BoolType bool_type;
+ ast::type::I32Type i32;
+
+ // if (true) {
+ // v = 2;
+ // } elseif (true) {
+ // v = 3;
+ // } elseif (false) {
+ // v = 4;
+ // } else {
+ // v = 5;
+ // }
+ auto var =
+ std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 2))));
+ ast::StatementList elseif_1_body;
+ elseif_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 3))));
+ ast::StatementList elseif_2_body;
+ elseif_2_body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 4))));
+ ast::StatementList else_body;
+ else_body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 5))));
+
+ auto elseif_1_cond = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::BoolLiteral>(&bool_type, true));
+ auto elseif_2_cond = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::BoolLiteral>(&bool_type, false));
+
+ ast::ElseStatementList else_stmts;
+ else_stmts.push_back(std::make_unique<ast::ElseStatement>(
+ std::move(elseif_1_cond), std::move(elseif_1_body)));
+ else_stmts.push_back(std::make_unique<ast::ElseStatement>(
+ std::move(elseif_2_cond), std::move(elseif_2_body)));
+ else_stmts.push_back(
+ std::make_unique<ast::ElseStatement>(std::move(else_body)));
+
+ auto cond = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::BoolLiteral>(&bool_type, true));
+
+ ast::IfStatement expr(std::move(cond), std::move(body));
+ expr.set_else_statements(std::move(else_stmts));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+ td.RegisterVariableForTesting(var.get());
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b;
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+ EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%1 = OpVariable %2 Private
+%4 = OpTypeBool
+%5 = OpConstantTrue %4
+%9 = OpConstant %3 2
+%13 = OpConstant %3 3
+%14 = OpConstantFalse %4
+%18 = OpConstant %3 4
+%19 = OpConstant %3 5
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpSelectionMerge %6 None
+OpBranchConditional %5 %7 %8
+%7 = OpLabel
+OpStore %1 %9
+OpBranch %6
+%8 = OpLabel
+OpSelectionMerge %10 None
+OpBranchConditional %5 %11 %12
+%11 = OpLabel
+OpStore %1 %13
+OpBranch %10
+%12 = OpLabel
+OpSelectionMerge %15 None
+OpBranchConditional %14 %16 %17
+%16 = OpLabel
+OpStore %1 %18
+OpBranch %15
+%17 = OpLabel
+OpStore %1 %19
+OpBranch %15
+%15 = OpLabel
+OpBranch %10
+%10 = OpLabel
+OpBranch %6
+%6 = OpLabel
+)");
+}
TEST_F(BuilderTest, DISABLED_If_WithBreak) {
// if (a) {