[tint][it][ToProgram] Correctly handle pointers
Involves some additional tracking as the IR doesn't have the concept of references.
Change-Id: I64a87ea0c4971ef058e96425cc981e7bdacfe974
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/139644
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index 69df972..3e30cfb 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -79,10 +79,6 @@
namespace {
-/// Empty struct used as a sentinel value to indicate that an ast::Value has been consumed by its
-/// single place of usage. Attempting to use this value a second time should result in an ICE.
-struct ConsumedValue {};
-
class State {
public:
explicit State(Module& m) : mod(m) {}
@@ -105,21 +101,43 @@
}
private:
+ /// The AST representation for an IR pointer type
+ enum class PtrKind {
+ kPtr, // IR pointer is represented in the AST as a pointer
+ kRef, // IR pointer is represented in the AST as a reference
+ };
+
/// The source IR module
Module& mod;
/// The target ProgramBuilder
ProgramBuilder b;
- using ValueBinding = std::variant<Symbol, const ast::Expression*, ConsumedValue>;
+ /// The structure for a value held by a 'let', 'var' or parameter.
+ struct VariableValue {
+ Symbol name; // Name of the variable
+ PtrKind ptr_kind = PtrKind::kRef;
+ };
- /// A hashmap of value to one of:
- /// * Symbol - Name of 'let' (non-inlinable value), 'var' or parameter.
- /// * ast::Expression* - single use, inlined expression.
- /// * ConsumedValue - a special value used to indicate that the value has already been
- /// consumed.
+ /// The structure for an inlined value
+ struct InlinedValue {
+ const ast::Expression* expr = nullptr;
+ PtrKind ptr_kind = PtrKind::kRef;
+ };
+
+ /// Empty struct used as a sentinel value to indicate that an ast::Value has been consumed by
+ /// its single place of usage. Attempting to use this value a second time should result in an
+ /// ICE.
+ struct ConsumedValue {};
+
+ using ValueBinding = std::variant<VariableValue, InlinedValue, ConsumedValue>;
+
+ /// IR values to their representation
utils::Hashmap<Value*, ValueBinding, 32> bindings_;
+ /// Names for values
+ utils::Hashmap<Value*, Symbol, 32> names_;
+
/// The nesting depth of the currently generated AST
/// 0 is module scope
/// 1 is root-level function scope
@@ -156,12 +174,13 @@
// TODO(crbug.com/tint/1915): Properly implement this when we've fleshed out Function
static constexpr size_t N = decltype(ast::Function::params)::static_length;
auto params = utils::Transform<N>(fn->Params(), [&](FunctionParam* param) {
- auto name = BindName(param);
+ auto name = NameFor(param);
+ Bind(param, name, PtrKind::kPtr);
auto ty = Type(param->Type());
return b.Param(name, ty);
});
- auto name = BindName(fn);
+ auto name = NameFor(fn);
auto ret_ty = Type(fn->ReturnType());
auto* body = Block(fn->Block());
utils::Vector<const ast::Attribute*, 1> attrs{};
@@ -423,7 +442,8 @@
void Var(ir::Var* var) {
auto* val = var->Result();
- Symbol name = BindName(val);
+ Symbol name = NameFor(var->Result());
+ Bind(var->Result(), name, PtrKind::kRef);
auto* ptr = As<type::Pointer>(val->Type());
auto ty = Type(ptr->StoreType());
@@ -457,16 +477,19 @@
}
void Call(ir::Call* call) {
- auto args = utils::Transform<4>(call->Args(), [&](ir::Value* arg) { return Expr(arg); });
+ auto args = utils::Transform<4>(call->Args(), [&](ir::Value* arg) {
+ // Pointer-like arguments are passed by pointer, never reference.
+ return Expr(arg, PtrKind::kPtr);
+ });
tint::Switch(
call, //
[&](ir::UserCall* c) {
- auto* expr = b.Call(BindName(c->Func()), std::move(args));
+ auto* expr = b.Call(NameFor(c->Func()), std::move(args));
if (!call->HasResults() || call->Result()->Usages().IsEmpty()) {
Append(b.CallStmt(expr));
return;
}
- Bind(c->Result(), expr);
+ Bind(c->Result(), expr, PtrKind::kPtr);
},
[&](ir::BuiltinCall* c) {
auto* expr = b.Call(c->Func(), std::move(args));
@@ -474,15 +497,15 @@
Append(b.CallStmt(expr));
return;
}
- Bind(c->Result(), expr);
+ Bind(c->Result(), expr, PtrKind::kPtr);
},
[&](ir::Construct* c) {
auto ty = Type(c->Result()->Type());
- Bind(c->Result(), b.Call(ty, std::move(args)));
+ Bind(c->Result(), b.Call(ty, std::move(args)), PtrKind::kPtr);
},
[&](ir::Convert* c) {
auto ty = Type(c->Result()->Type());
- Bind(c->Result(), b.Call(ty, std::move(args)));
+ Bind(c->Result(), b.Call(ty, std::move(args)), PtrKind::kPtr);
},
[&](Default) { UNHANDLED_CASE(call); });
}
@@ -620,31 +643,35 @@
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
- const ast::Expression* Expr(ir::Value* value) {
- return tint::Switch(
- value, //
- [&](ir::Constant* c) { return Constant(c); }, //
- [&](Default) -> const ast::Expression* {
+ const ast::Expression* Expr(ir::Value* value, PtrKind want_ptr_kind = PtrKind::kRef) {
+ using ExprAndPtrKind = std::pair<const ast::Expression*, PtrKind>;
+
+ auto [expr, got_ptr_kind] = tint::Switch(
+ value,
+ [&](ir::Constant* c) -> ExprAndPtrKind {
+ return {Constant(c), PtrKind::kRef};
+ },
+ [&](Default) -> ExprAndPtrKind {
auto lookup = bindings_.Find(value);
if (TINT_UNLIKELY(!lookup)) {
TINT_ICE(IR, b.Diagnostics())
<< "Expr(" << (value ? value->TypeInfo().name : "null")
<< ") value has no expression";
- return b.Expr("<error>");
+ return {};
}
return std::visit(
- [&](auto&& got) -> const ast::Expression* {
+ [&](auto&& got) -> ExprAndPtrKind {
using T = std::decay_t<decltype(got)>;
- if constexpr (std::is_same_v<T, Symbol>) {
- return b.Expr(got); // var, let or parameter.
+ if constexpr (std::is_same_v<T, VariableValue>) {
+ return {b.Expr(got.name), got.ptr_kind};
}
- if constexpr (std::is_same_v<T, const ast::Expression*>) {
+ if constexpr (std::is_same_v<T, InlinedValue>) {
// Single use (inlined) expression.
// Mark the bindings_ map entry as consumed.
*lookup = ConsumedValue{};
- return got;
+ return {got.expr, got.ptr_kind};
}
if constexpr (std::is_same_v<T, ConsumedValue>) {
@@ -654,10 +681,20 @@
TINT_ICE(IR, b.Diagnostics())
<< "Expr(" << value->TypeInfo().name << ") has unhandled value";
}
- return b.Expr("<error>");
+ return {};
},
*lookup);
});
+
+ if (!expr) {
+ return b.Expr("<error>");
+ }
+
+ if (value->Type()->Is<type::Pointer>()) {
+ return ToPtrKind(expr, got_ptr_kind, want_ptr_kind);
+ }
+
+ return expr;
}
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
@@ -818,16 +855,24 @@
return b.ty(n);
}
+ const ast::Expression* ToPtrKind(const ast::Expression* in, PtrKind got, PtrKind want) {
+ if (want == PtrKind::kRef && got == PtrKind::kPtr) {
+ return b.Deref(in);
+ }
+ if (want == PtrKind::kPtr && got == PtrKind::kRef) {
+ return b.AddressOf(in);
+ }
+ return in;
+ }
+
////////////////////////////////////////////////////////////////////////////////////////////////
// Bindings
////////////////////////////////////////////////////////////////////////////////////////////////
- /// Creates and returns a new, unique name for the given value, or returns the previously
- /// created name.
- /// @return the value's name
- Symbol BindName(Value* value, std::string_view suggested = {}) {
- TINT_ASSERT(IR, value);
- auto& existing = bindings_.GetOrCreate(value, [&] {
+ /// @returns the AST name for the given value, creating and returning a new name on the first
+ /// call.
+ Symbol NameFor(Value* value, std::string_view suggested = {}) {
+ return names_.GetOrCreate(value, [&] {
if (!suggested.empty()) {
return b.Symbols().New(suggested);
}
@@ -836,27 +881,41 @@
}
return b.Symbols().New("v");
});
- if (auto* name = std::get_if<Symbol>(&existing); TINT_LIKELY(name)) {
- return *name;
- }
-
- TINT_ICE(IR, b.Diagnostics()) << "BindName(" << value->TypeInfo().name
- << ") called on value that has non-name binding";
- return {};
}
- template <typename T>
- void Bind(ir::Value* value, const T* expr) {
+ /// Associates the IR value @p value with the AST expression @p expr.
+ /// @p ptr_kind defines how pointer values are represented by @p expr.
+ void Bind(ir::Value* value, const ast::Expression* expr, PtrKind ptr_kind = PtrKind::kRef) {
TINT_ASSERT(IR, value);
if (can_inline_.Remove(value)) {
// Value will be inlined at its place of usage.
- bool added = bindings_.Add(value, expr);
- if (TINT_UNLIKELY(!added)) {
- TINT_ICE(IR, b.Diagnostics())
- << "Bind(" << value->TypeInfo().name << ") called twice for same node";
+ if (TINT_LIKELY(bindings_.Add(value, InlinedValue{expr, ptr_kind}))) {
+ return;
}
} else {
- Append(b.Decl(b.Let(BindName(value), expr)));
+ if (value->Type()->Is<type::Pointer>()) {
+ expr = ToPtrKind(expr, ptr_kind, PtrKind::kPtr);
+ }
+ Symbol name = NameFor(value);
+ Append(b.Decl(b.Let(name, expr)));
+ Bind(value, name, PtrKind::kPtr);
+ return;
+ }
+
+ TINT_ICE(IR, b.Diagnostics())
+ << "Bind(" << value->TypeInfo().name << ") called twice for same value";
+ }
+
+ /// Associates the IR value @p value with the AST 'var', 'let' or parameter with the name @p
+ /// name.
+ /// @p ptr_kind defines how pointer values are represented by @p expr.
+ void Bind(ir::Value* value, Symbol name, PtrKind ptr_kind) {
+ TINT_ASSERT(IR, value);
+
+ bool added = bindings_.Add(value, VariableValue{name, ptr_kind});
+ if (TINT_UNLIKELY(!added)) {
+ TINT_ICE(IR, b.Diagnostics())
+ << "Bind(" << value->TypeInfo().name << ") called twice for same value";
}
}
diff --git a/src/tint/ir/to_program_roundtrip_test.cc b/src/tint/ir/to_program_roundtrip_test.cc
index 66ed84b..14bc8a1 100644
--- a/src/tint/ir/to_program_roundtrip_test.cc
+++ b/src/tint/ir/to_program_roundtrip_test.cc
@@ -152,6 +152,21 @@
)");
}
+TEST_F(IRToProgramRoundtripTest, FnCall_PtrArgs) {
+ Test(R"(
+var<private> y : i32 = 2i;
+
+fn a(px : ptr<function, i32>, py : ptr<private, i32>) -> i32 {
+ return (*(px) + *(py));
+}
+
+fn b() -> i32 {
+ var x : i32 = 1i;
+ return a(&(x), &(y));
+}
+)");
+}
+
////////////////////////////////////////////////////////////////////////////////
// Builtin Call
////////////////////////////////////////////////////////////////////////////////
@@ -171,6 +186,16 @@
)");
}
+TEST_F(IRToProgramRoundtripTest, BuiltinCall_PtrArg) {
+ Test(R"(
+var<workgroup> v : bool;
+
+fn foo() -> bool {
+ return workgroupUniformLoad(&(v));
+}
+)");
+}
+
////////////////////////////////////////////////////////////////////////////////
// Type Construct
////////////////////////////////////////////////////////////////////////////////
@@ -1243,6 +1268,31 @@
}
////////////////////////////////////////////////////////////////////////////////
+// Function-scope let
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, FunctionScopeLet_i32) {
+ Test(R"(
+fn f(i : i32) -> i32 {
+ let a = (42i + i);
+ let b = (24i + i);
+ let c = (a + b);
+ return c;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, FunctionScopeLet_ptr) {
+ Test(R"(
+fn f() -> i32 {
+ var a : array<i32, 3u>;
+ let b = &(a[1i]);
+ let c = *(b);
+ return c;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
// If
////////////////////////////////////////////////////////////////////////////////
TEST_F(IRToProgramRoundtripTest, If_CallFn) {