blob: 898383553c5c46a6cfb66fd3370c54bf451d8f17 [file] [log] [blame]
// Copyright 2025 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/spirv/reader/lower/transpose_row_major.h"
#include <utility>
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/type/matrix.h"
#include "src/tint/lang/spirv/type/explicit_layout_array.h"
namespace tint::spirv::reader::lower {
namespace {
using namespace tint::core::fluent_types; // NOLINT
/// Name given to the generated helper function that loads a single column of a row-major matrix.
constexpr std::string_view kTintLoadRowMajor{"tint_load_row_major_column"};
/// Name given to the generated helper function that transposes each matrix of an array.
constexpr std::string_view kTintTransposeRowMajorArray{"tint_transpose_row_major_array"};
/// Name given to the generated helper function that stores a single column of a row-major matrix.
constexpr std::string_view kTintStoreRowMajor{"tint_store_row_major_column"};
/// PIMPL state for the transform.
///
/// Struct members carrying the row-major flag (matrices, or arrays of matrices) are rewritten to
/// use the transposed matrix type without the flag. Every instruction that touches such a member
/// is then rewritten so that the program still observes the original (untransposed) values:
///  * whole-matrix loads and stores are wrapped in `transpose()` calls,
///  * loads and stores of a single column call generated per-type helper functions,
///  * loads and stores of arrays of matrices call a generated array transpose helper,
///  * vector element loads and stores swap the column and row indices.
struct State {
    /// The IR module.
    core::ir::Module& ir;

    /// The IR builder.
    core::ir::Builder b{ir};

    /// The type manager.
    core::type::Manager& ty{ir.Types()};

    /// The symbol manager.
    SymbolTable& sym{ir.symbols};

    /// Key type for `original_to_transposed`: a type plus whether it is being rewritten as part of
    /// a row-major decorated member.
    struct TypeAndRowMajor {
        const core::type::Type* type;
        bool row_major;
        bool operator==(const TypeAndRowMajor& other) const {
            return type == other.type && row_major == other.row_major;
        }
        tint::HashCode HashCode() const { return Hash(type, row_major); }
    };

    /// A map from a (type, is_row_major) pair to its replacement (which may be the same as the
    /// original).
    Hashmap<TypeAndRowMajor, const core::type::Type*, 32> original_to_transposed{};
    /// A map from rewritten structs to original structs.
    Hashmap<const core::type::Struct*, const core::type::Struct*, 4> struct_to_original{};
    /// List of instruction results where the usages need to be updated.
    Vector<core::ir::InstructionResult*, 32> results_to_update{};
    /// A map from a rewritten access result to the column index that was popped off it. The
    /// access now yields a pointer to the whole (transposed) matrix.
    Hashmap<core::ir::Value*, core::ir::Value*, 8> access_to_vector_index{};
    /// Instructions to destroy at the end of the transform if their result is unused.
    Vector<core::ir::Instruction*, 8> instructions_to_remove_if_unused{};
    /// A map of matrix pointer type to the column load helper.
    Hashmap<const core::type::Type*, core::ir::Function*, 4> load_functions{};
    /// A map of matrix pointer type to the column store helper.
    Hashmap<const core::type::Type*, core::ir::Function*, 4> store_functions{};
    /// A map of array type to the array transpose helper.
    Hashmap<const core::type::Type*, core::ir::Function*, 4> transpose_array_functions{};
    /// Instruction results whose usages have already been processed.
    Hashset<core::ir::InstructionResult*, 32> processed_instructions{};

