blob: 3643af87df818d1168576f97e1566cc289e2d2ea [file] [log] [blame]
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/lang/core/ir/to_program.h"
#include <string>
#include <tuple>
#include <utility>
#include "src/tint/lang/core/builtin/builtin.h"
#include "src/tint/lang/core/constant/splat.h"
#include "src/tint/lang/core/ir/access.h"
#include "src/tint/lang/core/ir/binary.h"
#include "src/tint/lang/core/ir/bitcast.h"
#include "src/tint/lang/core/ir/block.h"
#include "src/tint/lang/core/ir/break_if.h"
#include "src/tint/lang/core/ir/call.h"
#include "src/tint/lang/core/ir/constant.h"
#include "src/tint/lang/core/ir/construct.h"
#include "src/tint/lang/core/ir/continue.h"
#include "src/tint/lang/core/ir/convert.h"
#include "src/tint/lang/core/ir/core_builtin_call.h"
#include "src/tint/lang/core/ir/discard.h"
#include "src/tint/lang/core/ir/exit_if.h"
#include "src/tint/lang/core/ir/exit_loop.h"
#include "src/tint/lang/core/ir/exit_switch.h"
#include "src/tint/lang/core/ir/if.h"
#include "src/tint/lang/core/ir/instruction.h"
#include "src/tint/lang/core/ir/let.h"
#include "src/tint/lang/core/ir/load.h"
#include "src/tint/lang/core/ir/load_vector_element.h"
#include "src/tint/lang/core/ir/loop.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/multi_in_block.h"
#include "src/tint/lang/core/ir/next_iteration.h"
#include "src/tint/lang/core/ir/return.h"
#include "src/tint/lang/core/ir/store.h"
#include "src/tint/lang/core/ir/store_vector_element.h"
#include "src/tint/lang/core/ir/switch.h"
#include "src/tint/lang/core/ir/swizzle.h"
#include "src/tint/lang/core/ir/transform/rename_conflicts_wgsl.h"
#include "src/tint/lang/core/ir/unary.h"
#include "src/tint/lang/core/ir/unreachable.h"
#include "src/tint/lang/core/ir/user_call.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/ir/var.h"
#include "src/tint/lang/core/type/atomic.h"
#include "src/tint/lang/core/type/depth_multisampled_texture.h"
#include "src/tint/lang/core/type/depth_texture.h"
#include "src/tint/lang/core/type/multisampled_texture.h"
#include "src/tint/lang/core/type/pointer.h"
#include "src/tint/lang/core/type/reference.h"
#include "src/tint/lang/core/type/sampler.h"
#include "src/tint/lang/core/type/texture.h"
#include "src/tint/lang/wgsl/program/program_builder.h"
#include "src/tint/utils/containers/hashmap.h"
#include "src/tint/utils/containers/predicates.h"
#include "src/tint/utils/containers/reverse.h"
#include "src/tint/utils/containers/transform.h"
#include "src/tint/utils/containers/vector.h"
#include "src/tint/utils/macros/scoped_assignment.h"
#include "src/tint/utils/math/math.h"
#include "src/tint/utils/rtti/switch.h"
// Helper for calling TINT_UNIMPLEMENTED() from a Switch(object_ptr) default case.
// Requires a ProgramBuilder named 'b' to be in scope at the point of use.
// The macro argument is parenthesized at every use so that arguments such as `*it` or a
// conditional expression expand with the intended precedence.
#define UNHANDLED_CASE(object_ptr)          \
    TINT_UNIMPLEMENTED(IR, b.Diagnostics()) \
        << "unhandled case in Switch(): "   \
        << ((object_ptr) ? (object_ptr)->TypeInfo().name : "<null>")
