Convert Function to use BlockStatement. This CL converts the Function class to using a BlockStatement internally. All usages have been updated execept for the two readers. Bug: tint:130 Change-Id: I7159cf2d3ed5cb8a34d51fbe848b88f0e5479605 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/25720 Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h index c7c1587..f1be730 100644 --- a/src/ast/block_statement.h +++ b/src/ast/block_statement.h
@@ -48,6 +48,11 @@ /// Retrieves the statement at |idx| /// @param idx the index. The index is not bounds checked. /// @returns the statement at |idx| + ast::Statement* get(size_t idx) { return statements_[idx].get(); } + + /// Retrieves the statement at |idx| + /// @param idx the index. The index is not bounds checked. + /// @returns the statement at |idx| ast::Statement* operator[](size_t idx) { return statements_[idx].get(); } /// Retrieves the statement at |idx| /// @param idx the index. The index is not bounds checked.
diff --git a/src/ast/function.cc b/src/ast/function.cc index 380d9a6..6fe5a32 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc
@@ -29,7 +29,8 @@ : Node(), name_(name), params_(std::move(params)), - return_type_(return_type) {} + return_type_(return_type), + body_(std::make_unique<BlockStatement>()) {} Function::Function(const Source& source, const std::string& name, @@ -38,7 +39,8 @@ : Node(source), name_(name), params_(std::move(params)), - return_type_(return_type) {} + return_type_(return_type), + body_(std::make_unique<BlockStatement>()) {} Function::Function(Function&&) = default; @@ -154,16 +156,20 @@ ancestor_entry_points_.push_back(ep); } +void Function::set_body(StatementList body) { + for (auto& stmt : body) { + body_->append(std::move(stmt)); + } +} + bool Function::IsValid() const { for (const auto& param : params_) { if (param == nullptr || !param->IsValid()) return false; } - for (const auto& stmt : body_) { - if (stmt == nullptr || !stmt->IsValid()) - return false; + if (body_ == nullptr || !body_->IsValid()) { + return false; } - if (name_.length() == 0) { return false; } @@ -194,7 +200,7 @@ make_indent(out, indent); out << "{" << std::endl; - for (const auto& stmt : body_) + for (const auto& stmt : *body_) stmt->to_str(out, indent + 2); make_indent(out, indent);
diff --git a/src/ast/function.h b/src/ast/function.h index b130ddc..a6dbe18 100644 --- a/src/ast/function.h +++ b/src/ast/function.h
@@ -22,6 +22,7 @@ #include <vector> #include "src/ast/binding_decoration.h" +#include "src/ast/block_statement.h" #include "src/ast/builtin_decoration.h" #include "src/ast/expression.h" #include "src/ast/location_decoration.h" @@ -123,9 +124,14 @@ /// Sets the body of the function /// @param body the function body - void set_body(StatementList body) { body_ = std::move(body); } + void set_body(StatementList body); + /// Sets the body of the function + /// @param body the function body + void set_body(std::unique_ptr<BlockStatement> body) { + body_ = std::move(body); + } /// @returns the function body - const StatementList& body() const { return body_; } + BlockStatement* body() const { return body_.get(); } /// @returns true if the name and type are both present bool IsValid() const override; @@ -144,7 +150,7 @@ std::string name_; VariableList params_; type::Type* return_type_ = nullptr; - StatementList body_; + std::unique_ptr<BlockStatement> body_; std::vector<Variable*> referenced_module_vars_; std::vector<std::string> ancestor_entry_points_; };
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc index a2c6471..7988388 100644 --- a/src/ast/function_test.cc +++ b/src/ast/function_test.cc
@@ -188,11 +188,11 @@ params.push_back( std::make_unique<Variable>("var", StorageClass::kNone, &i32)); - StatementList body; - body.push_back(std::make_unique<DiscardStatement>()); + auto block = std::make_unique<ast::BlockStatement>(); + block->append(std::make_unique<DiscardStatement>()); Function f("func", std::move(params), &void_type); - f.set_body(std::move(body)); + f.set_body(std::move(block)); EXPECT_TRUE(f.IsValid()); } @@ -251,12 +251,12 @@ params.push_back( std::make_unique<Variable>("var", StorageClass::kNone, &i32)); - StatementList body; - body.push_back(std::make_unique<DiscardStatement>()); - body.push_back(nullptr); + auto block = std::make_unique<ast::BlockStatement>(); + block->append(std::make_unique<DiscardStatement>()); + block->append(nullptr); Function f("func", std::move(params), &void_type); - f.set_body(std::move(body)); + f.set_body(std::move(block)); EXPECT_FALSE(f.IsValid()); } @@ -268,12 +268,12 @@ params.push_back( std::make_unique<Variable>("var", StorageClass::kNone, &i32)); - StatementList body; - body.push_back(std::make_unique<DiscardStatement>()); - body.push_back(nullptr); + auto block = std::make_unique<ast::BlockStatement>(); + block->append(std::make_unique<DiscardStatement>()); + block->append(nullptr); Function f("func", std::move(params), &void_type); - f.set_body(std::move(body)); + f.set_body(std::move(block)); EXPECT_FALSE(f.IsValid()); } @@ -281,11 +281,11 @@ type::VoidType void_type; type::I32Type i32; - StatementList body; - body.push_back(std::make_unique<DiscardStatement>()); + auto block = std::make_unique<ast::BlockStatement>(); + block->append(std::make_unique<DiscardStatement>()); Function f("func", {}, &void_type); - f.set_body(std::move(body)); + f.set_body(std::move(block)); std::ostringstream out; f.to_str(out, 2); @@ -305,11 +305,11 @@ params.push_back( std::make_unique<Variable>("var", StorageClass::kNone, &i32)); - StatementList body; - body.push_back(std::make_unique<DiscardStatement>()); + auto block = std::make_unique<ast::BlockStatement>(); + block->append(std::make_unique<DiscardStatement>()); Function f("func", std::move(params), &void_type); - f.set_body(std::move(body)); + f.set_body(std::move(block)); std::ostringstream out; f.to_str(out, 2);
diff --git a/src/reader/wgsl/parser_impl_function_decl_test.cc b/src/reader/wgsl/parser_impl_function_decl_test.cc index 271634e..fe8e3c7 100644 --- a/src/reader/wgsl/parser_impl_function_decl_test.cc +++ b/src/reader/wgsl/parser_impl_function_decl_test.cc
@@ -40,8 +40,9 @@ ASSERT_NE(f->return_type(), nullptr); EXPECT_TRUE(f->return_type()->IsVoid()); - ASSERT_EQ(f->body().size(), 1u); - EXPECT_TRUE(f->body()[0]->IsReturn()); + auto* body = f->body(); + ASSERT_EQ(body->size(), 1u); + EXPECT_TRUE(body->get(0)->IsReturn()); } TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 075eed5..c7f9bcc 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc
@@ -734,10 +734,9 @@ std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32); var->set_is_const(true); - ast::StatementList body; - body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var))); - - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::VariableDeclStatement>(std::move(var))); + body->append(std::make_unique<ast::AssignmentStatement>( std::move(my_var), std::make_unique<ast::IdentifierExpression>("my_var"))); @@ -756,12 +755,12 @@ auto my_var = std::make_unique<ast::IdentifierExpression>("my_var"); auto* my_var_ptr = my_var.get(); - ast::StatementList body; - body.push_back(std::make_unique<ast::VariableDeclStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::VariableDeclStatement>( std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::move(my_var), std::make_unique<ast::IdentifierExpression>("my_var"))); @@ -823,17 +822,17 @@ std::make_unique<ast::Function>("my_func", std::move(params), &f32); auto* func_ptr = func.get(); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("out_var"), std::make_unique<ast::IdentifierExpression>("in_var"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("wg_var"), std::make_unique<ast::IdentifierExpression>("wg_var"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("sb_var"), std::make_unique<ast::IdentifierExpression>("sb_var"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("priv_var"), std::make_unique<ast::IdentifierExpression>("priv_var"))); func->set_body(std::move(body)); @@ -882,17 +881,17 @@ auto func = std::make_unique<ast::Function>("my_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("out_var"), std::make_unique<ast::IdentifierExpression>("in_var"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("wg_var"), std::make_unique<ast::IdentifierExpression>("wg_var"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("sb_var"), std::make_unique<ast::IdentifierExpression>("sb_var"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("priv_var"), std::make_unique<ast::IdentifierExpression>("priv_var"))); func->set_body(std::move(body)); @@ -901,7 +900,9 @@ auto func2 = std::make_unique<ast::Function>("func", std::move(params), &f32); auto* func2_ptr = func2.get(); - body.push_back(std::make_unique<ast::AssignmentStatement>( + + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("out_var"), std::make_unique<ast::CallExpression>( std::make_unique<ast::IdentifierExpression>("my_func"), @@ -933,9 +934,9 @@ std::make_unique<ast::Function>("my_func", std::move(params), &f32); auto* func_ptr = func.get(); - ast::StatementList body; - body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::VariableDeclStatement>(std::move(var))); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("var"), std::make_unique<ast::ScalarConstructorExpression>( std::make_unique<ast::FloatLiteral>(&f32, 1.f)))); @@ -1990,9 +1991,10 @@ auto func = std::make_unique<ast::Function>("func", ast::VariableList{}, &i32); - ast::StatementList stmts; - stmts.push_back(std::move(stmt)); - func->set_body(std::move(stmts)); + + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::move(stmt)); + func->set_body(std::move(body)); mod()->AddFunction(std::move(func)); @@ -2011,9 +2013,10 @@ auto func = std::make_unique<ast::Function>("func", ast::VariableList{}, &i32); - ast::StatementList stmts; - stmts.push_back(std::move(stmt)); - func->set_body(std::move(stmts)); + + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::move(stmt)); + func->set_body(std::move(body)); mod()->AddFunction(std::move(func)); @@ -2030,9 +2033,10 @@ auto func = std::make_unique<ast::Function>("func", ast::VariableList{}, &i32); - ast::StatementList stmts; - stmts.push_back(std::move(stmt)); - func->set_body(std::move(stmts)); + + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::move(stmt)); + func->set_body(std::move(body)); mod()->AddFunction(std::move(func)); @@ -3963,13 +3967,14 @@ auto func_b = std::make_unique<ast::Function>("b", std::move(params), &f32); auto* func_b_ptr = func_b.get(); - ast::StatementList body; + auto body = std::make_unique<ast::BlockStatement>(); func_b->set_body(std::move(body)); auto func_c = std::make_unique<ast::Function>("c", std::move(params), &f32); auto* func_c_ptr = func_c.get(); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("second"), std::make_unique<ast::CallExpression>( std::make_unique<ast::IdentifierExpression>("b"), @@ -3979,7 +3984,8 @@ auto func_a = std::make_unique<ast::Function>("a", std::move(params), &f32); auto* func_a_ptr = func_a.get(); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("first"), std::make_unique<ast::CallExpression>( std::make_unique<ast::IdentifierExpression>("c"), @@ -3990,12 +3996,13 @@ std::make_unique<ast::Function>("ep_1_func", std::move(params), &f32); auto* ep_1_func_ptr = ep_1_func.get(); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("call_a"), std::make_unique<ast::CallExpression>( std::make_unique<ast::IdentifierExpression>("a"), ast::ExpressionList{}))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("call_b"), std::make_unique<ast::CallExpression>( std::make_unique<ast::IdentifierExpression>("b"), @@ -4006,7 +4013,8 @@ std::make_unique<ast::Function>("ep_2_func", std::move(params), &f32); auto* ep_2_func_ptr = ep_2_func.get(); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("call_c"), std::make_unique<ast::CallExpression>( std::make_unique<ast::IdentifierExpression>("c"),
diff --git a/src/validator_impl.cc b/src/validator_impl.cc index 9195baf..3645c96 100644 --- a/src/validator_impl.cc +++ b/src/validator_impl.cc
@@ -43,11 +43,20 @@ } bool ValidatorImpl::ValidateFunction(const ast::Function& func) { - if (!ValidateStatements(func.body())) + if (!ValidateStatements(*(func.body()))) return false; return true; } +bool ValidatorImpl::ValidateStatements(const ast::BlockStatement& block) { + for (const auto& stmt : block) { + if (!ValidateStatement(*(stmt.get()))) { + return false; + } + } + return true; +} + bool ValidatorImpl::ValidateStatements(const ast::StatementList& stmts) { for (const auto& stmt : stmts) { if (!ValidateStatement(*(stmt.get()))) {
diff --git a/src/validator_impl.h b/src/validator_impl.h index 27ffb8d..750fb66 100644 --- a/src/validator_impl.h +++ b/src/validator_impl.h
@@ -54,6 +54,10 @@ /// @param func the function to check /// @returns true if the validation was successful bool ValidateFunction(const ast::Function& func); + /// Validates a block of statements + /// @param block the statements to check + /// @returns true if the validation was successful + bool ValidateStatements(const ast::BlockStatement& block); /// Validates a set of statements /// @param stmts the statements to check /// @returns true if the validation was successful
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index f0588b6..79fd9ab 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc
@@ -1231,11 +1231,11 @@ } } - out_ << ")"; + out_ << ") "; current_ep_name_ = ep_name; - if (!EmitStatementBlockAndNewline(func->body())) { + if (!EmitBlockAndNewline(func->body())) { return false; } @@ -1395,7 +1395,7 @@ } generating_entry_point_ = true; - for (const auto& s : func->body()) { + for (const auto& s : *(func->body())) { if (!EmitStatement(s.get())) { return false; } @@ -1578,8 +1578,6 @@ } bool GeneratorImpl::EmitBlock(ast::BlockStatement* stmt) { - make_indent(); - out_ << "{" << std::endl; increment_indent(); @@ -1604,6 +1602,15 @@ return result; } +bool GeneratorImpl::EmitIndentedBlockAndNewline(ast::BlockStatement* stmt) { + make_indent(); + const bool result = EmitBlock(stmt); + if (result) { + out_ << std::endl; + } + return result; +} + bool GeneratorImpl::EmitStatementBlock(const ast::StatementList& statements) { out_ << " {" << std::endl; @@ -1636,7 +1643,7 @@ return EmitAssign(stmt->AsAssign()); } if (stmt->IsBlock()) { - return EmitBlockAndNewline(stmt->AsBlock()); + return EmitIndentedBlockAndNewline(stmt->AsBlock()); } if (stmt->IsBreak()) { return EmitBreak(stmt->AsBreak());
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index 1bc5178..4178a30 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h
@@ -80,6 +80,10 @@ /// Handles a block statement with a newline at the end /// @param stmt the statement to emit /// @returns true if the statement was emitted successfully + bool EmitIndentedBlockAndNewline(ast::BlockStatement* stmt); + /// Handles a block statement with a newline at the end + /// @param stmt the statement to emit + /// @returns true if the statement was emitted successfully bool EmitBlockAndNewline(ast::BlockStatement* stmt); /// Handles a break statement /// @param stmt the statement to emit
diff --git a/src/writer/msl/generator_impl_block_test.cc b/src/writer/msl/generator_impl_block_test.cc index 9be5b60..f901c4b 100644 --- a/src/writer/msl/generator_impl_block_test.cc +++ b/src/writer/msl/generator_impl_block_test.cc
@@ -50,7 +50,7 @@ g.increment_indent(); ASSERT_TRUE(g.EmitBlock(&b)) << g.error(); - EXPECT_EQ(g.result(), R"( { + EXPECT_EQ(g.result(), R"({ discard_fragment(); })"); }
diff --git a/src/writer/msl/generator_impl_entry_point_test.cc b/src/writer/msl/generator_impl_entry_point_test.cc index ec237da..8c85901 100644 --- a/src/writer/msl/generator_impl_entry_point_test.cc +++ b/src/writer/msl/generator_impl_entry_point_test.cc
@@ -73,11 +73,11 @@ auto func = std::make_unique<ast::Function>("vtx_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("foo"), std::make_unique<ast::IdentifierExpression>("foo"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::IdentifierExpression>("bar"))); func->set_body(std::move(body)); @@ -139,11 +139,11 @@ auto func = std::make_unique<ast::Function>("vtx_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("foo"), std::make_unique<ast::IdentifierExpression>("foo"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::IdentifierExpression>("bar"))); func->set_body(std::move(body)); @@ -205,11 +205,11 @@ auto func = std::make_unique<ast::Function>("frag_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("foo"), std::make_unique<ast::IdentifierExpression>("foo"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::IdentifierExpression>("bar"))); func->set_body(std::move(body)); @@ -271,11 +271,11 @@ auto func = std::make_unique<ast::Function>("frag_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("foo"), std::make_unique<ast::IdentifierExpression>("foo"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::IdentifierExpression>("bar"))); func->set_body(std::move(body)); @@ -334,11 +334,11 @@ auto func = std::make_unique<ast::Function>("comp_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("foo"), std::make_unique<ast::IdentifierExpression>("foo"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::IdentifierExpression>("bar"))); func->set_body(std::move(body)); @@ -392,11 +392,11 @@ auto func = std::make_unique<ast::Function>("comp_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("foo"), std::make_unique<ast::IdentifierExpression>("foo"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::IdentifierExpression>("bar"))); func->set_body(std::move(body)); @@ -460,8 +460,8 @@ auto func = std::make_unique<ast::Function>("frag_main", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("depth"), std::make_unique<ast::MemberAccessorExpression>( std::make_unique<ast::IdentifierExpression>("coord"),
diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index 0fb8de8..4a810e6 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc
@@ -53,8 +53,8 @@ auto func = std::make_unique<ast::Function>("my_func", ast::VariableList{}, &void_type); - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>()); + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>()); func->set_body(std::move(body)); ast::Module m; @@ -79,8 +79,8 @@ auto func = std::make_unique<ast::Function>("main", ast::VariableList{}, &void_type); - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>()); + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>()); func->set_body(std::move(body)); ast::Module m; @@ -113,8 +113,8 @@ auto func = std::make_unique<ast::Function>("my_func", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>()); + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>()); func->set_body(std::move(body)); ast::Module m; @@ -184,11 +184,11 @@ auto func = std::make_unique<ast::Function>("frag_main", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::IdentifierExpression>("foo"))); - body.push_back(std::make_unique<ast::ReturnStatement>()); + body->append(std::make_unique<ast::ReturnStatement>()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -254,13 +254,13 @@ auto func = std::make_unique<ast::Function>("frag_main", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("depth"), std::make_unique<ast::MemberAccessorExpression>( std::make_unique<ast::IdentifierExpression>("coord"), std::make_unique<ast::IdentifierExpression>("x")))); - body.push_back(std::make_unique<ast::ReturnStatement>()); + body->append(std::make_unique<ast::ReturnStatement>()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -319,9 +319,9 @@ std::make_unique<ast::IdentifierExpression>("coord"), std::make_unique<ast::IdentifierExpression>("x"))); - ast::StatementList body; - body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var))); - body.push_back(std::make_unique<ast::ReturnStatement>()); + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::VariableDeclStatement>(std::move(var))); + body->append(std::make_unique<ast::ReturnStatement>()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -375,9 +375,9 @@ std::make_unique<ast::IdentifierExpression>("coord"), std::make_unique<ast::IdentifierExpression>("x"))); - ast::StatementList body; - body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var))); - body.push_back(std::make_unique<ast::ReturnStatement>()); + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::VariableDeclStatement>(std::move(var))); + body->append(std::make_unique<ast::ReturnStatement>()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -439,14 +439,14 @@ auto sub_func = std::make_unique<ast::Function>("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::IdentifierExpression>("foo"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("val"), std::make_unique<ast::IdentifierExpression>("param"))); - body.push_back(std::make_unique<ast::ReturnStatement>( + body->append(std::make_unique<ast::ReturnStatement>( std::make_unique<ast::IdentifierExpression>("foo"))); sub_func->set_body(std::move(body)); @@ -458,12 +458,14 @@ ast::ExpressionList expr; expr.push_back(std::make_unique<ast::ScalarConstructorExpression>( std::make_unique<ast::FloatLiteral>(&f32, 1.0f))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::CallExpression>( std::make_unique<ast::IdentifierExpression>("sub_func"), std::move(expr)))); - body.push_back(std::make_unique<ast::ReturnStatement>()); + body->append(std::make_unique<ast::ReturnStatement>()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -530,8 +532,8 @@ auto sub_func = std::make_unique<ast::Function>("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>( std::make_unique<ast::IdentifierExpression>("param"))); sub_func->set_body(std::move(body)); @@ -543,12 +545,14 @@ ast::ExpressionList expr; expr.push_back(std::make_unique<ast::ScalarConstructorExpression>( std::make_unique<ast::FloatLiteral>(&f32, 1.0f))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("depth"), std::make_unique<ast::CallExpression>( std::make_unique<ast::IdentifierExpression>("sub_func"), std::move(expr)))); - body.push_back(std::make_unique<ast::ReturnStatement>()); + body->append(std::make_unique<ast::ReturnStatement>()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -617,13 +621,13 @@ auto sub_func = std::make_unique<ast::Function>("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("depth"), std::make_unique<ast::MemberAccessorExpression>( std::make_unique<ast::IdentifierExpression>("coord"), std::make_unique<ast::IdentifierExpression>("x")))); - body.push_back(std::make_unique<ast::ReturnStatement>( + body->append(std::make_unique<ast::ReturnStatement>( std::make_unique<ast::IdentifierExpression>("param"))); sub_func->set_body(std::move(body)); @@ -635,12 +639,14 @@ ast::ExpressionList expr; expr.push_back(std::make_unique<ast::ScalarConstructorExpression>( std::make_unique<ast::FloatLiteral>(&f32, 1.0f))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("depth"), std::make_unique<ast::CallExpression>( std::make_unique<ast::IdentifierExpression>("sub_func"), std::move(expr)))); - body.push_back(std::make_unique<ast::ReturnStatement>()); + body->append(std::make_unique<ast::ReturnStatement>()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -700,8 +706,8 @@ auto sub_func = std::make_unique<ast::Function>("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>( std::make_unique<ast::MemberAccessorExpression>( std::make_unique<ast::IdentifierExpression>("coord"), std::make_unique<ast::IdentifierExpression>("x")))); @@ -722,8 +728,9 @@ std::make_unique<ast::IdentifierExpression>("sub_func"), std::move(expr))); - body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var))); - body.push_back(std::make_unique<ast::ReturnStatement>()); + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::VariableDeclStatement>(std::move(var))); + body->append(std::make_unique<ast::ReturnStatement>()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -778,8 +785,8 @@ auto sub_func = std::make_unique<ast::Function>("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>( std::make_unique<ast::MemberAccessorExpression>( std::make_unique<ast::IdentifierExpression>("coord"), std::make_unique<ast::IdentifierExpression>("x")))); @@ -800,8 +807,9 @@ std::make_unique<ast::IdentifierExpression>("sub_func"), std::move(expr))); - body.push_back(std::make_unique<ast::VariableDeclStatement>(std::move(var))); - body.push_back(std::make_unique<ast::ReturnStatement>()); + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::VariableDeclStatement>(std::move(var))); + body->append(std::make_unique<ast::ReturnStatement>()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -857,11 +865,11 @@ auto sub_func = std::make_unique<ast::Function>("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::IdentifierExpression>("foo"))); - body.push_back(std::make_unique<ast::ReturnStatement>( + body->append(std::make_unique<ast::ReturnStatement>( std::make_unique<ast::IdentifierExpression>("foo"))); sub_func->set_body(std::move(body)); @@ -870,12 +878,13 @@ auto func_1 = std::make_unique<ast::Function>("frag_1_main", std::move(params), &void_type); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::CallExpression>( std::make_unique<ast::IdentifierExpression>("sub_func"), ast::ExpressionList{}))); - body.push_back(std::make_unique<ast::ReturnStatement>()); + body->append(std::make_unique<ast::ReturnStatement>()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -956,15 +965,15 @@ auto func_1 = std::make_unique<ast::Function>("frag_1_main", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("bar"), std::make_unique<ast::ScalarConstructorExpression>( std::make_unique<ast::FloatLiteral>(&f32, 1.0f)))); ast::StatementList list; list.push_back(std::make_unique<ast::ReturnStatement>()); - body.push_back(std::make_unique<ast::IfStatement>( + body->append(std::make_unique<ast::IfStatement>( std::make_unique<ast::BinaryExpression>( ast::BinaryOp::kEqual, std::make_unique<ast::ScalarConstructorExpression>( @@ -973,7 +982,7 @@ std::make_unique<ast::SintLiteral>(&i32, 1))), std::move(list))); - body.push_back(std::make_unique<ast::ReturnStatement>()); + body->append(std::make_unique<ast::ReturnStatement>()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -1017,8 +1026,8 @@ auto sub_func = std::make_unique<ast::Function>("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>( std::make_unique<ast::ScalarConstructorExpression>( std::make_unique<ast::FloatLiteral>(&f32, 1.0)))); sub_func->set_body(std::move(body)); @@ -1028,15 +1037,15 @@ auto func_1 = std::make_unique<ast::Function>("frag_1_main", std::move(params), &void_type); - body.push_back(std::make_unique<ast::VariableDeclStatement>( - std::make_unique<ast::Variable>("foo", ast::StorageClass::kFunction, - &f32))); - body.back()->AsVariableDecl()->variable()->set_constructor( - std::make_unique<ast::CallExpression>( - std::make_unique<ast::IdentifierExpression>("sub_func"), - ast::ExpressionList{})); + auto var = std::make_unique<ast::Variable>( + "foo", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique<ast::CallExpression>( + std::make_unique<ast::IdentifierExpression>("sub_func"), + ast::ExpressionList{})); - body.push_back(std::make_unique<ast::ReturnStatement>()); + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::VariableDeclStatement>(std::move(var))); + body->append(std::make_unique<ast::ReturnStatement>()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -1126,8 +1135,8 @@ auto func = std::make_unique<ast::Function>("my_func", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>()); + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>()); func->set_body(std::move(body)); ast::Module m;
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 94914ce..0040481 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc
@@ -426,7 +426,7 @@ push_function(Function{definition_inst, result_op(), std::move(params)}); - for (const auto& stmt : func->body()) { + for (const auto& stmt : *(func->body())) { if (!GenerateStatement(stmt.get())) { return false; }
diff --git a/src/writer/spirv/builder_call_test.cc b/src/writer/spirv/builder_call_test.cc index 5499f59..5b78cbf 100644 --- a/src/writer/spirv/builder_call_test.cc +++ b/src/writer/spirv/builder_call_test.cc
@@ -142,8 +142,8 @@ ast::Function a_func("a_func", std::move(func_params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>( std::make_unique<ast::BinaryExpression>( ast::BinaryOp::kAdd, std::make_unique<ast::IdentifierExpression>("a"), std::make_unique<ast::IdentifierExpression>("b")))); @@ -210,8 +210,8 @@ ast::Function a_func("a_func", std::move(func_params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>( std::make_unique<ast::BinaryExpression>( ast::BinaryOp::kAdd, std::make_unique<ast::IdentifierExpression>("a"), std::make_unique<ast::IdentifierExpression>("b"))));
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc index 34dc0f8..04176c1 100644 --- a/src/writer/spirv/builder_entry_point_test.cc +++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -168,15 +168,16 @@ ast::type::VoidType void_type; ast::Function func("main", {}, &void_type); - ast::StatementList body; - body.push_back(std::make_unique<ast::AssignmentStatement>( + + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("my_out"), std::make_unique<ast::IdentifierExpression>("my_in"))); - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("my_wg"), std::make_unique<ast::IdentifierExpression>("my_wg"))); // Add duplicate usages so we show they don't get output multiple times. - body.push_back(std::make_unique<ast::AssignmentStatement>( + body->append(std::make_unique<ast::AssignmentStatement>( std::make_unique<ast::IdentifierExpression>("my_out"), std::make_unique<ast::IdentifierExpression>("my_in"))); func.set_body(std::move(body));
diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc index a0fb76f..700d8fe 100644 --- a/src/writer/spirv/builder_function_test.cc +++ b/src/writer/spirv/builder_function_test.cc
@@ -67,8 +67,8 @@ ast::Function func("a_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>( + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>( std::make_unique<ast::IdentifierExpression>("a"))); func.set_body(std::move(body)); @@ -93,8 +93,8 @@ TEST_F(BuilderTest, Function_WithBody) { ast::type::VoidType void_type; - ast::StatementList body; - body.push_back(std::make_unique<ast::ReturnStatement>()); + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::ReturnStatement>()); ast::Function func("a_func", {}, &void_type); func.set_body(std::move(body));
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index 373ef40..3fdabee 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc
@@ -346,7 +346,8 @@ return false; } - return EmitStatementBlockAndNewline(func->body()); + out_ << " "; + return EmitBlockAndNewline(func->body()); } bool GeneratorImpl::EmitType(ast::type::Type* type) { @@ -600,8 +601,6 @@ } bool GeneratorImpl::EmitBlock(ast::BlockStatement* stmt) { - make_indent(); - out_ << "{" << std::endl; increment_indent(); @@ -618,6 +617,15 @@ return true; } +bool GeneratorImpl::EmitIndentedBlockAndNewline(ast::BlockStatement* stmt) { + make_indent(); + const bool result = EmitBlock(stmt); + if (result) { + out_ << std::endl; + } + return result; +} + bool GeneratorImpl::EmitBlockAndNewline(ast::BlockStatement* stmt) { const bool result = EmitBlock(stmt); if (result) { @@ -658,7 +666,7 @@ return EmitAssign(stmt->AsAssign()); } if (stmt->IsBlock()) { - return EmitBlockAndNewline(stmt->AsBlock()); + return EmitIndentedBlockAndNewline(stmt->AsBlock()); } if (stmt->IsBreak()) { return EmitBreak(stmt->AsBreak());
diff --git a/src/writer/wgsl/generator_impl.h b/src/writer/wgsl/generator_impl.h index 6c211cf..50e3141 100644 --- a/src/writer/wgsl/generator_impl.h +++ b/src/writer/wgsl/generator_impl.h
@@ -74,6 +74,10 @@ /// Handles a block statement with a newline at the end /// @param stmt the statement to emit /// @returns true if the statement was emitted successfully + bool EmitIndentedBlockAndNewline(ast::BlockStatement* stmt); + /// Handles a block statement with a newline at the end + /// @param stmt the statement to emit + /// @returns true if the statement was emitted successfully bool EmitBlockAndNewline(ast::BlockStatement* stmt); /// Handles a break statement /// @param stmt the statement to emit
diff --git a/src/writer/wgsl/generator_impl_block_test.cc b/src/writer/wgsl/generator_impl_block_test.cc index bc82273..43c0286 100644 --- a/src/writer/wgsl/generator_impl_block_test.cc +++ b/src/writer/wgsl/generator_impl_block_test.cc
@@ -48,7 +48,7 @@ g.increment_indent(); ASSERT_TRUE(g.EmitBlock(&b)) << g.error(); - EXPECT_EQ(g.result(), R"( { + EXPECT_EQ(g.result(), R"({ discard; })"); }
diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc index c650e93..0240a62 100644 --- a/src/writer/wgsl/generator_impl_function_test.cc +++ b/src/writer/wgsl/generator_impl_function_test.cc
@@ -30,9 +30,9 @@ using WgslGeneratorImplTest = testing::Test; TEST_F(WgslGeneratorImplTest, Emit_Function) { - ast::StatementList body; - body.push_back(std::make_unique<ast::DiscardStatement>()); - body.push_back(std::make_unique<ast::ReturnStatement>()); + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::DiscardStatement>()); + body->append(std::make_unique<ast::ReturnStatement>()); ast::type::VoidType void_type; ast::Function func("my_func", {}, &void_type); @@ -50,9 +50,9 @@ } TEST_F(WgslGeneratorImplTest, Emit_Function_WithParams) { - ast::StatementList body; - body.push_back(std::make_unique<ast::DiscardStatement>()); - body.push_back(std::make_unique<ast::ReturnStatement>()); + auto body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::DiscardStatement>()); + body->append(std::make_unique<ast::ReturnStatement>()); ast::type::F32Type f32; ast::type::I32Type i32;