blob: f4b9fa33882c147e6505857bb46d9ca8566d17fc [file] [log] [blame]
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/lang/wgsl/ast/transform/combine_samplers.h"
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/tint/lang/wgsl/program/program_builder.h"
#include "src/tint/lang/wgsl/sem/function.h"
#include "src/tint/lang/wgsl/sem/statement.h"
#include "src/tint/utils/containers/map.h"
// Register the transform and its input data type with Tint's TypeInfo system.
// NOTE(review): CombineSamplers::Apply reports `TypeInfo().name` in its error message;
// that name presumably comes from the macro argument, so keep these spelled exactly.
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CombineSamplers);
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CombineSamplers::BindingInfo);
namespace {

/// Classifies a texture/sampler pair as module-scope or not.
/// @param pair the (texture, sampler) variable pair; the sampler may be null for
///             texture-only usages (e.g. textureLoad / textureDimensions)
/// @returns true if the texture is a global variable and the sampler is either
///          absent or also a global variable
bool IsGlobal(const tint::sem::VariablePair& pair) {
    if (!pair.first->Is<tint::sem::GlobalVariable>()) {
        return false;
    }
    // A missing sampler never forces the pair to be function-local.
    return pair.second == nullptr || pair.second->Is<tint::sem::GlobalVariable>();
}

}  // namespace
namespace tint::ast::transform {
using namespace tint::number_suffixes;  // NOLINT

/// Constructs the transform's input data.
/// @param map lookup from a texture/sampler binding point pair to the name used for the
///            combined sampler global created for that pair
/// @param placeholder the binding point substituted for the sampler when a texture is
///                    used without any sampler
CombineSamplers::BindingInfo::BindingInfo(const BindingMap& map, const BindingPoint& placeholder)
    : binding_map(map),
      placeholder_binding_point(placeholder) {}

/// Copy constructor (defaulted).
CombineSamplers::BindingInfo::BindingInfo(const BindingInfo& other) = default;

/// Destructor (defaulted).
CombineSamplers::BindingInfo::~BindingInfo() = default;
/// PIMPL state for the transform
struct CombineSamplers::State {
    /// The source program
    const Program* const src;

    /// The target program builder
    ProgramBuilder b;

    /// The clone context
    CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};

    /// The binding info
    const BindingInfo* binding_info;

    /// Map from a texture/sampler pair to the corresponding combined sampler
    /// variable
    using CombinedTextureSamplerMap = std::unordered_map<sem::VariablePair, const Variable*>;

    /// A map of all global texture/sampler variable pairs to the global
    /// combined sampler variable that will replace it.
    CombinedTextureSamplerMap global_combined_texture_samplers_;

    /// A map of all texture/sampler variable pairs that contain a function
    /// parameter to the combined sampler function parameter that will replace it.
    std::unordered_map<const sem::Function*, CombinedTextureSamplerMap>
        function_combined_texture_samplers_;

    /// Placeholder global samplers used when a function contains texture-only
    /// references (one comparison sampler, one regular). These are also used as
    /// temporary sampler parameters to the texture builtins to satisfy the WGSL
    /// resolver, but are then ignored and removed by the GLSL writer.
    const Variable* placeholder_samplers_[2] = {};

    /// Group and binding attributes used by all combined sampler globals.
    /// Group 0 and binding 0 are used, with binding point collision validation disabled.
    /// @returns the newly-created attribute list
    auto Attributes() const {
        tint::Vector<const Attribute*, 3> attributes{ctx.dst->Group(0_a), ctx.dst->Binding(0_a)};
        attributes.Push(ctx.dst->Disable(DisabledValidation::kBindingPointCollision));
        return attributes;
    }

    /// Constructor
    /// @param program the source program
    /// @param info the binding map information
    State(const Program* program, const BindingInfo* info) : src(program), binding_info(info) {}

    /// @param type the type to test (may be null, e.g. for a declaration without an explicit
    ///             type)
    /// @returns true if `type` is a texture or sampler type that this transform replaces with
    /// combined samplers. Storage textures are excluded: they never have associated samplers
    /// in GLSL, so storage texture globals, parameters and arguments are all left untouched.
    static bool IsCombinableHandle(const type::Type* type) {
        return type != nullptr && type->IsAnyOf<type::Texture, type::Sampler>() &&
               !type->Is<type::StorageTexture>();
    }

