[ir][to_program]: Simplify error handling
Emit semantically-invalid placeholder expressions to reduce the amount of error handling required through the code.
As well as greatly improving readability, will provide substantial performance wins.
Bug: tint:1902
Change-Id: I045755ad494af0db5cfb1cd4c40665c45bc61561
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/135940
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index 33c9674..7563bf2 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -96,16 +96,10 @@
// TODO(crbug.com/tint/1915): Properly implement this when we've fleshed out Function
utils::Vector<const ast::Parameter*, 1> params{};
auto ret_ty = Type(fn->ReturnType());
- if (!ret_ty) {
- return nullptr;
- }
auto* body = BlockGraph(fn->StartTarget());
- if (!body) {
- return nullptr;
- }
utils::Vector<const ast::Attribute*, 1> attrs{};
utils::Vector<const ast::Attribute*, 1> ret_attrs{};
- return b.Func(name, std::move(params), ret_ty.Get(), body, std::move(attrs),
+ return b.Func(name, std::move(params), ret_ty, body, std::move(attrs),
std::move(ret_attrs));
}
@@ -123,12 +117,8 @@
TINT_ASSERT(IR, block->HasBranchTarget());
for (auto* inst : *block) {
- auto stmt = Stmt(inst);
- if (TINT_UNLIKELY(!stmt)) {
- return nullptr;
- }
- if (auto* s = stmt.Get()) {
- stmts.Push(s);
+ if (auto* stmt = Stmt(inst)) {
+ stmts.Push(stmt);
}
}
if (auto* if_ = block->Branch()->As<ir::If>()) {
@@ -148,6 +138,35 @@
return b.Block(std::move(stmts));
}
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+ // Statements
+ //
+ // Statement methods may return nullptr, in the case of instructions that do not map to an AST
+ // statement, or in the case of an error. These should simply be ignored.
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /// @param inst the ir::Instruction
+ /// @return an ast::Statement from @p inst, or nullptr if there was an error
+ const ast::Statement* Stmt(const ir::Instruction* inst) {
+ return tint::Switch(
+ inst, //
+ [&](const ir::Call* i) { return CallStmt(i); }, //
+ [&](const ir::Var* i) { return Var(i); }, //
+ [&](const ir::Load*) { return nullptr; }, //
+ [&](const ir::Store* i) { return Store(i); }, //
+ [&](const ir::If* if_) { return If(if_); }, //
+ [&](const ir::Switch* switch_) { return Switch(switch_); }, //
+ [&](const ir::Return* ret) { return Return(ret); }, //
+ // TODO(dsinclair): Remove when branch is only a parent ...
+ [&](const ir::Branch*) { return nullptr; },
+ [&](Default) {
+ UNHANDLED_CASE(inst);
+ return nullptr;
+ });
+ }
+
+ /// @param i the ir::If
+ /// @return an ast::IfStatement from @p i, or nullptr if there was an error
const ast::IfStatement* If(const ir::If* i) {
SCOPED_NESTING();
auto* cond = Expr(i->Condition());
@@ -181,6 +200,8 @@
return b.If(cond, t);
}
+ /// @param s the ir::Switch
+ /// @return an ast::SwitchStatement from @p s, or nullptr if there was an error
const ast::SwitchStatement* Switch(const ir::Switch* s) {
SCOPED_NESTING();
@@ -223,7 +244,9 @@
return b.Switch(cond, std::move(cases));
}
- utils::Result<const ast::ReturnStatement*> Return(const ir::Return* ret) {
+ /// @param ret the ir::Return
+ /// @return an ast::ReturnStatement from @p ret, or nullptr if there was an error
+ const ast::ReturnStatement* Return(const ir::Return* ret) {
if (ret->Args().IsEmpty()) {
// Return has no arguments.
// If this block is nested withing some control flow, then we must
@@ -239,88 +262,61 @@
if (ret->Args().Length() != 1) {
TINT_ICE(IR, b.Diagnostics())
<< "expected 1 value for return, got " << ret->Args().Length();
- return utils::Failure;
+ return b.Return();
}
auto* val = Expr(ret->Args().Front());
if (TINT_UNLIKELY(!val)) {
- return utils::Failure;
+ return b.Return();
}
return b.Return(val);
}
- utils::Result<const ast::Statement*> Stmt(const ir::Instruction* inst) {
- return tint::Switch<utils::Result<const ast::Statement*>>(
- inst, //
- [&](const ir::Call* i) { return CallStmt(i); }, //
- [&](const ir::Var* i) { return Var(i); }, //
- [&](const ir::Load*) { return nullptr; },
- [&](const ir::Store* i) { return Store(i); }, //
- [&](const ir::If* if_) { return If(if_); },
- [&](const ir::Switch* switch_) { return Switch(switch_); },
- [&](const ir::Return* ret) { return Return(ret); },
- // TODO(dsinclair): Remove when branch is only a parent ...
- [&](const ir::Branch*) { return utils::Result<const ast::Statement*>{nullptr}; },
- [&](Default) {
- UNHANDLED_CASE(inst);
- return utils::Failure;
- });
- }
+ /// @param call the ir::Call
+ /// @return an ast::CallStatement from @p call, or nullptr if there was an error
+ const ast::CallStatement* CallStmt(const ir::Call* call) { return b.CallStmt(Call(call)); }
- const ast::CallStatement* CallStmt(const ir::Call* call) {
- auto* expr = Call(call);
- if (!expr) {
- return nullptr;
- }
- return b.CallStmt(expr);
- }
-
+ /// @param var the ir::Var
+ /// @return an ast::VariableDeclStatement from @p var
const ast::VariableDeclStatement* Var(const ir::Var* var) {
Symbol name = NameOf(var);
auto* ptr = var->Type()->As<type::Pointer>();
- if (!ptr) {
- Err("Incorrect type for var");
- return nullptr;
- }
auto ty = Type(ptr->StoreType());
const ast::Expression* init = nullptr;
if (var->Initializer()) {
init = Expr(var->Initializer());
- if (!init) {
- return nullptr;
- }
}
switch (ptr->AddressSpace()) {
case builtin::AddressSpace::kFunction:
- return b.Decl(b.Var(name, ty.Get(), init));
+ return b.Decl(b.Var(name, ty, init));
case builtin::AddressSpace::kStorage:
- return b.Decl(b.Var(name, ty.Get(), init, ptr->Access(), ptr->AddressSpace()));
+ return b.Decl(b.Var(name, ty, init, ptr->Access(), ptr->AddressSpace()));
default:
- return b.Decl(b.Var(name, ty.Get(), init, ptr->AddressSpace()));
+ return b.Decl(b.Var(name, ty, init, ptr->AddressSpace()));
}
}
+ /// @param store the ir::Store
+ /// @return an ast::AssignmentStatement from @p call
const ast::AssignmentStatement* Store(const ir::Store* store) {
auto* expr = Expr(store->From());
return b.Assign(NameOf(store->To()), expr);
}
- const ast::CallExpression* Call(const ir::Call* call) {
- auto args =
- utils::Transform<2>(call->Args(), [&](const ir::Value* arg) { return Expr(arg); });
- if (args.Any(utils::IsNull)) {
- return nullptr;
- }
- return tint::Switch(
- call, //
- [&](const ir::UserCall* c) { return b.Call(NameOf(c->Func()), std::move(args)); },
- [&](Default) {
- UNHANDLED_CASE(call);
- return nullptr;
- });
- }
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+ // Expressions
+ //
+ // The the case of an error:
+ // * The expression generating methods must return a non-null ast expression pointer, which may
+ // not be semantically legal, but is enough to populate the AST.
+ // * A diagnostic error must be added to the ast::ProgramBuilder.
+ // This prevents littering the ToProgram logic with expensive error checking code.
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+ /// @param val the ir::Expression
+ /// @return an ast::Expression from @p val.
+ /// @note May be a semantically-invalid placeholder expression on error.
const ast::Expression* Expr(const ir::Value* val) {
return tint::Switch(
val, //
@@ -329,10 +325,28 @@
[&](const ir::Var* v) { return VarExpr(v); },
[&](Default) {
UNHANDLED_CASE(val);
- return nullptr;
+ return b.Expr("<error>");
});
}
+ /// @param call the ir::Call
+ /// @return an ast::CallExpression from @p call.
+ /// @note May be a semantically-invalid placeholder expression on error.
+ const ast::CallExpression* Call(const ir::Call* call) {
+ auto args =
+ utils::Transform<2>(call->Args(), [&](const ir::Value* arg) { return Expr(arg); });
+ return tint::Switch(
+ call, //
+ [&](const ir::UserCall* c) { return b.Call(NameOf(c->Func()), std::move(args)); },
+ [&](Default) {
+ UNHANDLED_CASE(call);
+ return b.Call("<error>");
+ });
+ }
+
+ /// @param c the ir::Constant
+ /// @return an ast::Expression from @p c.
+ /// @note May be a semantically-invalid placeholder expression on error.
const ast::Expression* ConstExpr(const ir::Constant* c) {
return tint::Switch(
c->Type(), //
@@ -343,16 +357,35 @@
[&](const type::Bool*) { return b.Expr(c->Value()->ValueAs<bool>()); },
[&](Default) {
UNHANDLED_CASE(c);
- return nullptr;
+ return b.Expr("<error>");
});
}
+ /// @param l the ir::Load
+ /// @return an ast::Expression from @p l.
+ /// @note May be a semantically-invalid placeholder expression on error.
const ast::Expression* LoadExpr(const ir::Load* l) { return Expr(l->From()); }
+ /// @param v the ir::Var
+ /// @return an ast::Expression from @p v.
+ /// @note May be a semantically-invalid placeholder expression on error.
const ast::Expression* VarExpr(const ir::Var* v) { return b.Expr(NameOf(v)); }
- utils::Result<ast::Type> Type(const type::Type* ty) {
- return tint::Switch<utils::Result<ast::Type>>(
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+ // Types
+ //
+ // The the case of an error:
+ // * The types generating methods must return a non-null ast type, which may not be semantically
+ // legal, but is enough to populate the AST.
+ // * A diagnostic error must be added to the ast::ProgramBuilder.
+ // This prevents littering the ToProgram logic with expensive error checking code.
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /// @param ty the type::Type
+ /// @return an ast::Type from @p ty.
+ /// @note May be a semantically-invalid placeholder type on error.
+ ast::Type Type(const type::Type* ty) {
+ return tint::Switch(
ty, //
[&](const type::Void*) { return ast::Type{}; }, //
[&](const type::I32*) { return b.ty.i32(); }, //
@@ -360,98 +393,78 @@
[&](const type::F16*) { return b.ty.f16(); }, //
[&](const type::F32*) { return b.ty.f32(); }, //
[&](const type::Bool*) { return b.ty.bool_(); },
- [&](const type::Matrix* m) -> utils::Result<ast::Type> {
- auto el = Type(m->type());
- if (!el) {
- return utils::Failure;
- }
- return b.ty.mat(el.Get(), m->columns(), m->rows());
+ [&](const type::Matrix* m) {
+ return b.ty.mat(Type(m->type()), m->columns(), m->rows());
},
- [&](const type::Vector* v) -> utils::Result<ast::Type> {
+ [&](const type::Vector* v) {
auto el = Type(v->type());
- if (!el) {
- return utils::Failure;
- }
if (v->Packed()) {
TINT_ASSERT(IR, v->Width() == 3u);
- return b.ty(builtin::Builtin::kPackedVec3, el.Get());
+ return b.ty(builtin::Builtin::kPackedVec3, el);
} else {
- return b.ty.vec(el.Get(), v->Width());
+ return b.ty.vec(el, v->Width());
}
},
- [&](const type::Array* a) -> utils::Result<ast::Type> {
+ [&](const type::Array* a) {
auto el = Type(a->ElemType());
- if (!el) {
- return utils::Failure;
- }
utils::Vector<const ast::Attribute*, 1> attrs;
if (!a->IsStrideImplicit()) {
attrs.Push(b.Stride(a->Stride()));
}
if (a->Count()->Is<type::RuntimeArrayCount>()) {
- return b.ty.array(el.Get(), std::move(attrs));
+ return b.ty.array(el, std::move(attrs));
}
auto count = a->ConstantCount();
if (TINT_UNLIKELY(!count)) {
TINT_ICE(IR, b.Diagnostics()) << type::Array::kErrExpectedConstantCount;
- return b.ty.array(el.Get(), u32(1), std::move(attrs));
+ return b.ty.array(el, u32(1), std::move(attrs));
}
- return b.ty.array(el.Get(), u32(count.value()), std::move(attrs));
+ return b.ty.array(el, u32(count.value()), std::move(attrs));
},
[&](const type::Struct* s) { return b.ty(s->Name().NameView()); },
- [&](const type::Atomic* a) -> utils::Result<ast::Type> {
- auto el = Type(a->Type());
- if (!el) {
- return utils::Failure;
- }
- return b.ty.atomic(el.Get());
- },
+ [&](const type::Atomic* a) { return b.ty.atomic(Type(a->Type())); },
[&](const type::DepthTexture* t) { return b.ty.depth_texture(t->dim()); },
[&](const type::DepthMultisampledTexture* t) {
return b.ty.depth_multisampled_texture(t->dim());
},
[&](const type::ExternalTexture*) { return b.ty.external_texture(); },
- [&](const type::MultisampledTexture* t) -> utils::Result<ast::Type> {
+ [&](const type::MultisampledTexture* t) {
auto el = Type(t->type());
- if (!el) {
- return utils::Failure;
- }
- return b.ty.multisampled_texture(t->dim(), el.Get());
+ return b.ty.multisampled_texture(t->dim(), el);
},
- [&](const type::SampledTexture* t) -> utils::Result<ast::Type> {
+ [&](const type::SampledTexture* t) {
auto el = Type(t->type());
- if (!el) {
- return utils::Failure;
- }
- return b.ty.sampled_texture(t->dim(), el.Get());
+ return b.ty.sampled_texture(t->dim(), el);
},
[&](const type::StorageTexture* t) {
return b.ty.storage_texture(t->dim(), t->texel_format(), t->access());
},
[&](const type::Sampler* s) { return b.ty.sampler(s->kind()); },
- [&](const type::Pointer* p) -> utils::Result<ast::Type> {
+ [&](const type::Pointer* p) {
// Note: type::Pointer always has an inferred access, but WGSL only allows an
// explicit access in the 'storage' address space.
auto el = Type(p->StoreType());
- if (!el) {
- return utils::Failure;
- }
auto address_space = p->AddressSpace();
auto access = address_space == builtin::AddressSpace::kStorage
? p->Access()
: builtin::Access::kUndefined;
- return b.ty.pointer(el.Get(), address_space, access);
+ return b.ty.pointer(el, address_space, access);
},
- [&](const type::Reference*) -> utils::Result<ast::Type> {
+ [&](const type::Reference*) {
TINT_ICE(IR, b.Diagnostics()) << "reference types should never appear in the IR";
- return ast::Type{};
+ return b.ty.i32();
},
[&](Default) {
UNHANDLED_CASE(ty);
- return ast::Type{};
+ return b.ty.i32();
});
}
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+ // Helpers
+ ////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /// @return the Symbol to emit for the given value
Symbol NameOf(const Value* value) {
TINT_ASSERT(IR, value);
return value_names_.GetOrCreate(value, [&] {
@@ -461,8 +474,6 @@
return b.Symbols().New("v" + std::to_string(value_names_.Count()));
});
}
-
- void Err(std::string str) { b.Diagnostics().add_error(diag::System::IR, std::move(str)); }
};
} // namespace