[spirv-reader] Support return, return-value

Bug: tint:3
Change-Id: Iaaf6ace739ac30e7f9f0bd4ddcef209ab1b71ed8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22480
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index d6a25bd..79bbd0e 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -32,6 +32,7 @@
 #include "src/ast/if_statement.h"
 #include "src/ast/loop_statement.h"
 #include "src/ast/member_accessor_expression.h"
+#include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/storage_class.h"
 #include "src/ast/switch_statement.h"
@@ -1753,9 +1754,23 @@
   return success();
 }
 
-bool FunctionEmitter::EmitNormalTerminator(const BlockInfo&) {
-  // TODO(dneto): emit fallthrough, break, continue, return, kill
-  return true;
+bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) {
+  const auto& terminator = *(block_info.basic_block->terminator());
+  switch (terminator.opcode()) {
+    case SpvOpReturn:
+      AddStatement(std::make_unique<ast::ReturnStatement>());
+      return true;
+    case SpvOpReturnValue: {
+      auto value = MakeExpression(terminator.GetSingleWordInOperand(0));
+      AddStatement(
+          std::make_unique<ast::ReturnStatement>(std::move(value.expr)));
+    }
+      return true;
+    default:
+      break;
+  }
+  // TODO(dneto): emit fallthrough, break, continue, kill
+  return success();
 }
 
 bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc
index 1460556..58e4502 100644
--- a/src/reader/spirv/function_cfg_test.cc
+++ b/src/reader/spirv/function_cfg_test.cc
@@ -62,6 +62,8 @@
     %uint = OpTypeInt 32 0
     %selector = OpConstant %uint 42
 
+    %uintfn = OpTypeFunction %uint
+
     %uint_0 = OpConstant %uint 0
     %uint_1 = OpConstant %uint 1
     %uint_2 = OpConstant %uint 2
@@ -6771,6 +6773,7 @@
   {
   }
 }
+Return{}
 )"));
 }
 
@@ -6820,6 +6823,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )"));
 }
 
@@ -6869,6 +6873,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )"));
 }
 
@@ -6926,6 +6931,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )"));
 }
 
@@ -6994,6 +7000,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )"));
 }
 
@@ -7052,6 +7059,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )"));
 }
 
@@ -7110,6 +7118,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )"));
 }
 
@@ -7223,6 +7232,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )"));
 }
 
@@ -7264,6 +7274,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )"));
 }
 
@@ -7305,6 +7316,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )"));
 }
 
@@ -7345,6 +7357,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )"));
 }
 
@@ -7385,6 +7398,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )"));
 }
 
@@ -7443,6 +7457,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )")) << ToString(fe.ast_body());
 }
 
@@ -7509,6 +7524,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )")) << ToString(fe.ast_body());
 }
 
@@ -7595,6 +7611,7 @@
   Identifier{var}
   ScalarConstructor{999}
 }
+Return{}
 )")) << ToString(fe.ast_body());
 }
 
@@ -7626,6 +7643,224 @@
   // TODO(dneto): Needs "break" support
 }
 
+TEST_F(SpvParserTest, EmitBody_Return_TopLevel) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Return{}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitBody_Return_InsideIf) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpBranchConditional %cond %20 %99
+
+     %20 = OpLabel
+     OpReturn
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(If{
+  (
+    ScalarConstructor{false}
+  )
+  {
+    Return{}
+  }
+}
+Else{
+  {
+  }
+}
+Return{}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitBody_Return_InsideLoop) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %20
+
+     %20 = OpLabel
+     OpLoopMerge %99 %80 None
+     OpBranchConditional %cond %30 %30
+
+     %30 = OpLabel
+     OpReturn
+
+     %80 = OpLabel
+     OpBranch %20
+
+     %99 = OpLabel
+     OpReturn
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Loop{
+  Return{}
+}
+Return{}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitBody_ReturnValue_TopLevel) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %200 = OpFunction %uint None %uintfn
+
+     %210 = OpLabel
+     OpReturnValue %uint_2
+
+     OpFunctionEnd
+
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     %11 = OpFunctionCall %uint %200
+     OpReturn
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(200));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Return{
+  {
+    ScalarConstructor{2}
+  }
+}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitBody_ReturnValue_InsideIf) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %200 = OpFunction %uint None %uintfn
+
+     %210 = OpLabel
+     OpSelectionMerge %299 None
+     OpBranchConditional %cond %220 %299
+
+     %220 = OpLabel
+     OpReturnValue %uint_2
+
+     %299 = OpLabel
+     OpReturnValue %uint_3
+
+     OpFunctionEnd
+
+
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     %11 = OpFunctionCall %uint %200
+     OpReturn
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(200));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(If{
+  (
+    ScalarConstructor{false}
+  )
+  {
+    Return{
+      {
+        ScalarConstructor{2}
+      }
+    }
+  }
+}
+Else{
+  {
+  }
+}
+Return{
+  {
+    ScalarConstructor{3}
+  }
+}
+)")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest, EmitBody_ReturnValue_Loop) {
+  auto* p = parser(test::Assemble(CommonTypes() + R"(
+     %200 = OpFunction %void None %voidfn
+
+     %210 = OpLabel
+     OpBranch %220
+
+     %220 = OpLabel
+     OpLoopMerge %299 %280 None
+     OpBranchConditional %cond %230 %230
+
+     %230 = OpLabel
+     OpReturnValue %uint_2
+
+     %280 = OpLabel
+     OpBranch %220
+
+     %299 = OpLabel
+     OpReturnValue %uint_3
+
+     OpFunctionEnd
+
+
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     %11 = OpFunctionCall %uint %200
+     OpReturn
+
+     OpFunctionEnd
+  )"));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p, *spirv_function(200));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Loop{
+  Return{
+    {
+      ScalarConstructor{2}
+    }
+  }
+}
+Return{
+  {
+    ScalarConstructor{3}
+  }
+}
+)")) << ToString(fe.ast_body());
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader
diff --git a/src/reader/spirv/parser_impl_function_decl_test.cc b/src/reader/spirv/parser_impl_function_decl_test.cc
index ab6f5ab..03ef876 100644
--- a/src/reader/spirv/parser_impl_function_decl_test.cc
+++ b/src/reader/spirv/parser_impl_function_decl_test.cc
@@ -107,18 +107,24 @@
   EXPECT_TRUE(p->BuildAndParseInternalModule());
   EXPECT_TRUE(p->error().empty());
   const auto module_ast = p->module().to_str();
+  // TODO(dneto): This will need to be updated when function calls are
+  // supported. Otherwise, use more general matching instead of substring
+  // equality.
   EXPECT_THAT(module_ast, HasSubstr(R"(
   Function leaf -> __void
   ()
   {
+    Return{}
   }
   Function branch -> __void
   ()
   {
+    Return{}
   }
   Function root -> __void
   ()
   {
+    Return{}
   })"));
 }
 
@@ -138,7 +144,13 @@
   Function ret_float -> __f32
   ()
   {
-  })"));
+    Return{
+      {
+        ScalarConstructor{0.000000}
+      }
+    }
+  })"))
+      << module_ast;
 }
 
 TEST_F(SpvParserTest, EmitFunctions_MixedParamTypes) {
@@ -177,6 +189,7 @@
     }
   )
   {
+    Return{}
   })"));
 }
 
@@ -215,6 +228,7 @@
     }
   )
   {
+    Return{}
   })"));
 }