[spirv-writer] Add else support.
This CL adds the start of support for else statements.
Bug: tint:5
Change-Id: I742fd4582bfee4f31715b94b7aea6cf8383f4e22
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19412
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 566b8c7..14bcb19 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -24,6 +24,7 @@
#include "src/ast/builtin_decoration.h"
#include "src/ast/constructor_expression.h"
#include "src/ast/decorated_variable.h"
+#include "src/ast/else_statement.h"
#include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
@@ -658,19 +659,26 @@
push_function_inst(spv::Op::OpLabel, {true_block});
for (const auto& inst : stmt->body()) {
if (!GenerateStatement(inst.get())) {
- return 0;
+ return false;
}
}
// TODO(dsinclair): The branch should be optional based on how the
// StatementList ended ...
-
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
if (false_block_id != merge_block_id) {
push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)});
- // TODO(dsinclair): Output else statements, pass in merge_block_id?
+ for (const auto& else_stmt : stmt->else_statements()) {
+ if (!GenerateElseStatement(else_stmt.get())) {
+ return false;
+ }
+ }
+
+ // TODO(dsinclair): The branch should be optional based on how the
+ // StatementList ended ...
+ push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
}
// Output the merge block
@@ -679,6 +687,21 @@
return true;
}
+bool Builder::GenerateElseStatement(ast::ElseStatement* stmt) {
+ // TODO(dsinclair): handle else if
+ if (stmt->HasCondition()) {
+ error_ = "else if not handled yet";
+ return false;
+ }
+
+ for (const auto& inst : stmt->body()) {
+ if (!GenerateStatement(inst.get())) {
+ return false;
+ }
+ }
+ return true;
+}
+
bool Builder::GenerateReturnStatement(ast::ReturnStatement* stmt) {
if (stmt->has_value()) {
auto val_id = GenerateExpression(stmt->value());
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index d2981a1..dab1e3e 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -148,6 +148,10 @@
/// @param assign the statement to generate
/// @returns true if the statement was successfully generated
bool GenerateAssignStatement(ast::AssignmentStatement* assign);
+ /// Generates an else statement
+ /// @param stmt the statement to generate
+ /// @returns true on successfull generation
+ bool GenerateElseStatement(ast::ElseStatement* stmt);
/// Generates an entry point instruction
/// @param ep the entry point
/// @returns true if the instruction was generated, false otherwise
diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc
index 1be836a..15d6809 100644
--- a/src/writer/spirv/builder_if_test.cc
+++ b/src/writer/spirv/builder_if_test.cc
@@ -17,6 +17,7 @@
#include "gtest/gtest.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/bool_literal.h"
+#include "src/ast/else_statement.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
#include "src/ast/int_literal.h"
@@ -109,19 +110,88 @@
)");
}
-TEST_F(BuilderTest, DISABLED_If_WithStatements_Returns) {
- // if (a) { return; }
-}
+TEST_F(BuilderTest, If_WithElse) {
+ ast::type::BoolType bool_type;
+ ast::type::I32Type i32;
-TEST_F(BuilderTest, DISABLED_If_WithElse) {}
+ auto var =
+ std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 2))));
+
+ ast::StatementList else_body;
+ else_body.push_back(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("v"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::IntLiteral>(&i32, 3))));
+
+ ast::ElseStatementList else_stmts;
+ else_stmts.push_back(
+ std::make_unique<ast::ElseStatement>(std::move(else_body)));
+
+ auto cond = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::BoolLiteral>(&bool_type, true));
+
+ ast::IfStatement expr(std::move(cond), std::move(body));
+ expr.set_else_statements(std::move(else_stmts));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+ td.RegisterVariableForTesting(var.get());
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b;
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+ EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%1 = OpVariable %2 Private
+%4 = OpTypeBool
+%5 = OpConstantTrue %4
+%9 = OpConstant %3 2
+%10 = OpConstant %3 3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(OpSelectionMerge %6 None
+OpBranchConditional %5 %7 %8
+%7 = OpLabel
+OpStore %1 %9
+OpBranch %6
+%8 = OpLabel
+OpStore %1 %10
+OpBranch %6
+%6 = OpLabel
+)");
+}
TEST_F(BuilderTest, DISABLED_If_WithElseIf) {}
TEST_F(BuilderTest, DISABLED_If_WithMultiple) {}
-TEST_F(BuilderTest, DISABLED_If_WithBreak) {}
+TEST_F(BuilderTest, DISABLED_If_WithBreak) {
+ // if (a) {
+ // break;
+ // }
+}
-TEST_F(BuilderTest, DISABLED_If_WithContinue) {}
+TEST_F(BuilderTest, DISABLED_If_WithContinue) {
+ // if (a) {
+ // continue;
+ // }
+}
+
+TEST_F(BuilderTest, DISABLED_IF_WithReturn) {
+ // if (a) {
+ // return;
+ // }
+}
} // namespace
} // namespace spirv