[tint][it][ToProgram] Correctly handle pointers

Involves some additional tracking as the IR doesn't have the concept of references.

Change-Id: I64a87ea0c4971ef058e96425cc981e7bdacfe974
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/139644
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index 69df972..3e30cfb 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -79,10 +79,6 @@
 
 namespace {
 
-/// Empty struct used as a sentinel value to indicate that an ast::Value has been consumed by its
-/// single place of usage. Attempting to use this value a second time should result in an ICE.
-struct ConsumedValue {};
-
 class State {
   public:
     explicit State(Module& m) : mod(m) {}
@@ -105,21 +101,43 @@
     }
 
   private:
+    /// The AST representation for an IR pointer type
+    enum class PtrKind {
+        kPtr,  // IR pointer is represented in the AST as a pointer
+        kRef,  // IR pointer is represented in the AST as a reference
+    };
+
     /// The source IR module
     Module& mod;
 
     /// The target ProgramBuilder
     ProgramBuilder b;
 
-    using ValueBinding = std::variant<Symbol, const ast::Expression*, ConsumedValue>;
+    /// The structure for a value held by a 'let', 'var' or parameter.
+    struct VariableValue {
+        Symbol name;  // Name of the variable
+        PtrKind ptr_kind = PtrKind::kRef;
+    };
 
-    /// A hashmap of value to one of:
-    /// * Symbol           - Name of 'let' (non-inlinable value), 'var' or parameter.
-    /// * ast::Expression* - single use, inlined expression.
-    /// * ConsumedValue    - a special value used to indicate that the value has already been
-    ///                      consumed.
+    /// The structure for an inlined value
+    struct InlinedValue {
+        const ast::Expression* expr = nullptr;
+        PtrKind ptr_kind = PtrKind::kRef;
+    };
+
+    /// Empty struct used as a sentinel value to indicate that an ast::Value has been consumed by
+    /// its single place of usage. Attempting to use this value a second time should result in an
+    /// ICE.
+    struct ConsumedValue {};
+
+    using ValueBinding = std::variant<VariableValue, InlinedValue, ConsumedValue>;
+
+    /// IR values to their representation
     utils::Hashmap<Value*, ValueBinding, 32> bindings_;
 
+    /// Names for values
+    utils::Hashmap<Value*, Symbol, 32> names_;
+
     /// The nesting depth of the currently generated AST
     /// 0  is module scope
     /// 1  is root-level function scope
@@ -156,12 +174,13 @@
         // TODO(crbug.com/tint/1915): Properly implement this when we've fleshed out Function
         static constexpr size_t N = decltype(ast::Function::params)::static_length;
         auto params = utils::Transform<N>(fn->Params(), [&](FunctionParam* param) {
-            auto name = BindName(param);
+            auto name = NameFor(param);
+            Bind(param, name, PtrKind::kPtr);
             auto ty = Type(param->Type());
             return b.Param(name, ty);
         });
 
-        auto name = BindName(fn);
+        auto name = NameFor(fn);
         auto ret_ty = Type(fn->ReturnType());
         auto* body = Block(fn->Block());
         utils::Vector<const ast::Attribute*, 1> attrs{};
@@ -423,7 +442,8 @@
 
     void Var(ir::Var* var) {
         auto* val = var->Result();
-        Symbol name = BindName(val);
+        Symbol name = NameFor(var->Result());
+        Bind(var->Result(), name, PtrKind::kRef);
         auto* ptr = As<type::Pointer>(val->Type());
         auto ty = Type(ptr->StoreType());
 
@@ -457,16 +477,19 @@
     }
 
     void Call(ir::Call* call) {
-        auto args = utils::Transform<4>(call->Args(), [&](ir::Value* arg) { return Expr(arg); });
+        auto args = utils::Transform<4>(call->Args(), [&](ir::Value* arg) {
+            // Pointer-like arguments are passed by pointer, never reference.
+            return Expr(arg, PtrKind::kPtr);
+        });
         tint::Switch(
             call,  //
             [&](ir::UserCall* c) {
-                auto* expr = b.Call(BindName(c->Func()), std::move(args));
+                auto* expr = b.Call(NameFor(c->Func()), std::move(args));
                 if (!call->HasResults() || call->Result()->Usages().IsEmpty()) {
                     Append(b.CallStmt(expr));
                     return;
                 }
-                Bind(c->Result(), expr);
+                Bind(c->Result(), expr, PtrKind::kPtr);
             },
             [&](ir::BuiltinCall* c) {
                 auto* expr = b.Call(c->Func(), std::move(args));
@@ -474,15 +497,15 @@
                     Append(b.CallStmt(expr));
                     return;
                 }
-                Bind(c->Result(), expr);
+                Bind(c->Result(), expr, PtrKind::kPtr);
             },
             [&](ir::Construct* c) {
                 auto ty = Type(c->Result()->Type());
-                Bind(c->Result(), b.Call(ty, std::move(args)));
+                Bind(c->Result(), b.Call(ty, std::move(args)), PtrKind::kPtr);
             },
             [&](ir::Convert* c) {
                 auto ty = Type(c->Result()->Type());
-                Bind(c->Result(), b.Call(ty, std::move(args)));
+                Bind(c->Result(), b.Call(ty, std::move(args)), PtrKind::kPtr);
             },
             [&](Default) { UNHANDLED_CASE(call); });
     }