    /// Process the module.
    void Process() {
        Vector<core::ir::Instruction*, 4> instructions_to_process;
        for (auto* inst : ir.Instructions()) {
            // Replace all constant operands where the type will be changed due to it containing a
            // structure that uses a row-major attribute.
            for (uint32_t i = 0; i < inst->Operands().Length(); ++i) {
                if (auto* constant = As<core::ir::Constant>(inst->Operands()[i])) {
                    auto* new_constant = RewriteConstant(constant->Value(), false);
                    if (new_constant != constant->Value()) {
                        inst->SetOperand(i, b.Constant(new_constant));
                    }
                }
            }
            // Rewrite the result types. An instruction whose result type changed needs its own
            // fix-up, and the uses of that result need to be revisited.
            for (auto* res : inst->Results()) {
                auto* new_ty = RewriteType(res->Type(), false);
                if (new_ty != res->Type()) {
                    res->SetType(new_ty);
                    instructions_to_process.Push(inst);
                    results_to_update.Push(res);
                }
            }
        }
        for (auto* inst : instructions_to_process) {
            ProcessInstruction(inst);
        }
        while (!results_to_update.IsEmpty()) {
            auto* result = results_to_update.Pop();
            if (!processed_instructions.Add(result)) {
                continue;
            }
            // It's possible we've already processed this instruction because we're working through
            // results, so if it isn't alive, then we've already replaced it and we can move on.
            if (!result->Alive()) {
                continue;
            }
            // Use sorted usages to force a copy, the replacements may create new usages of this
            // result.
            for (auto& usage : result->UsagesSorted()) {
                ProcessInstruction(usage.instruction);
            }
        }
        for (auto* inst : instructions_to_remove_if_unused) {
            // The same instruction can be queued more than once (e.g. an access shared by several
            // vector element loads/stores), so skip anything that was already destroyed.
            if (!inst->Alive()) {
                continue;
            }
            // If we've detached the result, just remove the instruction.
            if (inst->Results().IsEmpty()) {
                inst->Destroy();
                continue;
            }
            TINT_ASSERT(inst->Results().Length() == 1);
            if (!inst->Result()->IsUsed()) {
                inst->Destroy();
            }
        }
    }

    /// Dispatches @p inst, whose result or operand types were affected by the rewrite, to the
    /// matching fix-up. Any other instruction kind is an internal error.
    void ProcessInstruction(core::ir::Instruction* inst) {
        tint::Switch(
            inst,  //
            [&](core::ir::Var*) {
                // Nothing to do as we already substituted the type.
            },
            [&](core::ir::Let*) {
                // Nothing to do as we already substituted the type.
            },
            [&](core::ir::Load* ld) { ReplaceLoad(ld); },
            [&](core::ir::Store* store) { ReplaceStore(store); },
            [&](core::ir::Access* access) { ReplaceAccess(access); },
            [&](core::ir::Construct* construct) { ReplaceConstruct(construct); },
            [&](core::ir::LoadVectorElement* lve) { ReplaceLoadVectorElement(lve); },
            [&](core::ir::StoreVectorElement* sve) { ReplaceStoreVectorElement(sve); },
            TINT_ICE_ON_NO_MATCH);
    }

    /// Fixes up a struct construct whose result type was rewritten: each operand whose type no
    /// longer matches the (transposed) member type is transposed to match.
    void ReplaceConstruct(core::ir::Construct* construct) {
        auto* struct_type = construct->Result()->Type()->As<core::type::Struct>();
        if (!struct_type) {
            return;
        }
        b.InsertBefore(construct, [&] {
            Vector<core::ir::Value*, 8> new_operands;
            for (uint32_t i = 0; i < construct->Operands().Length(); i++) {
                auto* operand = construct->Operands()[i];
                auto* member_type = struct_type->Members()[i]->Type();
                if (member_type != operand->Type()) {
                    tint::Switch(
                        member_type,
                        [&](const core::type::Matrix*) {
                            new_operands.Push(
                                b.Call(member_type, core::BuiltinFn::kTranspose, operand)
                                    ->Result());
                        },
                        [&](const core::type::Array*) {
                            // TODO(437140112): Add support for arrays of matrices
                            TINT_UNIMPLEMENTED() << "handle construct of array of matrices";
                        },
                        TINT_ICE_ON_NO_MATCH);
                } else {
                    new_operands.Push(operand);
                }
            }
            construct->SetOperands(new_operands);
        });
    }

    /// Returns the type of a pointer to one column of the matrix pointed to by @p mat_ptr. The
    /// address space and access mode of @p mat_ptr are preserved, so the result matches what an
    /// access into @p mat_ptr produces.
    const core::type::Pointer* ColumnPtrType(const core::type::Pointer* mat_ptr) {
        auto* mat = mat_ptr->StoreType()->As<core::type::Matrix>();
        TINT_ASSERT(mat);
        return ty.ptr(mat_ptr->AddressSpace(), ty.vec(mat->Type(), mat->Rows()),
                      mat_ptr->Access());
    }

