[spirv-writer] Add BlockStatement emission.
This CL adds BlockStatement support to the spirv-writer. The type
determiner is also updated as needed.
Bug: tint:134
Change-Id: I91e08c3acafd67401a010fff21abde7feec46e8e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/25609
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/BUILD.gn b/BUILD.gn
index 42f91ec..0eaa74c 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -800,6 +800,7 @@
"src/writer/spirv/builder_as_expression_test.cc",
"src/writer/spirv/builder_assign_test.cc",
"src/writer/spirv/builder_binary_expression_test.cc",
+ "src/writer/spirv/builder_block_test.cc",
"src/writer/spirv/builder_call_test.cc",
"src/writer/spirv/builder_cast_expression_test.cc",
"src/writer/spirv/builder_constructor_expression_test.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 5181674..efa3aeb 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -454,6 +454,7 @@
writer/spirv/builder_as_expression_test.cc
writer/spirv/builder_assign_test.cc
writer/spirv/builder_binary_expression_test.cc
+ writer/spirv/builder_block_test.cc
writer/spirv/builder_call_test.cc
writer/spirv/builder_cast_expression_test.cc
writer/spirv/builder_constructor_expression_test.cc
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 0de445d..fbbc5d5 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -22,6 +22,7 @@
#include "src/ast/as_expression.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
+#include "src/ast/block_statement.h"
#include "src/ast/break_statement.h"
#include "src/ast/call_expression.h"
#include "src/ast/call_statement.h"
@@ -239,6 +240,19 @@
return true;
}
+bool TypeDeterminer::DetermineStatements(const ast::BlockStatement* stmts) {
+ for (const auto& stmt : *stmts) {
+ if (!DetermineVariableStorageClass(stmt.get())) {
+ return false;
+ }
+
+ if (!DetermineResultType(stmt.get())) {
+ return false;
+ }
+ }
+ return true;
+}
+
bool TypeDeterminer::DetermineStatements(const ast::StatementList& stmts) {
for (const auto& stmt : stmts) {
if (!DetermineVariableStorageClass(stmt.get())) {
@@ -282,6 +296,9 @@
auto* a = stmt->AsAssign();
return DetermineResultType(a->lhs()) && DetermineResultType(a->rhs());
}
+ if (stmt->IsBlock()) {
+ return DetermineStatements(stmt->AsBlock());
+ }
if (stmt->IsBreak()) {
return true;
}
@@ -347,7 +364,8 @@
return DetermineResultType(v->variable()->constructor());
}
- set_error(stmt->source(), "unknown statement type for type determination");
+ set_error(stmt->source(),
+ "unknown statement type for type determination: " + stmt->str());
return false;
}
diff --git a/src/type_determiner.h b/src/type_determiner.h
index ea9b9fc..d7dc0d1 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -64,6 +64,10 @@
/// Determines type information for a set of statements
/// @param stmts the statements to check
/// @returns true if the determination was successful
+ bool DetermineStatements(const ast::BlockStatement* stmts);
+ /// Determines type information for a set of statements
+ /// @param stmts the statements to check
+ /// @returns true if the determination was successful
bool DetermineStatements(const ast::StatementList& stmts);
/// Determines type information for a statement
/// @param stmt the statement to check
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index d238ab6..075eed5 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -24,6 +24,7 @@
#include "src/ast/as_expression.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
+#include "src/ast/block_statement.h"
#include "src/ast/break_statement.h"
#include "src/ast/call_expression.h"
#include "src/ast/call_statement.h"
@@ -63,7 +64,7 @@
class FakeStmt : public ast::Statement {
public:
bool IsValid() const override { return true; }
- void to_str(std::ostream&, size_t) const override {}
+ void to_str(std::ostream& out, size_t) const override { out << "Fake"; }
};
class FakeExpr : public ast::Expression {
@@ -97,7 +98,8 @@
s.set_source(Source{0, 0});
EXPECT_FALSE(td()->DetermineResultType(&s));
- EXPECT_EQ(td()->error(), "unknown statement type for type determination");
+ EXPECT_EQ(td()->error(),
+ "unknown statement type for type determination: Fake");
}
TEST_F(TypeDeterminerTest, Stmt_Error_Unknown) {
@@ -106,7 +108,7 @@
EXPECT_FALSE(td()->DetermineResultType(&s));
EXPECT_EQ(td()->error(),
- "2:30: unknown statement type for type determination");
+ "2:30: unknown statement type for type determination: Fake");
}
TEST_F(TypeDeterminerTest, Stmt_Assign) {
@@ -158,6 +160,29 @@
EXPECT_TRUE(rhs_ptr->result_type()->IsF32());
}
+TEST_F(TypeDeterminerTest, Stmt_Block) {
+ ast::type::I32Type i32;
+ ast::type::F32Type f32;
+
+ auto lhs = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2));
+ auto* lhs_ptr = lhs.get();
+
+ auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.3f));
+ auto* rhs_ptr = rhs.get();
+
+ ast::BlockStatement block;
+ block.append(std::make_unique<ast::AssignmentStatement>(std::move(lhs),
+ std::move(rhs)));
+
+ EXPECT_TRUE(td()->DetermineResultType(&block));
+ ASSERT_NE(lhs_ptr->result_type(), nullptr);
+ ASSERT_NE(rhs_ptr->result_type(), nullptr);
+ EXPECT_TRUE(lhs_ptr->result_type()->IsI32());
+ EXPECT_TRUE(rhs_ptr->result_type()->IsF32());
+}
+
TEST_F(TypeDeterminerTest, Stmt_Else) {
ast::type::I32Type i32;
ast::type::F32Type f32;
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 6f18c5c..94914ce 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -23,6 +23,7 @@
#include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
#include "src/ast/binding_decoration.h"
+#include "src/ast/block_statement.h"
#include "src/ast/bool_literal.h"
#include "src/ast/builtin_decoration.h"
#include "src/ast/call_expression.h"
@@ -1338,6 +1339,18 @@
return result_id;
}
+bool Builder::GenerateBlockStatement(ast::BlockStatement* stmt) {
+ scope_stack_.push_scope();
+ for (const auto& block_stmt : *stmt) {
+ if (!GenerateStatement(block_stmt.get())) {
+ return false;
+ }
+ }
+ scope_stack_.pop_scope();
+
+ return true;
+}
+
uint32_t Builder::GenerateCallExpression(ast::CallExpression* expr) {
if (!expr->func()->IsIdentifier()) {
error_ = "invalid function name";
@@ -1807,6 +1820,9 @@
if (stmt->IsAssign()) {
return GenerateAssignStatement(stmt->AsAssign());
}
+ if (stmt->IsBlock()) {
+ return GenerateBlockStatement(stmt->AsBlock());
+ }
if (stmt->IsBreak()) {
return GenerateBreakStatement(stmt->AsBreak());
}
@@ -1839,7 +1855,7 @@
return GenerateVariableDeclStatement(stmt->AsVariableDecl());
}
- error_ = "Unknown statement";
+ error_ = "Unknown statement: " + stmt->str();
return false;
}
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 50e63fa..1e62b3a 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -167,6 +167,10 @@
/// @param assign the statement to generate
/// @returns true if the statement was successfully generated
bool GenerateAssignStatement(ast::AssignmentStatement* assign);
+ /// Generates a block statement
+ /// @param stmt the statement to generate
+ /// @returns true if the statement was successfully generated
+ bool GenerateBlockStatement(ast::BlockStatement* stmt);
/// Generates a break statement
/// @param stmt the statement to generate
/// @returns true if the statement was successfully generated
diff --git a/src/writer/spirv/builder_block_test.cc b/src/writer/spirv/builder_block_test.cc
new file mode 100644
index 0000000..b13ca1d
--- /dev/null
+++ b/src/writer/spirv/builder_block_test.cc
@@ -0,0 +1,101 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+
+#include "gtest/gtest.h"
+#include "src/ast/assignment_statement.h"
+#include "src/ast/block_statement.h"
+#include "src/ast/float_literal.h"
+#include "src/ast/identifier_expression.h"
+#include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/type/f32_type.h"
+#include "src/ast/variable_decl_statement.h"
+#include "src/context.h"
+#include "src/type_determiner.h"
+#include "src/writer/spirv/builder.h"
+#include "src/writer/spirv/spv_dump.h"
+
+namespace tint {
+namespace writer {
+namespace spirv {
+namespace {
+
+using BuilderTest = testing::Test;
+
+TEST_F(BuilderTest, Block) {
+ ast::type::F32Type f32;
+
+ // Note, this test uses shadow variables which aren't allowed in WGSL but
+ // serves to prove the block code is pushing new scopes as needed.
+ ast::BlockStatement outer;
+
+ outer.append(std::make_unique<ast::VariableDeclStatement>(
+ std::make_unique<ast::Variable>("var", ast::StorageClass::kFunction,
+ &f32)));
+ outer.append(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("var"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f))));
+
+ auto inner = std::make_unique<ast::BlockStatement>();
+ inner->append(std::make_unique<ast::VariableDeclStatement>(
+ std::make_unique<ast::Variable>("var", ast::StorageClass::kFunction,
+ &f32)));
+ inner->append(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("var"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0f))));
+
+ outer.append(std::move(inner));
+ outer.append(std::make_unique<ast::AssignmentStatement>(
+ std::make_unique<ast::IdentifierExpression>("var"),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f))));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ ASSERT_TRUE(td.DetermineResultType(&outer)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_FALSE(b.has_error()) << b.error();
+
+ EXPECT_TRUE(b.GenerateStatement(&outer)) << b.error();
+ EXPECT_FALSE(b.has_error());
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
+%2 = OpTypePointer Function %3
+%4 = OpConstantNull %3
+%5 = OpConstant %3 1
+%7 = OpConstant %3 2
+%8 = OpConstant %3 3
+)");
+
+ EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
+ R"(%1 = OpVariable %2 Function %4
+%6 = OpVariable %2 Function %4
+)");
+
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(OpStore %1 %5
+OpStore %6 %7
+OpStore %1 %8
+)");
+}
+
+} // namespace
+} // namespace spirv
+} // namespace writer
+} // namespace tint