    /// Creates a combined sampler global variable.
    /// (Note this is actually a Texture node at the AST level, but it will be
    /// written as the corresponding sampler (eg., sampler2D) on GLSL output.)
    /// @param texture_var the texture (global) variable
    /// @param sampler_var the sampler (global) variable, or null for texture-only usage, in
    ///                    which case the placeholder binding point is used for the lookup
    /// @param name the default name to use (may be overridden by map lookup)
    /// @returns the newly-created global variable
    const Variable* CreateCombinedGlobal(const sem::Variable* texture_var,
                                         const sem::Variable* sampler_var,
                                         std::string name) {
        SamplerTexturePair bp_pair;
        bp_pair.texture_binding_point = *texture_var->As<sem::GlobalVariable>()->BindingPoint();
        bp_pair.sampler_binding_point =
            sampler_var ? *sampler_var->As<sem::GlobalVariable>()->BindingPoint()
                        : binding_info->placeholder_binding_point;
        auto it = binding_info->binding_map.find(bp_pair);
        if (it != binding_info->binding_map.end()) {
            name = it->second;
        }
        Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
        Symbol symbol = ctx.dst->Symbols().New(name);
        return ctx.dst->GlobalVar(symbol, type, Attributes());
    }

    /// Creates a placeholder global sampler variable.
    /// @param kind the sampler kind to create for
    /// @returns the newly-created global variable
    const Variable* CreatePlaceholder(type::SamplerKind kind) {
        Type type = ctx.dst->ty.sampler(kind);
        const char* name = kind == type::SamplerKind::kComparisonSampler
                               ? "placeholder_comparison_sampler"
                               : "placeholder_sampler";
        Symbol symbol = ctx.dst->Symbols().New(name);
        return ctx.dst->GlobalVar(symbol, type, Attributes());
    }

    /// Creates the AST type for a given texture and sampler variable pair.
    /// Depth textures with no samplers are turned into the corresponding
    /// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>).
    /// @param texture the texture variable of interest
    /// @param sampler the sampler variable of interest (may be null)
    /// @returns the newly-created type
    Type CreateCombinedASTTypeFor(const sem::Variable* texture, const sem::Variable* sampler) {
        const type::Type* texture_type = texture->Type()->UnwrapRef();
        const type::DepthTexture* depth = texture_type->As<type::DepthTexture>();
        if (depth && !sampler) {
            return ctx.dst->ty.sampled_texture(depth->dim(), ctx.dst->ty.f32());
        } else {
            return CreateASTTypeFor(ctx, texture_type);
        }
    }