    // This is a store vector element that is going to a matrix which we've transposed the size of.
    // So, we need to swap the index on this store with the last index of the source access.
    void ReplaceStoreVectorElement(core::ir::StoreVectorElement* sve) {
        auto* src_to = sve->To()->As<core::ir::InstructionResult>();
        TINT_ASSERT(src_to);
        auto* src_access = src_to->Instruction()->As<core::ir::Access>();
        TINT_ASSERT(src_access);
        auto access_idx = access_to_vector_index.Get(sve->To());
        TINT_ASSERT(access_idx);
        auto* src_ty = src_to->Type()->As<core::type::Pointer>();
        TINT_ASSERT(src_ty);
        // Build the replacement right before `sve`, not after `src_access`. The element index may
        // be defined between the two, and it is only guaranteed to dominate `sve` itself.
        b.InsertBefore(sve, [&] {
            // The stored (transposed) column selected by the original row index...
            auto* column = b.Access(ColumnPtrType(src_ty), src_access, Vector{sve->Index()});
            // ...holds the original column's element at the original column index.
            b.StoreVectorElement(column, *access_idx, sve->Value());
        });
        sve->Destroy();
        instructions_to_remove_if_unused.Push(src_access);
    }

    // This is a load vector element that is coming from a matrix which we've transposed the size
    // of. So, we need to swap the index on this load with the last index of the source access.
    void ReplaceLoadVectorElement(core::ir::LoadVectorElement* lve) {
        auto* src_result = lve->From()->As<core::ir::InstructionResult>();
        TINT_ASSERT(src_result);
        auto* src_access = src_result->Instruction()->As<core::ir::Access>();
        TINT_ASSERT(src_access);
        auto access_idx = access_to_vector_index.Get(lve->From());
        TINT_ASSERT(access_idx);
        auto* src_ty = src_result->Type()->As<core::type::Pointer>();
        TINT_ASSERT(src_ty);
        // Build the replacement right before `lve`, not after `src_access`. The element index may
        // be defined between the two, and it is only guaranteed to dominate `lve` itself.
        b.InsertBefore(lve, [&] {
            // The stored (transposed) column selected by the original row index...
            auto* column = b.Access(ColumnPtrType(src_ty), src_access, Vector{lve->Index()});
            // ...holds the original column's element at the original column index.
            b.LoadVectorElementWithResult(lve->DetachResult(), column, *access_idx);
        });
        lve->Destroy();
        instructions_to_remove_if_unused.Push(src_access);
    }

