blob: 48d73ad4043e21b244d141c5d5499e8c17b2a8c0 [file] [log] [blame]
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/lang/wgsl/ast/transform/packed_vec3.h"
#include <algorithm>
#include <string>
#include <utility>
#include "src/tint/lang/core/builtin.h"
#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/reference.h"
#include "src/tint/lang/core/type/vector.h"
#include "src/tint/lang/wgsl/ast/assignment_statement.h"
#include "src/tint/lang/wgsl/program/clone_context.h"
#include "src/tint/lang/wgsl/program/program_builder.h"
#include "src/tint/lang/wgsl/resolver/resolve.h"
#include "src/tint/lang/wgsl/sem/array_count.h"
#include "src/tint/lang/wgsl/sem/index_accessor_expression.h"
#include "src/tint/lang/wgsl/sem/load.h"
#include "src/tint/lang/wgsl/sem/statement.h"
#include "src/tint/lang/wgsl/sem/type_expression.h"
#include "src/tint/lang/wgsl/sem/variable.h"
#include "src/tint/utils/containers/hashmap.h"
#include "src/tint/utils/containers/hashset.h"
#include "src/tint/utils/containers/vector.h"
#include "src/tint/utils/rtti/switch.h"
// Presumably registers PackedVec3 with Tint's Castable/type-info machinery — confirm against the
// TINT_INSTANTIATE_TYPEINFO macro definition.
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::PackedVec3);
// Brings numeric literal suffixes (e.g. `1_a`, used by the loop increment below) into scope.
using namespace tint::number_suffixes; // NOLINT
namespace tint::ast::transform {
/// PIMPL state for the transform
struct PackedVec3::State {
    /// Constructor
    /// @param program the source program
    explicit State(const Program* program) : src(program) {}

    /// The name of the struct member used when wrapping packed vec3 types.
    static constexpr const char* kStructMemberName = "elements";

    /// The names of the structures used to wrap packed vec3 types, keyed by the vec3 type.
    Hashmap<const type::Type*, Symbol, 4> packed_vec3_wrapper_struct_names;

    /// A cache of host-shareable structures that have been rewritten, keyed by the original
    /// structure type.
    Hashmap<const type::Type*, Symbol, 4> rewritten_structs;

    /// A map from type to the name of a helper function used to pack that type.
    Hashmap<const type::Type*, Symbol, 4> pack_helpers;

    /// A map from type to the name of a helper function used to unpack that type.
    Hashmap<const type::Type*, Symbol, 4> unpack_helpers;

    /// @param ty the type to test
    /// @returns true if `ty` is a vec3, false otherwise
    bool IsVec3(const type::Type* ty) {
        if (auto* vec = ty->As<type::Vector>()) {
            if (vec->Width() == 3) {
                return true;
            }
        }
        return false;
    }

    /// @param ty the type to test
    /// @returns true if `ty` is or contains a vec3, false otherwise. Types not handled below
    /// (e.g. scalars) yield the Switch default, `false`.
    bool ContainsVec3(const type::Type* ty) {
        return Switch(
            ty,  //
            [&](const type::Vector* vec) { return IsVec3(vec); },
            [&](const type::Matrix* mat) { return ContainsVec3(mat->ColumnType()); },
            [&](const type::Array* arr) { return ContainsVec3(arr->ElemType()); },
            [&](const type::Struct* str) {
                for (auto* member : str->Members()) {
                    if (ContainsVec3(member->Type())) {
                        return true;
                    }
                }
                return false;
            });
    }

    /// Create a `__packed_vec3` type with the same element type as `ty`.
    /// @param ty a three-element vector type
    /// @returns the new AST type
    Type MakePackedVec3(const type::Type* ty) {
        auto* vec = ty->As<type::Vector>();
        TINT_ASSERT(vec != nullptr && vec->Width() == 3);
        return b.ty(core::Builtin::kPackedVec3, CreateASTTypeFor(ctx, vec->type()));
    }

