[spirv-reader] Support return, return-value
Bug: tint:3
Change-Id: Iaaf6ace739ac30e7f9f0bd4ddcef209ab1b71ed8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22480
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index d6a25bd..79bbd0e 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -32,6 +32,7 @@
#include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
+#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/storage_class.h"
#include "src/ast/switch_statement.h"
@@ -1753,9 +1754,23 @@
return success();
}
-bool FunctionEmitter::EmitNormalTerminator(const BlockInfo&) {
- // TODO(dneto): emit fallthrough, break, continue, return, kill
- return true;
+bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) {
+ const auto& terminator = *(block_info.basic_block->terminator());
+ switch (terminator.opcode()) {
+ case SpvOpReturn:
+ AddStatement(std::make_unique<ast::ReturnStatement>());
+ return true;
+ case SpvOpReturnValue: {
+ auto value = MakeExpression(terminator.GetSingleWordInOperand(0));
+ AddStatement(
+ std::make_unique<ast::ReturnStatement>(std::move(value.expr)));
+ }
+ return true;
+ default:
+ break;
+ }
+ // TODO(dneto): emit fallthrough, break, continue, kill
+ return success();
}
bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc
index 1460556..58e4502 100644
--- a/src/reader/spirv/function_cfg_test.cc
+++ b/src/reader/spirv/function_cfg_test.cc
@@ -62,6 +62,8 @@
%uint = OpTypeInt 32 0
%selector = OpConstant %uint 42
+ %uintfn = OpTypeFunction %uint
+
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
@@ -6771,6 +6773,7 @@
{
}
}
+Return{}
)"));
}
@@ -6820,6 +6823,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)"));
}
@@ -6869,6 +6873,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)"));
}
@@ -6926,6 +6931,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)"));
}
@@ -6994,6 +7000,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)"));
}
@@ -7052,6 +7059,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)"));
}
@@ -7110,6 +7118,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)"));
}
@@ -7223,6 +7232,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)"));
}
@@ -7264,6 +7274,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)"));
}
@@ -7305,6 +7316,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)"));
}
@@ -7345,6 +7357,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)"));
}
@@ -7385,6 +7398,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)"));
}
@@ -7443,6 +7457,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)")) << ToString(fe.ast_body());
}
@@ -7509,6 +7524,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)")) << ToString(fe.ast_body());
}
@@ -7595,6 +7611,7 @@
Identifier{var}
ScalarConstructor{999}
}
+Return{}
)")) << ToString(fe.ast_body());
}
@@ -7626,6 +7643,224 @@
// TODO(dneto): Needs "break" support
}
+TEST_F(SpvParserTest, EmitBody_Return_TopLevel) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpReturn
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+ EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Return{}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitBody_Return_InsideIf) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpSelectionMerge %99 None
+ OpBranchConditional %cond %20 %99
+
+ %20 = OpLabel
+ OpReturn
+
+ %99 = OpLabel
+ OpReturn
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+ EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(If{
+ (
+ ScalarConstructor{false}
+ )
+ {
+ Return{}
+ }
+}
+Else{
+ {
+ }
+}
+Return{}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitBody_Return_InsideLoop) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ OpBranch %20
+
+ %20 = OpLabel
+ OpLoopMerge %99 %80 None
+ OpBranchConditional %cond %30 %30
+
+ %30 = OpLabel
+ OpReturn
+
+ %80 = OpLabel
+ OpBranch %20
+
+ %99 = OpLabel
+ OpReturn
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+ EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Loop{
+ Return{}
+}
+Return{}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitBody_ReturnValue_TopLevel) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %200 = OpFunction %uint None %uintfn
+
+ %210 = OpLabel
+ OpReturnValue %uint_2
+
+ OpFunctionEnd
+
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ %11 = OpFunctionCall %uint %200
+ OpReturn
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(200));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+ EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Return{
+ {
+ ScalarConstructor{2}
+ }
+}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitBody_ReturnValue_InsideIf) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %200 = OpFunction %uint None %uintfn
+
+ %210 = OpLabel
+ OpSelectionMerge %299 None
+ OpBranchConditional %cond %220 %299
+
+ %220 = OpLabel
+ OpReturnValue %uint_2
+
+ %299 = OpLabel
+ OpReturnValue %uint_3
+
+ OpFunctionEnd
+
+
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ %11 = OpFunctionCall %uint %200
+ OpReturn
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(200));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+ EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(If{
+ (
+ ScalarConstructor{false}
+ )
+ {
+ Return{
+ {
+ ScalarConstructor{2}
+ }
+ }
+ }
+}
+Else{
+ {
+ }
+}
+Return{
+ {
+ ScalarConstructor{3}
+ }
+}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitBody_ReturnValue_Loop) {
+ auto* p = parser(test::Assemble(CommonTypes() + R"(
+ %200 = OpFunction %void None %voidfn
+
+ %210 = OpLabel
+ OpBranch %220
+
+ %220 = OpLabel
+ OpLoopMerge %299 %280 None
+ OpBranchConditional %cond %230 %230
+
+ %230 = OpLabel
+ OpReturnValue %uint_2
+
+ %280 = OpLabel
+ OpBranch %220
+
+ %299 = OpLabel
+ OpReturnValue %uint_3
+
+ OpFunctionEnd
+
+
+ %100 = OpFunction %void None %voidfn
+
+ %10 = OpLabel
+ %11 = OpFunctionCall %uint %200
+ OpReturn
+
+ OpFunctionEnd
+ )"));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ FunctionEmitter fe(p, *spirv_function(200));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+ EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Loop{
+ Return{
+ {
+ ScalarConstructor{2}
+ }
+ }
+}
+Return{
+ {
+ ScalarConstructor{3}
+ }
+}
+)")) << ToString(fe.ast_body());
+}
+
} // namespace
} // namespace spirv
} // namespace reader
diff --git a/src/reader/spirv/parser_impl_function_decl_test.cc b/src/reader/spirv/parser_impl_function_decl_test.cc
index ab6f5ab..03ef876 100644
--- a/src/reader/spirv/parser_impl_function_decl_test.cc
+++ b/src/reader/spirv/parser_impl_function_decl_test.cc
@@ -107,18 +107,24 @@
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str();
+ // TODO(dneto): This will need to be updated when function calls are
+ // supported. Otherwise, use more general matching instead of substring
+ // equality.
EXPECT_THAT(module_ast, HasSubstr(R"(
Function leaf -> __void
()
{
+ Return{}
}
Function branch -> __void
()
{
+ Return{}
}
Function root -> __void
()
{
+ Return{}
})"));
}
@@ -138,7 +144,13 @@
Function ret_float -> __f32
()
{
- })"));
+ Return{
+ {
+ ScalarConstructor{0.000000}
+ }
+ }
+ })"))
+ << module_ast;
}
TEST_F(SpvParserTest, EmitFunctions_MixedParamTypes) {
@@ -177,6 +189,7 @@
}
)
{
+ Return{}
})"));
}
@@ -215,6 +228,7 @@
}
)
{
+ Return{}
})"));
}