    /// Rewrites an access whose object type was rewritten. Depending on what is accessed:
    ///  * an array containing row-major matrices: the result type is updated and its uses are
    ///    revisited,
    ///  * a row-major matrix through a pointer: the result pointer type is transposed,
    ///  * a column of a row-major matrix through a pointer: the column index is removed and
    ///    recorded in `access_to_vector_index`,
    ///  * a row-major matrix (or part of one) as a value: the stored matrix is extracted,
    ///    transposed back, and the remaining indices are applied to that.
    void ReplaceAccess(core::ir::Access* access) {
        bool indexed_through_row_major = false;
        auto* cur_ty = access->Object()->Type()->UnwrapPtr();
        const core::type::Type* parent_ty = nullptr;
        const core::type::Matrix* mat_ty = nullptr;
        auto indices = access->Indices();
        int32_t matrix_index = -1;
        for (uint32_t i = 0; i < indices.Length(); ++i) {
            auto* idx = indices[i];
            if (auto* struct_ty = cur_ty->As<core::type::Struct>()) {
                // Struct member indices are expected to be constants.
                auto const_idx = idx->As<core::ir::Constant>()->Value()->ValueAs<uint32_t>();
                parent_ty = cur_ty;
                cur_ty = cur_ty->Element(const_idx);
                auto* orig_struct = struct_to_original.GetOr(struct_ty, nullptr);
                if (!orig_struct) {
                    // Structure didn't change, so doesn't contain a row-major member
                    continue;
                }
                auto* mem = orig_struct->Members()[const_idx];
                if (mem->RowMajor()) {
                    indexed_through_row_major = true;
                }
            } else {
                parent_ty = cur_ty;
                cur_ty = cur_ty->Elements().type;
            }
            // We do this at the end because we want the index that we load the matrix from, so the
            // next thing we look at would be the matrix.
            if (cur_ty->Is<core::type::Matrix>()) {
                TINT_ASSERT(matrix_index == -1);
                mat_ty = cur_ty->As<core::type::Matrix>();
                matrix_index = int32_t(i);
            }
        }
        // The matrix is in an array and we're dealing with the array. We'll need to handle the
        // array when we load/store it, so add the access to the results to update.
        if (cur_ty->Is<core::type::Array>()) {
            ReplacePointerAccess(access, parent_ty, cur_ty);
            results_to_update.Push(access->Result());
            return;
        }
        // The thing we're accessing has changed, so we need to change.
        if (!indexed_through_row_major && cur_ty != access->Result()->Type()->UnwrapPtr()) {
            ReplacePointerAccess(access, parent_ty, cur_ty);
            return;
        }
        // If we don't have a matrix index we need to update, then we've updated the `var` type
        // we're accessing but we're accessing another member of the struct, so we're done.
        if (matrix_index == -1) {
            return;
        }
        // Not accessing through a matrix, nothing to do.
        if (!indexed_through_row_major) {
            return;
        }
        if (access->Object()->Type()->Is<core::type::Pointer>()) {
            ReplacePointerAccess(access, parent_ty, cur_ty);
            return;
        }
        // This isn't a pointer access, so we'll just split the access in half, transpose the
        // matrix itself and then access that matrix with the rest of the expression.
        Vector<core::ir::Value*, 4> mat_indices = indices.Truncate(size_t(matrix_index) + 1);
        b.InsertBefore(access, [&] {
            auto* m = b.Access(mat_ty, access->Object(), mat_indices);
            auto* t = b.Call(RewriteType(mat_ty, true), core::BuiltinFn::kTranspose, m)->Result();
            if (uint32_t(matrix_index) != indices.Length() - 1) {
                Vector<core::ir::Value*, 4> access_indices =
                    indices.Offset(size_t(matrix_index) + 1);
                b.AccessWithResult(access->DetachResult(), t, access_indices);
            } else {
                access->Result()->ReplaceAllUsesWith(t);
            }
        });
        access->Destroy();
    }

    /// Updates a pointer access into a rewritten struct. @p cur_ty is the (rewritten) type that
    /// the full index list selects, and @p parent_ty is the type it was selected from.
    void ReplacePointerAccess(core::ir::Access* access,
                              const core::type::Type* parent_ty,
                              const core::type::Type* cur_ty) {
        // We've accessed the matrix itself, or an array containing the matrix. We need to update
        // the access result type, and if changed, update the uses of this access
        if (cur_ty->Is<core::type::Matrix>() || cur_ty->Is<core::type::Array>()) {
            auto* new_access_ty = RewriteType(access->Result()->Type(), true);
            // NOTE(review): for a square matrix the transposed type appears to be identical to the
            // original, so this check is false and the uses of the access are never revisited (no
            // transpose is inserted on load/store). The type-equality checks in ReplaceConstruct
            // and RewriteConstant skip square matrices in the same way. Confirm whether square
            // row-major matrices need handling.
            if (new_access_ty != access->Result()->Type()) {
                access->Result()->SetType(new_access_ty);
                results_to_update.Push(access->Result());
            }
            return;
        }
        // We're accessing a row of the vector. We need to replace this access with an access of
        // the parent matrix and then store away which row we're accessing so we can rebuild any
        // needed accesses later.
        if (cur_ty->Is<core::type::Vector>()) {
            TINT_ASSERT(parent_ty != nullptr);
            auto* idx = access->PopLastIndex();
            // The access now returns the transposed matrix
            auto* new_access_ty = access->Result()->Type();
            if (auto* access_ptr = access->Result()->Type()->As<core::type::Pointer>()) {
                new_access_ty = ty.ptr(access_ptr->AddressSpace(), parent_ty, access_ptr->Access());
            }
            access->Result()->SetType(new_access_ty);
            access_to_vector_index.Add(access->Result(), idx);
            results_to_update.Push(access->Result());
            return;
        }
        TINT_UNREACHABLE() << "access of unknown type for row-major matrix";
    }