    /// Recursively rewrite a type using `__packed_vec3`, if needed.
    /// When used as an array element type, the `__packed_vec3` type will be wrapped in a structure
    /// and given an `@align()` attribute to give it alignment it needs to yield the correct array
    /// element stride. For vec3 types used in structures directly, the `@align()` attribute is
    /// placed on the containing structure instead. Matrices with three rows become arrays of
    /// columns, and use the aligned wrapper struct for the column type.
    /// @param ty the type to rewrite
    /// @param array_element `true` if this is being called for the element of an array
    /// @returns the new AST type, or an empty Type if rewriting was not necessary
    Type RewriteType(const type::Type* ty, bool array_element = false) {
        return Switch(
            ty,
            [&](const type::Vector* vec) -> Type {
                if (IsVec3(vec)) {
                    if (array_element) {
                        // Create a struct with a single `__packed_vec3` member.
                        // Give the struct member the same alignment as the original unpacked vec3
                        // type, to avoid changing the array element stride.
                        return b.ty(packed_vec3_wrapper_struct_names.GetOrCreate(vec, [&] {
                            // Wrapper structs are only ever created for array elements.
                            auto name = b.Symbols().New("tint_packed_vec3_" +
                                                        vec->type()->FriendlyName() +
                                                        "_array_element");
                            auto* member =
                                b.Member(kStructMemberName, MakePackedVec3(vec),
                                         tint::Vector{b.MemberAlign(AInt(vec->Align()))});
                            b.Structure(b.Ident(name), tint::Vector{member}, tint::Empty);
                            return name;
                        }));
                    } else {
                        return MakePackedVec3(vec);
                    }
                }
                return {};
            },
            [&](const type::Matrix* mat) -> Type {
                // Rewrite the matrix as an array of columns that use the aligned wrapper struct.
                auto new_col_type = RewriteType(mat->ColumnType(), /* array_element */ true);
                if (new_col_type) {
                    return b.ty.array(new_col_type, u32(mat->columns()));
                }
                return {};
            },
            [&](const type::Array* arr) -> Type {
                // Rewrite the array with the modified element type.
                auto new_type = RewriteType(arr->ElemType(), /* array_element */ true);
                if (new_type) {
                    tint::Vector<const Attribute*, 1> attrs;
                    if (arr->Count()->Is<type::RuntimeArrayCount>()) {
                        return b.ty.array(new_type, std::move(attrs));
                    } else if (auto count = arr->ConstantCount()) {
                        return b.ty.array(new_type, u32(count.value()), std::move(attrs));
                    } else {
                        TINT_ICE() << type::Array::kErrExpectedConstantCount;
                        return {};
                    }
                }
                return {};
            },
            [&](const type::Struct* str) -> Type {
                if (ContainsVec3(str)) {
                    auto name = rewritten_structs.GetOrCreate(str, [&] {
                        tint::Vector<const StructMember*, 4> members;
                        for (auto* member : str->Members()) {
                            // If the member type contains a vec3, rewrite it.
                            auto new_type = RewriteType(member->Type());
                            if (new_type) {
                                // Copy the member attributes.
                                bool needs_align = true;
                                tint::Vector<const Attribute*, 4> attributes;
                                if (auto* sem_mem = member->As<sem::StructMember>()) {
                                    for (auto* attr : sem_mem->Declaration()->attributes) {
                                        if (attr->IsAnyOf<StructMemberAlignAttribute,
                                                          StructMemberOffsetAttribute>()) {
                                            needs_align = false;
                                        }
                                        attributes.Push(ctx.Clone(attr));
                                    }
                                }
                                // If the alignment wasn't already specified, add an attribute to
                                // make sure that we don't alter the alignment when using the packed
                                // vector type.
                                if (needs_align) {
                                    attributes.Push(b.MemberAlign(AInt(member->Align())));
                                }
                                members.Push(b.Member(ctx.Clone(member->Name()), new_type,
                                                      std::move(attributes)));
                            } else {
                                // No vec3s, just clone the member as is.
                                if (auto* sem_mem = member->As<sem::StructMember>()) {
                                    members.Push(ctx.Clone(sem_mem->Declaration()));
                                } else {
                                    // There is no AST declaration to clone, and `new_type` is
                                    // empty in this branch, so build the member type from the
                                    // semantic type instead.
                                    members.Push(b.Member(ctx.Clone(member->Name()),
                                                          CreateASTTypeFor(ctx, member->Type()),
                                                          tint::Empty));
                                }
                            }
                        }
                        // Create the new structure.
                        auto struct_name =
                            b.Symbols().New(str->Name().Name() + "_tint_packed_vec3");
                        b.Structure(struct_name, std::move(members));
                        return struct_name;
                    });
                    return b.ty(name);
                }
                return {};
            });
    }

