// blob: ba61160940126fe060520f374578db74a90e2fd0 [file] [log] [blame]
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/module_scope_var_to_entry_point_param.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "src/ast/disable_validation_decoration.h"
#include "src/program_builder.h"
#include "src/sem/call.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/variable.h"
// Instantiates the type info for the ModuleScopeVarToEntryPointParam transform.
// NOTE(review): the macro is defined outside this file; presumably it is
// required for Tint's run-time type identification of the class - confirm.
TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam);
namespace tint {
namespace transform {
namespace {
// Reports whether `type` (after stripping any reference) is a matrix, an
// array whose element type contains a matrix, or a structure with at least
// one member whose type contains a matrix.
// @param type the semantic type to inspect
// @returns `true` if a matrix type was found, otherwise `false`
bool ContainsMatrix(const sem::Type* type) {
  // Peel off the reference and any number of array layers; only the innermost
  // element type decides the answer for an array.
  const sem::Type* current = type->UnwrapRef();
  while (auto* array = current->As<sem::Array>()) {
    current = array->ElemType()->UnwrapRef();
  }

  if (current->Is<sem::Matrix>()) {
    return true;
  }

  auto* structure = current->As<sem::Struct>();
  if (!structure) {
    return false;
  }

  // A structure contains a matrix if any of its members does.
  for (auto* member : structure->Members()) {
    if (ContainsMatrix(member->Type())) {
      return true;
    }
  }
  return false;
}
} // namespace
/// State holds the current transform state.
struct ModuleScopeVarToEntryPointParam::State {
  /// The clone context.
  CloneContext& ctx;

  /// Constructor
  /// @param context the clone context
  explicit State(CloneContext& context) : ctx(context) {}

  /// Clones any struct types that are contained in `ty` (including `ty`
  /// itself), and adds them to the global declarations now, so that they
  /// precede new global declarations that need to reference them.
  /// Each struct is cloned at most once, and the structs used by its members
  /// are cloned before the struct itself.
  /// @param ty the type to clone
  void CloneStructTypes(const sem::Type* ty) {
    if (auto* str = ty->As<sem::Struct>()) {
      if (!cloned_structs_.emplace(str).second) {
        // The struct has already been cloned.
        return;
      }

      // Recurse into members.
      for (auto* member : str->Members()) {
        CloneStructTypes(member->Type());
      }

      // Clone the struct and add it to the global declaration list.
      // Remove the old declaration.
      auto* ast_str = str->Declaration();
      ctx.dst->AST().AddTypeDecl(ctx.Clone(ast_str));
      ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
    } else if (auto* arr = ty->As<sem::Array>()) {
      CloneStructTypes(arr->ElemType());
    }
  }

  /// Process the module.
  /// For every function that transitively references a module-scope variable
  /// with a storage class other than `kNone`, the variable is redeclared as a
  /// parameter or (for some entry point cases) as a function-scope variable,
  /// uses of the variable are replaced, and calls made from that function are
  /// given the extra arguments. Finally, the module-scope variable
  /// declarations are removed.
  void Process() {
    // Predetermine the list of function calls that need to be replaced.
    // The map is keyed by the *calling* function.
    using CallList = std::vector<const ast::CallExpression*>;
    std::unordered_map<const ast::Function*, CallList> calls_to_replace;

    std::vector<const ast::Function*> functions_to_process;

    // Build a list of functions that transitively reference any module-scope
    // variables.
    for (auto* func_ast : ctx.src->AST().Functions()) {
      auto* func_sem = ctx.src->Sem().Get(func_ast);

      bool needs_processing = false;
      for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
        if (var->StorageClass() != ast::StorageClass::kNone) {
          needs_processing = true;
          break;
        }
      }

      if (needs_processing) {
        functions_to_process.push_back(func_ast);

        // Find all of the calls to this function that will need to be replaced.
        for (auto* call : func_sem->CallSites()) {
          calls_to_replace[call->Stmt()->Function()->Declaration()].push_back(
              call->Declaration());
        }
      }
    }

    // Build a list of `&ident` expressions. We'll use this later to avoid
    // generating expressions of the form `&*ident`, which break WGSL validation
    // rules when this expression is passed to a function.
    // TODO(jrprice): We should add support for bidirectional SEM tree traversal
    // so that we can do this on the fly instead.
    std::unordered_map<const ast::IdentifierExpression*,
                       const ast::UnaryOpExpression*>
        ident_to_address_of;
    for (auto* node : ctx.src->ASTNodes().Objects()) {
      auto* address_of = node->As<ast::UnaryOpExpression>();
      if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
        continue;
      }
      if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) {
        ident_to_address_of[ident] = address_of;
      }
    }

