blob: c536ae6bce309248671104e5ec1232c1c5a7ed64 [file] [log] [blame]
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ir/to_program.h"
#include <string>
#include <utility>
#include "src/tint/ir/binary.h"
#include "src/tint/ir/block.h"
#include "src/tint/ir/call.h"
#include "src/tint/ir/constant.h"
#include "src/tint/ir/exit_if.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/load.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
#include "src/tint/ir/switch.h"
#include "src/tint/ir/unary.h"
#include "src/tint/ir/user_call.h"
#include "src/tint/ir/var.h"
#include "src/tint/program_builder.h"
#include "src/tint/switch.h"
#include "src/tint/type/atomic.h"
#include "src/tint/type/depth_multisampled_texture.h"
#include "src/tint/type/depth_texture.h"
#include "src/tint/type/multisampled_texture.h"
#include "src/tint/type/pointer.h"
#include "src/tint/type/reference.h"
#include "src/tint/type/sampler.h"
#include "src/tint/type/texture.h"
#include "src/tint/utils/hashmap.h"
#include "src/tint/utils/predicates.h"
#include "src/tint/utils/transform.h"
#include "src/tint/utils/vector.h"
// Helper for calling TINT_UNIMPLEMENTED() from a Switch(object_ptr) default case.
// Reports the dynamic type name of `object_ptr`, or "<null>" if it is null.
// Notes:
//  * `object_ptr` is evaluated more than once, so it must be free of side effects.
//  * The macro refers to `b.Diagnostics()`, so a `b` exposing Diagnostics() must be in scope.
#define UNHANDLED_CASE(object_ptr)          \
    TINT_UNIMPLEMENTED(IR, b.Diagnostics()) \
        << "unhandled case in Switch(): "   \
        << ((object_ptr) ? (object_ptr)->TypeInfo().name : "<null>")

// Helper for incrementing nesting_depth_ and then decrementing nesting_depth_ at the end
// of the scope that holds the call.
#define SCOPED_NESTING() \
    nesting_depth_++;    \
    TINT_DEFER(nesting_depth_--)
namespace tint::ir {
namespace {
class State {
  public:
    /// Constructor
    /// @param m the IR module to convert. Held by reference, so it must outlive this State.
    explicit State(Module& m) : mod(m) {}

    /// Converts every function of the IR module into the AST.
    /// @return the generated Program. Errors encountered during conversion are recorded on the
    /// builder's diagnostics.
    Program Run() {
        // TODO(crbug.com/tint/1902): Emit root block
        // TODO(crbug.com/tint/1902): Emit user-declared types
        for (auto* fn : mod.functions) {
            Fn(fn);
        }
        return Program{std::move(b)};
    }

  private:
    /// The source IR module
    Module& mod;

    /// The target ProgramBuilder
    ProgramBuilder b;

    /// A hashmap of value to symbol used in the emitted AST
    utils::Hashmap<Value*, Symbol, 32> value_names_;

    // The nesting depth of the currently generated AST
    // 0 is module scope
    // 1 is root-level function scope
    // 2+ is within control flow
    uint32_t nesting_depth_ = 0;

    /// Emits an ast::Function for the IR function @p fn.
    /// @param fn the IR function
    /// @return the emitted ast::Function
    const ast::Function* Fn(Function* fn) {
        SCOPED_NESTING();

        // TODO(crbug.com/tint/1915): Properly implement this when we've fleshed out Function
        static constexpr size_t N = decltype(ast::Function::params)::static_length;
        auto params = utils::Transform<N>(fn->Params(), [&](FunctionParam* param) {
            auto name = AssignNameTo(param);
            auto ty = Type(param->Type());
            return b.Param(name, ty);
        });
        auto name = AssignNameTo(fn);
        auto ret_ty = Type(fn->ReturnType());
        auto* body = BlockGraph(fn->StartTarget());
        utils::Vector<const ast::Attribute*, 1> attrs{};
        utils::Vector<const ast::Attribute*, 1> ret_attrs{};
        return b.Func(name, std::move(params), ret_ty, body, std::move(attrs),
                      std::move(ret_attrs));
    }