    /// Create a helper function to recursively pack or unpack a composite that contains vec3 types.
    /// @param name_prefix the name of the helper function
    /// @param ty the composite type to pack or unpack
    /// @param pack_or_unpack_element a function that packs or unpacks an element with a given type
    /// @param in_type a function that creates an AST type for the input type
    /// @param out_type a function that creates an AST type for the output type
    /// @returns the name of the helper function
    Symbol MakePackUnpackHelper(
        const char* name_prefix,
        const type::Type* ty,
        const std::function<const Expression*(const Expression*, const type::Type*)>&
            pack_or_unpack_element,
        const std::function<Type()>& in_type,
        const std::function<Type()>& out_type) {
        // Allocate a variable to hold the return value of the function.
        tint::Vector<const Statement*, 4> statements;
        statements.Push(b.Decl(b.Var("result", out_type())));

        // Helper that generates a loop to copy and pack/unpack elements of an array to the result:
        //   for (var i = 0u; i < num_elements; i = i + 1) {
        //     result[i] = pack_or_unpack_element(in[i]);
        //   }
        auto copy_array_elements = [&](uint32_t num_elements, const type::Type* element_type) {
            // Generate an expression for packing or unpacking an element of the array.
            auto* element = pack_or_unpack_element(b.IndexAccessor("in", "i"), element_type);
            statements.Push(b.For(                                 //
                b.Decl(b.Var("i", b.ty.u32())),                    //
                b.LessThan("i", u32(num_elements)),                //
                b.Assign("i", b.Add("i", 1_a)),                    //
                b.Block(tint::Vector{
                    b.Assign(b.IndexAccessor("result", "i"), element),
                })));
        };

        // Copy the elements of the value over to the result.
        Switch(
            ty,
            [&](const type::Array* arr) {
                TINT_ASSERT(arr->ConstantCount());
                copy_array_elements(arr->ConstantCount().value(), arr->ElemType());
            },
            [&](const type::Matrix* mat) {
                copy_array_elements(mat->columns(), mat->ColumnType());
            },
            [&](const type::Struct* str) {
                // Copy the struct members over one at a time, packing/unpacking as necessary.
                for (auto* member : str->Members()) {
                    const Expression* element =
                        b.MemberAccessor("in", b.Ident(ctx.Clone(member->Name())));
                    if (ContainsVec3(member->Type())) {
                        element = pack_or_unpack_element(element, member->Type());
                    }
                    statements.Push(b.Assign(
                        b.MemberAccessor("result", b.Ident(ctx.Clone(member->Name()))), element));
                }
            });

        // Return the result.
        statements.Push(b.Return("result"));

        // Create the function and return its name.
        auto name = b.Symbols().New(name_prefix);
        b.Func(name, tint::Vector{b.Param("in", in_type())}, out_type(), std::move(statements));
        return name;
    }

    /// Unpack the composite value `expr` to the unpacked type `ty`. If `ty` is a matrix, this will
    /// produce a regular matNx3 value from an array of packed column vectors.
    /// @param expr the composite value expression to unpack
    /// @param ty the unpacked type
    /// @returns an expression that holds the unpacked value
    const Expression* UnpackComposite(const Expression* expr, const type::Type* ty) {
        auto helper = unpack_helpers.GetOrCreate(ty, [&] {
            return MakePackUnpackHelper(
                "tint_unpack_vec3_in_composite", ty,
                [&](const Expression* element,
                    const type::Type* element_type) -> const Expression* {
                    if (element_type->Is<type::Vector>()) {
                        // Unpack a `__packed_vec3` by casting it to a regular vec3.
                        // If it is an array element, extract the vector from the wrapper struct.
                        if (element->Is<IndexAccessorExpression>()) {
                            element = b.MemberAccessor(element, kStructMemberName);
                        }
                        return b.Call(CreateASTTypeFor(ctx, element_type), element);
                    } else {
                        return UnpackComposite(element, element_type);
                    }
                },
                [&] { return RewriteType(ty); },  //
                [&] { return CreateASTTypeFor(ctx, ty); });
        });
        return b.Call(helper, expr);
    }

    /// Pack the composite value `expr` from the unpacked type `ty`. If `ty` is a matrix, this will
    /// produce an array of packed column vectors.
    /// @param expr the composite value expression to pack
    /// @param ty the unpacked type
    /// @returns an expression that holds the packed value
    const Expression* PackComposite(const Expression* expr, const type::Type* ty) {
        auto helper = pack_helpers.GetOrCreate(ty, [&] {
            return MakePackUnpackHelper(
                "tint_pack_vec3_in_composite", ty,
                [&](const Expression* element,
                    const type::Type* element_type) -> const Expression* {
                    if (element_type->Is<type::Vector>()) {
                        // Pack a vector element by casting it to a packed_vec3.
                        // If it is an array element, construct a wrapper struct.
                        auto* packed = b.Call(MakePackedVec3(element_type), element);
                        if (element->Is<IndexAccessorExpression>()) {
                            packed = b.Call(RewriteType(element_type, true), packed);
                        }
                        return packed;
                    } else {
                        return PackComposite(element, element_type);
                    }
                },
                [&] { return CreateASTTypeFor(ctx, ty); },  //
                [&] { return RewriteType(ty); });
        });
        return b.Call(helper, expr);
    }

