|  | // Copyright 2020 The Tint Authors. | 
|  | // | 
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | // you may not use this file except in compliance with the License. | 
|  | // You may obtain a copy of the License at | 
|  | // | 
|  | //     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | // | 
|  | // Unless required by applicable law or agreed to in writing, software | 
|  | // distributed under the License is distributed on an "AS IS" BASIS, | 
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | // See the License for the specific language governing permissions and | 
|  | // limitations under the License. | 
|  |  | 
|  | #include "src/transform/vertex_pulling.h" | 
|  |  | 
|  | #include <algorithm> | 
|  | #include <utility> | 
|  |  | 
|  | #include "src/ast/assignment_statement.h" | 
|  | #include "src/ast/bitcast_expression.h" | 
|  | #include "src/ast/struct_block_decoration.h" | 
|  | #include "src/ast/variable_decl_statement.h" | 
|  | #include "src/program_builder.h" | 
|  | #include "src/sem/variable.h" | 
|  | #include "src/utils/get_or_create.h" | 
|  | #include "src/utils/math.h" | 
|  |  | 
// Instantiate Tint type-info for the VertexPulling transform and for its
// Config data class.
TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling);
TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config);
|  |  | 
|  | namespace tint { | 
|  | namespace transform { | 
|  |  | 
|  | namespace { | 
|  |  | 
/// Scalar type of a single component of a vertex attribute or WGSL variable.
/// A full data type is either this scalar or a vector of it.
enum class BaseType {
  /// Not a supported scalar (or not a scalar/vector at all).
  kInvalid,
  /// 32-bit unsigned integer.
  kU32,
  /// 32-bit signed integer.
  kI32,
  /// 32-bit float.
  kF32,
};

/// Streams the WGSL-style spelling of a BaseType ("u32", "i32", "f32"), or
/// "invalid" for kInvalid. Values outside the enumeration print "<unknown>".
/// @param out the stream to append to
/// @param format the BaseType to describe
/// @returns `out`, so calls can be chained
std::ostream& operator<<(std::ostream& out, BaseType format) {
  const char* spelling = "<unknown>";
  switch (format) {
    case BaseType::kInvalid:
      spelling = "invalid";
      break;
    case BaseType::kU32:
      spelling = "u32";
      break;
    case BaseType::kI32:
      spelling = "i32";
      break;
    case BaseType::kF32:
      spelling = "f32";
      break;
  }
  return out << spelling;
}
|  |  | 
|  | /// Writes the VertexFormat to the std::ostream. | 
|  | /// @param out the std::ostream to write to | 
|  | /// @param format the VertexFormat to write | 
|  | /// @returns out so calls can be chained | 
|  | std::ostream& operator<<(std::ostream& out, VertexFormat format) { | 
|  | switch (format) { | 
|  | case VertexFormat::kUint8x2: | 
|  | return out << "uint8x2"; | 
|  | case VertexFormat::kUint8x4: | 
|  | return out << "uint8x4"; | 
|  | case VertexFormat::kSint8x2: | 
|  | return out << "sint8x2"; | 
|  | case VertexFormat::kSint8x4: | 
|  | return out << "sint8x4"; | 
|  | case VertexFormat::kUnorm8x2: | 
|  | return out << "unorm8x2"; | 
|  | case VertexFormat::kUnorm8x4: | 
|  | return out << "unorm8x4"; | 
|  | case VertexFormat::kSnorm8x2: | 
|  | return out << "snorm8x2"; | 
|  | case VertexFormat::kSnorm8x4: | 
|  | return out << "snorm8x4"; | 
|  | case VertexFormat::kUint16x2: | 
|  | return out << "uint16x2"; | 
|  | case VertexFormat::kUint16x4: | 
|  | return out << "uint16x4"; | 
|  | case VertexFormat::kSint16x2: | 
|  | return out << "sint16x2"; | 
|  | case VertexFormat::kSint16x4: | 
|  | return out << "sint16x4"; | 
|  | case VertexFormat::kUnorm16x2: | 
|  | return out << "unorm16x2"; | 
|  | case VertexFormat::kUnorm16x4: | 
|  | return out << "unorm16x4"; | 
|  | case VertexFormat::kSnorm16x2: | 
|  | return out << "snorm16x2"; | 
|  | case VertexFormat::kSnorm16x4: | 
|  | return out << "snorm16x4"; | 
|  | case VertexFormat::kFloat16x2: | 
|  | return out << "float16x2"; | 
|  | case VertexFormat::kFloat16x4: | 
|  | return out << "float16x4"; | 
|  | case VertexFormat::kFloat32: | 
|  | return out << "float32"; | 
|  | case VertexFormat::kFloat32x2: | 
|  | return out << "float32x2"; | 
|  | case VertexFormat::kFloat32x3: | 
|  | return out << "float32x3"; | 
|  | case VertexFormat::kFloat32x4: | 
|  | return out << "float32x4"; | 
|  | case VertexFormat::kUint32: | 
|  | return out << "uint32"; | 
|  | case VertexFormat::kUint32x2: | 
|  | return out << "uint32x2"; | 
|  | case VertexFormat::kUint32x3: | 
|  | return out << "uint32x3"; | 
|  | case VertexFormat::kUint32x4: | 
|  | return out << "uint32x4"; | 
|  | case VertexFormat::kSint32: | 
|  | return out << "sint32"; | 
|  | case VertexFormat::kSint32x2: | 
|  | return out << "sint32x2"; | 
|  | case VertexFormat::kSint32x3: | 
|  | return out << "sint32x3"; | 
|  | case VertexFormat::kSint32x4: | 
|  | return out << "sint32x4"; | 
|  | } | 
|  | return out << "<unknown>"; | 
|  | } | 
|  |  | 
|  | /// A vertex attribute data format. | 
|  | struct DataType { | 
|  | BaseType base_type; | 
|  | uint32_t width;  // 1 for scalar, 2+ for a vector | 
|  | }; | 
|  |  | 
|  | DataType DataTypeOf(sem::Type* ty) { | 
|  | if (ty->Is<sem::I32>()) { | 
|  | return {BaseType::kI32, 1}; | 
|  | } | 
|  | if (ty->Is<sem::U32>()) { | 
|  | return {BaseType::kU32, 1}; | 
|  | } | 
|  | if (ty->Is<sem::F32>()) { | 
|  | return {BaseType::kF32, 1}; | 
|  | } | 
|  | if (auto* vec = ty->As<sem::Vector>()) { | 
|  | return {DataTypeOf(vec->type()).base_type, vec->Width()}; | 
|  | } | 
|  | return {BaseType::kInvalid, 0}; | 
|  | } | 
|  |  | 
|  | DataType DataTypeOf(VertexFormat format) { | 
|  | switch (format) { | 
|  | case VertexFormat::kUint32: | 
|  | return {BaseType::kU32, 1}; | 
|  | case VertexFormat::kUint8x2: | 
|  | case VertexFormat::kUint16x2: | 
|  | case VertexFormat::kUint32x2: | 
|  | return {BaseType::kU32, 2}; | 
|  | case VertexFormat::kUint32x3: | 
|  | return {BaseType::kU32, 3}; | 
|  | case VertexFormat::kUint8x4: | 
|  | case VertexFormat::kUint16x4: | 
|  | case VertexFormat::kUint32x4: | 
|  | return {BaseType::kU32, 4}; | 
|  | case VertexFormat::kSint32: | 
|  | return {BaseType::kI32, 1}; | 
|  | case VertexFormat::kSint8x2: | 
|  | case VertexFormat::kSint16x2: | 
|  | case VertexFormat::kSint32x2: | 
|  | return {BaseType::kI32, 2}; | 
|  | case VertexFormat::kSint32x3: | 
|  | return {BaseType::kI32, 3}; | 
|  | case VertexFormat::kSint8x4: | 
|  | case VertexFormat::kSint16x4: | 
|  | case VertexFormat::kSint32x4: | 
|  | return {BaseType::kI32, 4}; | 
|  | case VertexFormat::kFloat32: | 
|  | return {BaseType::kF32, 1}; | 
|  | case VertexFormat::kUnorm8x2: | 
|  | case VertexFormat::kSnorm8x2: | 
|  | case VertexFormat::kUnorm16x2: | 
|  | case VertexFormat::kSnorm16x2: | 
|  | case VertexFormat::kFloat16x2: | 
|  | case VertexFormat::kFloat32x2: | 
|  | return {BaseType::kF32, 2}; | 
|  | case VertexFormat::kFloat32x3: | 
|  | return {BaseType::kF32, 3}; | 
|  | case VertexFormat::kUnorm8x4: | 
|  | case VertexFormat::kSnorm8x4: | 
|  | case VertexFormat::kUnorm16x4: | 
|  | case VertexFormat::kSnorm16x4: | 
|  | case VertexFormat::kFloat16x4: | 
|  | case VertexFormat::kFloat32x4: | 
|  | return {BaseType::kF32, 4}; | 
|  | } | 
|  | return {BaseType::kInvalid, 0}; | 
|  | } | 
|  |  | 
|  | struct State { | 
|  | State(CloneContext& context, const VertexPulling::Config& c) | 
|  | : ctx(context), cfg(c) {} | 
|  | State(const State&) = default; | 
|  | ~State() = default; | 
|  |  | 
|  | /// LocationReplacement describes an ast::Variable replacement for a | 
|  | /// location input. | 
|  | struct LocationReplacement { | 
|  | /// The variable to replace in the source Program | 
|  | ast::Variable* from; | 
|  | /// The replacement to use in the target ProgramBuilder | 
|  | ast::Variable* to; | 
|  | }; | 
|  |  | 
|  | struct LocationInfo { | 
|  | std::function<ast::Expression*()> expr; | 
|  | sem::Type* type; | 
|  | }; | 
|  |  | 
|  | CloneContext& ctx; | 
|  | VertexPulling::Config const cfg; | 
|  | std::unordered_map<uint32_t, LocationInfo> location_info; | 
|  | std::function<ast::Expression*()> vertex_index_expr = nullptr; | 
|  | std::function<ast::Expression*()> instance_index_expr = nullptr; | 
|  | Symbol pulling_position_name; | 
|  | Symbol struct_buffer_name; | 
|  | std::unordered_map<uint32_t, Symbol> vertex_buffer_names; | 
|  | ast::VariableList new_function_parameters; | 
|  |  | 
|  | /// Generate the vertex buffer binding name | 
|  | /// @param index index to append to buffer name | 
|  | Symbol GetVertexBufferName(uint32_t index) { | 
|  | return utils::GetOrCreate(vertex_buffer_names, index, [&] { | 
|  | static const char kVertexBufferNamePrefix[] = | 
|  | "tint_pulling_vertex_buffer_"; | 
|  | return ctx.dst->Symbols().New(kVertexBufferNamePrefix + | 
|  | std::to_string(index)); | 
|  | }); | 
|  | } | 
|  |  | 
|  | /// Lazily generates the structure buffer symbol | 
|  | Symbol GetStructBufferName() { | 
|  | if (!struct_buffer_name.IsValid()) { | 
|  | static const char kStructBufferName[] = "tint_vertex_data"; | 
|  | struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName); | 
|  | } | 
|  | return struct_buffer_name; | 
|  | } | 
|  |  | 
|  | /// Adds storage buffer decorated variables for the vertex buffers | 
|  | void AddVertexStorageBuffers() { | 
|  | // Creating the struct type | 
|  | static const char kStructName[] = "TintVertexData"; | 
|  | auto* struct_type = ctx.dst->Structure( | 
|  | ctx.dst->Symbols().New(kStructName), | 
|  | { | 
|  | ctx.dst->Member(GetStructBufferName(), | 
|  | ctx.dst->ty.array<ProgramBuilder::u32, 0>(4)), | 
|  | }, | 
|  | { | 
|  | ctx.dst->create<ast::StructBlockDecoration>(), | 
|  | }); | 
|  | for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { | 
|  | // The decorated variable with struct type | 
|  | ctx.dst->Global( | 
|  | GetVertexBufferName(i), ctx.dst->ty.Of(struct_type), | 
|  | ast::StorageClass::kStorage, ast::Access::kRead, | 
|  | ast::DecorationList{ | 
|  | ctx.dst->create<ast::BindingDecoration>(i), | 
|  | ctx.dst->create<ast::GroupDecoration>(cfg.pulling_group), | 
|  | }); | 
|  | } | 
|  | } | 
|  |  | 
/// Creates and returns the assignments to the variables from the buffers.
/// For every configured vertex buffer this declares a
/// `buffer_array_base_<N>` let holding the buffer's base offset (in 4-byte
/// units). Then, for every attribute of that buffer whose shader location is
/// recorded in `location_info`, it fetches the value, adjusts it to the
/// variable's vector width, and assigns it through the location's `expr`.
/// NOTE(review): assumes `vertex_index_expr` / `instance_index_expr` have
/// been set for every step mode in use. Calling an empty std::function throws
/// std::bad_function_call; presumably Process() guarantees this.
/// @returns the generated statements as a block, or nullptr if an error was
/// reported to `ctx.dst->Diagnostics()` or no statements were generated
ast::BlockStatement* CreateVertexPullingPreamble() {
  // Assign by looking at the vertex descriptor to find attributes with
  // matching location.

  ast::StatementList stmts;

  for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size();
       ++buffer_idx) {
    const VertexBufferLayoutDescriptor& buffer_layout =
        cfg.vertex_state[buffer_idx];

    // Buffers are addressed as arrays of u32, so the stride must be a whole
    // number of 4-byte words.
    if ((buffer_layout.array_stride & 3) != 0) {
      ctx.dst->Diagnostics().add_error(
          diag::System::Transform,
          "WebGPU requires that vertex stride must be a multiple of 4 bytes, "
          "but VertexPulling array stride for buffer " +
              std::to_string(buffer_idx) + " was " +
              std::to_string(buffer_layout.array_stride) + " bytes");
      return nullptr;
    }

    // Buffers stepped per-vertex are indexed by the vertex index; all others
    // by the instance index.
    auto* index_expr = buffer_layout.step_mode == VertexStepMode::kVertex
                           ? vertex_index_expr()
                           : instance_index_expr();

    // buffer_array_base is the base array offset for all the vertex
    // attributes. These are units of uint (4 bytes).
    auto buffer_array_base = ctx.dst->Symbols().New(
        "buffer_array_base_" + std::to_string(buffer_idx));

    // Scale the element index by the stride in 4-byte words. The multiply is
    // omitted when the stride is exactly one word.
    auto* attribute_offset = index_expr;
    if (buffer_layout.array_stride != 4) {
      attribute_offset =
          ctx.dst->Mul(index_expr, buffer_layout.array_stride / 4u);
    }

    // let buffer_array_base_<N> = <attribute_offset>
    stmts.emplace_back(ctx.dst->Decl(
        ctx.dst->Const(buffer_array_base, nullptr, attribute_offset)));

    for (const VertexAttributeDescriptor& attribute_desc :
         buffer_layout.attributes) {
      // Skip attributes whose location the shader does not consume.
      auto it = location_info.find(attribute_desc.shader_location);
      if (it == location_info.end()) {
        continue;
      }
      auto& var = it->second;

      // Data type of the target WGSL variable
      auto var_dt = DataTypeOf(var.type);
      // Data type of the vertex stream attribute
      auto fmt_dt = DataTypeOf(attribute_desc.format);

      // Base types must match between the vertex stream and the WGSL variable
      if (var_dt.base_type != fmt_dt.base_type) {
        std::stringstream err;
        err << "VertexAttributeDescriptor for location "
            << std::to_string(attribute_desc.shader_location)
            << " has format " << attribute_desc.format
            << " but shader expects "
            << var.type->FriendlyName(ctx.src->Symbols());
        ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str());
        return nullptr;
      }

      // Load the attribute value
      auto* fetch = Fetch(buffer_array_base, attribute_desc.offset,
                          buffer_idx, attribute_desc.format);

      // The attribute value may not be of the desired vector width. If it is
      // not, we'll need to either reduce the width with a swizzle, or append
      // 0's and / or a 1.
      auto* value = fetch;
      if (var_dt.width < fmt_dt.width) {
        // WGSL variable vector width is smaller than the loaded vector width
        switch (var_dt.width) {
          case 1:
            value = ctx.dst->MemberAccessor(fetch, "x");
            break;
          case 2:
            value = ctx.dst->MemberAccessor(fetch, "xy");
            break;
          case 3:
            value = ctx.dst->MemberAccessor(fetch, "xyz");
            break;
          default:
            TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
                << var_dt.width;
            return nullptr;
        }
      } else if (var_dt.width > fmt_dt.width) {
        // WGSL variable vector width is wider than the loaded vector width.
        // Missing components are filled with 0, except the 4th (index 3),
        // which is filled with 1.
        ast::Type* ty = nullptr;
        ast::ExpressionList values{fetch};
        switch (var_dt.base_type) {
          case BaseType::kI32:
            ty = ctx.dst->ty.i32();
            for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
              values.emplace_back(ctx.dst->Expr((i == 3) ? 1 : 0));
            }
            break;
          case BaseType::kU32:
            ty = ctx.dst->ty.u32();
            for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
              values.emplace_back(ctx.dst->Expr((i == 3) ? 1u : 0u));
            }
            break;
          case BaseType::kF32:
            ty = ctx.dst->ty.f32();
            for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
              values.emplace_back(ctx.dst->Expr((i == 3) ? 1.f : 0.f));
            }
            break;
          default:
            TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
                << var_dt.base_type;
            return nullptr;
        }
        value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values);
      }

      // Assign the value to the WGSL variable
      stmts.emplace_back(ctx.dst->Assign(var.expr(), value));
    }
  }

  // NOTE(review): every processed buffer appends a buffer_array_base
  // declaration above, so this is only taken when cfg.vertex_state is empty.
  if (stmts.empty()) {
    return nullptr;
  }

  return ctx.dst->create<ast::BlockStatement>(stmts);
}
|  |  | 
/// Generates an expression that reads one vertex attribute of a specific
/// format from a vertex buffer and converts it to the u32/i32/f32 scalar or
/// vector used to represent that format in the shader.
/// @param array_base the symbol of the variable holding the base array offset
/// of the vertex array (each index is 4-bytes).
/// @param offset the byte offset of the data from `array_base`
/// @param buffer the index of the vertex buffer
/// @param format the format to read
/// @returns the generated expression, or nullptr after reporting an internal
/// error if `format` is not handled
ast::Expression* Fetch(Symbol array_base,
                       uint32_t offset,
                       uint32_t buffer,
                       VertexFormat format) {
  using u32 = ProgramBuilder::u32;
  using i32 = ProgramBuilder::i32;
  using f32 = ProgramBuilder::f32;

  // Returns a u32 loaded from byte `offset` relative to `array_base`.
  auto load_u32 = [&] {
    return LoadPrimitive(array_base, offset, buffer, VertexFormat::kUint32);
  };

  // Returns an i32 loaded from byte `offset` relative to `array_base`.
  auto load_i32 = [&] { return ctx.dst->Bitcast<i32>(load_u32()); };

  // Returns a u32 loaded from byte `offset + 4` relative to `array_base`.
  auto load_next_u32 = [&] {
    return LoadPrimitive(array_base, offset + 4, buffer,
                         VertexFormat::kUint32);
  };

  // Returns an i32 loaded from byte `offset + 4` relative to `array_base`.
  auto load_next_i32 = [&] { return ctx.dst->Bitcast<i32>(load_next_u32()); };

  // Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
  // The low 16 bits are 0.
  // Every byte position within the containing 4-byte word is handled. When
  // the u16 straddles two words (offset & 3 == 3), the following word is
  // loaded as well.
  auto load_u16_h = [&] {
    auto low_u32_offset = offset & ~3u;
    auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer,
                                  VertexFormat::kUint32);
    switch (offset & 3) {
      case 0:
        return ctx.dst->Shl(low_u32, 16u);
      case 1:
        return ctx.dst->And(ctx.dst->Shl(low_u32, 8u), 0xffff0000u);
      case 2:
        return ctx.dst->And(low_u32, 0xffff0000u);
      default: {  // 3:
        auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
                                       VertexFormat::kUint32);
        auto* shr = ctx.dst->Shr(low_u32, 8u);
        auto* shl = ctx.dst->Shl(high_u32, 24u);
        return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000u);
      }
    }
  };

  // Returns a u16 loaded from offset, packed in the low 16 bits of a u32.
  // The high 16 bits are 0.
  // As with load_u16_h, every byte position is handled, including the
  // word-straddling case (offset & 3 == 3).
  auto load_u16_l = [&] {
    auto low_u32_offset = offset & ~3u;
    auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer,
                                  VertexFormat::kUint32);
    switch (offset & 3) {
      case 0:
        return ctx.dst->And(low_u32, 0xffffu);
      case 1:
        return ctx.dst->And(ctx.dst->Shr(low_u32, 8u), 0xffffu);
      case 2:
        return ctx.dst->Shr(low_u32, 16u);
      default: {  // 3:
        auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
                                       VertexFormat::kUint32);
        auto* shr = ctx.dst->Shr(low_u32, 24u);
        auto* shl = ctx.dst->Shl(high_u32, 8u);
        return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffffu);
      }
    }
  };

  // Returns an i16 loaded from offset, packed in the high 16 bits of an i32.
  // The low 16 bits are 0. Keeping the value in the high bits lets a later
  // arithmetic right shift sign-extend it.
  auto load_i16_h = [&] { return ctx.dst->Bitcast<i32>(load_u16_h()); };

  // Assumptions are made that alignment must be at least as large as the size
  // of a single component.
  switch (format) {
    // Basic primitives
    case VertexFormat::kUint32:
    case VertexFormat::kSint32:
    case VertexFormat::kFloat32:
      return LoadPrimitive(array_base, offset, buffer, format);

      // Vectors of basic primitives
    case VertexFormat::kUint32x2:
      return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
                     VertexFormat::kUint32, 2);
    case VertexFormat::kUint32x3:
      return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
                     VertexFormat::kUint32, 3);
    case VertexFormat::kUint32x4:
      return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
                     VertexFormat::kUint32, 4);
    case VertexFormat::kSint32x2:
      return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
                     VertexFormat::kSint32, 2);
    case VertexFormat::kSint32x3:
      return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
                     VertexFormat::kSint32, 3);
    case VertexFormat::kSint32x4:
      return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
                     VertexFormat::kSint32, 4);
    case VertexFormat::kFloat32x2:
      return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
                     VertexFormat::kFloat32, 2);
    case VertexFormat::kFloat32x3:
      return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
                     VertexFormat::kFloat32, 3);
    case VertexFormat::kFloat32x4:
      return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
                     VertexFormat::kFloat32, 4);

    // Packed integer formats: isolate each component by shifting it to the
    // top of a 32-bit lane, then shifting back down. For the signed (i32)
    // variants, the right shift sign-extends.
    case VertexFormat::kUint8x2: {
      // yyxx0000, yyxx0000
      auto* u16s = ctx.dst->vec2<u32>(load_u16_h());
      // xx000000, yyxx0000
      auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2<u32>(8u, 0u));
      // 000000xx, 000000yy
      return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
    }
    case VertexFormat::kUint8x4: {
      // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
      auto* u32s = ctx.dst->vec4<u32>(load_u32());
      // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
      auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
      // 000000xx, 000000yy, 000000zz, 000000ww
      return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
    }
    case VertexFormat::kUint16x2: {
      // yyyyxxxx, yyyyxxxx
      auto* u32s = ctx.dst->vec2<u32>(load_u32());
      // xxxx0000, yyyyxxxx
      auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2<u32>(16u, 0u));
      // 0000xxxx, 0000yyyy
      return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
    }
    case VertexFormat::kUint16x4: {
      // yyyyxxxx, wwwwzzzz
      auto* u32s = ctx.dst->vec2<u32>(load_u32(), load_next_u32());
      // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
      auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy");
      // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
      auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
      // 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
      return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
    }
    case VertexFormat::kSint8x2: {
      // yyxx0000, yyxx0000
      auto* i16s = ctx.dst->vec2<i32>(load_i16_h());
      // xx000000, yyxx0000
      auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2<u32>(8u, 0u));
      // ssssssxx, ssssssyy
      return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
    }
    case VertexFormat::kSint8x4: {
      // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
      auto* i32s = ctx.dst->vec4<i32>(load_i32());
      // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
      auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
      // ssssssxx, ssssssyy, sssssszz, ssssssww
      return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
    }
    case VertexFormat::kSint16x2: {
      // yyyyxxxx, yyyyxxxx
      auto* i32s = ctx.dst->vec2<i32>(load_i32());
      // xxxx0000, yyyyxxxx
      auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2<u32>(16u, 0u));
      // ssssxxxx, ssssyyyy
      return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
    }
    case VertexFormat::kSint16x4: {
      // yyyyxxxx, wwwwzzzz
      auto* i32s = ctx.dst->vec2<i32>(load_i32(), load_next_i32());
      // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
      auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy");
      // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
      auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
      // ssssxxxx, ssssyyyy, sssszzzz, sssswwww
      return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
    }
    // Normalized / half-float formats are decoded with the WGSL unpack
    // builtins. The 2-component 8-bit forms unpack 4 lanes and keep `.xy`.
    case VertexFormat::kUnorm8x2:
      return ctx.dst->MemberAccessor(
          ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy");
    case VertexFormat::kSnorm8x2:
      return ctx.dst->MemberAccessor(
          ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy");
    case VertexFormat::kUnorm8x4:
      return ctx.dst->Call("unpack4x8unorm", load_u32());
    case VertexFormat::kSnorm8x4:
      return ctx.dst->Call("unpack4x8snorm", load_u32());
    case VertexFormat::kUnorm16x2:
      return ctx.dst->Call("unpack2x16unorm", load_u32());
    case VertexFormat::kSnorm16x2:
      return ctx.dst->Call("unpack2x16snorm", load_u32());
    case VertexFormat::kFloat16x2:
      return ctx.dst->Call("unpack2x16float", load_u32());
    case VertexFormat::kUnorm16x4:
      return ctx.dst->vec4<f32>(
          ctx.dst->Call("unpack2x16unorm", load_u32()),
          ctx.dst->Call("unpack2x16unorm", load_next_u32()));
    case VertexFormat::kSnorm16x4:
      return ctx.dst->vec4<f32>(
          ctx.dst->Call("unpack2x16snorm", load_u32()),
          ctx.dst->Call("unpack2x16snorm", load_next_u32()));
    case VertexFormat::kFloat16x4:
      return ctx.dst->vec4<f32>(
          ctx.dst->Call("unpack2x16float", load_u32()),
          ctx.dst->Call("unpack2x16float", load_next_u32()));
  }

  // Every VertexFormat enumerator returns above; only an out-of-range value
  // reaches this point.
  TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
      << "format " << static_cast<int>(format);
  return nullptr;
}
|  |  | 
|  | /// Generates an expression reading an aligned basic type (u32, i32, f32) from | 
|  | /// a vertex buffer. | 
|  | /// @param array_base the symbol of the variable holding the base array offset | 
|  | /// of the vertex array (each index is 4-bytes). | 
|  | /// @param offset the byte offset of the data from `buffer_base` | 
|  | /// @param buffer the index of the vertex buffer | 
|  | /// @param format VertexFormat::kUint32, VertexFormat::kSint32 or | 
|  | /// VertexFormat::kFloat32 | 
|  | ast::Expression* LoadPrimitive(Symbol array_base, | 
|  | uint32_t offset, | 
|  | uint32_t buffer, | 
|  | VertexFormat format) { | 
|  | ast::Expression* u32 = nullptr; | 
|  | if ((offset & 3) == 0) { | 
|  | // Aligned load. | 
|  |  | 
|  | ast ::Expression* index = nullptr; | 
|  | if (offset > 0) { | 
|  | index = ctx.dst->Add(array_base, offset / 4); | 
|  | } else { | 
|  | index = ctx.dst->Expr(array_base); | 
|  | } | 
|  | u32 = ctx.dst->IndexAccessor( | 
|  | ctx.dst->MemberAccessor(GetVertexBufferName(buffer), | 
|  | GetStructBufferName()), | 
|  | index); | 
|  |  | 
|  | } else { | 
|  | // Unaligned load | 
|  | uint32_t offset_aligned = offset & ~3u; | 
|  | auto* low = LoadPrimitive(array_base, offset_aligned, buffer, | 
|  | VertexFormat::kUint32); | 
|  | auto* high = LoadPrimitive(array_base, offset_aligned + 4u, buffer, | 
|  | VertexFormat::kUint32); | 
|  |  | 
|  | uint32_t shift = 8u * (offset & 3u); | 
|  |  | 
|  | auto* low_shr = ctx.dst->Shr(low, shift); | 
|  | auto* high_shl = ctx.dst->Shl(high, 32u - shift); | 
|  | u32 = ctx.dst->Or(low_shr, high_shl); | 
|  | } | 
|  |  | 
|  | switch (format) { | 
|  | case VertexFormat::kUint32: | 
|  | return u32; | 
|  | case VertexFormat::kSint32: | 
|  | return ctx.dst->Bitcast(ctx.dst->ty.i32(), u32); | 
|  | case VertexFormat::kFloat32: | 
|  | return ctx.dst->Bitcast(ctx.dst->ty.f32(), u32); | 
|  | default: | 
|  | break; | 
|  | } | 
|  | TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) | 
|  | << "invalid format for LoadPrimitive" << static_cast<int>(format); | 
|  | return nullptr; | 
|  | } | 
|  |  | 
|  | /// Generates an expression reading a vec2/3/4 from a vertex buffer. | 
|  | /// @param array_base the symbol of the variable holding the base array offset | 
|  | /// of the vertex array (each index is 4-bytes). | 
|  | /// @param offset the byte offset of the data from `buffer_base` | 
|  | /// @param buffer the index of the vertex buffer | 
|  | /// @param element_stride stride between elements, in bytes | 
|  | /// @param base_type underlying AST type | 
|  | /// @param base_format underlying vertex format | 
|  | /// @param count how many elements the vector has | 
|  | ast::Expression* LoadVec(Symbol array_base, | 
|  | uint32_t offset, | 
|  | uint32_t buffer, | 
|  | uint32_t element_stride, | 
|  | ast::Type* base_type, | 
|  | VertexFormat base_format, | 
|  | uint32_t count) { | 
|  | ast::ExpressionList expr_list; | 
|  | for (uint32_t i = 0; i < count; ++i) { | 
|  | // Offset read position by element_stride for each component | 
|  | uint32_t primitive_offset = offset + element_stride * i; | 
|  | expr_list.push_back( | 
|  | LoadPrimitive(array_base, primitive_offset, buffer, base_format)); | 
|  | } | 
|  |  | 
|  | return ctx.dst->create<ast::TypeConstructorExpression>( | 
|  | ctx.dst->create<ast::Vector>(base_type, count), std::move(expr_list)); | 
|  | } | 
|  |  | 
|  | /// Process a non-struct entry point parameter. | 
|  | /// Generate function-scope variables for location parameters, and record | 
|  | /// vertex_index and instance_index builtins if present. | 
|  | /// @param func the entry point function | 
|  | /// @param param the parameter to process | 
|  | void ProcessNonStructParameter(ast::Function* func, ast::Variable* param) { | 
|  | if (auto* location = | 
|  | ast::GetDecoration<ast::LocationDecoration>(param->decorations())) { | 
|  | // Create a function-scope variable to replace the parameter. | 
|  | auto func_var_sym = ctx.Clone(param->symbol()); | 
|  | auto* func_var_type = ctx.Clone(param->type()); | 
|  | auto* func_var = ctx.dst->Var(func_var_sym, func_var_type); | 
|  | ctx.InsertFront(func->body()->statements(), ctx.dst->Decl(func_var)); | 
|  | // Capture mapping from location to the new variable. | 
|  | LocationInfo info; | 
|  | info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); }; | 
|  | info.type = ctx.src->Sem().Get(param)->Type(); | 
|  | location_info[location->value()] = info; | 
|  | } else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>( | 
|  | param->decorations())) { | 
|  | // Check for existing vertex_index and instance_index builtins. | 
|  | if (builtin->value() == ast::Builtin::kVertexIndex) { | 
|  | vertex_index_expr = [this, param]() { | 
|  | return ctx.dst->Expr(ctx.Clone(param->symbol())); | 
|  | }; | 
|  | } else if (builtin->value() == ast::Builtin::kInstanceIndex) { | 
|  | instance_index_expr = [this, param]() { | 
|  | return ctx.dst->Expr(ctx.Clone(param->symbol())); | 
|  | }; | 
|  | } | 
|  | new_function_parameters.push_back(ctx.Clone(param)); | 
|  | } else { | 
|  | TINT_ICE(Transform, ctx.dst->Diagnostics()) | 
|  | << "Invalid entry point parameter"; | 
|  | } | 
|  | } | 
|  |  | 
|  | /// Process a struct entry point parameter. | 
|  | /// If the struct has members with location attributes, push the parameter to | 
|  | /// a function-scope variable and create a new struct parameter without those | 
|  | /// attributes. Record expressions for members that are vertex_index and | 
|  | /// instance_index builtins. | 
|  | /// @param func the entry point function | 
|  | /// @param param the parameter to process | 
|  | /// @param struct_ty the structure type | 
|  | void ProcessStructParameter(ast::Function* func, | 
|  | ast::Variable* param, | 
|  | const ast::Struct* struct_ty) { | 
|  | auto param_sym = ctx.Clone(param->symbol()); | 
|  |  | 
|  | // Process the struct members. | 
|  | bool has_locations = false; | 
|  | ast::StructMemberList members_to_clone; | 
|  | for (auto* member : struct_ty->members()) { | 
|  | auto member_sym = ctx.Clone(member->symbol()); | 
|  | std::function<ast::Expression*()> member_expr = [this, param_sym, | 
|  | member_sym]() { | 
|  | return ctx.dst->MemberAccessor(param_sym, member_sym); | 
|  | }; | 
|  |  | 
|  | if (auto* location = ast::GetDecoration<ast::LocationDecoration>( | 
|  | member->decorations())) { | 
|  | // Capture mapping from location to struct member. | 
|  | LocationInfo info; | 
|  | info.expr = member_expr; | 
|  | info.type = ctx.src->Sem().Get(member)->Type(); | 
|  | location_info[location->value()] = info; | 
|  | has_locations = true; | 
|  | } else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>( | 
|  | member->decorations())) { | 
|  | // Check for existing vertex_index and instance_index builtins. | 
|  | if (builtin->value() == ast::Builtin::kVertexIndex) { | 
|  | vertex_index_expr = member_expr; | 
|  | } else if (builtin->value() == ast::Builtin::kInstanceIndex) { | 
|  | instance_index_expr = member_expr; | 
|  | } | 
|  | members_to_clone.push_back(member); | 
|  | } else { | 
|  | TINT_ICE(Transform, ctx.dst->Diagnostics()) | 
|  | << "Invalid entry point parameter"; | 
|  | } | 
|  | } | 
|  |  | 
|  | if (!has_locations) { | 
|  | // Nothing to do. | 
|  | new_function_parameters.push_back(ctx.Clone(param)); | 
|  | return; | 
|  | } | 
|  |  | 
|  | // Create a function-scope variable to replace the parameter. | 
|  | auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type())); | 
|  | ctx.InsertFront(func->body()->statements(), ctx.dst->Decl(func_var)); | 
|  |  | 
|  | if (!members_to_clone.empty()) { | 
|  | // Create a new struct without the location attributes. | 
|  | ast::StructMemberList new_members; | 
|  | for (auto* member : members_to_clone) { | 
|  | auto member_sym = ctx.Clone(member->symbol()); | 
|  | auto* member_type = ctx.Clone(member->type()); | 
|  | auto member_decos = ctx.Clone(member->decorations()); | 
|  | new_members.push_back( | 
|  | ctx.dst->Member(member_sym, member_type, std::move(member_decos))); | 
|  | } | 
|  | auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members); | 
|  |  | 
|  | // Create a new function parameter with this struct. | 
|  | auto* new_param = | 
|  | ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct)); | 
|  | new_function_parameters.push_back(new_param); | 
|  |  | 
|  | // Copy values from the new parameter to the function-scope variable. | 
|  | for (auto* member : members_to_clone) { | 
|  | auto member_name = ctx.Clone(member->symbol()); | 
|  | ctx.InsertFront( | 
|  | func->body()->statements(), | 
|  | ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name), | 
|  | ctx.dst->MemberAccessor(new_param, member_name))); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | /// Process an entry point function. | 
|  | /// @param func the entry point function | 
|  | void Process(ast::Function* func) { | 
|  | if (func->body()->empty()) { | 
|  | return; | 
|  | } | 
|  |  | 
|  | // Process entry point parameters. | 
|  | for (auto* param : func->params()) { | 
|  | auto* sem = ctx.src->Sem().Get(param); | 
|  | if (auto* str = sem->Type()->As<sem::Struct>()) { | 
|  | ProcessStructParameter(func, param, str->Declaration()); | 
|  | } else { | 
|  | ProcessNonStructParameter(func, param); | 
|  | } | 
|  | } | 
|  |  | 
|  | // Insert new parameters for vertex_index and instance_index if needed. | 
|  | if (!vertex_index_expr) { | 
|  | for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { | 
|  | if (layout.step_mode == VertexStepMode::kVertex) { | 
|  | auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index"); | 
|  | new_function_parameters.push_back( | 
|  | ctx.dst->Param(name, ctx.dst->ty.u32(), | 
|  | {ctx.dst->Builtin(ast::Builtin::kVertexIndex)})); | 
|  | vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); }; | 
|  | break; | 
|  | } | 
|  | } | 
|  | } | 
|  | if (!instance_index_expr) { | 
|  | for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { | 
|  | if (layout.step_mode == VertexStepMode::kInstance) { | 
|  | auto name = ctx.dst->Symbols().New("tint_pulling_instance_index"); | 
|  | new_function_parameters.push_back( | 
|  | ctx.dst->Param(name, ctx.dst->ty.u32(), | 
|  | {ctx.dst->Builtin(ast::Builtin::kInstanceIndex)})); | 
|  | instance_index_expr = [this, name]() { return ctx.dst->Expr(name); }; | 
|  | break; | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | // Generate vertex pulling preamble. | 
|  | if (auto* block = CreateVertexPullingPreamble()) { | 
|  | ctx.InsertFront(func->body()->statements(), block); | 
|  | } | 
|  |  | 
|  | // Rewrite the function header with the new parameters. | 
|  | auto func_sym = ctx.Clone(func->symbol()); | 
|  | auto* ret_type = ctx.Clone(func->return_type()); | 
|  | auto* body = ctx.Clone(func->body()); | 
|  | auto decos = ctx.Clone(func->decorations()); | 
|  | auto ret_decos = ctx.Clone(func->return_type_decorations()); | 
|  | auto* new_func = ctx.dst->create<ast::Function>( | 
|  | func->source(), func_sym, new_function_parameters, ret_type, body, | 
|  | std::move(decos), std::move(ret_decos)); | 
|  | ctx.Replace(func, new_func); | 
|  | } | 
|  | }; | 
|  |  | 
|  | }  // namespace | 
|  |  | 
|  | VertexPulling::VertexPulling() = default; | 
|  | VertexPulling::~VertexPulling() = default; | 
|  |  | 
|  | void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) { | 
|  | auto cfg = cfg_; | 
|  | if (auto* cfg_data = inputs.Get<Config>()) { | 
|  | cfg = *cfg_data; | 
|  | } | 
|  |  | 
|  | // Find entry point | 
|  | auto* func = ctx.src->AST().Functions().Find( | 
|  | ctx.src->Symbols().Get(cfg.entry_point_name), | 
|  | ast::PipelineStage::kVertex); | 
|  | if (func == nullptr) { | 
|  | ctx.dst->Diagnostics().add_error(diag::System::Transform, | 
|  | "Vertex stage entry point not found"); | 
|  | return; | 
|  | } | 
|  |  | 
|  | // TODO(idanr): Need to check shader locations in descriptor cover all | 
|  | // attributes | 
|  |  | 
|  | // TODO(idanr): Make sure we covered all error cases, to guarantee the | 
|  | // following stages will pass | 
|  |  | 
|  | State state{ctx, cfg}; | 
|  | state.AddVertexStorageBuffers(); | 
|  | state.Process(func); | 
|  |  | 
|  | ctx.Clone(); | 
|  | } | 
|  |  | 
|  | VertexPulling::Config::Config() = default; | 
|  | VertexPulling::Config::Config(const Config&) = default; | 
|  | VertexPulling::Config::~Config() = default; | 
|  | VertexPulling::Config& VertexPulling::Config::operator=(const Config&) = | 
|  | default; | 
|  |  | 
|  | VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default; | 
|  |  | 
|  | VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor( | 
|  | uint32_t in_array_stride, | 
|  | VertexStepMode in_step_mode, | 
|  | std::vector<VertexAttributeDescriptor> in_attributes) | 
|  | : array_stride(in_array_stride), | 
|  | step_mode(in_step_mode), | 
|  | attributes(std::move(in_attributes)) {} | 
|  |  | 
|  | VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor( | 
|  | const VertexBufferLayoutDescriptor& other) = default; | 
|  |  | 
|  | VertexBufferLayoutDescriptor& VertexBufferLayoutDescriptor::operator=( | 
|  | const VertexBufferLayoutDescriptor& other) = default; | 
|  |  | 
|  | VertexBufferLayoutDescriptor::~VertexBufferLayoutDescriptor() = default; | 
|  |  | 
|  | }  // namespace transform | 
|  | }  // namespace tint |