    /// Runs the transform
    /// @returns the new program or SkipTransform if the transform is not required
    ApplyResult Run() {
        auto& sem = ctx.src->Sem();

        // Remove all texture and sampler global variables. These will be replaced
        // by combined samplers. Storage textures are kept as-is.
        for (auto* global : ctx.src->AST().GlobalVariables()) {
            auto* global_sem = sem.Get(global)->As<sem::GlobalVariable>();
            auto* type = ctx.src->TypeOf(global->type);
            if (IsCombinableHandle(type)) {
                ctx.Remove(ctx.src->AST().GlobalDeclarations(), global);
            } else if (auto binding_point = global_sem->BindingPoint()) {
                // Combined sampler globals all use @group(0) @binding(0) (see Attributes()),
                // so any remaining resource at that binding point must also opt out of the
                // binding point collision validation.
                if (binding_point->group == 0 && binding_point->binding == 0) {
                    auto* attribute = ctx.dst->Disable(DisabledValidation::kBindingPointCollision);
                    ctx.InsertFront(global->attributes, attribute);
                }
            }
        }

        // Rewrite all function signatures to use combined samplers, and remove
        // separate textures & samplers. Create new combined globals where found.
        ctx.ReplaceAll([&](const Function* ast_fn) -> const Function* {
            auto* fn = sem.Get(ast_fn);
            if (!fn) {
                return nullptr;
            }
            auto pairs = fn->TextureSamplerPairs();

            // Call sites drop every texture/sampler argument (see the CallExpression
            // replacer below), so the signature must also be rewritten when the function
            // has such parameters that no texture builtin ever uses (no recorded pairs).
            bool has_handle_params = false;
            for (auto* param : fn->Parameters()) {
                if (IsCombinableHandle(param->Type())) {
                    has_handle_params = true;
                    break;
                }
            }
            if (pairs.IsEmpty() && !has_handle_params) {
                return nullptr;
            }

            tint::Vector<const Parameter*, 8> params;
            for (auto pair : pairs) {
                const sem::Variable* texture_var = pair.first;
                const sem::Variable* sampler_var = pair.second;
                std::string name = texture_var->Declaration()->name->symbol.Name();
                if (sampler_var) {
                    name += "_" + sampler_var->Declaration()->name->symbol.Name();
                }
                if (IsGlobal(pair)) {
                    // Both texture and sampler are global; add a new global variable
                    // to represent the combined sampler (if not already created).
                    tint::GetOrCreate(global_combined_texture_samplers_, pair, [&] {
                        return CreateCombinedGlobal(texture_var, sampler_var, name);
                    });
                } else {
                    // Either texture or sampler (or both) is a function parameter;
                    // add a new function parameter to represent the combined sampler.
                    Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
                    auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
                    params.Push(var);
                    function_combined_texture_samplers_[fn][pair] = var;
                }
            }

            // Filter out separate textures and samplers from the original
            // function signature. Storage textures are kept.
            for (auto* param : fn->Parameters()) {
                if (!IsCombinableHandle(param->Type())) {
                    params.Push(ctx.Clone(param->Declaration()));
                }
            }

            // Create a new function signature that differs only in the parameter
            // list.
            auto name = ctx.Clone(ast_fn->name);
            auto return_type = ctx.Clone(ast_fn->return_type);
            auto* body = ctx.Clone(ast_fn->body);
            auto attributes = ctx.Clone(ast_fn->attributes);
            auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes);
            return ctx.dst->create<Function>(name, params, return_type, body,
                                             std::move(attributes),
                                             std::move(return_type_attributes));
        });

