| // Copyright 2022 The Tint Authors. | 
 | // | 
 | // Licensed under the Apache License, Version 2.0 (the "License"); | 
 | // you may not use this file except in compliance with the License. | 
 | // You may obtain a copy of the License at | 
 | // | 
 | //     http://www.apache.org/licenses/LICENSE-2.0 | 
 | // | 
 | // Unless required by applicable law or agreed to in writing, software | 
 | // distributed under the License is distributed on an "AS IS" BASIS, | 
 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
 | // See the License for the specific language governing permissions and | 
 | // limitations under the License. | 
 |  | 
 | #include "src/tint/transform/packed_vec3.h" | 
 |  | 
 | #include <algorithm> | 
 | #include <string> | 
 | #include <utility> | 
 |  | 
 | #include "src/tint/program_builder.h" | 
 | #include "src/tint/sem/index_accessor_expression.h" | 
 | #include "src/tint/sem/member_accessor_expression.h" | 
 | #include "src/tint/sem/statement.h" | 
 | #include "src/tint/sem/variable.h" | 
 | #include "src/tint/utils/hashmap.h" | 
 | #include "src/tint/utils/hashset.h" | 
 |  | 
 | TINT_INSTANTIATE_TYPEINFO(tint::transform::PackedVec3); | 
 | TINT_INSTANTIATE_TYPEINFO(tint::transform::PackedVec3::Attribute); | 
 |  | 
 | using namespace tint::number_suffixes;  // NOLINT | 
 |  | 
 | namespace tint::transform { | 
 |  | 
 | /// PIMPL state for the transform | 
 | struct PackedVec3::State { | 
 |     /// Constructor | 
 |     /// @param program the source program | 
 |     explicit State(const Program* program) : src(program) {} | 
 |  | 
 |     /// Runs the transform | 
 |     /// @returns the new program or SkipTransform if the transform is not required | 
 |     ApplyResult Run() { | 
 |         // Packed vec3<T> struct members | 
 |         utils::Hashset<const sem::StructMember*, 8> members; | 
 |  | 
 |         // Find all the packed vector struct members, and apply the @internal(packed_vector) | 
 |         // attribute. | 
 |         for (auto* decl : ctx.src->AST().GlobalDeclarations()) { | 
 |             if (auto* str = sem.Get<sem::Struct>(decl)) { | 
 |                 if (str->IsHostShareable()) { | 
 |                     for (auto* member : str->Members()) { | 
 |                         if (auto* vec = member->Type()->As<type::Vector>()) { | 
 |                             if (vec->Width() == 3) { | 
 |                                 members.Add(member); | 
 |  | 
 |                                 // Apply the PackedVec3::Attribute to the member | 
 |                                 ctx.InsertFront( | 
 |                                     member->Declaration()->attributes, | 
 |                                     b.ASTNodes().Create<Attribute>(b.ID(), b.AllocateNodeID())); | 
 |                             } | 
 |                         } | 
 |                     } | 
 |                 } | 
 |             } | 
 |         } | 
 |  | 
 |         if (members.IsEmpty()) { | 
 |             return SkipTransform; | 
 |         } | 
 |  | 
 |         // Walk the nodes, starting with the most deeply nested, finding all the AST expressions | 
 |         // that load a whole packed vector (not a scalar / swizzle of the vector). | 
 |         utils::Hashset<const sem::ValueExpression*, 16> refs; | 
 |         for (auto* node : ctx.src->ASTNodes().Objects()) { | 
 |             auto* sem_node = sem.Get(node); | 
 |             if (sem_node) { | 
 |                 if (auto* expr = sem_node->As<sem::ValueExpression>()) { | 
 |                     sem_node = expr->UnwrapLoad(); | 
 |                 } | 
 |             } | 
 |             Switch( | 
 |                 sem_node,  // | 
 |                 [&](const sem::StructMemberAccess* access) { | 
 |                     if (members.Contains(access->Member())) { | 
 |                         // Access to a packed vector member. Seed the expression tracking. | 
 |                         refs.Add(access); | 
 |                     } | 
 |                 }, | 
 |                 [&](const sem::IndexAccessorExpression* access) { | 
 |                     // Not loading a whole packed vector. Ignore. | 
 |                     refs.Remove(access->Object()->UnwrapLoad()); | 
 |                 }, | 
 |                 [&](const sem::Swizzle* access) { | 
 |                     // Not loading a whole packed vector. Ignore. | 
 |                     refs.Remove(access->Object()->UnwrapLoad()); | 
 |                 }, | 
 |                 [&](const sem::VariableUser* user) { | 
 |                     auto* v = user->Variable(); | 
 |                     if (v->Declaration()->Is<ast::Let>() &&  // if variable is let... | 
 |                         v->Type()->Is<type::Pointer>() &&    // and let is a pointer... | 
 |                         refs.Contains(v->Initializer())) {   // and pointer is to a packed vector... | 
 |                         refs.Add(user);  // then propagate tracking to pointer usage | 
 |                     } | 
 |                 }, | 
 |                 [&](const sem::ValueExpression* expr) { | 
 |                     if (auto* unary = expr->Declaration()->As<ast::UnaryOpExpression>()) { | 
 |                         if (unary->op == ast::UnaryOp::kAddressOf || | 
 |                             unary->op == ast::UnaryOp::kIndirection) { | 
 |                             // Memory access on the packed vector. Track these. | 
 |                             auto* inner = sem.GetVal(unary->expr); | 
 |                             if (refs.Remove(inner)) { | 
 |                                 refs.Add(expr); | 
 |                             } | 
 |                         } | 
 |                         // Note: non-memory ops (e.g. '-') are ignored, leaving any tracked | 
 |                         // reference at the inner expression, so we'd cast, then apply the unary op. | 
 |                     } | 
 |                 }, | 
 |                 [&](const sem::Statement* e) { | 
 |                     if (auto* assign = e->Declaration()->As<ast::AssignmentStatement>()) { | 
 |                         // We don't want to cast packed_vectors if they're being assigned to. | 
 |                         refs.Remove(sem.GetVal(assign->lhs)); | 
 |                     } | 
 |                 }); | 
 |         } | 
 |  | 
 |         // Wrap the load expressions with a cast to the unpacked type. | 
 |         utils::Hashmap<const type::Vector*, Symbol, 3> unpack_fns; | 
 |         for (auto* ref : refs) { | 
 |             // ref is either a packed vec3 that needs casting, or a pointer to a vec3 which we just | 
 |             // leave alone. | 
 |             if (auto* vec_ty = ref->Type()->UnwrapRef()->As<type::Vector>()) { | 
 |                 auto* expr = ref->Declaration(); | 
 |                 ctx.Replace(expr, [this, vec_ty, expr] {  // | 
 |                     auto* packed = ctx.CloneWithoutTransform(expr); | 
 |                     return b.Call(CreateASTTypeFor(ctx, vec_ty), packed); | 
 |                 }); | 
 |             } | 
 |         } | 
 |  | 
 |         ctx.Clone(); | 
 |         return Program(std::move(b)); | 
 |     } | 
 |  | 
 |   private: | 
 |     /// The source program | 
 |     const Program* const src; | 
 |     /// The target program builder | 
 |     ProgramBuilder b; | 
 |     /// The clone context | 
 |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | 
 |     /// Alias to the semantic info in ctx.src | 
 |     const sem::Info& sem = ctx.src->Sem(); | 
 |     /// Alias to the symbols in ctx.src | 
 |     const SymbolTable& sym = ctx.src->Symbols(); | 
 | }; | 
 |  | 
 | PackedVec3::Attribute::Attribute(ProgramID pid, ast::NodeID nid) : Base(pid, nid) {} | 
 | PackedVec3::Attribute::~Attribute() = default; | 
 |  | 
 | const PackedVec3::Attribute* PackedVec3::Attribute::Clone(CloneContext* ctx) const { | 
 |     return ctx->dst->ASTNodes().Create<Attribute>(ctx->dst->ID(), ctx->dst->AllocateNodeID()); | 
 | } | 
 |  | 
 | std::string PackedVec3::Attribute::InternalName() const { | 
 |     return "packed_vector"; | 
 | } | 
 |  | 
 | PackedVec3::PackedVec3() = default; | 
 | PackedVec3::~PackedVec3() = default; | 
 |  | 
 | Transform::ApplyResult PackedVec3::Apply(const Program* src, const DataMap&, DataMap&) const { | 
 |     return State{src}.Run(); | 
 | } | 
 |  | 
 | }  // namespace tint::transform |