    /// Rewrites a load whose source pointer now refers to transposed data, so that the load
    /// still produces a value of the original type.
    void ReplaceLoad(core::ir::Load* ld) {
        auto* ld_ty = ld->Result()->Type();
        tint::Switch(
            ld_ty,  //
            [&](const core::type::Matrix*) {
                b.InsertAfter(ld, [&] {
                    // We're replacing the load, which means the source must have been a transposed
                    // matrix, so we need to get the load result as if it was row-major decorated.
                    auto* new_res = b.InstructionResult(RewriteType(ld->Result()->Type(), true));
                    b.CallWithResult(ld->DetachResult(), core::BuiltinFn::kTranspose, new_res);
                    ld->SetResult(new_res);
                });
            },
            [&](const core::type::Vector*) {
                auto idx = access_to_vector_index.Get(ld->From());
                TINT_ASSERT(idx);
                if (idx) {
                    // We're loading a vector from an access chain, we need to determine if this
                    // vector came from a row-major matrix
                    auto* load_fn = LoadColumnHelper(ld->From()->Type()->As<core::type::Pointer>());
                    b.InsertAfter(ld, [&] {
                        auto* v = *idx;
                        if (v->Type()->Is<core::type::I32>()) {
                            v = b.Convert(ty.u32(), v)->Result();
                        }
                        b.CallWithResult(ld->DetachResult(), load_fn, ld->From(), v);
                    });
                    // Do this after we're done so we don't end up modifying the usage list of the
                    // result we're iterating over.
                    instructions_to_remove_if_unused.Push(ld);
                } else {
                    TINT_UNREACHABLE() << "attempting to load a row-major vector?";
                }
            },
            [&](const core::type::Struct*) {
                // Handled elsewhere
            },
            [&](const core::type::Array*) {
                // We're replacing the load, which means the source must have been a transposed
                // array of matrix.
                auto* load_fn = TransposeArrayHelper(ld->From()->Type()->UnwrapPtr());
                b.InsertAfter(ld, [&] {
                    auto* new_ld = b.Load(ld->From());
                    b.CallWithResult(ld->DetachResult(), load_fn, new_ld);
                });
                // Do this after we're done so we don't end up modifying the usage list of the
                // result we're iterating over.
                instructions_to_remove_if_unused.Push(ld);
            },
            TINT_ICE_ON_NO_MATCH);
    }

    /// Rewrites a store whose destination pointer now refers to transposed data, so that the
    /// value is transposed (or scattered, for a single column) before it reaches memory.
    void ReplaceStore(core::ir::Store* store) {
        auto vec_idx = access_to_vector_index.Get(store->To());
        auto* to_ty = store->To()->Type()->UnwrapPtr();
        if (!vec_idx) {
            tint::Switch(
                to_ty,  //
                [&](const core::type::Matrix*) {
                    // Storing the full matrix
                    b.InsertBefore(store, [&] {
                        auto* from = b.Call(to_ty, core::BuiltinFn::kTranspose, store->From());
                        store->SetFrom(from->Result());
                    });
                },
                [&](const core::type::Array*) {
                    b.InsertBefore(store, [&] {
                        auto* from = store->From();
                        if (from->Type()->Is<core::type::Pointer>()) {
                            from = b.Load(from)->Result();
                        }
                        auto* fn = TransposeArrayHelper(from->Type());
                        store->SetFrom(b.Call(to_ty, fn, from)->Result());
                    });
                },
                [&](const core::type::Struct*) {
                    // Should already be fixed
                },
                TINT_ICE_ON_NO_MATCH);
            return;
        }
        // Storing a vector
        auto* store_fn = StoreColumnHelper(store->To()->Type()->As<core::type::Pointer>());
        b.InsertAfter(store, [&] {
            auto* v = *vec_idx;
            if (v->Type()->Is<core::type::I32>()) {
                v = b.Convert(ty.u32(), v)->Result();
            }
            b.Call(ty.void_(), store_fn, store->To(), v, store->From());
        });
        instructions_to_remove_if_unused.Push(store);
    }