        // Replace all function call expressions containing texture or
        // sampler parameters to use the current function's combined samplers or
        // the combined global samplers, as appropriate.
        ctx.ReplaceAll([&](const CallExpression* expr) -> const Expression* {
            if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
                tint::Vector<const Expression*, 8> args;
                // Replace all texture builtin calls.
                if (auto* builtin = call->Target()->As<sem::Builtin>()) {
                    const auto& signature = builtin->Signature();
                    auto sampler_index = signature.IndexOf(sem::ParameterUsage::kSampler);
                    auto texture_index = signature.IndexOf(sem::ParameterUsage::kTexture);
                    if (texture_index == -1) {
                        return nullptr;
                    }
                    const sem::ValueExpression* texture =
                        call->Arguments()[static_cast<size_t>(texture_index)];
                    // We don't want to combine storage textures with anything, since
                    // they never have associated samplers in GLSL.
                    if (texture->Type()->UnwrapRef()->Is<type::StorageTexture>()) {
                        return nullptr;
                    }
                    const sem::ValueExpression* sampler =
                        sampler_index != -1 ? call->Arguments()[static_cast<size_t>(sampler_index)]
                                            : nullptr;
                    auto* texture_var = texture->UnwrapLoad()->As<sem::VariableUser>()->Variable();
                    auto* sampler_var =
                        sampler ? sampler->UnwrapLoad()->As<sem::VariableUser>()->Variable()
                                : nullptr;
                    sem::VariablePair new_pair(texture_var, sampler_var);
                    for (auto* arg : expr->args) {
                        auto* type = ctx.src->TypeOf(arg)->UnwrapRef();
                        if (type->Is<type::Texture>()) {
                            const Variable* var =
                                IsGlobal(new_pair)
                                    ? global_combined_texture_samplers_[new_pair]
                                    : function_combined_texture_samplers_[call->Stmt()->Function()]
                                                                         [new_pair];
                            args.Push(ctx.dst->Expr(var->name->symbol));
                        } else if (auto* sampler_type = type->As<type::Sampler>()) {
                            // The sampler argument is replaced by a placeholder of the same
                            // kind, created on first use.
                            type::SamplerKind kind = sampler_type->kind();
                            int index = (kind == type::SamplerKind::kSampler) ? 0 : 1;
                            const Variable*& p = placeholder_samplers_[index];
                            if (!p) {
                                p = CreatePlaceholder(kind);
                            }
                            args.Push(ctx.dst->Expr(p->name->symbol));
                        } else {
                            args.Push(ctx.Clone(arg));
                        }
                    }
                    const Expression* value = ctx.dst->Call(ctx.Clone(expr->target), args);
                    // A depth texture used without a sampler is retyped to a sampled f32
                    // texture (see CreateCombinedASTTypeFor), whose textureLoad() yields
                    // vec4<f32> rather than f32. Swizzle `.x` to restore the original result
                    // type whenever the value is used - including when this call is nested
                    // inside a call statement, e.g. `f(textureLoad(t, c, 0));`. Only a call
                    // that is itself the whole call statement has its result discarded.
                    if (builtin->Type() == builtin::Function::kTextureLoad &&
                        texture_var->Type()->UnwrapRef()->Is<type::DepthTexture>()) {
                        auto* call_stmt = call->Stmt()->Declaration()->As<CallStatement>();
                        bool result_discarded = call_stmt && call_stmt->expr == expr;
                        if (!result_discarded) {
                            value = ctx.dst->MemberAccessor(value, "x");
                        }
                    }
                    return value;
                }
                // Replace all function calls.
                if (auto* callee = call->Target()->As<sem::Function>()) {
                    for (auto pair : callee->TextureSamplerPairs()) {
                        // Global pairs used by the callee do not require a function
                        // parameter at the call site.
                        if (IsGlobal(pair)) {
                            continue;
                        }
                        const sem::Variable* texture_var = pair.first;
                        const sem::Variable* sampler_var = pair.second;
                        // Map the callee's parameters back to the variables passed in by
                        // this caller.
                        if (auto* param = texture_var->As<sem::Parameter>()) {
                            const sem::ValueExpression* texture = call->Arguments()[param->Index()];
                            texture_var =
                                texture->UnwrapLoad()->As<sem::VariableUser>()->Variable();
                        }
                        if (sampler_var) {
                            if (auto* param = sampler_var->As<sem::Parameter>()) {
                                const sem::ValueExpression* sampler =
                                    call->Arguments()[param->Index()];
                                sampler_var =
                                    sampler->UnwrapLoad()->As<sem::VariableUser>()->Variable();
                            }
                        }
                        sem::VariablePair new_pair(texture_var, sampler_var);
                        // If both texture and sampler are (now) global, pass that
                        // global variable to the callee. Otherwise use the caller's
                        // function parameter for this pair.
                        const Variable* var =
                            IsGlobal(new_pair)
                                ? global_combined_texture_samplers_[new_pair]
                                : function_combined_texture_samplers_[call->Stmt()->Function()]
                                                                     [new_pair];
                        auto* arg = ctx.dst->Expr(var->name->symbol);
                        args.Push(arg);
                    }
                    // Append all of the remaining non-texture and non-sampler
                    // arguments. Storage textures are kept, matching the callee's
                    // rewritten signature.
                    for (auto* arg : expr->args) {
                        if (!IsCombinableHandle(ctx.src->TypeOf(arg)->UnwrapRef())) {
                            args.Push(ctx.Clone(arg));
                        }
                    }
                    return ctx.dst->Call(ctx.Clone(expr->target), args);
                }
            }
            return nullptr;
        });

        ctx.Clone();
        return Program(std::move(b));
    }
};
/// Constructor (defaulted).
CombineSamplers::CombineSamplers() = default;

/// Destructor (defaulted).
CombineSamplers::~CombineSamplers() = default;

/// Runs the transform on `src`. The BindingInfo input is mandatory; without it a
/// program carrying only an error diagnostic is returned.
Transform::ApplyResult CombineSamplers::Apply(const Program* src,
                                              const DataMap& inputs,
                                              DataMap&) const {
    if (const auto* info = inputs.Get<BindingInfo>()) {
        return State(src, info).Run();
    }

    // No binding info was supplied: report it through the output program's diagnostics.
    ProgramBuilder builder;
    builder.Diagnostics().add_error(diag::System::Transform,
                                    "missing transform data for " + std::string(TypeInfo().name));
    return Program(std::move(builder));
}
} // namespace tint::ast::transform