    for (auto* func_ast : functions_to_process) {
      auto* func_sem = ctx.src->Sem().Get(func_ast);
      bool is_entry_point = func_ast->IsEntryPoint();

      // Map module-scope variables onto their replacement.
      struct NewVar {
        Symbol symbol;
        bool is_pointer;
      };
      std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;

      // We aggregate all workgroup variables into a struct to avoid hitting
      // MSL's limit for threadgroup memory arguments.
      Symbol workgroup_parameter_symbol;
      ast::StructMemberList workgroup_parameter_members;
      // Lazily creates the symbol of the workgroup struct parameter, so that
      // a symbol is only generated when it is actually needed.
      auto workgroup_param = [&]() {
        if (!workgroup_parameter_symbol.IsValid()) {
          workgroup_parameter_symbol = ctx.dst->Sym();
        }
        return workgroup_parameter_symbol;
      };

      for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
        auto sc = var->StorageClass();
        if (sc == ast::StorageClass::kNone) {
          continue;
        }
        if (sc != ast::StorageClass::kPrivate &&
            sc != ast::StorageClass::kStorage &&
            sc != ast::StorageClass::kUniform &&
            sc != ast::StorageClass::kUniformConstant &&
            sc != ast::StorageClass::kWorkgroup) {
          // NOTE(review): processing of this variable continues after the
          // error has been raised - confirm that this is intended.
          TINT_ICE(Transform, ctx.dst->Diagnostics())
              << "unhandled module-scope storage class (" << sc << ")";
        }

        // This is the symbol for the variable that replaces the module-scope
        // var.
        auto new_var_symbol = ctx.dst->Sym();

        // Helper to create an AST node for the store type of the variable.
        // It is invoked once for every place the type is needed.
        auto store_type = [&]() {
          return CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
        };

        // Track whether the new variable is a pointer or not.
        bool is_pointer = false;

        if (is_entry_point) {
          if (var->Type()->UnwrapRef()->is_handle()) {
            // For a texture or sampler variable, redeclare it as an entry point
            // parameter. Disable entry point parameter validation.
            auto* disable_validation =
                ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
            auto decos = ctx.Clone(var->Declaration()->decorations);
            decos.push_back(disable_validation);
            auto* param = ctx.dst->Param(new_var_symbol, store_type(), decos);
            ctx.InsertFront(func_ast->params, param);
          } else if (sc == ast::StorageClass::kStorage ||
                     sc == ast::StorageClass::kUniform) {
            // Variables in the Storage and Uniform storage classes are
            // redeclared as entry point parameters with a pointer type.
            auto attributes = ctx.Clone(var->Declaration()->decorations);
            attributes.push_back(ctx.dst->Disable(
                ast::DisabledValidation::kEntryPointParameter));
            attributes.push_back(
                ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
            auto* param_type = ctx.dst->ty.pointer(
                store_type(), sc, var->Declaration()->declared_access);
            auto* param =
                ctx.dst->Param(new_var_symbol, param_type, attributes);
            ctx.InsertFront(func_ast->params, param);
            is_pointer = true;
          } else if (sc == ast::StorageClass::kWorkgroup &&
                     ContainsMatrix(var->Type())) {
            // Due to a bug in the MSL compiler, we use a threadgroup memory
            // argument for any workgroup allocation that contains a matrix.
            // See crbug.com/tint/938.
            // TODO(jrprice): Do this for all other workgroup variables too.

            // Create a member in the workgroup parameter struct.
            auto member = ctx.Clone(var->Declaration()->symbol);
            workgroup_parameter_members.push_back(
                ctx.dst->Member(member, store_type()));
            CloneStructTypes(var->Type()->UnwrapRef());

            // Create a function-scope variable that is a pointer to the member.
            auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
                ctx.dst->Deref(workgroup_param()), member));
            auto* local_var =
                ctx.dst->Const(new_var_symbol,
                               ctx.dst->ty.pointer(
                                   store_type(), ast::StorageClass::kWorkgroup),
                               member_ptr);
            ctx.InsertFront(func_ast->body->statements,
                            ctx.dst->Decl(local_var));
            is_pointer = true;
          } else {
            // Variables in the Private and Workgroup storage classes are
            // redeclared at function scope. Disable storage class validation on
            // this variable.
            auto* disable_validation =
                ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
            auto* constructor = ctx.Clone(var->Declaration()->constructor);
            auto* local_var =
                ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
                             ast::DecorationList{disable_validation});
            ctx.InsertFront(func_ast->body->statements,
                            ctx.dst->Decl(local_var));
          }
        } else {
          // For a regular function, redeclare the variable as a parameter.
          // Use a pointer for non-handle types.
          auto* param_type = store_type();
          ast::DecorationList attributes;
          if (!var->Type()->UnwrapRef()->is_handle()) {
            param_type = ctx.dst->ty.pointer(
                param_type, sc, var->Declaration()->declared_access);
            is_pointer = true;

            // Disable validation of the parameter's storage class and of
            // arguments passed to it.
            attributes.push_back(
                ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
            attributes.push_back(ctx.dst->Disable(
                ast::DisabledValidation::kIgnoreInvalidPointerArgument));
          }
          ctx.InsertBack(
              func_ast->params,
              ctx.dst->Param(new_var_symbol, param_type, attributes));
        }

        // Replace all uses of the module-scope variable.
        // For non-entry points, dereference non-handle pointer parameters.
        for (auto* user : var->Users()) {
          if (user->Stmt()->Function()->Declaration() == func_ast) {
            const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
            if (is_pointer) {
              // If this identifier is used by an address-of operator, just
              // remove the address-of instead of adding a deref, since we
              // already have a pointer.
              auto* ident =
                  user->Declaration()->As<ast::IdentifierExpression>();
              // A single find() avoids hashing the key twice (count + []).
              auto address_of_it = ident_to_address_of.find(ident);
              if (address_of_it != ident_to_address_of.end()) {
                ctx.Replace(address_of_it->second, expr);
                continue;
              }

              expr = ctx.dst->Deref(expr);
            }
            ctx.Replace(user->Declaration(), expr);
          }
        }

        var_to_newvar[var] = {new_var_symbol, is_pointer};
      }

      if (!workgroup_parameter_members.empty()) {
        // Create the workgroup memory parameter.
        // The parameter is a struct that contains members for each workgroup
        // variable.
        auto* str = ctx.dst->Structure(ctx.dst->Sym(),
                                       std::move(workgroup_parameter_members));
        auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str),
                                               ast::StorageClass::kWorkgroup);
        auto* disable_validation =
            ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
        auto* param =
            ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
        ctx.InsertFront(func_ast->params, param);
      }

      // Pass the variables as pointers to any functions that need them.
      for (auto* call : calls_to_replace[func_ast]) {
        auto* target =
            ctx.src->AST().Functions().Find(call->target.name->symbol);
        auto* target_sem = ctx.src->Sem().Get(target);

        // Add new arguments for any variables that are needed by the callee.
        // For entry points, pass non-handle types as pointers.
        for (auto* target_var : target_sem->TransitivelyReferencedGlobals()) {
          auto sc = target_var->StorageClass();
          if (sc == ast::StorageClass::kNone) {
            continue;
          }

          // Look the replacement up without operator[], which would silently
          // insert an entry holding an invalid symbol if the variable was
          // never replaced in this (calling) function.
          auto new_var_it = var_to_newvar.find(target_var);
          if (new_var_it == var_to_newvar.end()) {
            TINT_ICE(Transform, ctx.dst->Diagnostics())
                << "no replacement found for module-scope variable referenced "
                   "by callee";
            continue;
          }
          const NewVar& new_var = new_var_it->second;

          bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
          const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
          if (is_entry_point && !is_handle && !new_var.is_pointer) {
            // We need to pass a pointer and we don't already have one, so take
            // the address of the new variable.
            arg = ctx.dst->AddressOf(arg);
          }
          ctx.InsertBack(call->args, arg);
        }
      }
    }

    // Now remove all module-scope variables with these storage classes.
    for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
      auto* var_sem = ctx.src->Sem().Get(var_ast);
      if (var_sem->StorageClass() != ast::StorageClass::kNone) {
        ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
      }
    }
  }

 private:
  /// The set of structures that CloneStructTypes() has already cloned.
  std::unordered_set<const sem::Struct*> cloned_structs_;
};
/// Constructor (compiler-generated default).
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
/// Destructor (compiler-generated default).
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
/// Runs the transform: a State built on the clone context processes the
/// module first, and `ctx.Clone()` is invoked afterwards.
/// The two DataMap arguments are not used by this transform.
/// @param ctx the clone context
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
                                          const DataMap&,
                                          DataMap&) {
  // The state is only needed while the replacements are being registered.
  State{ctx}.Process();
  ctx.Clone();
}
} // namespace transform
} // namespace tint