    /// Returns (creating on first use) a helper `(mat: ptr, row: u32, col: vec)` that writes the
    /// original column `row` by storing element `row` of every column of the stored matrix.
    /// Note, the type provided in the pointer has _already_ been transposed.
    core::ir::Function* StoreColumnHelper(const core::type::Pointer* ptr) {
        return store_functions.GetOrAdd(ptr, [&] {
            TINT_ASSERT(ptr);
            auto* row_major_ty = ptr->UnwrapPtr()->As<core::type::Matrix>();
            TINT_ASSERT(row_major_ty);
            auto* vec_ty = ty.vec(row_major_ty->Type(), row_major_ty->Columns());
            auto* col_ptr_ty = ColumnPtrType(ptr);
            auto* fn = b.Function(kTintStoreRowMajor, ty.void_());
            auto* mat = b.FunctionParam(ptr);
            auto* row = b.FunctionParam(ty.u32());
            auto* col = b.FunctionParam(vec_ty);
            fn->SetParams({mat, row, col});
            b.Append(fn->Block(), [&] {
                for (uint32_t i = 0; i < row_major_ty->Columns(); ++i) {
                    auto* col_access = b.Access(vec_ty->DeepestElement(), col, u32(i));
                    b.StoreVectorElement(b.Access(col_ptr_ty, mat, u32(i)), row, col_access);
                }
                b.Return(fn);
            });
            return fn;
        });
    }

    /// Returns (creating on first use) a helper `(mat: ptr, row: u32) -> vec` that reads the
    /// original column `row` by loading element `row` from every column of the stored matrix.
    /// Note, the type provided in the pointer has _already_ been transposed.
    core::ir::Function* LoadColumnHelper(const core::type::Pointer* ptr) {
        return load_functions.GetOrAdd(ptr, [&] {
            TINT_ASSERT(ptr);
            auto* row_major_ty = ptr->UnwrapPtr()->As<core::type::Matrix>();
            TINT_ASSERT(row_major_ty);
            auto* vec_ty = ty.vec(row_major_ty->Type(), row_major_ty->Columns());
            auto* col_ptr_ty = ColumnPtrType(ptr);
            auto* fn = b.Function(kTintLoadRowMajor, vec_ty);
            auto* mat = b.FunctionParam(ptr);
            auto* row = b.FunctionParam(ty.u32());
            fn->SetParams({mat, row});
            b.Append(fn->Block(), [&] {
                Vector<core::ir::Value*, 4> values;
                for (uint32_t i = 0; i < row_major_ty->Columns(); ++i) {
                    values.Push(
                        b.LoadVectorElement(b.Access(col_ptr_ty, mat, u32(i)), row)->Result());
                }
                b.Return(fn, b.Construct(vec_ty, values));
            });
            return fn;
        });
    }

    /// Returns (creating on first use) a helper that transposes every matrix in an array value
    /// of type @p in_type. It recurses into nested arrays.
    core::ir::Function* TransposeArrayHelper(const core::type::Type* in_type) {
        // The helper function will look like this:
        // fn tint_transpose_array(from: array<mat3x2<f32>, 4>) -> array<mat2x3<f32>, 4> {
        //   var result : array<mat2x3<f32>, 4>;
        //   for (var i = 0; i < 4; i++) {
        //     result[i] = transpose(from[i]);
        //   }
        //   return result;
        // }
        TINT_ASSERT(in_type);
        return transpose_array_functions.GetOrAdd(in_type, [&] {
            auto* from_ty = in_type->As<core::type::Array>();
            TINT_ASSERT(from_ty);
            auto* to_ty = RewriteType(in_type, true)->As<core::type::Array>();
            TINT_ASSERT(to_ty);
            // The array is passed by value, so it must have a constant element count.
            auto count = from_ty->ConstantCount();
            TINT_ASSERT(count);
            auto* fn = b.Function(kTintTransposeRowMajorArray, to_ty);
            auto* in = b.FunctionParam(in_type);
            fn->SetParams({in});
            b.Append(fn->Block(), [&] {
                auto* res = b.Var(ty.ptr(function, to_ty));
                b.LoopRange(u32(0), u32(*count), u32(1), [&](core::ir::Value* idx) {
                    core::ir::Value* transposed = nullptr;
                    auto* cur = b.Access(from_ty->ElemType(), in, idx);
                    if (auto* nested = from_ty->ElemType()->As<core::type::Array>()) {
                        auto* inner_fn = TransposeArrayHelper(nested);
                        transposed = b.Call(to_ty->ElemType(), inner_fn, cur)->Result();
                    } else {
                        transposed =
                            b.Call(to_ty->ElemType(), core::BuiltinFn::kTranspose, cur)->Result();
                    }
                    auto* slot = b.Access(ty.ptr(function, to_ty->ElemType()), res, idx);
                    b.Store(slot, transposed);
                });
                auto* ld = b.Load(res);
                b.Return(fn, ld);
            });
            return fn;
        });
    }

