// blob: b79a02971f0a6d8e6726c97c0e47915ded84f7bc [file] [log] [blame]
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/robustness.h"
#include <algorithm>
#include <limits>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/builtin.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/index_accessor_expression.h"
#include "src/tint/sem/load.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/value_expression.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
#include "src/tint/type/reference.h"
// Instantiate Tint TypeInfo for the transform and its config type (presumably what lets
// `DataMap::Get<Config>()` and the Castable machinery identify these types — confirm against
// TINT_INSTANTIATE_TYPEINFO).
TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness::Config);
// Brings number literal suffixes such as `1_u` and `1_a` (used below) into scope.
using namespace tint::number_suffixes; // NOLINT
namespace tint::transform {
/// PIMPL state for the transform
struct Robustness::State {
/// Builds the per-run transform state.
/// @param p the program being transformed; it is the clone source for `ctx`
/// @param c the transform configuration, stored in `cfg`
State(const Program* p, Config&& c) : src{p}, cfg{std::move(c)} {}
/// Runs the transform
/// @returns the new program. Note: this implementation always clones and returns a new
/// program; it never returns SkipTransform.
ApplyResult Run() {
// Pointer parameters in address spaces configured for predication receive an extra `bool`
// parameter. This happens before the walk so that the uses of those parameters are already
// recorded in `predicates` when the walk reaches them.
if (HasAction(Action::kPredicate)) {
AddPredicateParameters();
}
// Walk all the AST nodes in the module, starting with the leaf nodes.
// The most deeply nested expressions will come first.
// Because children are visited before their parents, each handler below can read the
// predicate (if any) already recorded for its sub-expressions.
for (auto* node : ctx.src->ASTNodes().Objects()) {
Switch(
node, //
[&](const ast::IndexAccessorExpression* e) {
// obj[idx]
// Array, matrix and vector indexing may require robustness transformation.
auto* expr = sem.Get(e)->Unwrap()->As<sem::IndexAccessorExpression>();
switch (ActionFor(expr)) {
case Action::kPredicate:
PredicateIndexAccessor(expr);
break;
case Action::kClamp:
ClampIndexAccessor(expr);
break;
case Action::kIgnore:
break;
}
},
[&](const ast::IdentifierExpression* e) {
// Identifiers may resolve to pointer lets, which may be predicated.
// Inspect.
if (auto* user = sem.Get<sem::VariableUser>(e)) {
auto* v = user->Variable();
if (v->Type()->Is<type::Pointer>()) {
// Propagate predicate from pointer
// (the pointer's initializer expression carries the predicate, if any)
if (auto pred = predicates.Get(v->Declaration()->initializer)) {
predicates.Add(e, *pred);
}
}
}
},
[&](const ast::AccessorExpression* e) {
// obj.member
// Propagate the predication from the object to this expression.
if (auto pred = predicates.Get(e->object)) {
predicates.Add(e, *pred);
}
},
[&](const ast::UnaryOpExpression* e) {
// Includes address-of, or indirection
// Propagate the predication from the inner expression to this expression.
if (auto pred = predicates.Get(e->expr)) {
predicates.Add(e, *pred);
}
},
[&](const ast::AssignmentStatement* s) {
if (auto pred = predicates.Get(s->lhs)) {
// Assignment target is predicated
// Replace statement with condition on the predicate
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
}
},
[&](const ast::CompoundAssignmentStatement* s) {
if (auto pred = predicates.Get(s->lhs)) {
// Assignment expression is predicated
// Replace statement with condition on the predicate
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
}
},
[&](const ast::IncrementDecrementStatement* s) {
if (auto pred = predicates.Get(s->lhs)) {
// Assignment expression is predicated
// Replace statement with condition on the predicate
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
}
},
[&](const ast::CallExpression* e) {
if (auto* call = sem.Get<sem::Call>(e)) {
Switch(
call->Target(), //
[&](const sem::Builtin* builtin) {
// Calls to builtins may require robustness transformation.
// Inspect.
if (builtin->IsTexture()) {
// Texture builtins follow `cfg.texture_action` rather than the
// action of any pointer argument's address space.
switch (cfg.texture_action) {
case Action::kPredicate:
PredicateTextureBuiltin(call, builtin);
break;
case Action::kClamp:
ClampTextureBuiltin(call, builtin);
break;
case Action::kIgnore:
break;
}
} else {
MaybePredicateNonTextureBuiltin(call, builtin);
}
},
[&](const sem::Function* fn) {
// Calls to user function may require passing additional predicate
// arguments.
InsertPredicateArguments(call, fn);
});
}
});
// Check whether the node is an expression that:
// * Has a predicate
// * Is of a non-pointer or non-reference type
// If the above is true, then we need to predicate evaluation of this expression by
// replacing `expr` with `predicated_expr` and injecting the following above the
// expression's statement:
//
// var predicated_expr : expr_ty;
// if (predicate) {
// predicated_expr = expr;
// }
//
if (auto* expr = node->As<ast::Expression>()) {
if (auto pred = predicates.Get(expr)) {
// Expression is predicated
auto* sem_expr = sem.GetVal(expr);
if (!sem_expr->Type()->IsAnyOf<type::Reference, type::Pointer>()) {
auto pred_load = b.Symbols().New("predicated_expr");
auto ty = CreateASTTypeFor(ctx, sem_expr->Type());
hoist.InsertBefore(sem_expr->Stmt(), b.Decl(b.Var(pred_load, ty)));
hoist.InsertBefore(
sem_expr->Stmt(),
b.If(*pred, b.Block(b.Assign(pred_load, ctx.Clone(expr)))));
ctx.Replace(expr, b.Expr(pred_load));
// The predication has been consumed for this expression.
// Don't predicate expressions that use this expression.
predicates.Remove(expr);
}
}
}
}
// Clone the source program into `b`, applying the replacements and insertions registered
// on `ctx` / `hoist` above.
ctx.Clone();
return Program(std::move(b));
}
private:
// NOTE: members are initialized in declaration order. `ctx` is built from `b` and `src`,
// while `hoist` and `sem` are built from `ctx`, so this ordering must be preserved.
/// The source program
const Program* const src;
/// The transform's config
Config cfg;
/// The target program builder
ProgramBuilder b{};
/// The clone context
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Helper for hoisting declarations
HoistToDeclBefore hoist{ctx};
/// Alias to the source program's semantic info
const sem::Info& sem = ctx.src->Sem();
/// Map of expression to predicate condition. The Symbol names a `bool` function parameter or
/// `let` that must be true for the expression's access to be performed.
utils::Hashmap<const ast::Expression*, Symbol, 32> predicates{};
/// Computes the largest valid index for the index accessor @p expr.
/// @param expr the index accessor expression
/// @return a `u32` typed expression holding the maximum indexable value, or nullptr when no
/// robustness limit is needed (e.g. a constant index into an object of constant size, which
/// validation has already bounds-checked) or when an error was reported.
const ast::Expression* DynamicLimitFor(const sem::IndexAccessorExpression* expr) {
    const auto* obj_type = expr->Object()->Type();
    const auto* target = obj_type->UnwrapRef();
    const auto* index = expr->Index();

    if (const auto* vec = target->As<type::Vector>()) {
        if (index->ConstantValue() || index->Is<sem::Swizzle>()) {
            // Constant index into a fixed-width vector: validation rejected any OOB access.
            return nullptr;
        }
        return b.Expr(u32(vec->Width() - 1u));
    }

    if (const auto* mat = target->As<type::Matrix>()) {
        if (index->ConstantValue()) {
            // Constant index into a fixed-size matrix: validation rejected any OOB access.
            return nullptr;
        }
        return b.Expr(u32(mat->columns() - 1u));
    }

    if (const auto* arr = target->As<type::Array>()) {
        if (arr->Count()->Is<type::RuntimeArrayCount>()) {
            // The length is only known at runtime, so even a constant index must be limited
            // by `arrayLength(&obj) - 1`.
            auto* arr_ptr = b.AddressOf(ctx.Clone(expr->Object()->Declaration()));
            return b.Sub(b.Call(sem::BuiltinType::kArrayLength, arr_ptr), 1_u);
        }
        if (auto count = arr->ConstantCount()) {
            if (index->ConstantValue()) {
                // Constant index into a constant-sized array: already validated.
                return nullptr;
            }
            return b.Expr(u32(count.value() - 1u));
        }
        // Note: Don't be tempted to use the array override variable as an expression here,
        // the name might be shadowed!
        b.Diagnostics().add_error(diag::System::Transform,
                                  type::Array::kErrExpectedConstantCount);
        return nullptr;
    }

    TINT_ICE(Transform, b.Diagnostics())
        << "unhandled object type in robustness of array index: "
        << src->FriendlyName(obj_type->UnwrapRef());
    return nullptr;
}
/// Inserts an extra `bool` predicate parameter directly after every user-function pointer
/// parameter whose address space is configured with Action::kPredicate. Every use of such a
/// pointer parameter is recorded in `predicates` as guarded by the new parameter.
void AddPredicateParameters() {
    for (auto* fn : src->AST().Functions()) {
        for (auto* param : fn->params) {
            auto* sem_param = sem.Get(param);
            auto* ptr = sem_param->Type()->As<type::Pointer>();
            if (!ptr || ActionFor(ptr->AddressSpace()) != Action::kPredicate) {
                continue;
            }
            // The new parameter's name is derived from the pointer parameter's name with a
            // `_predicate` suffix.
            const auto pred_name =
                b.Symbols().New(src->Symbols().NameFor(param->name->symbol) + "_predicate");
            ctx.InsertAfter(fn->params, param, b.Param(pred_name, b.ty.bool_()));
            for (auto* use : sem_param->Users()) {
                predicates.Add(use->Declaration(), pred_name);
            }
        }
    }
}
/// Adds the predicate arguments expected by AddPredicateParameters() to a call of a user
/// function: after each pointer argument in a predicated address space, pass that argument's
/// recorded predicate, or `true` when none is recorded.
/// @param call the semantic call expression
/// @param fn the called user function
void InsertPredicateArguments(const sem::Call* call, const sem::Function* fn) {
    auto* call_expr = call->Declaration();
    for (size_t n = 0; n < fn->Parameters().Length(); ++n) {
        const auto* ptr = fn->Parameters()[n]->Type()->As<type::Pointer>();
        if (ptr == nullptr || ActionFor(ptr->AddressSpace()) != Action::kPredicate) {
            continue;
        }
        auto* arg = call_expr->args[n];
        auto arg_pred = predicates.Get(arg);
        if (!arg_pred) {
            // No predicate is recorded for this pointer argument.
            ctx.InsertAfter(call_expr->args, arg, b.Expr(true));
        } else {
            // Forward the argument's predicate to the callee.
            ctx.InsertAfter(call_expr->args, arg, b.Expr(*arg_pred));
        }
    }
}
/// Applies predication to the index on an array, vector or matrix.
/// When a runtime limit is required, hoists the following before the expression's statement:
///   let index = <idx>;
///   let predicate = [obj_predicate &] (u32(index) <= max);
/// replaces the index with `index`, and records `predicate` for @p expr.
/// When no limit is required, only the object's predicate (if any) is propagated to @p expr.
/// @param expr the index accessor expression.
void PredicateIndexAccessor(const sem::IndexAccessorExpression* expr) {
    auto* obj = expr->Object()->Declaration();
    auto* idx = expr->Index()->Declaration();
    auto* max = DynamicLimitFor(expr);
    if (!max) {
        // robustness is not required
        // Just propagate predicate from object
        if (auto pred = predicates.Get(obj)) {
            predicates.Add(expr->Declaration(), *pred);
        }
        return;
    }
    auto* stmt = expr->Stmt();
    // Read the object's predicate without modifying `predicates`. (A get-or-insert lookup
    // would store an invalid Symbol for unpredicated objects, which a later `Get()` would
    // then report as present.)
    Symbol obj_pred;
    if (auto pred = predicates.Get(obj)) {
        obj_pred = *pred;
    }
    auto idx_let = b.Symbols().New("index");
    auto pred = b.Symbols().New("predicate");
    // Evaluate the index once, so that it is used consistently by the check and the access.
    hoist.InsertBefore(stmt, b.Decl(b.Let(idx_let, ctx.Clone(idx))));
    ctx.Replace(idx, b.Expr(idx_let));
    // u32() makes negative signed indices large, so they fail the `<= max` check.
    auto* cond = b.LessThanEqual(b.Call<u32>(b.Expr(idx_let)), max);
    if (obj_pred.IsValid()) {
        cond = b.And(b.Expr(obj_pred), cond);
    }
    hoist.InsertBefore(stmt, b.Decl(b.Let(pred, cond)));
    predicates.Add(expr->Declaration(), pred);
}
/// Applies bounds clamping to the index on an array, vector or matrix by replacing the index
/// `idx` with `min(idx, max)`. A signed `idx` is first converted with `u32(idx)`.
/// @param expr the index accessor expression (already unwrapped by the caller).
void ClampIndexAccessor(const sem::IndexAccessorExpression* expr) {
    auto* max = DynamicLimitFor(expr);
    if (!max) {
        return;  // robustness is not required
    }
    auto* idx_decl = expr->Declaration()->index;
    const ast::Expression* idx = ctx.Clone(idx_decl);
    // `expr` is already the unwrapped index accessor, so its index type is read directly
    // (no re-unwrap / unchecked downcast).
    if (expr->Index()->Type()->is_signed_integer_scalar()) {
        // u32(idx): negative indices become large values, which min() clamps to `max`.
        idx = b.Call<u32>(idx);
    }
    ctx.Replace(idx_decl, b.Call(sem::BuiltinType::kMin, idx, max));
}
/// Predicates a non-texture builtin call on the `&` of the predicates of its arguments.
/// Nothing is done when none of the arguments is predicated.
/// @param call the semantic call expression
/// @param builtin the called builtin
void MaybePredicateNonTextureBuiltin(const sem::Call* call, const sem::Builtin* builtin) {
    const ast::Expression* combined = nullptr;
    for (auto* arg : call->Declaration()->args) {
        auto arg_pred = predicates.Get(arg);
        if (arg_pred) {
            combined = And(combined, b.Expr(*arg_pred));
        }
    }
    if (!combined) {
        return;
    }
    if (builtin->Type() != sem::BuiltinType::kWorkgroupUniformLoad) {
        PredicateCall(call, combined);
        return;
    }
    // https://www.w3.org/TR/WGSL/#workgroupUniformLoad-builtin:
    // "Executes a control barrier synchronization function that affects memory and
    // atomic operations in the workgroup address space."
    // Because the call acts like a control barrier, a workgroup barrier must still be
    // executed when the predicate fails.
    PredicateCall(call, combined,
                  b.Block(b.CallStmt(b.Call(sem::BuiltinType::kWorkgroupBarrier))));
}
/// Applies predication to texture builtins, based on whether the coordinates, array index and
/// level arguments are all in bounds.
/// Integer level / coords arguments and the array-index argument are converted to unsigned
/// values, hoisted into `let`s, and replaced by those lets. The call is then guarded by the
/// conjunction of the per-argument bounds checks via PredicateCall().
/// @param call the semantic call expression
/// @param builtin the called texture builtin
void PredicateTextureBuiltin(const sem::Call* call, const sem::Builtin* builtin) {
if (!TextureBuiltinNeedsRobustness(builtin->Type())) {
return;
}
auto* expr = call->Declaration();
auto* stmt = call->Stmt();
// Indices of the mandatory texture and coords parameters, and the optional
// array and level parameters.
auto& signature = builtin->Signature();
auto texture_arg_idx = signature.IndexOf(sem::ParameterUsage::kTexture);
auto coords_arg_idx = signature.IndexOf(sem::ParameterUsage::kCoords);
auto array_arg_idx = signature.IndexOf(sem::ParameterUsage::kArrayIndex);
auto level_arg_idx = signature.IndexOf(sem::ParameterUsage::kLevel);
auto* texture_arg = expr->args[static_cast<size_t>(texture_arg_idx)];
// Build the builtin predicate from the arguments
const ast::Expression* predicate = nullptr;
Symbol level_idx, num_levels;
if (level_arg_idx >= 0) {
auto* param = builtin->Parameters()[static_cast<size_t>(level_arg_idx)];
if (param->Type()->is_integer_scalar()) {
// let level_idx = u32(level-arg);
level_idx = b.Symbols().New("level_idx");
auto* arg = expr->args[static_cast<size_t>(level_arg_idx)];
hoist.InsertBefore(stmt,
b.Decl(b.Let(level_idx, CastToUnsigned(ctx.Clone(arg), 1u))));
// let num_levels = textureNumLevels(texture-arg);
num_levels = b.Symbols().New("num_levels");
hoist.InsertBefore(
stmt, b.Decl(b.Let(num_levels, b.Call(sem::BuiltinType::kTextureNumLevels,
ctx.Clone(texture_arg)))));
// predicate: level_idx < num_levels
predicate = And(predicate, b.LessThan(level_idx, num_levels));
// Replace the level argument with `level_idx`
ctx.Replace(arg, b.Expr(level_idx));
}
}
Symbol coords;
if (coords_arg_idx >= 0) {
auto* param = builtin->Parameters()[static_cast<size_t>(coords_arg_idx)];
if (param->Type()->is_integer_scalar_or_vector()) {
// let coords = u32(coords-arg)
coords = b.Symbols().New("coords");
auto* arg = expr->args[static_cast<size_t>(coords_arg_idx)];
hoist.InsertBefore(stmt,
b.Decl(b.Let(coords, CastToUnsigned(b.Expr(ctx.Clone(arg)),
WidthOf(param->Type())))));
// predicate: all(coords < textureDimensions(texture))
// When a level is present, textureDimensions() is queried with the level clamped to
// `num_levels - 1`, so the dimensions query itself never uses an OOB level.
auto* dimensions =
level_idx.IsValid()
? b.Call(sem::BuiltinType::kTextureDimensions, ctx.Clone(texture_arg),
b.Call(sem::BuiltinType::kMin, b.Expr(level_idx),
b.Sub(num_levels, 1_a)))
: b.Call(sem::BuiltinType::kTextureDimensions, ctx.Clone(texture_arg));
predicate =
And(predicate, b.Call(sem::BuiltinType::kAll, b.LessThan(coords, dimensions)));
// Replace the coords argument with `coords`
ctx.Replace(arg, b.Expr(coords));
}
}
if (array_arg_idx >= 0) {
// let array_idx = u32(array-arg)
auto* arg = expr->args[static_cast<size_t>(array_arg_idx)];
auto* num_layers = b.Call(sem::BuiltinType::kTextureNumLayers, ctx.Clone(texture_arg));
auto array_idx = b.Symbols().New("array_idx");
hoist.InsertBefore(stmt, b.Decl(b.Let(array_idx, CastToUnsigned(ctx.Clone(arg), 1u))));
// predicate: array_idx < textureNumLayers(texture)
predicate = And(predicate, b.LessThan(array_idx, num_layers));
// Replace the array index argument with `array_idx`
ctx.Replace(arg, b.Expr(array_idx));
}
if (predicate) {
PredicateCall(call, predicate);
}
}
/// Applies bounds clamping to the coordinates, array index and level arguments of the texture
/// builtin.
/// An integer level is hoisted into `let level_idx = min(u32(level), textureNumLevels(t) - 1)`.
/// Integer coordinates are clamped to `[0, textureDimensions(t[, level_idx]) - 1]`, and the
/// array index is clamped to `[0, textureNumLayers(t) - 1]`. Signed arguments use clamp() with
/// a signed maximum, and unsigned arguments use min().
/// @param call the semantic call expression
/// @param builtin the called texture builtin
void ClampTextureBuiltin(const sem::Call* call, const sem::Builtin* builtin) {
if (!TextureBuiltinNeedsRobustness(builtin->Type())) {
return;
}
auto* expr = call->Declaration();
auto* stmt = call->Stmt();
// Indices of the mandatory texture and coords parameters, and the optional
// array and level parameters.
auto& signature = builtin->Signature();
auto texture_arg_idx = signature.IndexOf(sem::ParameterUsage::kTexture);
auto coords_arg_idx = signature.IndexOf(sem::ParameterUsage::kCoords);
auto array_arg_idx = signature.IndexOf(sem::ParameterUsage::kArrayIndex);
auto level_arg_idx = signature.IndexOf(sem::ParameterUsage::kLevel);
auto* texture_arg = expr->args[static_cast<size_t>(texture_arg_idx)];
// If the level is provided, then we need to clamp this. As the level is used by
// textureDimensions() and the texture[Load|Store]() calls, we need to clamp both usages.
Symbol level_idx;
if (level_arg_idx >= 0) {
const auto* param = builtin->Parameters()[static_cast<size_t>(level_arg_idx)];
if (param->Type()->is_integer_scalar()) {
const auto* arg = expr->args[static_cast<size_t>(level_arg_idx)];
level_idx = b.Symbols().New("level_idx");
const auto* num_levels =
b.Call(sem::BuiltinType::kTextureNumLevels, ctx.Clone(texture_arg));
const auto* max = b.Sub(num_levels, 1_a);
// let level_idx = min(u32(level-arg), textureNumLevels(texture) - 1);
hoist.InsertBefore(
stmt, b.Decl(b.Let(level_idx, b.Call(sem::BuiltinType::kMin,
b.Call<u32>(ctx.Clone(arg)), max))));
ctx.Replace(arg, b.Expr(level_idx));
}
}
// Clamp the coordinates argument
if (coords_arg_idx >= 0) {
const auto* param = builtin->Parameters()[static_cast<size_t>(coords_arg_idx)];
if (param->Type()->is_integer_scalar_or_vector()) {
auto* arg = expr->args[static_cast<size_t>(coords_arg_idx)];
const auto width = WidthOf(param->Type());
const auto* dimensions =
level_idx.IsValid()
? b.Call(sem::BuiltinType::kTextureDimensions, ctx.Clone(texture_arg),
level_idx)
: b.Call(sem::BuiltinType::kTextureDimensions, ctx.Clone(texture_arg));
// dimensions is u32 or vecN<u32>
const auto* unsigned_max = b.Sub(dimensions, ScalarOrVec(b.Expr(1_a), width));
if (param->Type()->is_signed_integer_scalar_or_vector()) {
// Signed coords: clamp(coords, 0, iN(dimensions - 1))
const auto* zero = ScalarOrVec(b.Expr(0_a), width);
const auto* signed_max = CastToSigned(unsigned_max, width);
ctx.Replace(arg,
b.Call(sem::BuiltinType::kClamp, ctx.Clone(arg), zero, signed_max));
} else {
// Unsigned coords: min(coords, dimensions - 1)
ctx.Replace(arg, b.Call(sem::BuiltinType::kMin, ctx.Clone(arg), unsigned_max));
}
}
}
// Clamp the array_index argument, if provided
if (array_arg_idx >= 0) {
auto* param = builtin->Parameters()[static_cast<size_t>(array_arg_idx)];
auto* arg = expr->args[static_cast<size_t>(array_arg_idx)];
auto* num_layers = b.Call(sem::BuiltinType::kTextureNumLayers, ctx.Clone(texture_arg));
const auto* unsigned_max = b.Sub(num_layers, 1_a);
if (param->Type()->is_signed_integer_scalar()) {
const auto* signed_max = CastToSigned(unsigned_max, 1u);
ctx.Replace(arg, b.Call(sem::BuiltinType::kClamp, ctx.Clone(arg), 0_a, signed_max));
} else {
ctx.Replace(arg, b.Call(sem::BuiltinType::kMin, ctx.Clone(arg), unsigned_max));
}
}
}
/// Reports whether a texture builtin has arguments that this transform predicates or clamps.
/// @param type builtin type
/// @returns true for textureLoad, textureStore and textureDimensions; false otherwise.
bool TextureBuiltinNeedsRobustness(sem::BuiltinType type) {
    switch (type) {
        case sem::BuiltinType::kTextureLoad:
        case sem::BuiltinType::kTextureStore:
        case sem::BuiltinType::kTextureDimensions:
            return true;
        default:
            return false;
    }
}
/// Combines two optional boolean expressions with a bitwise `&`.
/// @param lhs the left-hand expression, may be nullptr
/// @param rhs the right-hand expression, may be nullptr
/// @returns `lhs & rhs` when both are non-null, otherwise whichever one is non-null
/// (nullptr when both are null).
const ast::Expression* And(const ast::Expression* lhs, const ast::Expression* rhs) {
    if (!lhs) {
        return rhs;
    }
    if (!rhs) {
        return lhs;
    }
    return b.And(lhs, rhs);
}
/// Transforms a call statement or expression so that the call is only made when @p predicate
/// holds.
/// @param call the semantic call to guard
/// @param predicate the boolean condition that must be true for the call to execute
/// @param else_stmt - the statement to execute for the predication failure (nullptr for none)
void PredicateCall(const sem::Call* call,
                   const ast::Expression* predicate,
                   const ast::BlockStatement* else_stmt = nullptr) {
    auto* call_expr = call->Declaration();
    auto* sem_stmt = call->Stmt();
    auto* as_call_stmt = sem_stmt->Declaration()->As<ast::CallStatement>();
    const bool is_whole_statement = as_call_stmt != nullptr && as_call_stmt->expr == call_expr;
    if (!is_whole_statement) {
        // The call's value is used. Emit before the expression's statement:
        //   var predicated_value : return-type;
        //   if (predicate) {
        //     predicated_value = call(...);
        //   }
        // and use `predicated_value` in place of the call.
        auto result = b.Symbols().New("predicated_value");
        hoist.InsertBefore(sem_stmt,
                           b.Decl(b.Var(result, CreateASTTypeFor(ctx, call->Type()))));
        hoist.InsertBefore(sem_stmt,
                           b.If(predicate, b.Block(b.Assign(result, ctx.Clone(call_expr))),
                                ProgramBuilder::ElseStmt(else_stmt)));
        ctx.Replace(call_expr, b.Expr(result));
        return;
    }
    // The call is the whole statement: wrap the statement in an if-statement on the predicate.
    hoist.Replace(sem_stmt, b.If(predicate, b.Block(ctx.Clone(sem_stmt->Declaration())),
                                 ProgramBuilder::ElseStmt(else_stmt)));
}
/// @param action the action to look for
/// @returns true if @p action is configured for any address space (including the texture
/// action); the value action is not considered.
bool HasAction(Action action) const {
    for (Action configured : {cfg.function_action, cfg.texture_action, cfg.private_action,
                              cfg.push_constant_action, cfg.storage_action,
                              cfg.uniform_action, cfg.workgroup_action}) {
        if (configured == action) {
            return true;
        }
    }
    return false;
}
/// @param expr the object expression being accessed
/// @returns the robustness action to perform for an OOB access with the expression @p expr.
/// Reference-typed expressions use the action of their address space; any other
/// expression uses `cfg.value_action`.
Action ActionFor(const sem::ValueExpression* expr) {
    if (const auto* ref = expr->Type()->As<type::Reference>()) {
        return ActionFor(ref->AddressSpace());
    }
    return cfg.value_action;
}
/// @param address_space the address space of the accessed memory
/// @returns the robustness action to perform for an OOB access in the address space @p
/// address_space. For an address space without a corresponding config field, this reports
/// TINT_UNREACHABLE and returns Action::kDefault.
Action ActionFor(builtin::AddressSpace address_space) {
    switch (address_space) {
        case builtin::AddressSpace::kFunction:
            return cfg.function_action;
        case builtin::AddressSpace::kHandle:
            return cfg.texture_action;
        case builtin::AddressSpace::kPrivate:
            return cfg.private_action;
        case builtin::AddressSpace::kPushConstant:
            return cfg.push_constant_action;
        case builtin::AddressSpace::kStorage:
            return cfg.storage_action;
        case builtin::AddressSpace::kUniform:
            return cfg.uniform_action;
        case builtin::AddressSpace::kWorkgroup:
            return cfg.workgroup_action;
        default:
            break;
    }
    // Separate the message from the streamed address space. Without the ": ", the two ran
    // together (e.g. "unhandled address spaceout").
    TINT_UNREACHABLE(Transform, b.Diagnostics()) << "unhandled address space: " << address_space;
    return Action::kDefault;
}
/// @param ty the type to inspect
/// @returns the vector width of @p ty, or 1 if @p ty is not a vector
static uint32_t WidthOf(const type::Type* ty) {
    const auto* vec = ty->As<type::Vector>();
    return vec ? vec->Width() : 1u;
}
/// @param scalar the element type
/// @param width the number of elements
/// @returns `vecN<scalar>` (N = @p width) when @p width > 1, otherwise @p scalar unchanged
ast::Type ScalarOrVecTy(ast::Type scalar, uint32_t width) const {
    if (width <= 1) {
        return scalar;
    }
    return b.ty.vec(scalar, width);
}
/// @param scalar the scalar expression
/// @param width the number of vector elements
/// @returns `vecN(scalar)` (N = @p width, element type inferred) when @p width > 1,
/// otherwise @p scalar unchanged.
const ast::Expression* ScalarOrVec(const ast::Expression* scalar, uint32_t width) {
    if (width <= 1) {
        return scalar;
    }
    return b.Call(b.ty.vec<Infer>(width), scalar);
}
/// Builds a conversion of @p val to a signed integer scalar or vector.
/// @param val the expression to convert
/// @param width the number of elements (1 for a scalar)
/// @returns `vecN<i32>(val)` for N = @p width > 1, otherwise `i32(val)`
const ast::CallExpression* CastToSigned(const ast::Expression* val, uint32_t width) {
    auto target_ty = ScalarOrVecTy(b.ty.i32(), width);
    return b.Call(std::move(target_ty), val);
}
/// Builds a conversion of @p val to an unsigned integer scalar or vector.
/// @param val the expression to convert
/// @param width the number of elements (1 for a scalar)
/// @returns `vecN<u32>(val)` for N = @p width > 1, otherwise `u32(val)`
const ast::CallExpression* CastToUnsigned(const ast::Expression* val, uint32_t width) {
    auto target_ty = ScalarOrVecTy(b.ty.u32(), width);
    return b.Call(std::move(target_ty), val);
}
};
// Robustness::Config: all special members are defaulted, defined out-of-line.
Robustness::Config::Config() = default;
Robustness::Config::Config(const Robustness::Config&) = default;
Robustness::Config& Robustness::Config::operator=(const Robustness::Config&) = default;
Robustness::Config::~Config() = default;

// Robustness: defaulted constructor and destructor, defined out-of-line.
Robustness::Robustness() = default;
Robustness::~Robustness() = default;
/// Runs the robustness transform on @p src.
/// @param src the program to transform
/// @param inputs the transform inputs; a Config found here is used, otherwise a
/// default-constructed Config
/// @returns the result of State::Run()
Transform::ApplyResult Robustness::Apply(const Program* src,
                                         const DataMap& inputs,
                                         DataMap&) const {
    const auto* provided = inputs.Get<Config>();
    Config cfg = provided ? Config{*provided} : Config{};
    return State{src, std::move(cfg)}.Run();
}
} // namespace tint::transform