    /// Emits the statements of @p start_node. If the block terminates with an `if` or `switch`
    /// whose merge block has a branch target, the merge block's statements are appended to the
    /// same AST block, and so on until a block without such a continuation is reached.
    /// @param start_node the first IR block to emit
    /// @return an ast::BlockStatement holding the emitted statements
    const ast::BlockStatement* BlockGraph(ir::Block* start_node) {
        // TODO(crbug.com/tint/1902): Check if the block is dead
        utils::Vector<const ast::Statement*,
                      decltype(ast::BlockStatement::statements)::static_length>
            stmts;
        ir::Block* block = start_node;

        // TODO(crbug.com/tint/1902): Handle block arguments.

        while (block) {
            TINT_ASSERT(IR, block->HasBranchTarget());

            for (auto* inst : *block) {
                if (auto* stmt = Stmt(inst)) {
                    stmts.Push(stmt);
                }
            }

            if (auto* if_ = block->Branch()->As<ir::If>()) {
                if (if_->Merge()->HasBranchTarget()) {
                    block = if_->Merge();
                    continue;
                }
            } else if (auto* switch_ = block->Branch()->As<ir::Switch>()) {
                if (switch_->Merge()->HasBranchTarget()) {
                    block = switch_->Merge();
                    continue;
                }
            }

            break;
        }

        return b.Block(std::move(stmts));
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////
    // Statements
    //
    // Statement methods may return nullptr, in the case of instructions that do not map to an AST
    // statement, or in the case of an error. These should simply be ignored.
    ////////////////////////////////////////////////////////////////////////////////////////////////

    /// @param inst the ir::Instruction
    /// @return an ast::Statement from @p inst, or nullptr if there was an error
    const ast::Statement* Stmt(ir::Instruction* inst) {
        return tint::Switch(
            inst,                                                  //
            [&](ir::Store* i) { return Store(i); },                //
            [&](ir::Call* i) { return CallStmt(i); },              //
            [&](ir::Var* i) { return Var(i); },                    //
            [&](ir::If* if_) { return If(if_); },                  //
            [&](ir::Switch* switch_) { return Switch(switch_); },  //
            [&](ir::Return* ret) { return Return(ret); },          //
            // Branches are instructions, and instructions are values, so this case must come
            // before the generic ir::Value case below or it can never be selected.
            // TODO(dsinclair): Remove when branch is only a parent ...
            [&](ir::Branch*) { return nullptr; },  //
            [&](ir::Value*) { return ValueStmt(inst); },
            [&](Default) {
                UNHANDLED_CASE(inst);
                return nullptr;
            });
    }

    /// @param i the ir::If
    /// @return an ast::IfStatement from @p i, or nullptr if there was an error
    const ast::IfStatement* If(ir::If* i) {
        SCOPED_NESTING();

        auto* cond = Expr(i->Condition());
        auto* t = BlockGraph(i->True());
        if (TINT_UNLIKELY(!t)) {
            return nullptr;
        }

        auto* false_blk = i->False();

        // The false block only needs to be emitted as an 'else' if it holds something other than
        // a lone exit_if back to this if.
        bool has_else = false_blk->Length() > 1 ||
                        (false_blk->Length() == 1 && false_blk->HasBranchTarget() &&
                         !false_blk->Branch()->Is<ir::ExitIf>());
        if (!has_else) {
            return b.If(cond, t);
        }

        // If the else target is an `if` whose merge block consists solely of an exit_if that
        // bounces to this if's merge target, then emit an 'else if' instead of a block statement
        // for the else. The merge block must hold nothing but the exit_if, as If() does not emit
        // the contents of the nested if's merge block.
        if (auto* first = false_blk->Instructions(); first && first->Is<ir::If>()) {
            auto* if_ = first->As<ir::If>();
            auto* merge = if_->Merge();
            if (auto* br = merge->Branch()->As<ir::ExitIf>();
                br && br->If() == i && merge->Length() == 1) {
                auto* f = If(if_);
                if (!f) {
                    return nullptr;
                }
                return b.If(cond, t, b.Else(f));
            }
        }

        // Otherwise emit the false block as a regular 'else' block. This also covers a nested
        // `if` that does not qualify for the 'else if' form above, which must not be dropped.
        auto* f = BlockGraph(false_blk);
        if (!f) {
            return nullptr;
        }
        return b.If(cond, t, b.Else(f));
    }

    /// @param s the ir::Switch
    /// @return an ast::SwitchStatement from @p s, or nullptr if there was an error
    const ast::SwitchStatement* Switch(ir::Switch* s) {
        SCOPED_NESTING();

        auto* cond = Expr(s->Condition());
        if (!cond) {
            return nullptr;
        }

        auto cases =
            utils::Transform(s->Cases(),  //
                             [&](ir::Switch::Case c) -> const tint::ast::CaseStatement* {
                                 SCOPED_NESTING();
                                 auto* body = BlockGraph(c.start);
                                 if (!body) {
                                     return nullptr;
                                 }
                                 auto selectors = utils::Transform(
                                     c.selectors,  //
                                     [&](ir::Switch::CaseSelector cs) -> const ast::CaseSelector* {
                                         if (cs.IsDefault()) {
                                             return b.DefaultCaseSelector();
                                         }
                                         auto* expr = Expr(cs.val);
                                         if (!expr) {
                                             return nullptr;
                                         }
                                         return b.CaseSelector(expr);
                                     });
                                 if (selectors.Any(utils::IsNull)) {
                                     return nullptr;
                                 }
                                 return b.Case(std::move(selectors), body);
                             });
        if (cases.Any(utils::IsNull)) {
            return nullptr;
        }

        return b.Switch(cond, std::move(cases));
    }

    /// @param ret the ir::Return
    /// @return an ast::ReturnStatement from @p ret, or nullptr if there was an error
    const ast::ReturnStatement* Return(ir::Return* ret) {
        if (ret->Args().IsEmpty()) {
            // Return has no arguments.
            // If this block is nested within some control flow, then we must
            // emit a 'return' statement, otherwise we've just naturally reached
            // the end of the function where the 'return' is redundant.
            if (nesting_depth_ > 1) {
                return b.Return();
            }
            return nullptr;
        }

        // Return has arguments - this is the return value.
        if (ret->Args().Length() != 1) {
            TINT_ICE(IR, b.Diagnostics())
                << "expected 1 value for return, got " << ret->Args().Length();
            return b.Return();
        }

        auto* val = Expr(ret->Args().Front());
        if (TINT_UNLIKELY(!val)) {
            return b.Return();
        }

        return b.Return(val);
    }

    /// @param call the ir::Call
    /// @return an ast::CallStatement from @p call, or nullptr if there was an error
    const ast::CallStatement* CallStmt(ir::Call* call) { return b.CallStmt(Call(call)); }

    /// Emits a variable declaration for @p var. Variables in the 'function' address space are
    /// emitted without an explicit address space, 'storage' variables also carry their access
    /// mode, and all others carry just their address space.
    /// @param var the ir::Var
    /// @return an ast::VariableDeclStatement from @p var
    const ast::VariableDeclStatement* Var(ir::Var* var) {
        Symbol name = AssignNameTo(var);
        auto* ptr = var->Type();
        auto ty = Type(ptr->StoreType());
        const ast::Expression* init = nullptr;
        if (var->Initializer()) {
            init = Expr(var->Initializer());
        }
        switch (ptr->AddressSpace()) {
            case builtin::AddressSpace::kFunction:
                return b.Decl(b.Var(name, ty, init));
            case builtin::AddressSpace::kStorage:
                return b.Decl(b.Var(name, ty, init, ptr->Access(), ptr->AddressSpace()));
            default:
                return b.Decl(b.Var(name, ty, init, ptr->AddressSpace()));
        }
    }

    /// @param store the ir::Store
    /// @return an ast::AssignmentStatement from @p store
    const ast::AssignmentStatement* Store(ir::Store* store) {
        auto* expr = Expr(store->From());
        return b.Assign(AssignNameTo(store->To()), expr);
    }

    /// @param val the ir::Value
    /// @return an ast::Statement from @p val, or nullptr if the value does not produce a statement.
    const ast::Statement* ValueStmt(ir::Value* val) {
        // As we're visiting this value's declaration it shouldn't already have a name reserved.
        TINT_ASSERT(IR, !value_names_.Contains(val));

        // Determine whether the value should be placed into a let, or inlined in its single place
        // of usage. Currently a value is inlined if it has a single usage and is unnamed.
        // TODO(crbug.com/tint/1902): This logic needs to check that the sequence of side-effecting
        // expressions is not changed by inlining the expression. This needs fixing.
        bool create_let = val->Usages().Count() > 1 || mod.NameOf(val).IsValid();
        if (create_let) {
            auto* init = Expr(val);  // Must come before giving the value a name
            auto name = AssignNameTo(val);
            return b.Decl(b.Let(name, init));
        }
        return nullptr;  // Value will be inlined at its place of usage.
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////
    // Expressions
    //
    // In the case of an error:
    // * The expression generating methods must return a non-null ast expression pointer, which may
    //   not be semantically legal, but is enough to populate the AST.
    // * A diagnostic error must be added to the ProgramBuilder.
    // This prevents littering the ToProgram logic with expensive error checking code.
    ////////////////////////////////////////////////////////////////////////////////////////////////

    /// @param val the ir::Value
    /// @return an ast::Expression from @p val. If @p val has already been given a name, the
    /// expression is a reference to that name.
    /// @note May be a semantically-invalid placeholder expression on error.
    const ast::Expression* Expr(ir::Value* val) {
        if (auto name = value_names_.Get(val)) {
            return b.Expr(name.value());
        }

        return tint::Switch(
            val,                                              //
            [&](ir::Constant* c) { return ConstExpr(c); },    //
            [&](ir::Load* l) { return LoadExpr(l); },         //
            [&](ir::Unary* u) { return UnaryExpr(u); },       //
            [&](ir::Binary* u) { return BinaryExpr(u); },     //
            [&](Default) {
                UNHANDLED_CASE(val);
                return b.Expr("<error>");
            });
    }

    /// @param call the ir::Call
    /// @return an ast::CallExpression from @p call.
    /// @note May be a semantically-invalid placeholder expression on error.
    const ast::CallExpression* Call(ir::Call* call) {
        auto args = utils::Transform<2>(call->Args(), [&](ir::Value* arg) { return Expr(arg); });
        return tint::Switch(
            call,  //
            [&](ir::UserCall* c) { return b.Call(AssignNameTo(c->Func()), std::move(args)); },
            [&](Default) {
                UNHANDLED_CASE(call);
                return b.Call("<error>");
            });
    }

    /// @param c the ir::Constant
    /// @return an ast::Expression from @p c.
    /// @note May be a semantically-invalid placeholder expression on error.
    const ast::Expression* ConstExpr(ir::Constant* c) {
        return tint::Switch(
            c->Type(),  //
            [&](const type::I32*) { return b.Expr(c->Value()->ValueAs<i32>()); },
            [&](const type::U32*) { return b.Expr(c->Value()->ValueAs<u32>()); },
            [&](const type::F32*) { return b.Expr(c->Value()->ValueAs<f32>()); },
            [&](const type::F16*) { return b.Expr(c->Value()->ValueAs<f16>()); },
            [&](const type::Bool*) { return b.Expr(c->Value()->ValueAs<bool>()); },
            [&](Default) {
                // Report the type that was switched on, not the (always ir::Constant) value.
                UNHANDLED_CASE(c->Type());
                return b.Expr("<error>");
            });
    }

    /// @param l the ir::Load
    /// @return an ast::Expression from @p l.
    /// @note May be a semantically-invalid placeholder expression on error.
    const ast::Expression* LoadExpr(ir::Load* l) { return Expr(l->From()); }

    /// @param u the ir::Unary
    /// @return an ast::UnaryOpExpression from @p u.
    /// @note May be a semantically-invalid placeholder expression on error.
    const ast::Expression* UnaryExpr(ir::Unary* u) {
        switch (u->Kind()) {
            case ir::Unary::Kind::kComplement:
                return b.Complement(Expr(u->Val()));
            case ir::Unary::Kind::kNegation:
                return b.Negation(Expr(u->Val()));
        }
        // Only reachable with an out-of-range kind. Report it, as required by the contract above.
        TINT_ICE(IR, b.Diagnostics()) << "unhandled ir::Unary kind: " << static_cast<int>(u->Kind());
        return b.Expr("<error>");
    }

    /// @param e the ir::Binary
    /// @return an ast::BinaryOpExpression from @p e.
    /// @note May be a semantically-invalid placeholder expression on error.
    const ast::Expression* BinaryExpr(ir::Binary* e) {
        if (e->Kind() == ir::Binary::Kind::kEqual) {
            auto* rhs = e->RHS()->As<ir::Constant>();
            if (rhs && rhs->Type()->Is<type::Bool>() && rhs->Value()->ValueAs<bool>() == false) {
                // expr == false
                return b.Not(Expr(e->LHS()));
            }
        }
        auto* lhs = Expr(e->LHS());
        auto* rhs = Expr(e->RHS());
        switch (e->Kind()) {
            case ir::Binary::Kind::kAdd:
                return b.Add(lhs, rhs);
            case ir::Binary::Kind::kSubtract:
                return b.Sub(lhs, rhs);
            case ir::Binary::Kind::kMultiply:
                return b.Mul(lhs, rhs);
            case ir::Binary::Kind::kDivide:
                return b.Div(lhs, rhs);
            case ir::Binary::Kind::kModulo:
                return b.Mod(lhs, rhs);
            case ir::Binary::Kind::kAnd:
                return b.And(lhs, rhs);
            case ir::Binary::Kind::kOr:
                return b.Or(lhs, rhs);
            case ir::Binary::Kind::kXor:
                return b.Xor(lhs, rhs);
            case ir::Binary::Kind::kEqual:
                return b.Equal(lhs, rhs);
            case ir::Binary::Kind::kNotEqual:
                return b.NotEqual(lhs, rhs);
            case ir::Binary::Kind::kLessThan:
                return b.LessThan(lhs, rhs);
            case ir::Binary::Kind::kGreaterThan:
                return b.GreaterThan(lhs, rhs);
            case ir::Binary::Kind::kLessThanEqual:
                return b.LessThanEqual(lhs, rhs);
            case ir::Binary::Kind::kGreaterThanEqual:
                return b.GreaterThanEqual(lhs, rhs);
            case ir::Binary::Kind::kShiftLeft:
                return b.Shl(lhs, rhs);
            case ir::Binary::Kind::kShiftRight:
                return b.Shr(lhs, rhs);
        }
        // Only reachable with an out-of-range kind. Report it, as required by the contract above.
        TINT_ICE(IR, b.Diagnostics())
            << "unhandled ir::Binary kind: " << static_cast<int>(e->Kind());
        return b.Expr("<error>");
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////
    // Types
    //
    // In the case of an error:
    // * The types generating methods must return a non-null ast type, which may not be semantically
    //   legal, but is enough to populate the AST.
    // * A diagnostic error must be added to the ProgramBuilder.
    // This prevents littering the ToProgram logic with expensive error checking code.
    ////////////////////////////////////////////////////////////////////////////////////////////////

    /// @param ty the type::Type
    /// @return an ast::Type from @p ty.
    /// @note May be a semantically-invalid placeholder type on error.
    ast::Type Type(const type::Type* ty) {
        return tint::Switch(
            ty,                                              //
            [&](const type::Void*) { return ast::Type{}; },  //
            [&](const type::I32*) { return b.ty.i32(); },    //
            [&](const type::U32*) { return b.ty.u32(); },    //
            [&](const type::F16*) { return b.ty.f16(); },    //
            [&](const type::F32*) { return b.ty.f32(); },    //
            [&](const type::Bool*) { return b.ty.bool_(); },
            [&](const type::Matrix* m) {
                return b.ty.mat(Type(m->type()), m->columns(), m->rows());
            },
            [&](const type::Vector* v) {
                auto el = Type(v->type());
                if (v->Packed()) {
                    TINT_ASSERT(IR, v->Width() == 3u);
                    return b.ty(builtin::Builtin::kPackedVec3, el);
                } else {
                    return b.ty.vec(el, v->Width());
                }
            },
            [&](const type::Array* a) {
                auto el = Type(a->ElemType());
                utils::Vector<const ast::Attribute*, 1> attrs;
                if (!a->IsStrideImplicit()) {
                    attrs.Push(b.Stride(a->Stride()));
                }
                if (a->Count()->Is<type::RuntimeArrayCount>()) {
                    return b.ty.array(el, std::move(attrs));
                }
                auto count = a->ConstantCount();
                if (TINT_UNLIKELY(!count)) {
                    TINT_ICE(IR, b.Diagnostics()) << type::Array::kErrExpectedConstantCount;
                    return b.ty.array(el, u32(1), std::move(attrs));
                }
                return b.ty.array(el, u32(count.value()), std::move(attrs));
            },
            [&](const type::Struct* s) { return b.ty(s->Name().NameView()); },
            [&](const type::Atomic* a) { return b.ty.atomic(Type(a->Type())); },
            [&](const type::DepthTexture* t) { return b.ty.depth_texture(t->dim()); },
            [&](const type::DepthMultisampledTexture* t) {
                return b.ty.depth_multisampled_texture(t->dim());
            },
            [&](const type::ExternalTexture*) { return b.ty.external_texture(); },
            [&](const type::MultisampledTexture* t) {
                auto el = Type(t->type());
                return b.ty.multisampled_texture(t->dim(), el);
            },
            [&](const type::SampledTexture* t) {
                auto el = Type(t->type());
                return b.ty.sampled_texture(t->dim(), el);
            },
            [&](const type::StorageTexture* t) {
                return b.ty.storage_texture(t->dim(), t->texel_format(), t->access());
            },
            [&](const type::Sampler* s) { return b.ty.sampler(s->kind()); },
            [&](const type::Pointer* p) {
                // Note: type::Pointer always has an inferred access, but WGSL only allows an
                // explicit access in the 'storage' address space.
                auto el = Type(p->StoreType());
                auto address_space = p->AddressSpace();
                auto access = address_space == builtin::AddressSpace::kStorage
                                  ? p->Access()
                                  : builtin::Access::kUndefined;
                return b.ty.ptr(address_space, el, access);
            },
            [&](const type::Reference*) {
                TINT_ICE(IR, b.Diagnostics()) << "reference types should never appear in the IR";
                return b.ty.i32();
            },
            [&](Default) {
                UNHANDLED_CASE(ty);
                return b.ty.i32();
            });
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////
    // Helpers
    ////////////////////////////////////////////////////////////////////////////////////////////////

    /// Creates and returns a new, unique name for the given value, or returns the previously
    /// created name. If the module has a name for @p value, that name is used as the basis for
    /// the new symbol; otherwise a name of the form 'vN' is generated.
    /// @param value the value to name. Must not be null.
    /// @return the value's name
    Symbol AssignNameTo(Value* value) {
        TINT_ASSERT(IR, value);
        return value_names_.GetOrCreate(value, [&] {
            if (auto sym = mod.NameOf(value)) {
                return b.Symbols().New(sym.Name());
            }
            return b.Symbols().New("v" + std::to_string(value_names_.Count()));
        });
    }
};
} // namespace
/// Builds an AST Program from the given IR module.
/// @param i the IR module to convert
/// @return the Program produced by a fresh State converter run over @p i
Program ToProgram(Module& i) {
    State converter(i);
    return converter.Run();
}
} // namespace tint::ir