    /// Returns the replacement for @p type. @p decorated_row_major is true when @p type is (part
    /// of) a member carrying the row-major flag, in which case matrices are transposed. Results
    /// are cached in `original_to_transposed`.
    const core::type::Type* RewriteType(const core::type::Type* type, bool decorated_row_major) {
        return original_to_transposed.GetOrAdd(TypeAndRowMajor{type, decorated_row_major}, [&] {
            return tint::Switch(
                type,
                [&](const core::type::Array* arr) {
                    return RewriteArray(arr, decorated_row_major);
                },
                [&](const core::type::Struct* str) { return RewriteStruct(str); },
                [&](const core::type::Matrix* mat) {
                    if (!decorated_row_major) {
                        return mat;
                    }
                    return ty.mat(mat->Type(), mat->Rows(), mat->Columns());
                },
                [&](const core::type::Pointer* ptr) {
                    return ty.ptr(ptr->AddressSpace(),
                                  RewriteType(ptr->StoreType(), decorated_row_major),
                                  ptr->Access());
                },
                [&](Default) { return type; });
        });
    }

    /// Returns @p arr with its element type rewritten, or @p arr itself if the element type is
    /// unchanged.
    const core::type::Type* RewriteArray(const core::type::Array* arr, bool decorated_row_major) {
        auto* elem_ty = RewriteType(arr->ElemType(), decorated_row_major);
        if (elem_ty == arr->ElemType()) {
            return arr;
        }
        // The element type is the only thing that will change. That does not affect the stride of
        // the array itself, which may either be the natural stride or a larger stride in the case
        // of an explicitly laid out array.
        if (auto* ex = arr->As<spirv::type::ExplicitLayoutArray>()) {
            return ty.Get<spirv::type::ExplicitLayoutArray>(elem_ty, arr->Count(), arr->Size(),
                                                            ex->Stride());
        }
        // NOTE(review): transposing a non-square matrix can change its natural size (e.g.
        // mat3x2<f32> is 24 bytes but mat2x3<f32> is 32 bytes), so reusing `arr->Size()` here
        // should be confirmed for arrays without an explicit layout.
        return ty.Get<core::type::Array>(elem_ty, arr->Count(), arr->Size());
    }

    /// Returns a copy of @p old_struct in which row-major members are replaced by their transposed
    /// type (without the row-major flag), or @p old_struct itself if nothing changes. The new
    /// struct is recorded in `struct_to_original`.
    const core::type::Type* RewriteStruct(const core::type::Struct* old_struct) {
        bool made_changes = false;
        Vector<const core::type::StructMember*, 8> new_members;
        new_members.Reserve(old_struct->Members().Length());
        for (auto* member : old_struct->Members()) {
            auto* new_member_type = RewriteType(member->Type(), member->RowMajor());
            if (member->RowMajor() || new_member_type != member->Type()) {
                // Recreate the struct member without the row major attribute, using the new type.
                auto* new_member = ty.Get<core::type::StructMember>(
                    member->Name(), new_member_type, member->Index(), member->Offset(),
                    member->Align(), member->Size(), member->Attributes());
                if (member->HasMatrixStride()) {
                    new_member->SetMatrixStride(member->MatrixStride());
                }
                new_members.Push(new_member);
                made_changes = true;
            } else {
                new_members.Push(member);
            }
        }
        if (!made_changes) {
            return old_struct;
        }
        // Create the new struct and record the mapping to the old struct.
        auto* new_struct = ty.Struct(sym.New(old_struct->Name().Name()), std::move(new_members));
        struct_to_original.Add(new_struct, old_struct);
        return new_struct;
    }

