[tint][ir][ToProgram] Emit returns with values

And implement functions with return values.

Bug: tint:1902
Change-Id: Id4015aa83bf75de2a0f3dfdbfe19f728c05226c8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/133142
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index 5e46e92..94b0f59 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -41,10 +41,17 @@
 #include "src/tint/utils/transform.h"
 #include "src/tint/utils/vector.h"
 
+// Helper for calling TINT_UNIMPLEMENTED() from a Switch(object_ptr) default case.
 #define UNHANDLED_CASE(object_ptr)          \
     TINT_UNIMPLEMENTED(IR, b.Diagnostics()) \
         << "unhandled case in Switch(): " << (object_ptr ? object_ptr->TypeInfo().name : "<null>")
 
+// Helper for incrementing nesting_depth_ and then decrementing nesting_depth_ at the end
+// of the scope that holds the call.
+#define SCOPED_NESTING() \
+    nesting_depth_++;    \
+    TINT_DEFER(nesting_depth_--)
+
 namespace tint::ir {
 
 namespace {
@@ -63,56 +70,109 @@
     }
 
   private:
+    /// The source IR module
     const Module& mod;
+
+    /// The target ProgramBuilder
     ProgramBuilder b;
+
+    /// A hashmap of value to symbol used in the emitted AST
     utils::Hashmap<const Value*, Symbol, 32> value_names_;
 
-    void Fn(const Function* fn) {
+    // The nesting depth of the currently generated AST
+    // 0 is module scope
+    // 1 is root-level function scope
+    // 2+ is within control flow
+    uint32_t nesting_depth_ = 0;
+
+    const ast::Function* Fn(const Function* fn) {
+        SCOPED_NESTING();
+
         auto name = Sym(fn->name);
         // TODO(crbug.com/tint/1915): Properly implement this when we've fleshed out Function
         utils::Vector<const ast::Parameter*, 1> params{};
-        ast::Type ret_ty;
-        auto* body = FlowNodeGraph(fn->start_target, fn->end_target);
+        auto ret_ty = Type(fn->return_type);
+        if (!ret_ty) {
+            return nullptr;
+        }
+        auto* body = FlowNodeGraph(fn->start_target);
+        if (!body) {
+            return nullptr;
+        }
         utils::Vector<const ast::Attribute*, 1> attrs{};
         utils::Vector<const ast::Attribute*, 1> ret_attrs{};
-        b.Func(name, std::move(params), ret_ty, body, std::move(attrs), std::move(ret_attrs));
+        return b.Func(name, std::move(params), ret_ty.Get(), body, std::move(attrs),
+                      std::move(ret_attrs));
     }
 
-    const ast::BlockStatement* FlowNodeGraph(const ir::FlowNode* node,
-                                             const ir::FlowNode* stop_at) {
+    const ast::BlockStatement* FlowNodeGraph(ir::FlowNode* start_node,
+                                             ir::FlowNode* stop_at = nullptr) {
         // TODO(crbug.com/tint/1902): Check if the block is dead
         utils::Vector<const ast::Statement*,
                       decltype(ast::BlockStatement::statements)::static_length>
             stmts;
-        while (node != stop_at) {
+
+        ir::Branch root_branch{start_node, {}};
+        const ir::Branch* branch = &root_branch;
+
+        while (branch->target != stop_at) {
             enum Status { kContinue, kStop, kError };
             Status status = Switch(
-                node,  //
+                branch->target,
+
                 [&](const ir::Block* block) {
                     for (auto* inst : block->instructions) {
-                        if (auto* stmt = Stmt(inst); TINT_LIKELY(stmt)) {
-                            stmts.Push(stmt);
-                        } else {
+                        auto* stmt = Stmt(inst);
+                        if (TINT_UNLIKELY(!stmt)) {
                             return kError;
                         }
+                        stmts.Push(stmt);
                     }
-                    node = block->branch.target;
+                    branch = &block->branch;
                     return kContinue;
                 },
+
                 [&](const ir::If* if_) {
-                    if (auto* stmt = If(if_); TINT_LIKELY(stmt)) {
-                        stmts.Push(stmt);
-                        node = if_->merge.target;
-                        return node->inbound_branches.IsEmpty() ? kStop : kContinue;
+                    auto* stmt = If(if_);
+                    if (TINT_UNLIKELY(!stmt)) {
+                        return kError;
                     }
-                    return kError;
+                    stmts.Push(stmt);
+                    branch = &if_->merge;
+                    return branch->target->inbound_branches.IsEmpty() ? kStop : kContinue;
                 },
+
                 [&](const ir::FunctionTerminator*) {
-                    stmts.Push(b.Return());
+                    if (branch->args.IsEmpty()) {
+                        // Branch to function terminator has no arguments.
+                        // If this block is nested withing some control flow, then we must emit a
+                        // 'return' statement, otherwise we've just naturally reached the end of the
+                        // function where the 'return' is redundant.
+                        if (nesting_depth_ > 1) {
+                            stmts.Push(b.Return());
+                        }
+                        return kStop;
+                    }
+
+                    // Branch to function terminator has arguments - this is the return value.
+                    if (branch->args.Length() != 1) {
+                        TINT_ICE(IR, b.Diagnostics())
+                            << "expected 1 value for function terminator (return value), got "
+                            << branch->args.Length();
+                        return kError;
+                    }
+
+                    auto* val = Expr(branch->args.Front());
+                    if (TINT_UNLIKELY(!val)) {
+                        return kError;
+                    }
+
+                    stmts.Push(b.Return(val));
                     return kStop;
                 },
+
                 [&](Default) {
-                    UNHANDLED_CASE(node);
+                    UNHANDLED_CASE(branch->target);
                     return kError;
                 });
 
@@ -128,11 +188,14 @@
     }
 
     const ast::IfStatement* If(const ir::If* i) {
+        SCOPED_NESTING();
+
         auto* cond = Expr(i->condition);
         auto* t = FlowNodeGraph(i->true_.target, i->merge.target);
-        if (!t) {
+        if (TINT_UNLIKELY(!t)) {
             return nullptr;
         }
+
         if (!IsEmpty(i->false_.target, i->merge.target)) {
             // If the else target is an if flow node with the same merge target as this if, then
             // emit an 'else if' instead of a block statement for the else.
@@ -152,6 +215,7 @@
                 return b.If(cond, t, b.Else(f));
             }
         }
+
         return b.If(cond, t);
     }
 
@@ -190,7 +254,7 @@
             inst,                                            //
             [&](const ir::Call* i) { return CallStmt(i); },  //
             [&](const ir::Var* i) { return Var(i); },        //
-            [&](const ir::Store* i) { return Store(i); },
+            [&](const ir::Store* i) { return Store(i); },    //
             [&](Default) {
                 UNHANDLED_CASE(inst);
                 return nullptr;
@@ -217,11 +281,11 @@
         }
         switch (var->address_space) {
             case builtin::AddressSpace::kFunction:
-                return b.Decl(b.Var(name, ty, init));
+                return b.Decl(b.Var(name, ty.Get(), init));
             case builtin::AddressSpace::kStorage:
-                return b.Decl(b.Var(name, ty, init, var->access, var->address_space));
+                return b.Decl(b.Var(name, ty.Get(), init, var->access, var->address_space));
             default:
-                return b.Decl(b.Var(name, ty, init, var->address_space));
+                return b.Decl(b.Var(name, ty.Get(), init, var->address_space));
         }
     }
 
@@ -271,8 +335,8 @@
 
     const ast::Expression* VarExpr(const ir::Var* v) { return b.Expr(NameOf(v)); }
 
-    const ast::Type Type(const type::Type* ty) {
-        return Switch(
+    utils::Result<ast::Type> Type(const type::Type* ty) {
+        return Switch<utils::Result<ast::Type>>(
             ty,                                              //
             [&](const type::Void*) { return ast::Type{}; },  //
             [&](const type::I32*) { return b.ty.i32(); },    //
@@ -280,60 +344,87 @@
             [&](const type::F16*) { return b.ty.f16(); },    //
             [&](const type::F32*) { return b.ty.f32(); },    //
             [&](const type::Bool*) { return b.ty.bool_(); },
-            [&](const type::Matrix* m) {
+            [&](const type::Matrix* m) -> utils::Result<ast::Type> {
                 auto el = Type(m->type());
-                return b.ty.mat(el, m->columns(), m->rows());
+                if (!el) {
+                    return utils::Failure;
+                }
+                return b.ty.mat(el.Get(), m->columns(), m->rows());
             },
-            [&](const type::Vector* v) {
+            [&](const type::Vector* v) -> utils::Result<ast::Type> {
                 auto el = Type(v->type());
+                if (!el) {
+                    return utils::Failure;
+                }
                 if (v->Packed()) {
                     TINT_ASSERT(IR, v->Width() == 3u);
-                    return b.ty(builtin::Builtin::kPackedVec3, el);
+                    return b.ty(builtin::Builtin::kPackedVec3, el.Get());
                 } else {
-                    return b.ty.vec(el, v->Width());
+                    return b.ty.vec(el.Get(), v->Width());
                 }
             },
-            [&](const type::Array* a) {
+            [&](const type::Array* a) -> utils::Result<ast::Type> {
                 auto el = Type(a->ElemType());
+                if (!el) {
+                    return utils::Failure;
+                }
                 utils::Vector<const ast::Attribute*, 1> attrs;
                 if (!a->IsStrideImplicit()) {
                     attrs.Push(b.Stride(a->Stride()));
                 }
                 if (a->Count()->Is<type::RuntimeArrayCount>()) {
-                    return b.ty.array(el, std::move(attrs));
+                    return b.ty.array(el.Get(), std::move(attrs));
                 }
                 auto count = a->ConstantCount();
                 if (TINT_UNLIKELY(!count)) {
                     TINT_ICE(IR, b.Diagnostics()) << type::Array::kErrExpectedConstantCount;
-                    return b.ty.array(el, u32(1), std::move(attrs));
+                    return b.ty.array(el.Get(), u32(1), std::move(attrs));
                 }
-                return b.ty.array(el, u32(count.value()), std::move(attrs));
+                return b.ty.array(el.Get(), u32(count.value()), std::move(attrs));
             },
             [&](const type::Struct* s) { return b.ty(s->Name().NameView()); },
-            [&](const type::Atomic* a) { return b.ty.atomic(Type(a->Type())); },
+            [&](const type::Atomic* a) -> utils::Result<ast::Type> {
+                auto el = Type(a->Type());
+                if (!el) {
+                    return utils::Failure;
+                }
+                return b.ty.atomic(el.Get());
+            },
             [&](const type::DepthTexture* t) { return b.ty.depth_texture(t->dim()); },
             [&](const type::DepthMultisampledTexture* t) {
                 return b.ty.depth_multisampled_texture(t->dim());
             },
             [&](const type::ExternalTexture*) { return b.ty.external_texture(); },
-            [&](const type::MultisampledTexture* t) {
-                return b.ty.multisampled_texture(t->dim(), Type(t->type()));
+            [&](const type::MultisampledTexture* t) -> utils::Result<ast::Type> {
+                auto el = Type(t->type());
+                if (!el) {
+                    return utils::Failure;
+                }
+                return b.ty.multisampled_texture(t->dim(), el.Get());
             },
-            [&](const type::SampledTexture* t) {
-                return b.ty.sampled_texture(t->dim(), Type(t->type()));
+            [&](const type::SampledTexture* t) -> utils::Result<ast::Type> {
+                auto el = Type(t->type());
+                if (!el) {
+                    return utils::Failure;
+                }
+                return b.ty.sampled_texture(t->dim(), el.Get());
             },
             [&](const type::StorageTexture* t) {
                 return b.ty.storage_texture(t->dim(), t->texel_format(), t->access());
             },
             [&](const type::Sampler* s) { return b.ty.sampler(s->kind()); },
-            [&](const type::Pointer* p) {
+            [&](const type::Pointer* p) -> utils::Result<ast::Type> {
                 // Note: type::Pointer always has an inferred access, but WGSL only allows an
                 // explicit access in the 'storage' address space.
+                auto el = Type(p->StoreType());
+                if (!el) {
+                    return utils::Failure;
+                }
                 auto address_space = p->AddressSpace();
                 auto access = address_space == builtin::AddressSpace::kStorage
                                   ? p->Access()
                                   : builtin::Access::kUndefined;
-                return b.ty.pointer(Type(p->StoreType()), address_space, access);
+                return b.ty.pointer(el.Get(), address_space, access);
             },
             [&](const type::Reference* r) { return Type(r->StoreType()); },
             [&](Default) {
diff --git a/src/tint/ir/to_program_roundtrip_test.cc b/src/tint/ir/to_program_roundtrip_test.cc
index 171f1f2..ab5bbbe 100644
--- a/src/tint/ir/to_program_roundtrip_test.cc
+++ b/src/tint/ir/to_program_roundtrip_test.cc
@@ -40,6 +40,13 @@
         ASSERT_TRUE(ir_module);
 
         auto output_program = ToProgram(ir_module.Get());
+        if (!output_program.IsValid()) {
+            tint::ir::Disassembler d{ir_module.Get()};
+            FAIL() << output_program.Diagnostics().str() << std::endl
+                   << "IR:" << std::endl
+                   << d.Disassemble();
+        }
+
         ASSERT_TRUE(output_program.IsValid()) << output_program.Diagnostics().str();
 
         auto output = writer::wgsl::Generate(&output_program, {});
@@ -79,6 +86,14 @@
 )");
 }
 
+TEST_F(IRToProgramRoundtripTest, SingleFunction_Return_i32) {
+    Test(R"(
+fn f() -> i32 {
+  return 42i;
+}
+)");
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Function-scope var
 ////////////////////////////////////////////////////////////////////////////////
@@ -127,9 +142,6 @@
 
 TEST_F(IRToProgramRoundtripTest, If_Return) {
     Test(R"(
-fn a() {
-}
-
 fn f() {
   var cond : bool = true;
   if (cond) {
@@ -139,6 +151,18 @@
 )");
 }
 
+TEST_F(IRToProgramRoundtripTest, If_Return_i32) {
+    Test(R"(
+fn f() -> i32 {
+  var cond : bool = true;
+  if (cond) {
+    return 42i;
+  }
+  return 10i;
+}
+)");
+}
+
 TEST_F(IRToProgramRoundtripTest, If_CallFn_Else_CallFn) {
     Test(R"(
 fn a() {
@@ -158,7 +182,20 @@
 )");
 }
 
-TEST_F(IRToProgramRoundtripTest, If_Return_Else_Return) {
+TEST_F(IRToProgramRoundtripTest, If_Return_f32_Else_Return_f32) {
+    Test(R"(
+fn f() -> f32 {
+  var cond : bool = true;
+  if (cond) {
+    return 1.0f;
+  } else {
+    return 2.0f;
+  }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, If_Return_u32_Else_CallFn) {
     Test(R"(
 fn a() {
 }
@@ -166,13 +203,15 @@
 fn b() {
 }
 
-fn f() {
+fn f() -> u32 {
   var cond : bool = true;
   if (cond) {
-    return;
+    return 1u;
   } else {
-    return;
+    a();
   }
+  b();
+  return 2u;
 }
 )");
 }
@@ -196,6 +235,7 @@
   } else if (cond_b) {
     b();
   }
+  c();
 }
 )");
 }