// Helper for incrementing nesting_depth_ and then decrementing nesting_depth_ at the end
// of the scope that holds the call.
// Note: this expands to two statements (an increment and a TINT_DEFER), so it must be used as
// a statement directly inside the scope it should cover, never as the body of an unbraced
// if/for/while.
#define SCOPED_NESTING() \
nesting_depth_++; \
TINT_DEFER(nesting_depth_--)
namespace tint::ir {
namespace {
class State {
public:
explicit State(Module& m) : mod(m) {}
/// Converts the IR module into a Program.
/// @returns the generated program. If the IR module fails validation, the returned program
///          carries the validation diagnostics instead of any generated AST.
Program Run() {
    auto validation = Validate(mod);
    if (!validation) {
        // The IR module is invalid: surface the validator's diagnostics through the program.
        b.Diagnostics() = validation.Failure();
        return Program{std::move(b)};
    }

    // Rename declarations whose names would conflict when emitted as WGSL.
    transform::Transform::DataMap data;
    transform::RenameConflictsWGSL{}.Run(&mod, data, data);

    // Module-scope declarations first, then each function.
    if (mod.root_block) {
        RootBlock(mod.root_block);
    }

    // TODO(crbug.com/tint/1902): Emit user-declared types
    for (auto* func : mod.functions) {
        Fn(func);
    }

    return Program{std::move(b)};
}
private:
/// The AST representation for an IR pointer type
enum class PtrKind {
kPtr, // IR pointer is represented in the AST as a pointer
kRef, // IR pointer is represented in the AST as a reference
};
/// The source IR module
Module& mod;
/// The target ProgramBuilder
ProgramBuilder b;
/// The structure for a value held by a 'let', 'var' or parameter.
struct VariableValue {
Symbol name; // Name of the variable
// How a pointer-typed value is represented when referenced by 'name'.
PtrKind ptr_kind = PtrKind::kRef;
};
/// The structure for an inlined value
struct InlinedValue {
// The expression to emit at the value's single place of use.
const ast::Expression* expr = nullptr;
// How a pointer-typed value is represented by 'expr'.
PtrKind ptr_kind = PtrKind::kRef;
};
/// Empty struct used as a sentinel value to indicate that the binding for an IR value has been
/// consumed by its single place of usage. Attempting to use this value a second time should
/// result in an ICE.
struct ConsumedValue {};
/// The AST representation bound to an IR value: a named variable, a not-yet-used inlined
/// expression, or an inlined expression that has already been used.
using ValueBinding = std::variant<VariableValue, InlinedValue, ConsumedValue>;
/// IR values to their representation
utils::Hashmap<Value*, ValueBinding, 32> bindings_;
/// Names for values
utils::Hashmap<Value*, Symbol, 32> names_;
/// The nesting depth of the currently generated AST
/// 0 is module scope
/// 1 is root-level function scope
/// 2+ is within control flow
uint32_t nesting_depth_ = 0;
/// A list of AST statements, sized to match ast::BlockStatement's inline storage.
using StatementList = utils::Vector<const ast::Statement*,
decltype(ast::BlockStatement::statements)::static_length>;
/// The statement list that Append() pushes to. Temporarily redirected (via
/// TINT_SCOPED_ASSIGNMENT) while the statements of a nested block are being built.
StatementList* statements_ = nullptr;
/// The current switch case block
ir::Block* current_switch_case_ = nullptr;
/// Values that can be inlined.
/// Populated by MarkInlinable() and consumed by Bind().
utils::Hashset<ir::Value*, 64> can_inline_;
/// Set of enable directives emitted.
utils::Hashset<builtin::Extension, 4> enables_;
/// Map of struct to output program name.
utils::Hashmap<const type::Struct*, Symbol, 8> structs_;
/// True if 'diagnostic(off, derivative_uniformity)' has been emitted
bool disabled_derivative_uniformity_ = false;
/// Emits the module-scope declarations held by the IR root block.
/// Only 'var' instructions are currently handled; any other instruction is reported as
/// unimplemented.
/// @param root the IR root block
void RootBlock(ir::Block* root) {
    for (auto* inst : *root) {
        if (auto* module_var = As<ir::Var>(inst)) {
            Var(module_var);
        } else {
            UNHANDLED_CASE(inst);
        }
    }
}
/// Emits the IR function @p fn as an AST function declaration.
/// @param fn the IR function
/// @returns the AST function
const ast::Function* Fn(ir::Function* fn) {
    SCOPED_NESTING();

    // TODO(crbug.com/tint/1915): Properly implement this when we've fleshed out Function
    static constexpr size_t kParamsInline = decltype(ast::Function::params)::static_length;
    auto ast_params = utils::Transform<kParamsInline>(fn->Params(), [&](FunctionParam* p) {
        auto param_ty = Type(p->Type());
        auto param_name = NameFor(p);
        // Parameters are referenced by name; pointer-typed parameters are AST pointers.
        Bind(p, param_name, PtrKind::kPtr);
        return b.Param(param_name, param_ty);
    });

    auto fn_name = NameFor(fn);
    auto return_ty = Type(fn->ReturnType());
    auto* fn_body = Block(fn->Block());

    // No function or return-value attributes are emitted yet.
    utils::Vector<const ast::Attribute*, 1> fn_attrs{};
    utils::Vector<const ast::Attribute*, 1> return_attrs{};
    return b.Func(fn_name, std::move(ast_params), return_ty, fn_body, std::move(fn_attrs),
                  std::move(return_attrs));
}
/// Emits the IR block @p block as an AST block statement.
/// @param block the IR block (may be null, which produces an empty block)
/// @returns the AST block statement
const ast::BlockStatement* Block(ir::Block* block) {
    // TODO(crbug.com/tint/1902): Handle block arguments.
    auto body = Statements(block);
    return b.Block(std::move(body));
}
/// Emits every instruction of the IR block @p block into a new statement list.
/// @param block the IR block. A null block yields an empty list.
/// @returns the emitted statements
StatementList Statements(ir::Block* block) {
    StatementList out;
    if (!block) {
        return out;
    }

    MarkInlinable(block);

    // Redirect Append() into 'out' while this block's instructions are emitted.
    TINT_SCOPED_ASSIGNMENT(statements_, &out);
    for (auto* inst : *block) {
        Instruction(inst);
    }
    return out;
}
/// Determines which single-use results produced in @p block may be emitted inline at their
/// single place of use (instead of as a 'let'), and adds them to can_inline_.
/// A result is a candidate only if it has exactly one usage and no IR name. Results of
/// sequenced instructions are only inlined when doing so cannot reorder them relative to other
/// sequenced instructions.
/// @param block the IR block to analyze
void MarkInlinable(ir::Block* block) {
// An ordered list of possibly-inlinable values returned by sequenced instructions that have
// not yet been marked-for or ruled-out-for inlining.
utils::UniqueVector<ir::Value*, 32> pending_resolution;
// Walk the instructions of the block starting with the first.
for (auto* inst : *block) {
// Is the instruction sequenced?
bool sequenced = inst->Sequenced();
// Walk the instruction's operands starting with the right-most.
// Walking right-to-left means the right-most operand must be the most recently added
// pending value for it to be popped (and inlined) without reordering.
auto operands = inst->Operands();
for (auto* operand : utils::Reverse(operands)) {
if (!pending_resolution.Contains(operand)) {
continue;
}
// Operand is in 'pending_resolution'
if (pending_resolution.TryPop(operand)) {
// Operand was the last sequenced value to be added to 'pending_resolution'
// This operand can be inlined as it does not change the sequencing order.
can_inline_.Add(operand);
sequenced = true; // Inherit the 'sequenced' flag from the inlined value
} else {
// Operand was in 'pending_resolution', but was not the last sequenced value to
// be added. Inlining this operand would break the sequencing order, so must be
// emitted as a let. All preceding pending values must also be emitted as a
// let to prevent them being inlined and breaking the sequencing order.
// Remove all the values in pending up to and including 'operand'.
for (size_t i = 0; i < pending_resolution.Length(); i++) {
if (pending_resolution[i] == operand) {
pending_resolution.Erase(0, i + 1);
break;
}
}
}
}
if (inst->Results().Length() == 1) {
// Instruction has a single result value.
// Check to see if the result of this instruction is a candidate for inlining.
auto* result = inst->Result();
// Only values with a single usage can be inlined.
// Named values are not inlined, as we want to emit the name for a let.
if (result->Usages().Count() == 1 && !mod.NameOf(result).IsValid()) {
if (sequenced) {
// The value comes from a sequenced instruction. We need to ensure
// instruction ordering so add it to 'pending_resolution'.
pending_resolution.Add(result);
} else {
// The value comes from an unsequenced instruction. Just inline.
can_inline_.Add(result);
}
continue;
}
}
// At this point the value has been ruled out for inlining.
if (sequenced) {
// A sequenced instruction whose result is not an inlining candidate (zero or multiple
// results, or a result that is named or used more than once) acts as a barrier:
// all preceding sequenced instructions cannot be inlined past this point.
pending_resolution.Clear();
}
}
}
/// Appends @p stmt to the statement list currently being built (statements_).
/// @param stmt the AST statement to append
void Append(const ast::Statement* stmt) {
    statements_->Push(stmt);
}
/// Emits the IR instruction @p inst by dispatching on its concrete type.
/// Continue, ExitIf, NextIteration and Unreachable produce no AST statements here; the
/// structured AST control flow built by If() / Loop() makes them implicit.
/// Unknown instruction types are reported via UNHANDLED_CASE.
/// @param inst the IR instruction
void Instruction(ir::Instruction* inst) {
tint::Switch(
inst, //
[&](ir::Access* i) { Access(i); }, //
[&](ir::Binary* i) { Binary(i); }, //
[&](ir::BreakIf* i) { BreakIf(i); }, //
[&](ir::Call* i) { Call(i); }, //
[&](ir::Continue*) {}, //
[&](ir::ExitIf*) {}, //
[&](ir::ExitLoop* i) { ExitLoop(i); }, //
[&](ir::ExitSwitch* i) { ExitSwitch(i); }, //
[&](ir::If* i) { If(i); }, //
[&](ir::Let* i) { Let(i); }, //
[&](ir::Load* l) { Load(l); }, //
[&](ir::LoadVectorElement* i) { LoadVectorElement(i); }, //
[&](ir::Loop* l) { Loop(l); }, //
[&](ir::NextIteration*) {}, //
[&](ir::Return* i) { Return(i); }, //
[&](ir::Store* i) { Store(i); }, //
[&](ir::StoreVectorElement* i) { StoreVectorElement(i); }, //
[&](ir::Switch* i) { Switch(i); }, //
[&](ir::Swizzle* i) { Swizzle(i); }, //
[&](ir::Unary* i) { Unary(i); }, //
[&](ir::Unreachable*) {}, //
[&](ir::Var* i) { Var(i); }, //
[&](Default) { UNHANDLED_CASE(inst); });
}
/// Emits the IR if-instruction @p if_, either as a short-circuit logical expression (see
/// AsShortCircuit()) or as an AST if-statement. An else block consisting of a single
/// if-statement is emitted as 'else if'.
/// @param if_ the IR if instruction
void If(ir::If* if_) {
    SCOPED_NESTING();

    // Build both branches up front; AsShortCircuit() inspects them.
    auto then_stmts = Statements(if_->True());
    auto else_stmts = Statements(if_->False());
    if (AsShortCircuit(if_, then_stmts, else_stmts)) {
        return;
    }

    auto* condition = Expr(if_->Condition());
    auto* then_block = b.Block(std::move(then_stmts));

    if (else_stmts.IsEmpty()) {
        // No else branch.
        Append(b.If(condition, then_block));
        return;
    }

    if (else_stmts.Length() == 1) {
        if (auto* nested_if = else_stmts.Front()->As<ast::IfStatement>()) {
            // else { if ... }  ->  else if ...
            Append(b.If(condition, then_block, b.Else(nested_if)));
            return;
        }
    }

    auto* else_block = b.Block(std::move(else_stmts));
    Append(b.If(condition, then_block, b.Else(else_block)));
}
/// Emits the IR loop @p l as a WGSL 'for', 'while' or 'loop' statement, choosing the form
/// that the initializer, condition and continuing blocks allow:
/// * 'loop' if there is no recognized condition, or the continuing block is non-trivial,
/// * 'for' if there is a condition and a single-statement initializer or continuing,
/// * 'while' otherwise.
/// Initializer statements that cannot be folded into a 'for' header are emitted before the
/// loop, inside an enclosing block.
/// @param l the IR loop
void Loop(ir::Loop* l) {
SCOPED_NESTING();
// Build all the initializer statements
auto init_stmts = Statements(l->Initializer());
// If there's a single initializer statement and meets the WGSL 'for_init' pattern, then
// this can be used as the initializer for a for-loop.
// @see https://www.w3.org/TR/WGSL/#syntax-for_init
auto* init = (init_stmts.Length() == 1) &&
init_stmts.Front()
->IsAnyOf<ast::VariableDeclStatement, ast::AssignmentStatement,
ast::CompoundAssignmentStatement,
ast::IncrementDecrementStatement, ast::CallStatement>()
? init_stmts.Front()
: nullptr;
// Build the loop body statements. If the loop body starts with a if with the following
// pattern, then treat it as the loop condition:
// if cond {
// block { exit_if }
// block { exit_loop }
// }
const ast::Expression* cond = nullptr;
StatementList body_stmts;
{
MarkInlinable(l->Body());
TINT_SCOPED_ASSIGNMENT(statements_, &body_stmts);
for (auto* inst : *l->Body()) {
// Only an 'if' seen before any body statement has been emitted can be the condition.
// (Instructions whose results were inlined emit no statements, so they may precede it.)
if (body_stmts.IsEmpty()) {
if (auto* if_ = inst->As<ir::If>()) {
if (!if_->HasResults() && //
if_->True()->Length() == 1 && //
if_->False()->Length() == 1 && //
tint::Is<ir::ExitIf>(if_->True()->Front()) && //
tint::Is<ir::ExitLoop>(if_->False()->Front())) {
// Matched the loop condition.
cond = Expr(if_->Condition());
continue; // Don't emit this as an instruction in the body.
}
}
}
// Process the loop body instruction. Append to 'body_stmts'
Instruction(inst);
}
}
// Build any continuing statements
auto cont_stmts = Statements(l->Continuing());
// If there's a single continuing statement and meets the WGSL 'for_update' pattern then
// this can be used as the continuing for a for-loop.
// @see https://www.w3.org/TR/WGSL/#syntax-for_update
auto* cont =
(cont_stmts.Length() == 1) &&
cont_stmts.Front()
->IsAnyOf<ast::AssignmentStatement, ast::CompoundAssignmentStatement,
ast::IncrementDecrementStatement, ast::CallStatement>()
? cont_stmts.Front()
: nullptr;
// Depending on 'init', 'cond' and 'cont', build a 'for', 'while' or 'loop'
const ast::Statement* loop = nullptr;
if ((!cont && !cont_stmts.IsEmpty()) // Non-trivial continuing
|| !cond // or non-trivial or no condition
) {
// Build a loop
if (cond) {
// Re-express the recognized condition as 'if (!cond) { break; }' at the top of the body.
body_stmts.Insert(0, b.If(b.Not(cond), b.Block(b.Break())));
}
auto* body = b.Block(std::move(body_stmts));
loop = cont_stmts.IsEmpty() ? b.Loop(body) //
: b.Loop(body, b.Block(std::move(cont_stmts)));
if (!init_stmts.IsEmpty()) {
init_stmts.Push(loop);
loop = b.Block(std::move(init_stmts));
}
} else if (init || cont) {
// Build a for-loop
auto* body = b.Block(std::move(body_stmts));
loop = b.For(init, cond, cont, body);
// If the initializer could not be folded into the for-header, emit it before the loop.
if (!init && !init_stmts.IsEmpty()) {
init_stmts.Push(loop);
loop = b.Block(std::move(init_stmts));
}
} else {
// Build a while-loop
auto* body = b.Block(std::move(body_stmts));
loop = b.While(cond, body);
if (!init_stmts.IsEmpty()) {
init_stmts.Push(loop);
loop = b.Block(std::move(init_stmts));
}
}
statements_->Push(loop);
}
/// Emits the IR switch @p s as an AST switch statement.
/// Each case body is built with current_switch_case_ set to the case block, which lets
/// ExitSwitch() omit the break for the case's terminating exit.
/// @param s the IR switch
void Switch(ir::Switch* s) {
SCOPED_NESTING();
auto* cond = Expr(s->Condition());
auto cases = utils::Transform(
s->Cases(), //
[&](ir::Switch::Case c) -> const tint::ast::CaseStatement* {
SCOPED_NESTING();
const ast::BlockStatement* body = nullptr;
{
TINT_SCOPED_ASSIGNMENT(current_switch_case_, c.Block());
body = Block(c.Block());
}
auto selectors = utils::Transform(c.selectors, //
[&](ir::Switch::CaseSelector cs) {
return cs.IsDefault()
? b.DefaultCaseSelector()
: b.CaseSelector(Expr(cs.val));
});
return b.Case(std::move(selectors), body);
});
Append(b.Switch(cond, std::move(cases)));
}
/// Emits an IR exit_switch as a 'break' statement, unless it is the terminator of the switch
/// case currently being emitted (which needs no explicit 'break').
/// @param e the IR exit_switch
void ExitSwitch(const ir::ExitSwitch* e) {
    const bool is_case_terminator =
        current_switch_case_ != nullptr && current_switch_case_->Terminator() == e;
    if (!is_case_terminator) {
        Append(b.Break());
    }
}
/// Emits an IR exit_loop as a 'break' statement.
void ExitLoop(const ir::ExitLoop*) {
    auto* brk = b.Break();
    Append(brk);
}
/// Emits an IR break_if as a 'break if <cond>' statement.
/// @param i the IR break_if
void BreakIf(ir::BreakIf* i) {
    auto* condition = Expr(i->Condition());
    Append(b.BreakIf(condition));
}
/// Emits an IR return instruction.
/// @param ret the IR return
void Return(ir::Return* ret) {
    const auto num_args = ret->Args().Length();

    if (num_args == 0) {
        // A bare 'return' is only required when nested within control flow. At the root of the
        // function (nesting depth 1) we have naturally reached the end of the function and the
        // 'return' would be redundant.
        if (nesting_depth_ > 1) {
            Append(b.Return());
        }
        return;
    }

    // A return with arguments carries exactly one return value.
    if (num_args != 1) {
        TINT_ICE(IR, b.Diagnostics()) << "expected 1 value for return, got " << num_args;
        return;
    }

    auto* value = Expr(ret->Args().Front());
    Append(b.Return(value));
}
/// Emits the IR var @p var: function-scope vars become a 'var' declaration statement;
/// all other address spaces become a module-scope 'var'.
/// @param var the IR var
void Var(ir::Var* var) {
    auto* result = var->Result();
    auto* ptr_ty = As<type::Pointer>(result->Type());
    auto store_ty = Type(ptr_ty->StoreType());

    // A var is referenced by name, and the name denotes a reference (not a pointer).
    Symbol name = NameFor(result);
    Bind(result, name, PtrKind::kRef);

    utils::Vector<const ast::Attribute*, 4> attrs;
    if (auto bp = var->BindingPoint()) {
        attrs.Push(b.Group(AInt(bp->group)));
        attrs.Push(b.Binding(AInt(bp->binding)));
    }

    const ast::Expression* init = var->Initializer() ? Expr(var->Initializer()) : nullptr;

    const auto space = ptr_ty->AddressSpace();
    if (space == builtin::AddressSpace::kFunction) {
        Append(b.Decl(b.Var(name, store_ty, init, std::move(attrs))));
    } else if (space == builtin::AddressSpace::kStorage) {
        // Only the 'storage' address space carries an explicit access mode.
        b.GlobalVar(name, store_ty, init, ptr_ty->Access(), ptr_ty->AddressSpace(),
                    std::move(attrs));
    } else if (space == builtin::AddressSpace::kHandle) {
        // The 'handle' address space is not spelled out.
        b.GlobalVar(name, store_ty, init, std::move(attrs));
    } else {
        b.GlobalVar(name, store_ty, init, ptr_ty->AddressSpace(), std::move(attrs));
    }
}
/// Emits the IR let @p let as a 'let' declaration, and binds its result to the new name.
/// @param let the IR let
void Let(ir::Let* let) {
    auto* result = let->Result();
    Symbol name = NameFor(result);
    // Pointer-typed values are held in the 'let' as pointers.
    auto* initializer = Expr(let->Value(), PtrKind::kPtr);
    Append(b.Decl(b.Let(name, initializer)));
    Bind(result, name, PtrKind::kPtr);
}
/// Emits the IR store @p store as an assignment statement.
/// @param store the IR store
void Store(ir::Store* store) {
    // The destination expression is fetched before the source.
    auto* lhs = Expr(store->To());
    auto* rhs = Expr(store->From());
    auto* assign = b.Assign(lhs, rhs);
    Append(assign);
}
/// Emits the IR store_vector_element @p store as an assignment to a vector component.
/// @param store the IR store_vector_element
void StoreVectorElement(ir::StoreVectorElement* store) {
    auto* vec = Expr(store->To());
    auto* value = Expr(store->Value());
    auto* element = VectorMemberAccess(vec, store->Index());
    Append(b.Assign(element, value));
}
/// Emits an IR call instruction. All arguments are converted to expressions first, then the
/// call is emitted according to its kind:
/// * user calls and core builtin calls become a call statement, or are bound as an expression,
/// * construct and convert become type-initializer calls,
/// * bitcast becomes a bitcast expression of its first argument,
/// * discard becomes a 'discard' statement.
/// @param call the IR call
void Call(ir::Call* call) {
auto args = utils::Transform<4>(call->Args(), [&](ir::Value* arg) {
// Pointer-like arguments are passed by pointer, never reference.
return Expr(arg, PtrKind::kPtr);
});
tint::Switch(
call, //
[&](ir::UserCall* c) {
auto* expr = b.Call(NameFor(c->Func()), std::move(args));
// A user call whose result is unused is emitted as a call statement.
if (!call->HasResults() || call->Result()->Usages().IsEmpty()) {
Append(b.CallStmt(expr));
return;
}
Bind(c->Result(), expr, PtrKind::kPtr);
},
[&](ir::CoreBuiltinCall* c) {
if (!disabled_derivative_uniformity_ && RequiresDerivativeUniformity(c->Func())) {
// TODO(crbug.com/tint/1985): Be smarter about disabling derivative uniformity.
b.DiagnosticDirective(builtin::DiagnosticSeverity::kOff,
builtin::CoreDiagnosticRule::kDerivativeUniformity);
disabled_derivative_uniformity_ = true;
}
auto* expr = b.Call(c->Func(), std::move(args));
// Only void builtin calls become call statements. A non-void result that is unused is
// still passed to Bind(), which emits it as a phony assignment when it is unnamed.
if (!call->HasResults() || call->Result()->Type()->Is<type::Void>()) {
Append(b.CallStmt(expr));
return;
}
Bind(c->Result(), expr, PtrKind::kPtr);
},
[&](ir::Construct* c) {
auto ty = Type(c->Result()->Type());
Bind(c->Result(), b.Call(ty, std::move(args)), PtrKind::kPtr);
},
[&](ir::Convert* c) {
auto ty = Type(c->Result()->Type());
Bind(c->Result(), b.Call(ty, std::move(args)), PtrKind::kPtr);
},
[&](ir::Bitcast* c) {
auto ty = Type(c->Result()->Type());
Bind(c->Result(), b.Bitcast(ty, args[0]), PtrKind::kPtr);
},
[&](ir::Discard*) { Append(b.Discard()); }, //
[&](Default) { UNHANDLED_CASE(call); });
}
/// Emits the IR load @p l by binding its result to the source expression. The source is
/// requested in reference form (the default for Expr()).
/// @param l the IR load
void Load(ir::Load* l) {
    auto* source = Expr(l->From());
    Bind(l->Result(), source);
}
/// Emits the IR load_vector_element @p load by binding its result to a vector component
/// access.
/// @param load the IR load_vector_element
void LoadVectorElement(ir::LoadVectorElement* load) {
    auto* vec = Expr(load->From());
    auto* element = VectorMemberAccess(vec, load->Index());
    Bind(load->Result(), element);
}
/// Emits the IR unary @p u by binding its result to the equivalent AST unary expression.
/// @param u the IR unary
void Unary(ir::Unary* u) {
    const ast::Expression* result = nullptr;
    switch (u->Kind()) {
        case ir::Unary::Kind::kComplement: {
            auto* operand = Expr(u->Val());
            result = b.Complement(operand);
            break;
        }
        case ir::Unary::Kind::kNegation: {
            auto* operand = Expr(u->Val());
            result = b.Negation(operand);
            break;
        }
    }
    Bind(u->Result(), result);
}
/// Emits the IR access @p a as a chain of accessor expressions, one per index, while tracking
/// the type being indexed. Vector indices use VectorMemberAccess(); matrix and array indices
/// become index accessors; struct indices must be constants and become member accesses by name.
/// @param a the IR access
void Access(ir::Access* a) {
auto* expr = Expr(a->Object());
auto* obj_ty = a->Object()->Type()->UnwrapPtr();
for (auto* index : a->Indices()) {
tint::Switch(
obj_ty,
[&](const type::Vector* vec) {
// Update the tracked type after the access has been built.
TINT_DEFER(obj_ty = vec->type());
expr = VectorMemberAccess(expr, index);
},
[&](const type::Matrix* mat) {
obj_ty = mat->ColumnType();
expr = b.IndexAccessor(expr, Expr(index));
},
[&](const type::Array* arr) {
obj_ty = arr->ElemType();
expr = b.IndexAccessor(expr, Expr(index));
},
[&](const type::Struct* s) {
if (auto* c = index->As<ir::Constant>()) {
auto i = c->Value()->ValueAs<uint32_t>();
TINT_ASSERT_OR_RETURN(IR, i < s->Members().Length());
auto* member = s->Members()[i];
obj_ty = member->Type();
expr = b.MemberAccessor(expr, member->Name().NameView());
} else {
TINT_ICE(IR, b.Diagnostics())
<< "invalid index for struct type: " << index->TypeInfo().name;
}
},
[&](Default) { UNHANDLED_CASE(obj_ty); });
}
Bind(a->Result(), expr);
}
/// Emits the IR swizzle @p s as a member access with a component string such as "xzy".
/// Raises an ICE and emits nothing if any index is outside [0, 3].
/// @param s the IR swizzle
void Swizzle(ir::Swizzle* s) {
    auto* object = Expr(s->Object());

    utils::Vector<char, 4> letters;
    for (uint32_t idx : s->Indices()) {
        if (TINT_UNLIKELY(idx >= 4)) {
            TINT_ICE(IR, b.Diagnostics()) << "invalid swizzle index: " << idx;
            return;
        }
        letters.Push("xyzw"[idx]);
    }

    Bind(s->Result(),
         b.MemberAccessor(object, std::string_view(letters.begin(), letters.Length())));
}
/// Emits the IR binary @p e by binding its result to the equivalent AST binary expression.
/// `x == false` is emitted as `!x`.
/// @param e the IR binary
void Binary(ir::Binary* e) {
    if (e->Kind() == ir::Binary::Kind::kEqual && IsConstant(e->RHS(), false)) {
        // expr == false  ->  !expr
        Bind(e->Result(), b.Not(Expr(e->LHS())));
        return;
    }

    // The LHS expression is fetched before the RHS.
    auto* lhs = Expr(e->LHS());
    auto* rhs = Expr(e->RHS());

    using Kind = ir::Binary::Kind;
    auto build = [&]() -> const ast::Expression* {
        switch (e->Kind()) {
            case Kind::kAdd:
                return b.Add(lhs, rhs);
            case Kind::kSubtract:
                return b.Sub(lhs, rhs);
            case Kind::kMultiply:
                return b.Mul(lhs, rhs);
            case Kind::kDivide:
                return b.Div(lhs, rhs);
            case Kind::kModulo:
                return b.Mod(lhs, rhs);
            case Kind::kAnd:
                return b.And(lhs, rhs);
            case Kind::kOr:
                return b.Or(lhs, rhs);
            case Kind::kXor:
                return b.Xor(lhs, rhs);
            case Kind::kEqual:
                return b.Equal(lhs, rhs);
            case Kind::kNotEqual:
                return b.NotEqual(lhs, rhs);
            case Kind::kLessThan:
                return b.LessThan(lhs, rhs);
            case Kind::kGreaterThan:
                return b.GreaterThan(lhs, rhs);
            case Kind::kLessThanEqual:
                return b.LessThanEqual(lhs, rhs);
            case Kind::kGreaterThanEqual:
                return b.GreaterThanEqual(lhs, rhs);
            case Kind::kShiftLeft:
                return b.Shl(lhs, rhs);
            case Kind::kShiftRight:
                return b.Shr(lhs, rhs);
        }
        return nullptr;
    };
    Bind(e->Result(), build());
}
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
/// Returns the AST expression for the IR value @p value.
/// Constants are emitted directly. Every other value must have been bound with Bind(); an
/// inlined (single-use) binding is consumed by this call, so a second request for the same
/// inlined value raises an ICE.
/// @param value the IR value
/// @param want_ptr_kind how a pointer-typed value should be represented in the returned
///        expression (AST pointer or AST reference). Ignored for non-pointer values.
/// @returns the expression, or an `<error>` identifier expression if an ICE was raised.
const ast::Expression* Expr(ir::Value* value, PtrKind want_ptr_kind = PtrKind::kRef) {
    using ExprAndPtrKind = std::pair<const ast::Expression*, PtrKind>;
    auto [expr, got_ptr_kind] = tint::Switch(
        value,
        [&](ir::Constant* c) -> ExprAndPtrKind {
            return {Constant(c), PtrKind::kRef};
        },
        [&](Default) -> ExprAndPtrKind {
            auto lookup = bindings_.Find(value);
            if (TINT_UNLIKELY(!lookup)) {
                TINT_ICE(IR, b.Diagnostics())
                    << "Expr(" << (value ? value->TypeInfo().name : "null")
                    << ") value has no expression";
                return {};
            }
            return std::visit(
                [&](auto&& got) -> ExprAndPtrKind {
                    using T = std::decay_t<decltype(got)>;
                    if constexpr (std::is_same_v<T, VariableValue>) {
                        return {b.Expr(got.name), got.ptr_kind};
                    }
                    if constexpr (std::is_same_v<T, InlinedValue>) {
                        // Single use (inlined) expression.
                        // 'got' refers to the alternative currently held by the bindings_ entry.
                        // Copy the result out *before* overwriting the entry: assigning
                        // ConsumedValue ends the lifetime of the InlinedValue 'got' refers to.
                        ExprAndPtrKind result{got.expr, got.ptr_kind};
                        // Mark the bindings_ map entry as consumed.
                        *lookup = ConsumedValue{};
                        return result;
                    }
                    if constexpr (std::is_same_v<T, ConsumedValue>) {
                        TINT_ICE(IR, b.Diagnostics()) << "Expr(" << value->TypeInfo().name
                                                      << ") called twice on the same value";
                    } else {
                        TINT_ICE(IR, b.Diagnostics())
                            << "Expr(" << value->TypeInfo().name << ") has unhandled value";
                    }
                    return {};
                },
                *lookup);
        });
    if (!expr) {
        return b.Expr("<error>");
    }
    if (value->Type()->Is<type::Pointer>()) {
        // Convert between AST pointer / reference form as requested by the caller.
        return ToPtrKind(expr, got_ptr_kind, want_ptr_kind);
    }
    return expr;
}
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
/// @param c the IR constant
/// @returns the AST expression for the value of @p c
const ast::Expression* Constant(ir::Constant* c) { return Constant(c->Value()); }
/// @param c the constant value
/// @returns an AST expression for @p c. Scalars become literals (f16 also enables the f16
/// extension). Composites become type-initializer calls: an all-zero composite uses the
/// zero-value initializer `T()`, a vector splat uses the single-argument form, and anything
/// else lists every element. Unhandled types produce an `<error>` expression.
const ast::Expression* Constant(const constant::Value* c) {
auto composite = [&](bool can_splat) {
auto ty = Type(c->Type());
if (c->AllZero()) {
return b.Call(ty);
}
if (can_splat && c->Is<constant::Splat>()) {
return b.Call(ty, Constant(c->Index(0)));
}
utils::Vector<const ast::Expression*, 8> els;
for (size_t i = 0, n = c->NumElements(); i < n; i++) {
els.Push(Constant(c->Index(i)));
}
return b.Call(ty, std::move(els));
};
return tint::Switch(
c->Type(), //
[&](const type::I32*) { return b.Expr(c->ValueAs<i32>()); },
[&](const type::U32*) { return b.Expr(c->ValueAs<u32>()); },
[&](const type::F32*) { return b.Expr(c->ValueAs<f32>()); },
[&](const type::F16*) {
Enable(builtin::Extension::kF16);
return b.Expr(c->ValueAs<f16>());
},
[&](const type::Bool*) { return b.Expr(c->ValueAs<bool>()); },
[&](const type::Array*) { return composite(/* can_splat */ false); },
[&](const type::Vector*) { return composite(/* can_splat */ true); },
[&](const type::Matrix*) { return composite(/* can_splat */ false); },
[&](const type::Struct*) { return composite(/* can_splat */ false); },
[&](Default) {
UNHANDLED_CASE(c->Type());
return b.Expr("<error>");
});
}
/// Emits an 'enable' directive for @p ext, at most once per extension.
/// @param ext the extension to enable
void Enable(builtin::Extension ext) {
    if (!enables_.Add(ext)) {
        return;  // Already enabled.
    }
    b.Enable(ext);
}
////////////////////////////////////////////////////////////////////////////////////////////////
// Types
//
// In the case of an error:
// * The type-generating methods must return a non-null ast type, which may not be semantically
// legal, but is enough to populate the AST.
// * A diagnostic error must be added to the ProgramBuilder.
// This prevents littering the ToProgram logic with expensive error checking code.
////////////////////////////////////////////////////////////////////////////////////////////////
/// @param ty the type::Type
/// @return an ast::Type from @p ty.
/// @note May be a semantically-invalid placeholder type on error.
/// @note type::Void maps to an empty ast::Type{}. Using f16 (directly or as an element type)
/// enables the f16 extension. Structs are declared on first use (see Struct()).
ast::Type Type(const type::Type* ty) {
return tint::Switch(
ty, //
[&](const type::Void*) { return ast::Type{}; }, //
[&](const type::I32*) { return b.ty.i32(); }, //
[&](const type::U32*) { return b.ty.u32(); }, //
[&](const type::F16*) {
Enable(builtin::Extension::kF16);
return b.ty.f16();
},
[&](const type::F32*) { return b.ty.f32(); }, //
[&](const type::Bool*) { return b.ty.bool_(); },
[&](const type::Matrix* m) {
return b.ty.mat(Type(m->type()), m->columns(), m->rows());
},
[&](const type::Vector* v) {
auto el = Type(v->type());
if (v->Packed()) {
// Packed vectors are only expected for 3-element vectors.
TINT_ASSERT(IR, v->Width() == 3u);
return b.ty(builtin::Builtin::kPackedVec3, el);
} else {
return b.ty.vec(el, v->Width());
}
},
[&](const type::Array* a) {
auto el = Type(a->ElemType());
utils::Vector<const ast::Attribute*, 1> attrs;
// Only emit @stride when it differs from the implicit stride.
if (!a->IsStrideImplicit()) {
attrs.Push(b.Stride(a->Stride()));
}
if (a->Count()->Is<type::RuntimeArrayCount>()) {
return b.ty.array(el, std::move(attrs));
}
auto count = a->ConstantCount();
if (TINT_UNLIKELY(!count)) {
TINT_ICE(IR, b.Diagnostics()) << type::Array::kErrExpectedConstantCount;
return b.ty.array(el, u32(1), std::move(attrs));
}
return b.ty.array(el, u32(count.value()), std::move(attrs));
},
[&](const type::Struct* s) { return Struct(s); },
[&](const type::Atomic* a) { return b.ty.atomic(Type(a->Type())); },
[&](const type::DepthTexture* t) { return b.ty.depth_texture(t->dim()); },
[&](const type::DepthMultisampledTexture* t) {
return b.ty.depth_multisampled_texture(t->dim());
},
[&](const type::ExternalTexture*) { return b.ty.external_texture(); },
[&](const type::MultisampledTexture* t) {
auto el = Type(t->type());
return b.ty.multisampled_texture(t->dim(), el);
},
[&](const type::SampledTexture* t) {
auto el = Type(t->type());
return b.ty.sampled_texture(t->dim(), el);
},
[&](const type::StorageTexture* t) {
return b.ty.storage_texture(t->dim(), t->texel_format(), t->access());
},
[&](const type::Sampler* s) { return b.ty.sampler(s->kind()); },
[&](const type::Pointer* p) {
// Note: type::Pointer always has an inferred access, but WGSL only allows an
// explicit access in the 'storage' address space.
auto el = Type(p->StoreType());
auto address_space = p->AddressSpace();
auto access = address_space == builtin::AddressSpace::kStorage
? p->Access()
: builtin::Access::kUndefined;
return b.ty.ptr(address_space, el, access);
},
[&](const type::Reference*) {
TINT_ICE(IR, b.Diagnostics()) << "reference types should never appear in the IR";
return b.ty.i32();
},
[&](Default) {
UNHANDLED_CASE(ty);
return b.ty.i32();
});
}
/// Declares the structure @p s in the output program the first time it is seen (cached in
/// structs_), and returns a type reference to it.
/// Member @align / @size attributes are emitted only where the member's alignment or size
/// differs from that of its type; IO attributes are copied from the IR member attributes.
/// @param s the IR structure type
/// @returns an ast::Type naming the structure
ast::Type Struct(const type::Struct* s) {
auto n = structs_.GetOrCreate(s, [&] {
auto members = utils::Transform<8>(s->Members(), [&](const type::StructMember* m) {
auto ty = Type(m->Type());
const auto& ir_attrs = m->Attributes();
utils::Vector<const ast::Attribute*, 4> ast_attrs;
if (m->Type()->Align() != m->Align()) {
ast_attrs.Push(b.MemberAlign(u32(m->Align())));
}
if (m->Type()->Size() != m->Size()) {
ast_attrs.Push(b.MemberSize(u32(m->Size())));
}
if (auto location = ir_attrs.location) {
ast_attrs.Push(b.Location(u32(*location)));
}
if (auto index = ir_attrs.index) {
// @index requires the chromium_internal_dual_source_blending extension.
Enable(builtin::Extension::kChromiumInternalDualSourceBlending);
ast_attrs.Push(b.Index(u32(*index)));
}
if (auto builtin = ir_attrs.builtin) {
ast_attrs.Push(b.Builtin(*builtin));
}
if (auto interpolation = ir_attrs.interpolation) {
ast_attrs.Push(b.Interpolate(interpolation->type, interpolation->sampling));
}
if (ir_attrs.invariant) {
ast_attrs.Push(b.Invariant());
}
return b.Member(m->Name().NameView(), ty, std::move(ast_attrs));
});
// TODO(crbug.com/tint/1902): Emit structure attributes
utils::Vector<const ast::Attribute*, 2> attrs;
auto name = b.Symbols().Register(s->Name().NameView());
b.Structure(name, std::move(members), std::move(attrs));
return name;
});
return b.ty(n);
}
/// Converts the expression @p in from the pointer representation @p got to @p want.
/// @param in the expression to convert
/// @param got how @p in currently represents the pointer
/// @param want the desired representation
/// @returns @p in unchanged if the kinds match, `*in` for pointer-to-reference, or `&in` for
///          reference-to-pointer.
const ast::Expression* ToPtrKind(const ast::Expression* in, PtrKind got, PtrKind want) {
    if (got == want) {
        return in;
    }
    if (want == PtrKind::kRef) {
        // pointer -> reference
        return b.Deref(in);
    }
    // reference -> pointer
    return b.AddressOf(in);
}
////////////////////////////////////////////////////////////////////////////////////////////////
// Bindings
////////////////////////////////////////////////////////////////////////////////////////////////
/// @returns the AST name for @p value. The name is created on the first call and cached.
/// On creation the name is, in order of preference: @p suggested (if non-empty), the value's
/// name in the IR module, or a fresh symbol based on "v".
/// @param value the IR value
/// @param suggested an optional preferred name
Symbol NameFor(Value* value, std::string_view suggested = {}) {
    return names_.GetOrCreate(value, [&]() -> Symbol {
        if (suggested.empty()) {
            if (auto ir_name = mod.NameOf(value)) {
                return b.Symbols().Register(ir_name.NameView());
            }
            return b.Symbols().New("v");
        }
        return b.Symbols().Register(suggested);
    });
}
/// Associates the IR value @p value with the AST expression @p expr.
/// @p ptr_kind defines how pointer values are represented by @p expr.
/// If @p value was marked inlinable by MarkInlinable(), @p expr is stored and emitted later at
/// the value's single place of use. Otherwise @p expr is emitted immediately: as a phony
/// assignment ('_ = expr') if the value has no usages and no name, or as a 'let' declaration
/// which the value is then bound to by name.
/// Raises an ICE if @p value is bound more than once.
void Bind(ir::Value* value, const ast::Expression* expr, PtrKind ptr_kind = PtrKind::kRef) {
TINT_ASSERT(IR, value);
if (can_inline_.Remove(value)) {
// Value will be inlined at its place of usage.
if (TINT_LIKELY(bindings_.Add(value, InlinedValue{expr, ptr_kind}))) {
return;
}
} else {
// Values held by a 'let' (or phony) are always held as pointers, never references.
if (value->Type()->Is<type::Pointer>()) {
expr = ToPtrKind(expr, ptr_kind, PtrKind::kPtr);
}
auto mod_name = mod.NameOf(value);
if (value->Usages().IsEmpty() && !mod_name.IsValid()) {
// Value has no usages and no name.
// Assign to a phony. These support more data types than a 'let', and avoids
// allocation of unused names.
Append(b.Assign(b.Phony(), expr));
} else {
Symbol name = NameFor(value, mod_name.NameView());
Append(b.Decl(b.Let(name, expr)));
Bind(value, name, PtrKind::kPtr);
}
return;
}
TINT_ICE(IR, b.Diagnostics())
<< "Bind(" << value->TypeInfo().name << ") called twice for same value";
}
/// Associates the IR value @p value with the AST 'var', 'let' or parameter named @p name.
/// Raises an ICE if @p value has already been bound.
/// @param value the IR value
/// @param name the name of the AST 'var', 'let' or parameter holding the value
/// @param ptr_kind defines how pointer values are represented when referenced by @p name
void Bind(ir::Value* value, Symbol name, PtrKind ptr_kind) {
    TINT_ASSERT(IR, value);
    if (TINT_LIKELY(bindings_.Add(value, VariableValue{name, ptr_kind}))) {
        return;
    }
    TINT_ICE(IR, b.Diagnostics())
        << "Bind(" << value->TypeInfo().name << ") called twice for same value";
}
////////////////////////////////////////////////////////////////////////////////////////////////
// Helpers
////////////////////////////////////////////////////////////////////////////////////////////////
/// Attempts to emit the IR if-instruction @p i as a short-circuiting logical expression, binding
/// the if's single bool result to `cond && true_val` or `cond || false_val`.
/// @param i the IR if instruction
/// @param true_stmts the statements already emitted for the true block
/// @param false_stmts the statements already emitted for the false block
/// @returns true if @p i was emitted as a short-circuit expression, false if the caller must
///          emit it as a regular if-statement.
bool AsShortCircuit(ir::If* i,
                    const StatementList& true_stmts,
                    const StatementList& false_stmts) {
    if (i->Results().Length() != 1) {
        // The short-circuit forms produce exactly one value. An if with no results has nothing
        // to bind, and an if with several results cannot be expressed as a single && / ||
        // (only the first result would be bound, leaving the others without an expression).
        return false;
    }
    auto* result = i->Result();
    if (!result->Type()->Is<type::Bool>()) {
        return false;  // Wrong result type
    }
    if (i->Exits().Count() != 2) {
        return false;  // Doesn't have two exits
    }
    if (!true_stmts.IsEmpty() || !false_stmts.IsEmpty()) {
        return false;  // True or False blocks contain statements
    }
    auto* cond = i->Condition();
    auto* true_val = i->True()->Back()->Operands().Front();
    auto* false_val = i->False()->Back()->Operands().Front();
    if (IsConstant(false_val, false)) {
        //  %res = if %cond {
        //     block {  # true
        //       exit_if %true_val;
        //     }
        //     block {  # false
        //       exit_if false;
        //     }
        //  }
        //
        // transform into:
        //
        //  res = cond && true_val;
        //
        auto* lhs = Expr(cond);
        auto* rhs = Expr(true_val);
        Bind(result, b.LogicalAnd(lhs, rhs));
        return true;
    }
    if (IsConstant(true_val, true)) {
        //  %res = if %cond {
        //     block {  # true
        //       exit_if true;
        //     }
        //     block {  # false
        //       exit_if %false_val;
        //     }
        //  }
        //
        // transform into:
        //
        //  res = cond || false_val;
        //
        auto* lhs = Expr(cond);
        auto* rhs = Expr(false_val);
        Bind(result, b.LogicalOr(lhs, rhs));
        return true;
    }
    return false;
}
/// @param val the IR value to test
/// @param value the expected boolean value
/// @returns true if @p val is a scalar bool constant equal to @p value
bool IsConstant(ir::Value* val, bool value) {
    auto* c = val->As<ir::Constant>();
    return c != nullptr && c->Type()->Is<type::Bool>() && c->Value()->ValueAs<bool>() == value;
}
/// Builds an access of the vector component @p index of @p expr.
/// A constant index in [0, 3] becomes a named component access (`.x`, `.y`, `.z`, `.w`);
/// any other index becomes an index accessor (`expr[index]`).
/// @param expr the vector expression
/// @param index the IR component index
/// @returns the access expression
const ast::Expression* VectorMemberAccess(const ast::Expression* expr, ir::Value* index) {
    auto* constant_index = index->As<ir::Constant>();
    if (!constant_index) {
        return b.IndexAccessor(expr, Expr(index));
    }
    switch (constant_index->Value()->ValueAs<int>()) {
        case 0:
            return b.MemberAccessor(expr, "x");
        case 1:
            return b.MemberAccessor(expr, "y");
        case 2:
            return b.MemberAccessor(expr, "z");
        case 3:
            return b.MemberAccessor(expr, "w");
        default:
            break;
    }
    // Constant index without a component name: fall back to an index accessor.
    return b.IndexAccessor(expr, Expr(index));
}
/// @param fn the builtin function
/// @returns true if a call to @p fn is subject to the derivative_uniformity diagnostic, which
///          ToProgram disables (once) when such a builtin is emitted.
bool RequiresDerivativeUniformity(builtin::Function fn) {
    switch (fn) {
        // Explicit derivative builtins.
        case builtin::Function::kDpdx:
        case builtin::Function::kDpdxCoarse:
        case builtin::Function::kDpdxFine:
        case builtin::Function::kDpdy:
        case builtin::Function::kDpdyCoarse:
        case builtin::Function::kDpdyFine:
        case builtin::Function::kFwidth:
        case builtin::Function::kFwidthCoarse:
        case builtin::Function::kFwidthFine:
        // Texture sampling builtins that compute derivatives implicitly.
        case builtin::Function::kTextureSample:
        case builtin::Function::kTextureSampleBias:
        case builtin::Function::kTextureSampleCompare:
            return true;
        default:
            return false;
    }
}
};
} // namespace
/// Converts the IR module @p i to a WGSL AST Program.
/// @param i the IR module. It is modified by the conversion (conflicting names are renamed).
/// @returns the program. On IR validation failure, the program holds the diagnostics.
Program ToProgram(Module& i) {
    State state{i};
    return state.Run();
}
} // namespace tint::ir