    /// Returns @p constant converted to its rewritten type. Row-major matrix values are
    /// transposed, and arrays and structs are rewritten element by element. Returns @p constant
    /// itself when its type does not change.
    const core::constant::Value* RewriteConstant(const core::constant::Value* constant,
                                                 bool is_row_major) {
        auto* orig_type = constant->Type();
        auto* new_type = RewriteType(orig_type, is_row_major);
        if (new_type == orig_type) {
            return constant;
        }
        return tint::Switch(
            new_type,  //
            [&](const core::type::Matrix* mat) {
                if (!is_row_major) {
                    return constant;
                }
                auto* orig_mat = orig_type->As<core::type::Matrix>();
                TINT_ASSERT(orig_mat);
                TINT_ASSERT(constant->NumElements() == mat->Rows());
                Vector<const core::constant::Value*, 4> columns;
                for (size_t i = 0; i < mat->Columns(); ++i) {
                    auto* vec_ty = ty.vec(mat->Type(), mat->Rows());
                    Vector<const core::constant::Value*, 4> vec_elements;
                    for (uint32_t j = 0; j < mat->Rows(); ++j) {
                        auto* value = constant->Index(j);
                        TINT_ASSERT(value->NumElements() == mat->Columns());
                        vec_elements.Push(value->Index(i));
                    }
                    columns.Push(ir.constant_values.Composite(vec_ty, std::move(vec_elements)));
                }
                return ir.constant_values.Composite(new_type, std::move(columns));
            },
            [&](const core::type::Array*) {
                // The array type changed, either because its matrices are row-major or because
                // its elements are structs containing row-major members (reached with
                // `is_row_major == false`). Either way, rebuild it element by element.
                Vector<const core::constant::Value*, 16> elements;
                for (uint32_t i = 0; i < constant->NumElements(); i++) {
                    auto* value = constant->Index(i);
                    elements.Push(RewriteConstant(value, is_row_major));
                }
                return ir.constant_values.Composite(new_type, std::move(elements));
            },
            [&](const core::type::Struct* str) {
                TINT_ASSERT(constant->NumElements() == str->Members().Length());
                Vector<const core::constant::Value*, 16> elements;
                elements.Reserve(str->Members().Length());
                auto* orig_str = orig_type->As<core::type::Struct>();
                TINT_ASSERT(orig_str);
                for (size_t i = 0; i < orig_str->Members().Length(); ++i) {
                    auto& orig_mem = orig_str->Members()[i];
                    auto& new_mem = str->Members()[i];
                    auto* value = constant->Index(i);
                    auto* new_member_type = new_mem->Type();
                    if (new_member_type != value->Type()) {
                        elements.Push(RewriteConstant(value, orig_mem->RowMajor()));
                    } else {
                        elements.Push(value);
                    }
                }
                return ir.constant_values.Composite(new_type, std::move(elements));
            },
            [&](Default) { return constant; });
    }
};
} // namespace
Result<SuccessType> TransposeRowMajor(core::ir::Module& ir) {
    // Validate the incoming IR, permitting the non-core features listed here (struct matrix
    // decorations, non-core types, overrides, ...) that this stage of the reader may contain.
    const core::ir::Capabilities capabilities{
        core::ir::Capability::kAllowMultipleEntryPoints,
        core::ir::Capability::kAllowStructMatrixDecorations,
        core::ir::Capability::kAllowNonCoreTypes,
        core::ir::Capability::kAllowOverrides,
        core::ir::Capability::kAllowPointerToHandle,
    };
    TINT_CHECK_RESULT(ValidateAndDumpIfNeeded(ir, "spirv.TransposeRowMajor", capabilities));

    // Run the rewrite over the whole module.
    State state{ir};
    state.Process();
    return Success;
}
} // namespace tint::spirv::reader::lower