@@ -620,31 +643,35 @@
 
     TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
 
-    const ast::Expression* Expr(ir::Value* value) {
-        return tint::Switch(
-            value,                                         //
-            [&](ir::Constant* c) { return Constant(c); },  //
-            [&](Default) -> const ast::Expression* {
+    const ast::Expression* Expr(ir::Value* value, PtrKind want_ptr_kind = PtrKind::kRef) {
+        using ExprAndPtrKind = std::pair<const ast::Expression*, PtrKind>;
+
+        auto [expr, got_ptr_kind] = tint::Switch(
+            value,
+            [&](ir::Constant* c) -> ExprAndPtrKind {
+                return {Constant(c), PtrKind::kRef};
+            },
+            [&](Default) -> ExprAndPtrKind {
                 auto lookup = bindings_.Find(value);
                 if (TINT_UNLIKELY(!lookup)) {
                     TINT_ICE(IR, b.Diagnostics())
                         << "Expr(" << (value ? value->TypeInfo().name : "null")
                         << ") value has no expression";
-                    return b.Expr("<error>");
+                    return {};
                 }
                 return std::visit(
-                    [&](auto&& got) -> const ast::Expression* {
+                    [&](auto&& got) -> ExprAndPtrKind {
                         using T = std::decay_t<decltype(got)>;
 
-                        if constexpr (std::is_same_v<T, Symbol>) {
-                            return b.Expr(got);  // var, let or parameter.
+                        if constexpr (std::is_same_v<T, VariableValue>) {
+                            return {b.Expr(got.name), got.ptr_kind};
                         }
 
-                        if constexpr (std::is_same_v<T, const ast::Expression*>) {
+                        if constexpr (std::is_same_v<T, InlinedValue>) {
                             // Single use (inlined) expression.
                             // Mark the bindings_ map entry as consumed.
                             *lookup = ConsumedValue{};
-                            return got;
+                            return {got.expr, got.ptr_kind};
                         }
 
                         if constexpr (std::is_same_v<T, ConsumedValue>) {
@@ -654,10 +681,20 @@
                             TINT_ICE(IR, b.Diagnostics())
                                 << "Expr(" << value->TypeInfo().name << ") has unhandled value";
                         }
-                        return b.Expr("<error>");
+                        return {};
                     },
                     *lookup);
             });
+
+        if (!expr) {
+            return b.Expr("<error>");
+        }
+
+        if (value->Type()->Is<type::Pointer>()) {
+            return ToPtrKind(expr, got_ptr_kind, want_ptr_kind);
+        }
+
+        return expr;
     }
 
     TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
@@ -818,16 +855,24 @@
         return b.ty(n);
     }
 
+    const ast::Expression* ToPtrKind(const ast::Expression* in, PtrKind got, PtrKind want) {
+        if (want == PtrKind::kRef && got == PtrKind::kPtr) {
+            return b.Deref(in);
+        }
+        if (want == PtrKind::kPtr && got == PtrKind::kRef) {
+            return b.AddressOf(in);
+        }
+        return in;
+    }
+
     ////////////////////////////////////////////////////////////////////////////////////////////////
     // Bindings
     ////////////////////////////////////////////////////////////////////////////////////////////////
 
-    /// Creates and returns a new, unique name for the given value, or returns the previously
-    /// created name.
-    /// @return the value's name
-    Symbol BindName(Value* value, std::string_view suggested = {}) {
-        TINT_ASSERT(IR, value);
-        auto& existing = bindings_.GetOrCreate(value, [&] {
+    /// @returns the AST name for the given value, creating and returning a new name on the first
+    /// call.
+    Symbol NameFor(Value* value, std::string_view suggested = {}) {
+        return names_.GetOrCreate(value, [&] {
             if (!suggested.empty()) {
                 return b.Symbols().New(suggested);
             }
@@ -836,27 +881,41 @@
             }
             return b.Symbols().New("v");
         });
-        if (auto* name = std::get_if<Symbol>(&existing); TINT_LIKELY(name)) {
-            return *name;
-        }
-
-        TINT_ICE(IR, b.Diagnostics()) << "BindName(" << value->TypeInfo().name
-                                      << ") called on value that has non-name binding";
-        return {};
     }
 
-    template <typename T>
-    void Bind(ir::Value* value, const T* expr) {
+    /// Associates the IR value @p value with the AST expression @p expr.
+    /// @p ptr_kind defines how pointer values are represented by @p expr.
+    void Bind(ir::Value* value, const ast::Expression* expr, PtrKind ptr_kind = PtrKind::kRef) {
         TINT_ASSERT(IR, value);
         if (can_inline_.Remove(value)) {
             // Value will be inlined at its place of usage.
-            bool added = bindings_.Add(value, expr);
-            if (TINT_UNLIKELY(!added)) {
-                TINT_ICE(IR, b.Diagnostics())
-                    << "Bind(" << value->TypeInfo().name << ") called twice for same node";
+            if (TINT_LIKELY(bindings_.Add(value, InlinedValue{expr, ptr_kind}))) {
+                return;
             }
         } else {
-            Append(b.Decl(b.Let(BindName(value), expr)));
+            if (value->Type()->Is<type::Pointer>()) {
+                expr = ToPtrKind(expr, ptr_kind, PtrKind::kPtr);
+            }
+            Symbol name = NameFor(value);
+            Append(b.Decl(b.Let(name, expr)));
+            Bind(value, name, PtrKind::kPtr);
+            return;
+        }
+
+        TINT_ICE(IR, b.Diagnostics())
+            << "Bind(" << value->TypeInfo().name << ") called twice for same value";
+    }
+
+    /// Associates the IR value @p value with the AST 'var', 'let' or parameter with the name @p
+    /// name.
+    /// @p ptr_kind defines how pointer values are represented by @p expr.
+    void Bind(ir::Value* value, Symbol name, PtrKind ptr_kind) {
+        TINT_ASSERT(IR, value);
+
+        bool added = bindings_.Add(value, VariableValue{name, ptr_kind});
+        if (TINT_UNLIKELY(!added)) {
+            TINT_ICE(IR, b.Diagnostics())
+                << "Bind(" << value->TypeInfo().name << ") called twice for same value";
         }
     }
 
diff --git a/src/tint/ir/to_program_roundtrip_test.cc b/src/tint/ir/to_program_roundtrip_test.cc
index 66ed84b..14bc8a1 100644
--- a/src/tint/ir/to_program_roundtrip_test.cc
+++ b/src/tint/ir/to_program_roundtrip_test.cc
@@ -152,6 +152,21 @@
 )");
 }
 
+TEST_F(IRToProgramRoundtripTest, FnCall_PtrArgs) {
+    Test(R"(
+var<private> y : i32 = 2i;
+
+fn a(px : ptr<function, i32>, py : ptr<private, i32>) -> i32 {
+  return (*(px) + *(py));
+}
+
+fn b() -> i32 {
+  var x : i32 = 1i;
+  return a(&(x), &(y));
+}
+)");
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Builtin Call
 ////////////////////////////////////////////////////////////////////////////////
@@ -171,6 +186,16 @@
 )");
 }
 
+TEST_F(IRToProgramRoundtripTest, BuiltinCall_PtrArg) {
+    Test(R"(
+var<workgroup> v : bool;
+
+fn foo() -> bool {
+  return workgroupUniformLoad(&(v));
+}
+)");
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Type Construct
 ////////////////////////////////////////////////////////////////////////////////
@@ -1243,6 +1268,31 @@
 }
 
 ////////////////////////////////////////////////////////////////////////////////
+// Function-scope let
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, FunctionScopeLet_i32) {
+    Test(R"(
+fn f(i : i32) -> i32 {
+  let a = (42i + i);
+  let b = (24i + i);
+  let c = (a + b);
+  return c;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, FunctionScopeLet_ptr) {
+    Test(R"(
+fn f() -> i32 {
+  var a : array<i32, 3u>;
+  let b = &(a[1i]);
+  let c = *(b);
+  return c;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
 // If
 ////////////////////////////////////////////////////////////////////////////////
 TEST_F(IRToProgramRoundtripTest, If_CallFn) {