    /// @returns true if there are host-shareable vec3's that need transforming
    bool ShouldRun() {
        // Check for vec3s in the types of all uniform and storage buffer variables to determine
        // if the transform is necessary.
        for (auto* decl : src->AST().GlobalVariables()) {
            auto* var = sem.Get<sem::GlobalVariable>(decl);
            if (var && core::IsHostShareable(var->AddressSpace()) &&
                ContainsVec3(var->Type()->UnwrapRef())) {
                return true;
            }
        }
        return false;
    }

    /// Runs the transform
    /// @returns the new program or SkipTransform if the transform is not required
    ApplyResult Run() {
        if (!ShouldRun()) {
            return SkipTransform;
        }

        // Changing the types of certain structure members can trigger stricter layout validation
        // rules for the uniform address space. In particular, replacing 16-bit matrices with arrays
        // violates the requirement that the array element stride is a multiple of 16 bytes, and
        // replacing vec3s with a structure violates the requirement that there must be at least 16
        // bytes from the start of a structure to the start of the next member.
        // Disable these validation rules using an internal extension, as MSL does not have these
        // restrictions.
        b.Enable(core::Extension::kChromiumInternalRelaxedUniformLayout);

        // Track expressions that need to be packed or unpacked.
        Hashset<const sem::ValueExpression*, 8> to_pack;
        Hashset<const sem::ValueExpression*, 8> to_unpack;

        // Replace vec3 types in host-shareable address spaces with `__packed_vec3` types, and
        // collect expressions that need to be converted to or from values that use the
        // `__packed_vec3` type.
        for (auto* node : ctx.src->ASTNodes().Objects()) {
            Switch(
                sem.Get(node),
                [&](const sem::TypeExpression* type) {
                    // Rewrite pointers to types that contain vec3s.
                    auto* ptr = type->Type()->As<type::Pointer>();
                    if (ptr && core::IsHostShareable(ptr->AddressSpace())) {
                        auto new_store_type = RewriteType(ptr->StoreType());
                        if (new_store_type) {
                            auto access = ptr->AddressSpace() == core::AddressSpace::kStorage
                                              ? ptr->Access()
                                              : core::Access::kUndefined;
                            auto new_ptr_type =
                                b.ty.ptr(ptr->AddressSpace(), new_store_type, access);
                            ctx.Replace(node, new_ptr_type.expr);
                        }
                    }
                },
                [&](const sem::Variable* var) {
                    if (!core::IsHostShareable(var->AddressSpace())) {
                        return;
                    }

                    // Rewrite the var type, if it contains vec3s.
                    auto new_store_type = RewriteType(var->Type()->UnwrapRef());
                    if (new_store_type) {
                        ctx.Replace(var->Declaration()->type.expr, new_store_type.expr);
                    }
                },
                [&](const sem::Statement* stmt) {
                    // Pack the RHS of assignment statements that are writing to packed types.
                    auto* assign = stmt->Declaration()->As<AssignmentStatement>();
                    if (!assign) {
                        return;
                    }
                    auto* lhs = sem.GetVal(assign->lhs);
                    auto* rhs = sem.GetVal(assign->rhs);
                    if (!lhs || !rhs) {
                        // Without semantic values for both sides (presumably e.g. a phony
                        // assignment `_ = expr`) there is no packed store to rewrite.
                        return;
                    }
                    auto* lhs_ref = lhs->Type()->As<type::Reference>();
                    if (!lhs_ref || !ContainsVec3(rhs->Type()) ||
                        !core::IsHostShareable(lhs_ref->AddressSpace())) {
                        // Skip assignments to address spaces that are not host-shareable, or
                        // that do not contain vec3 types.
                        return;
                    }

                    // Pack the RHS expression.
                    if (to_unpack.Contains(rhs)) {
                        // The expression will already be packed, so skip the pending unpack.
                        to_unpack.Remove(rhs);

                        // If the expression produces a vec3 from an array element, extract
                        // the packed vector from the wrapper struct.
                        if (IsVec3(rhs->Type()) &&
                            rhs->UnwrapLoad()->Is<sem::IndexAccessorExpression>()) {
                            ctx.Replace(rhs->Declaration(),
                                        b.MemberAccessor(ctx.Clone(rhs->Declaration()),
                                                         kStructMemberName));
                        }
                    } else {
                        to_pack.Add(rhs);
                    }
                },
                [&](const sem::Load* load) {
                    // Unpack loads of types that contain vec3s in host-shareable address spaces.
                    if (ContainsVec3(load->Type()) &&
                        core::IsHostShareable(load->ReferenceType()->AddressSpace())) {
                        to_unpack.Add(load);
                    }
                },
                [&](const sem::IndexAccessorExpression* accessor) {
                    // If the expression produces a reference to a vec3 in a host-shareable address
                    // space from an array element, extract the packed vector from the wrapper
                    // struct.
                    // NOTE(review): the replacement is built eagerly, so replacements registered
                    // later for sub-expressions of `accessor` (e.g. the pending unpacks applied
                    // after this loop) are not reflected in it — confirm this cannot occur.
                    if (auto* ref = accessor->Type()->As<type::Reference>()) {
                        if (IsVec3(ref->StoreType()) &&
                            core::IsHostShareable(ref->AddressSpace())) {
                            ctx.Replace(node, b.MemberAccessor(ctx.Clone(accessor->Declaration()),
                                                               kStructMemberName));
                        }
                    }
                });
        }

        // Sort the pending pack/unpack operations by AST node ID to make the order deterministic.
        auto to_unpack_sorted = to_unpack.Vector();
        auto to_pack_sorted = to_pack.Vector();
        auto pred = [&](auto* expr_a, auto* expr_b) {
            return expr_a->Declaration()->node_id < expr_b->Declaration()->node_id;
        };
        to_unpack_sorted.Sort(pred);
        to_pack_sorted.Sort(pred);

        // Apply all of the pending unpack operations that we have collected.
        for (auto* expr : to_unpack_sorted) {
            TINT_ASSERT(ContainsVec3(expr->Type()));
            auto* packed = ctx.Clone(expr->Declaration());
            const Expression* unpacked = nullptr;
            if (IsVec3(expr->Type())) {
                if (expr->UnwrapLoad()->Is<sem::IndexAccessorExpression>()) {
                    // If we are unpacking a vec3 from an array element, extract the vector from the
                    // wrapper struct.
                    packed = b.MemberAccessor(packed, kStructMemberName);
                }
                // Cast the packed vector to a regular vec3.
                unpacked = b.Call(CreateASTTypeFor(ctx, expr->Type()), packed);
            } else {
                // Use a helper function to unpack an array or matrix.
                unpacked = UnpackComposite(packed, expr->Type());
            }
            TINT_ASSERT(unpacked != nullptr);
            ctx.Replace(expr->Declaration(), unpacked);
        }

        // Apply all of the pending pack operations that we have collected.
        for (auto* expr : to_pack_sorted) {
            TINT_ASSERT(ContainsVec3(expr->Type()));
            auto* unpacked = ctx.Clone(expr->Declaration());
            const Expression* packed = nullptr;
            if (IsVec3(expr->Type())) {
                // Cast the regular vec3 to a packed vector type.
                packed = b.Call(MakePackedVec3(expr->Type()), unpacked);
            } else {
                // Use a helper function to pack an array or matrix.
                packed = PackComposite(unpacked, expr->Type());
            }
            TINT_ASSERT(packed != nullptr);
            ctx.Replace(expr->Declaration(), packed);
        }

        ctx.Clone();
        return resolver::Resolve(b);
    }

  private:
    /// The source program
    const Program* const src;
    /// The target program builder
    ProgramBuilder b;
    /// The clone context
    program::CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
    /// Alias to the semantic info in ctx.src
    const sem::Info& sem = ctx.src->Sem();
};
/// Default constructor.
PackedVec3::PackedVec3() = default;

/// Destructor.
PackedVec3::~PackedVec3() = default;

/// Runs the packed-vec3 rewrite over `src`. The input and output data maps are not used.
/// @param src the program to transform
/// @returns the transformed program, or SkipTransform when nothing needs rewriting
Transform::ApplyResult PackedVec3::Apply(const Program* src,
                                         const DataMap& /* inputs */,
                                         DataMap& /* outputs */) const {
    State state{src};
    return state.Run();
}
} // namespace tint::ast::transform