tint: Replace all remaining AST types with ast::Type
This CL removes the following AST nodes:
* ast::Array
* ast::Atomic
* ast::Matrix
* ast::MultisampledTexture
* ast::Pointer
* ast::SampledTexture
* ast::Texture
* ast::TypeName
* ast::Vector
ast::Type, which used to be the base class for all AST types, is now a
thin wrapper around ast::IdentifierExpression. All types are now
referred to using their type name.
The resolver now handles type resolution and validation of the types
listed above based on the ast::TemplatedIdentifier arguments.
Other changes:
* ProgramBuilder has undergone substantial refactoring.
* ProgramBuilder helpers that infer types are now more explicit.
Instead of passing 'nullptr' for the type, a new 'Infer' template
argument is passed.
* ast::CheckIdentifier() is now used by more tests that check
identifiers, including types.
Bug: tint:1810
Change-Id: I8e739ef49435dc1c20a462f3ec5ba265661a7edb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118723
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/transform/add_block_attribute.cc b/src/tint/transform/add_block_attribute.cc
index e2a0c00..2b1ec31 100644
--- a/src/tint/transform/add_block_attribute.cc
+++ b/src/tint/transform/add_block_attribute.cc
@@ -78,7 +78,7 @@
ctx.InsertBefore(src->AST().GlobalDeclarations(), global, ret);
return ret;
});
- ctx.Replace(global->type, b.ty.Of(wrapper));
+ ctx.Replace(global->type.expr, b.Expr(wrapper->name->symbol));
// Insert a member accessor to get the original type from the wrapper at
// any usage of the original variable.
diff --git a/src/tint/transform/binding_remapper.cc b/src/tint/transform/binding_remapper.cc
index 470826d..b657980 100644
--- a/src/tint/transform/binding_remapper.cc
+++ b/src/tint/transform/binding_remapper.cc
@@ -138,7 +138,7 @@
return Program(std::move(b));
}
auto* ty = sem->Type()->UnwrapRef();
- const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
+ auto inner_ty = CreateASTTypeFor(ctx, ty);
auto* new_var = b.Var(ctx.Clone(var->source), ctx.Clone(var->name->symbol),
inner_ty, var->declared_address_space, ac,
ctx.Clone(var->initializer), ctx.Clone(var->attributes));
diff --git a/src/tint/transform/builtin_polyfill.cc b/src/tint/transform/builtin_polyfill.cc
index 3d647a7..5baef8a 100644
--- a/src/tint/transform/builtin_polyfill.cc
+++ b/src/tint/transform/builtin_polyfill.cc
@@ -176,7 +176,7 @@
uint32_t width = WidthOf(ty);
// Returns either u32 or vecN<u32>
- auto U = [&]() -> const ast::Type* {
+ auto U = [&]() {
if (width == 1) {
return b.ty.u32();
}
@@ -234,7 +234,7 @@
uint32_t width = WidthOf(ty);
// Returns either u32 or vecN<u32>
- auto U = [&]() -> const ast::Type* {
+ auto U = [&]() {
if (width == 1) {
return b.ty.u32();
}
@@ -351,7 +351,7 @@
uint32_t width = WidthOf(ty);
// Returns either u32 or vecN<u32>
- auto U = [&]() -> const ast::Type* {
+ auto U = [&]() {
if (width == 1) {
return b.ty.u32();
}
@@ -423,7 +423,7 @@
uint32_t width = WidthOf(ty);
// Returns either u32 or vecN<u32>
- auto U = [&]() -> const ast::Type* {
+ auto U = [&]() {
if (width == 1) {
return b.ty.u32();
}
@@ -825,7 +825,7 @@
bool has_full_ptr_params;
/// @returns the AST type for the given sem type
- const ast::Type* T(const type::Type* ty) const { return CreateASTTypeFor(ctx, ty); }
+ ast::Type T(const type::Type* ty) const { return CreateASTTypeFor(ctx, ty); }
/// @returns 1 if `ty` is not a vector, otherwise the vector width
uint32_t WidthOf(const type::Type* ty) const {
@@ -1067,15 +1067,17 @@
break;
}
},
- [&](const ast::TypeName* type_name) {
+ [&](const ast::Expression* expr) {
if (polyfill.bgra8unorm) {
- if (auto* tex = src->Sem().Get<type::StorageTexture>(type_name)) {
- if (tex->texel_format() == type::TexelFormat::kBgra8Unorm) {
- ctx.Replace(type_name, [&ctx, tex] {
- return ctx.dst->ty.storage_texture(
- tex->dim(), type::TexelFormat::kRgba8Unorm, tex->access());
- });
- made_changes = true;
+ if (auto* ty_expr = src->Sem().Get<sem::TypeExpression>(expr)) {
+ if (auto* tex = ty_expr->Type()->As<type::StorageTexture>()) {
+ if (tex->texel_format() == type::TexelFormat::kBgra8Unorm) {
+ ctx.Replace(expr, [&ctx, tex] {
+ return ctx.dst->Expr(ctx.dst->ty.storage_texture(
+ tex->dim(), type::TexelFormat::kRgba8Unorm, tex->access()));
+ });
+ made_changes = true;
+ }
}
}
}
diff --git a/src/tint/transform/calculate_array_length.cc b/src/tint/transform/calculate_array_length.cc
index 028f8c5..5ed6547 100644
--- a/src/tint/transform/calculate_array_length.cc
+++ b/src/tint/transform/calculate_array_length.cc
@@ -104,10 +104,10 @@
auto get_buffer_size_intrinsic = [&](const type::Reference* buffer_type) {
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
auto name = b.Sym();
- auto* type = CreateASTTypeFor(ctx, buffer_type);
+ auto type = CreateASTTypeFor(ctx, buffer_type);
auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter);
- b.AST().AddFunction(b.create<ast::Function>(
- b.Ident(name),
+ b.Func(
+ name,
utils::Vector{
b.Param("buffer",
b.ty.pointer(type, buffer_type->AddressSpace(), buffer_type->Access()),
@@ -117,8 +117,7 @@
b.ty.void_(), nullptr,
utils::Vector{
b.ASTNodes().Create<BufferSizeIntrinsic>(b.ID(), b.AllocateNodeID()),
- },
- utils::Empty));
+ });
return name;
});
diff --git a/src/tint/transform/canonicalize_entry_point_io.cc b/src/tint/transform/canonicalize_entry_point_io.cc
index f51d4cc..81c2821 100644
--- a/src/tint/transform/canonicalize_entry_point_io.cc
+++ b/src/tint/transform/canonicalize_entry_point_io.cc
@@ -130,7 +130,7 @@
/// The name of the output value.
std::string name;
/// The type of the output value.
- const ast::Type* type;
+ ast::Type type;
/// The shader IO attributes.
utils::Vector<const ast::Attribute*, 2> attributes;
/// The value itself.
@@ -210,7 +210,7 @@
const type::Type* type,
std::optional<uint32_t> location,
utils::Vector<const ast::Attribute*, 8> attributes) {
- auto* ast_type = CreateASTTypeFor(ctx, type);
+ auto ast_type = CreateASTTypeFor(ctx, type);
if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) {
// Vulkan requires that integer user-defined fragment inputs are always decorated with
// `Flat`. See:
@@ -524,7 +524,7 @@
// Create the global variable and assign it the output value.
auto name = ctx.dst->Symbols().New(outval.name);
- auto* type = outval.type;
+ ast::Type type = outval.type;
const ast::Expression* lhs = ctx.dst->Expr(name);
if (HasSampleMask(attributes)) {
// Vulkan requires the type of a SampleMask builtin to be an array.
@@ -606,7 +606,7 @@
auto* call_inner = CallInnerFunction();
// Process the return type, and start building the wrapper function body.
- std::function<const ast::Type*()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); };
+ std::function<ast::Type()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); };
if (func_sem->ReturnType()->Is<type::Void>()) {
// The function call is just a statement with no result.
wrapper_body.Push(ctx.dst->CallStmt(call_inner));
@@ -665,7 +665,7 @@
}
auto* wrapper_func = ctx.dst->create<ast::Function>(
- ctx.dst->Ident(name), wrapper_ep_parameters, wrapper_ret_type(),
+ ctx.dst->Ident(name), wrapper_ep_parameters, ctx.dst->ty(wrapper_ret_type()),
ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->attributes), utils::Empty);
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast, wrapper_func);
}
@@ -727,7 +727,7 @@
/// WGSL expects
const ast::Expression* FromGLSLBuiltin(ast::BuiltinValue builtin,
const ast::Expression* value,
- const ast::Type*& ast_type) {
+ ast::Type& ast_type) {
switch (builtin) {
case ast::BuiltinValue::kVertexIndex:
case ast::BuiltinValue::kInstanceIndex:
diff --git a/src/tint/transform/clamp_frag_depth.cc b/src/tint/transform/clamp_frag_depth.cc
index 2e6ef2d..0b8d8f7 100644
--- a/src/tint/transform/clamp_frag_depth.cc
+++ b/src/tint/transform/clamp_frag_depth.cc
@@ -22,7 +22,6 @@
#include "src/tint/ast/function.h"
#include "src/tint/ast/module.h"
#include "src/tint/ast/struct.h"
-#include "src/tint/ast/type.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
@@ -163,10 +162,9 @@
// }
auto* struct_ty = sem.Get(fn)->ReturnType()->As<sem::Struct>()->Declaration();
auto helper = io_structs_clamp_helpers.GetOrCreate(struct_ty, [&] {
- auto* return_ty = fn->return_type;
+ auto return_ty = fn->return_type;
auto fn_sym =
- b.Symbols().New("clamp_frag_depth_" +
- sym.NameFor(return_ty->As<ast::TypeName>()->name->symbol));
+ b.Symbols().New("clamp_frag_depth_" + sym.NameFor(struct_ty->name->symbol));
utils::Vector<const ast::Expression*, 8u> initializer_args;
for (auto* member : struct_ty->members) {
diff --git a/src/tint/transform/combine_samplers.cc b/src/tint/transform/combine_samplers.cc
index c3593ef..7b9f9fd 100644
--- a/src/tint/transform/combine_samplers.cc
+++ b/src/tint/transform/combine_samplers.cc
@@ -115,7 +115,7 @@
if (it != binding_info->binding_map.end()) {
name = it->second;
}
- const ast::Type* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
+ ast::Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
Symbol symbol = ctx.dst->Symbols().New(name);
return ctx.dst->GlobalVar(symbol, type, Attributes());
}
@@ -124,7 +124,7 @@
/// @param kind the sampler kind to create for
/// @returns the newly-created global variable
const ast::Variable* CreatePlaceholder(type::SamplerKind kind) {
- const ast::Type* type = ctx.dst->ty.sampler(kind);
+ ast::Type type = ctx.dst->ty.sampler(kind);
const char* name = kind == type::SamplerKind::kComparisonSampler
? "placeholder_comparison_sampler"
: "placeholder_sampler";
@@ -132,18 +132,17 @@
return ctx.dst->GlobalVar(symbol, type, Attributes());
}
- /// Creates ast::Type for a given texture and sampler variable pair.
+ /// Creates the ast::Type for a given texture and sampler variable pair.
/// Depth textures with no samplers are turned into the corresponding
/// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>).
/// @param texture the texture variable of interest
/// @param sampler the texture variable of interest
/// @returns the newly-created type
- const ast::Type* CreateCombinedASTTypeFor(const sem::Variable* texture,
- const sem::Variable* sampler) {
+ ast::Type CreateCombinedASTTypeFor(const sem::Variable* texture, const sem::Variable* sampler) {
const type::Type* texture_type = texture->Type()->UnwrapRef();
const type::DepthTexture* depth = texture_type->As<type::DepthTexture>();
if (depth && !sampler) {
- return ctx.dst->create<ast::SampledTexture>(depth->dim(), ctx.dst->ty.f32());
+ return ctx.dst->ty.sampled_texture(depth->dim(), ctx.dst->ty.f32());
} else {
return CreateASTTypeFor(ctx, texture_type);
}
@@ -158,7 +157,7 @@
// by combined samplers.
for (auto* global : ctx.src->AST().GlobalVariables()) {
auto* global_sem = sem.Get(global)->As<sem::GlobalVariable>();
- auto* type = sem.Get(global->type);
+ auto* type = ctx.src->TypeOf(global->type);
if (tint::IsAnyOf<type::Texture, type::Sampler>(type) &&
!type->Is<type::StorageTexture>()) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), global);
@@ -199,7 +198,7 @@
} else {
// Either texture or sampler (or both) is a function parameter;
// add a new function parameter to represent the combined sampler.
- auto* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
+ ast::Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
params.Push(var);
function_combined_texture_samplers_[fn][pair] = var;
@@ -215,7 +214,7 @@
// Create a new function signature that differs only in the parameter
// list.
auto name = ctx.Clone(ast_fn->name);
- auto* return_type = ctx.Clone(ast_fn->return_type);
+ auto return_type = ctx.Clone(ast_fn->return_type);
auto* body = ctx.Clone(ast_fn->body);
auto attributes = ctx.Clone(ast_fn->attributes);
auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes);
@@ -276,8 +275,7 @@
args.Push(ctx.Clone(arg));
}
}
- const ast::Expression* value =
- ctx.dst->Call(ctx.Clone(expr->target.name), args);
+ const ast::Expression* value = ctx.dst->Call(ctx.Clone(expr->target), args);
if (builtin->Type() == sem::BuiltinType::kTextureLoad &&
texture_var->Type()->UnwrapRef()->Is<type::DepthTexture>() &&
!call->Stmt()->Declaration()->Is<ast::CallStatement>()) {
@@ -329,7 +327,7 @@
args.Push(ctx.Clone(arg));
}
}
- return ctx.dst->Call(ctx.Clone(expr->target.name), args);
+ return ctx.dst->Call(ctx.Clone(expr->target), args);
}
}
return nullptr;
diff --git a/src/tint/transform/decompose_memory_access.cc b/src/tint/transform/decompose_memory_access.cc
index fc2003c..072c9c6 100644
--- a/src/tint/transform/decompose_memory_access.cc
+++ b/src/tint/transform/decompose_memory_access.cc
@@ -23,7 +23,6 @@
#include "src/tint/ast/assignment_statement.h"
#include "src/tint/ast/call_statement.h"
#include "src/tint/ast/disable_validation_attribute.h"
-#include "src/tint/ast/type_name.h"
#include "src/tint/ast/unary_op.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
@@ -481,15 +480,12 @@
auto name = b.Sym();
if (auto* intrinsic = IntrinsicLoadFor(ctx.dst, address_space, el_ty)) {
- auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
- auto* func = b.create<ast::Function>(
- b.Ident(name), params, el_ast_ty, nullptr,
- utils::Vector{
- intrinsic,
- b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
- },
- utils::Empty);
- b.AST().AddFunction(func);
+ auto el_ast_ty = CreateASTTypeFor(ctx, el_ty);
+ b.Func(name, params, el_ast_ty, nullptr,
+ utils::Vector{
+ intrinsic,
+ b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
+ });
} else if (auto* arr_ty = el_ty->As<type::Array>()) {
// fn load_func(buffer : buf_ty, offset : u32) -> array<T, N> {
// var arr : array<T, N>;
@@ -581,14 +577,11 @@
auto name = b.Sym();
if (auto* intrinsic = IntrinsicStoreFor(ctx.dst, address_space, el_ty)) {
- auto* func = b.create<ast::Function>(
- b.Ident(name), params, b.ty.void_(), nullptr,
- utils::Vector{
- intrinsic,
- b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
- },
- utils::Empty);
- b.AST().AddFunction(func);
+ b.Func(name, params, b.ty.void_(), nullptr,
+ utils::Vector{
+ intrinsic,
+ b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
+ });
} else {
auto body = Switch<utils::Vector<const ast::Statement*, 8>>(
el_ty, //
@@ -695,7 +688,7 @@
// Other parameters are copied as-is:
for (size_t i = 1; i < intrinsic->Parameters().Length(); i++) {
auto* param = intrinsic->Parameters()[i];
- auto* ty = CreateASTTypeFor(ctx, param->Type());
+ auto ty = CreateASTTypeFor(ctx, param->Type());
params.Push(b.Param("param_" + std::to_string(i), ty));
}
@@ -706,7 +699,7 @@
<< el_ty->TypeInfo().name;
}
- const ast::Type* ret_ty = nullptr;
+ ast::Type ret_ty;
// For intrinsics that return a struct, there is no AST node for it, so create one now.
if (intrinsic->Type() == sem::BuiltinType::kAtomicCompareExchangeWeak) {
@@ -727,17 +720,13 @@
ret_ty = CreateASTTypeFor(ctx, intrinsic->ReturnType());
}
- auto* func = b.create<ast::Function>(
- b.Ident(b.Symbols().New(std::string{"tint_"} + intrinsic->str())), params, ret_ty,
- nullptr,
- utils::Vector{
- atomic,
- b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
- },
- utils::Empty);
-
- b.AST().AddFunction(func);
- return func->name->symbol;
+ auto name = b.Symbols().New(std::string{"tint_"} + intrinsic->str());
+ b.Func(name, std::move(params), ret_ty, nullptr,
+ utils::Vector{
+ atomic,
+ b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
+ });
+ return name;
});
}
};
diff --git a/src/tint/transform/decompose_memory_access_test.cc b/src/tint/transform/decompose_memory_access_test.cc
index b58d340..f0f441f 100644
--- a/src/tint/transform/decompose_memory_access_test.cc
+++ b/src/tint/transform/decompose_memory_access_test.cc
@@ -1584,16 +1584,16 @@
}
fn tint_symbol_34(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read_write>, offset : u32, value : array<vec3<f32>, 2u>) {
- var array = value;
+ var array_1 = value;
for(var i = 0u; (i < 2u); i = (i + 1u)) {
- tint_symbol_8(buffer, (offset + (i * 16u)), array[i]);
+ tint_symbol_8(buffer, (offset + (i * 16u)), array_1[i]);
}
}
fn tint_symbol_35(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read_write>, offset : u32, value : array<mat4x2<f16>, 2u>) {
- var array_1 = value;
+ var array_2 = value;
for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
- tint_symbol_31(buffer, (offset + (i_1 * 16u)), array_1[i_1]);
+ tint_symbol_31(buffer, (offset + (i_1 * 16u)), array_2[i_1]);
}
}
@@ -1889,16 +1889,16 @@
}
fn tint_symbol_34(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read_write>, offset : u32, value : array<vec3<f32>, 2u>) {
- var array = value;
+ var array_1 = value;
for(var i = 0u; (i < 2u); i = (i + 1u)) {
- tint_symbol_8(buffer, (offset + (i * 16u)), array[i]);
+ tint_symbol_8(buffer, (offset + (i * 16u)), array_1[i]);
}
}
fn tint_symbol_35(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read_write>, offset : u32, value : array<mat4x2<f16>, 2u>) {
- var array_1 = value;
+ var array_2 = value;
for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
- tint_symbol_31(buffer, (offset + (i_1 * 16u)), array_1[i_1]);
+ tint_symbol_31(buffer, (offset + (i_1 * 16u)), array_2[i_1]);
}
}
@@ -2733,16 +2733,16 @@
}
fn tint_symbol_35(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read_write>, offset : u32, value : array<vec3<f32>, 2u>) {
- var array = value;
+ var array_1 = value;
for(var i = 0u; (i < 2u); i = (i + 1u)) {
- tint_symbol_9(buffer, (offset + (i * 16u)), array[i]);
+ tint_symbol_9(buffer, (offset + (i * 16u)), array_1[i]);
}
}
fn tint_symbol_36(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read_write>, offset : u32, value : array<mat4x2<f16>, 2u>) {
- var array_1 = value;
+ var array_2 = value;
for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
- tint_symbol_32(buffer, (offset + (i_1 * 16u)), array_1[i_1]);
+ tint_symbol_32(buffer, (offset + (i_1 * 16u)), array_2[i_1]);
}
}
@@ -3007,16 +3007,16 @@
}
fn tint_symbol_35(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read_write>, offset : u32, value : array<vec3<f32>, 2u>) {
- var array = value;
+ var array_1 = value;
for(var i = 0u; (i < 2u); i = (i + 1u)) {
- tint_symbol_9(buffer, (offset + (i * 16u)), array[i]);
+ tint_symbol_9(buffer, (offset + (i * 16u)), array_1[i]);
}
}
fn tint_symbol_36(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read_write>, offset : u32, value : array<mat4x2<f16>, 2u>) {
- var array_1 = value;
+ var array_2 = value;
for(var i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
- tint_symbol_32(buffer, (offset + (i_1 * 16u)), array_1[i_1]);
+ tint_symbol_32(buffer, (offset + (i_1 * 16u)), array_2[i_1]);
}
}
diff --git a/src/tint/transform/decompose_strided_array.cc b/src/tint/transform/decompose_strided_array.cc
index 8602f14..e92b85e 100644
--- a/src/tint/transform/decompose_strided_array.cc
+++ b/src/tint/transform/decompose_strided_array.cc
@@ -21,6 +21,7 @@
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/member_accessor_expression.h"
+#include "src/tint/sem/type_expression.h"
#include "src/tint/sem/type_initializer.h"
#include "src/tint/sem/value_expression.h"
#include "src/tint/transform/simplify_pointers.h"
@@ -36,8 +37,8 @@
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
- if (auto* ast = node->As<ast::Array>()) {
- if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
+ if (auto* ident = node->As<ast::TemplatedIdentifier>()) {
+ if (ast::GetAttribute<ast::StrideAttribute>(ident->attributes)) {
return true;
}
}
@@ -73,27 +74,45 @@
// stride for the array element type, then replace the array element type with
// a structure, holding a single field with a @size attribute equal to the
// array stride.
- ctx.ReplaceAll([&](const ast::Array* ast) -> const ast::Array* {
- if (auto* arr = sem.Get(ast)) {
- if (!arr->IsStrideImplicit()) {
- auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
- auto name = b.Symbols().New("strided_arr");
- auto* member_ty = ctx.Clone(ast->type);
- auto* member = b.Member(kMemberName, member_ty,
- utils::Vector{
- b.MemberSize(AInt(arr->Stride())),
- });
- b.Structure(name, utils::Vector{member});
- return name;
- });
- auto* count = ctx.Clone(ast->count);
- return b.ty.array(b.ty(el_ty), count);
+ ctx.ReplaceAll([&](const ast::IdentifierExpression* expr) -> const ast::IdentifierExpression* {
+ auto* ident = expr->identifier->As<ast::TemplatedIdentifier>();
+ if (!ident) {
+ return nullptr;
+ }
+ auto* type_expr = sem.Get<sem::TypeExpression>(expr);
+ if (!type_expr) {
+ return nullptr;
+ }
+ auto* arr = type_expr->Type()->As<type::Array>();
+ if (!arr) {
+ return nullptr;
+ }
+ if (!arr->IsStrideImplicit()) {
+ auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
+ auto name = b.Symbols().New("strided_arr");
+ auto* member_ty = ctx.Clone(ident->arguments[0]->As<ast::IdentifierExpression>());
+ auto* member = b.Member(kMemberName, ast::Type{member_ty},
+ utils::Vector{
+ b.MemberSize(AInt(arr->Stride())),
+ });
+ b.Structure(name, utils::Vector{member});
+ return name;
+ });
+ if (ident->arguments.Length() > 1) {
+ auto* count = ctx.Clone(ident->arguments[1]);
+ return b.Expr(b.ty.array(b.ty(el_ty), count));
+ } else {
+ return b.Expr(b.ty.array(b.ty(el_ty)));
}
- if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
- // Strip the @stride attribute
- auto* ty = ctx.Clone(ast->type);
- auto* count = ctx.Clone(ast->count);
- return b.ty.array(ty, count);
+ }
+ if (ast::GetAttribute<ast::StrideAttribute>(ident->attributes)) {
+ // Strip the @stride attribute
+ auto* ty = ctx.Clone(ident->arguments[0]->As<ast::IdentifierExpression>());
+ if (ident->arguments.Length() > 1) {
+ auto* count = ctx.Clone(ident->arguments[1]);
+ return b.Expr(b.ty.array(ast::Type{ty}, count));
+ } else {
+ return b.Expr(b.ty.array(ast::Type{ty}));
}
}
return nullptr;
@@ -133,12 +152,8 @@
// decomposed.
// If this is an aliased array, decomposed should already be
// populated with any strided aliases.
- ast::CallExpression::Target target;
- if (expr->target.type) {
- target.type = ctx.Clone(expr->target.type);
- } else {
- target.name = ctx.Clone(expr->target.name);
- }
+
+ auto* target = ctx.Clone(expr->target);
utils::Vector<const ast::Expression*, 8> args;
if (auto it = decomposed.find(arr); it != decomposed.end()) {
@@ -150,8 +165,7 @@
args = ctx.Clone(expr->args);
}
- return target.type ? b.Call(target.type, std::move(args))
- : b.Call(target.name, std::move(args));
+ return b.Call(target, std::move(args));
}
}
}
diff --git a/src/tint/transform/decompose_strided_matrix.cc b/src/tint/transform/decompose_strided_matrix.cc
index 8b72ac2..4642c57 100644
--- a/src/tint/transform/decompose_strided_matrix.cc
+++ b/src/tint/transform/decompose_strided_matrix.cc
@@ -37,9 +37,8 @@
/// The type of the matrix
const type::Matrix* matrix = nullptr;
- /// @returns a new ast::Array that holds an vector column for each row of the
- /// matrix.
- const ast::Array* array(ProgramBuilder* b) const {
+ /// @returns the ast::Type of an array holding one column vector per matrix column.
+ ast::Type array(ProgramBuilder* b) const {
return b->ty.array(b->ty.vec<f32>(matrix->rows()), u32(matrix->columns()),
utils::Vector{
b->Stride(stride),
diff --git a/src/tint/transform/demote_to_helper.cc b/src/tint/transform/demote_to_helper.cc
index 128894d..dc5384b 100644
--- a/src/tint/transform/demote_to_helper.cc
+++ b/src/tint/transform/demote_to_helper.cc
@@ -177,7 +177,7 @@
// }
// let y = x + tmp;
auto result = b.Sym();
- const ast::Type* result_ty = nullptr;
+ ast::Type result_ty;
const ast::Statement* masked_call = nullptr;
if (builtin->Type() == sem::BuiltinType::kAtomicCompareExchangeWeak) {
// Special case for atomicCompareExchangeWeak as we cannot name its
diff --git a/src/tint/transform/direct_variable_access.cc b/src/tint/transform/direct_variable_access.cc
index 3e89ca3..a8a1272 100644
--- a/src/tint/transform/direct_variable_access.cc
+++ b/src/tint/transform/direct_variable_access.cc
@@ -831,14 +831,14 @@
if (auto incoming_shape = variant_sig.Find(param)) {
auto& symbols = *variant.ptr_param_symbols.Find(param);
if (symbols.base_ptr.IsValid()) {
- auto* base_ptr_ty =
+ auto base_ptr_ty =
b.ty.pointer(CreateASTTypeFor(ctx, incoming_shape->root.type),
incoming_shape->root.address_space);
params.Push(b.Param(symbols.base_ptr, base_ptr_ty));
}
if (symbols.indices.IsValid()) {
// Variant has dynamic indices for this variant, replace it.
- auto* dyn_idx_arr_type = DynamicIndexArrayType(*incoming_shape);
+ auto dyn_idx_arr_type = DynamicIndexArrayType(*incoming_shape);
params.Push(b.Param(symbols.indices, dyn_idx_arr_type));
}
} else {
@@ -850,7 +850,7 @@
// Build the variant by cloning the source function. The other clone callbacks will
// use clone_state->current_variant and clone_state->current_variant_sig to produce
// the variant.
- auto* ret_ty = ctx.Clone(fn->Declaration()->return_type);
+ auto ret_ty = ctx.Clone(fn->Declaration()->return_type);
auto body = ctx.Clone(fn->Declaration()->body);
auto attrs = ctx.Clone(fn->Declaration()->attributes);
auto ret_attrs = ctx.Clone(fn->Declaration()->return_type_attributes);
@@ -912,7 +912,7 @@
}
// Get or create the dynamic indices array.
- if (auto* dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) {
+ if (auto dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) {
// Build an array of dynamic indices to pass as the replacement for the pointer.
utils::Vector<const ast::Expression*, 8> dyn_idx_args;
if (auto* root_param = chain->root.variable->As<sem::Parameter>()) {
@@ -1064,7 +1064,7 @@
/// @returns the type alias used to hold the dynamic indices for @p shape, declaring a new alias
/// if this is the first call for the given shape.
- const ast::TypeName* DynamicIndexArrayType(const AccessShape& shape) {
+ ast::Type DynamicIndexArrayType(const AccessShape& shape) {
auto name = dynamic_index_array_aliases.GetOrCreate(shape, [&] {
// Count the number of dynamic indices
uint32_t num_dyn_indices = shape.NumDynamicIndices();
@@ -1075,7 +1075,7 @@
b.Alias(symbol, b.ty.array(b.ty.u32(), u32(num_dyn_indices)));
return symbol;
});
- return name.IsValid() ? b.ty(name) : nullptr;
+ return name.IsValid() ? b.ty(name) : ast::Type{};
}
/// @returns a name describing the given shape
diff --git a/src/tint/transform/module_scope_var_to_entry_point_param.cc b/src/tint/transform/module_scope_var_to_entry_point_param.cc
index 285b95c..364087b 100644
--- a/src/tint/transform/module_scope_var_to_entry_point_param.cc
+++ b/src/tint/transform/module_scope_var_to_entry_point_param.cc
@@ -145,7 +145,7 @@
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter));
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace));
- auto* param_type = store_type();
+ auto param_type = store_type();
if (auto* arr = ty->As<type::Array>();
arr && arr->Count()->Is<type::RuntimeArrayCount>()) {
// Wrap runtime-sized arrays in structures, so that we can declare pointers to
@@ -230,7 +230,7 @@
bool& is_pointer) {
auto* var_ast = var->Declaration()->As<ast::Var>();
auto* ty = var->Type()->UnwrapRef();
- auto* param_type = CreateASTTypeFor(ctx, ty);
+ auto param_type = CreateASTTypeFor(ctx, ty);
auto sc = var->AddressSpace();
switch (sc) {
case type::AddressSpace::kPrivate:
@@ -450,7 +450,7 @@
// The parameter is a struct that contains members for each workgroup variable.
auto* str =
ctx.dst->Structure(ctx.dst->Sym(), std::move(workgroup_parameter_members));
- auto* param_type =
+ auto param_type =
ctx.dst->ty.pointer(ctx.dst->ty.Of(str), type::AddressSpace::kWorkgroup);
auto* param = ctx.dst->Param(
workgroup_param(), param_type,
@@ -463,8 +463,8 @@
// Pass the variables as pointers to any functions that need them.
for (auto* call : calls_to_replace[func_ast]) {
- auto* target = ctx.src->AST().Functions().Find(call->target.name->symbol);
- auto* target_sem = ctx.src->Sem().Get(target);
+ auto* call_sem = ctx.src->Sem().Get(call)->Unwrap()->As<sem::Call>();
+ auto* target_sem = call_sem->Target()->As<sem::Function>();
// Add new arguments for any variables that are needed by the callee.
// For entry points, pass non-handle types as pointers.
diff --git a/src/tint/transform/multiplanar_external_texture.cc b/src/tint/transform/multiplanar_external_texture.cc
index 1bbbe22..06dfba4 100644
--- a/src/tint/transform/multiplanar_external_texture.cc
+++ b/src/tint/transform/multiplanar_external_texture.cc
@@ -34,8 +34,8 @@
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
- if (auto* ty = node->As<ast::Type>()) {
- if (program->Sem().Get<type::ExternalTexture>(ty)) {
+ if (auto* expr = node->As<ast::Expression>()) {
+ if (Is<type::ExternalTexture>(program->TypeOf(expr))) {
return true;
}
}
@@ -263,7 +263,8 @@
b.Member("gammaDecodeParams", b.ty("GammaTransferParams")),
b.Member("gammaEncodeParams", b.ty("GammaTransferParams")),
b.Member("gamutConversionMatrix", b.ty.mat3x3<f32>()),
- b.Member("coordTransformationMatrix", b.ty.mat3x2<f32>())};
+ b.Member("coordTransformationMatrix", b.ty.mat3x2<f32>()),
+ };
params_struct_sym = b.Symbols().New("ExternalTextureParams");
diff --git a/src/tint/transform/pad_structs.cc b/src/tint/transform/pad_structs.cc
index 7d9e7cd..954069c 100644
--- a/src/tint/transform/pad_structs.cc
+++ b/src/tint/transform/pad_structs.cc
@@ -75,7 +75,7 @@
}
auto* ty = mem->Type();
- const ast::Type* type = CreateASTTypeFor(ctx, ty);
+ auto type = CreateASTTypeFor(ctx, ty);
new_members.Push(b.Member(name, type));
diff --git a/src/tint/transform/remove_phonies.cc b/src/tint/transform/remove_phonies.cc
index 3b0dabd..6131892 100644
--- a/src/tint/transform/remove_phonies.cc
+++ b/src/tint/transform/remove_phonies.cc
@@ -108,7 +108,7 @@
auto name = b.Symbols().New("phony_sink");
utils::Vector<const ast::Parameter*, 8> params;
for (auto* ty : sig) {
- auto* ast_ty = CreateASTTypeFor(ctx, ty);
+ auto ast_ty = CreateASTTypeFor(ctx, ty);
params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty));
}
b.Func(name, params, b.ty.void_(), {});
diff --git a/src/tint/transform/renamer.cc b/src/tint/transform/renamer.cc
index ec8dcee..f2aea29 100644
--- a/src/tint/transform/renamer.cc
+++ b/src/tint/transform/renamer.cc
@@ -22,6 +22,7 @@
#include "src/tint/sem/call.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/type_conversion.h"
+#include "src/tint/sem/type_expression.h"
#include "src/tint/sem/type_initializer.h"
#include "src/tint/text/unicode.h"
@@ -1292,20 +1293,28 @@
[&](const ast::DiagnosticDirective* diagnostic) {
preserved_identifiers.Add(diagnostic->control.rule_name);
},
- [&](const ast::TypeName* ty) { preserve_if_builtin_type(ty->name); },
[&](const ast::IdentifierExpression* expr) {
- if (src->Sem().Get<sem::BuiltinEnumExpressionBase>(expr)) {
- preserved_identifiers.Add(expr->identifier);
- }
+ Switch(
+ src->Sem().Get(expr), //
+ [&](const sem::BuiltinEnumExpressionBase*) {
+ preserved_identifiers.Add(expr->identifier);
+ },
+ [&](const sem::TypeExpression*) {
+ preserve_if_builtin_type(expr->identifier);
+ });
},
[&](const ast::CallExpression* call) {
- if (auto* ident = call->target.name) {
- Switch(
- src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>()->Target(),
- [&](const sem::Builtin*) { preserved_identifiers.Add(ident); },
- [&](const sem::TypeConversion*) { preserve_if_builtin_type(ident); },
- [&](const sem::TypeInitializer*) { preserve_if_builtin_type(ident); });
- }
+ Switch(
+ src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>()->Target(),
+ [&](const sem::Builtin*) {
+ preserved_identifiers.Add(call->target->identifier);
+ },
+ [&](const sem::TypeConversion*) {
+ preserve_if_builtin_type(call->target->identifier);
+ },
+ [&](const sem::TypeInitializer*) {
+ preserve_if_builtin_type(call->target->identifier);
+ });
});
}
@@ -1369,7 +1378,7 @@
if (auto* tmpl_ident = ident->As<ast::TemplatedIdentifier>()) {
auto args = ctx.Clone(tmpl_ident->arguments);
return ctx.dst->create<ast::TemplatedIdentifier>(ctx.Clone(ident->source), replacement,
- std::move(args));
+ std::move(args), utils::Empty);
}
return ctx.dst->create<ast::Identifier>(ctx.Clone(ident->source), replacement);
});
diff --git a/src/tint/transform/renamer_test.cc b/src/tint/transform/renamer_test.cc
index 982a9a2..71cf2a2 100644
--- a/src/tint/transform/renamer_test.cc
+++ b/src/tint/transform/renamer_test.cc
@@ -1667,23 +1667,45 @@
/// @return all the identifiers parsed as keywords
std::unordered_set<std::string> Keywords() {
return {
+ "array",
+ "atomic",
"bool",
"f16",
"f32",
"i32",
+ "mat2x2",
+ "mat2x3",
+ "mat2x4",
+ "mat3x2",
+ "mat3x3",
+ "mat3x4",
+ "mat4x2",
+ "mat4x3",
+ "mat4x4",
+ "ptr",
"sampler_comparison",
"sampler",
+ "texture_1d",
+ "texture_2d_array",
+ "texture_2d",
+ "texture_3d",
+ "texture_cube_array",
+ "texture_cube",
"texture_depth_2d_array",
"texture_depth_2d",
"texture_depth_cube_array",
"texture_depth_cube",
"texture_depth_multisampled_2d",
"texture_external",
+ "texture_multisampled_2d",
"texture_storage_1d",
- "texture_storage_2d",
"texture_storage_2d_array",
+ "texture_storage_2d",
"texture_storage_3d",
"u32",
+ "vec2",
+ "vec3",
+ "vec4",
};
}
@@ -1802,6 +1824,30 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(RenamerBuiltinTypeTest, PreserveTypeExpression) {
+ auto src = R"(
+enable f16;
+
+@fragment
+fn f() {
+ var v : array<f32, 2> = array<f32, 2>();
+}
+)";
+
+ auto expect = R"(
+enable f16;
+
+@fragment
+fn tint_symbol() {
+ var tint_symbol_1 : array<f32, 2> = array<f32, 2>();
+}
+)";
+
+ auto got = Run<Renamer>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_P(RenamerBuiltinTypeTest, RenameShadowedByAlias) {
auto expand = [&](const char* source) {
auto out = utils::ReplaceAll(source, "$name", GetParam());
diff --git a/src/tint/transform/robustness.cc b/src/tint/transform/robustness.cc
index 87403e8..579d684 100644
--- a/src/tint/transform/robustness.cc
+++ b/src/tint/transform/robustness.cc
@@ -182,7 +182,7 @@
}
return 1u;
};
- auto scalar_or_vec_ty = [&](const ast::Type* scalar, uint32_t width) -> const ast::Type* {
+ auto scalar_or_vec_ty = [&](ast::Type scalar, uint32_t width) {
if (width > 1) {
return b.ty.vec(scalar, width);
}
@@ -191,7 +191,7 @@
auto scalar_or_vec = [&](const ast::Expression* scalar,
uint32_t width) -> const ast::Expression* {
if (width > 1) {
- return b.Call(b.ty.vec(nullptr, width), scalar);
+ return b.Call(b.ty.vec<Infer>(width), scalar);
}
return scalar;
};
diff --git a/src/tint/transform/spirv_atomic.cc b/src/tint/transform/spirv_atomic.cc
index afbc5cb..b3924cc 100644
--- a/src/tint/transform/spirv_atomic.cc
+++ b/src/tint/transform/spirv_atomic.cc
@@ -131,7 +131,7 @@
for (size_t i = 0; i < str->members.Length(); i++) {
auto* member = str->members[i];
if (forked.atomic_members.count(i)) {
- auto* type = AtomicTypeFor(ctx.src->Sem().Get(member)->Type());
+ auto type = AtomicTypeFor(ctx.src->Sem().Get(member)->Type());
auto name = ctx.src->Symbols().NameFor(member->name->symbol);
members.Push(b.Member(name, type, ctx.Clone(member->attributes)));
} else {
@@ -169,7 +169,7 @@
[&](const sem::VariableUser* user) {
auto* v = user->Variable()->Declaration();
if (v->type && atomic_variables.emplace(user->Variable()).second) {
- ctx.Replace(v->type, AtomicTypeFor(user->Variable()->Type()));
+ ctx.Replace(v->type.expr, b.Expr(AtomicTypeFor(user->Variable()->Type())));
}
if (auto* ctor = user->Variable()->Initializer()) {
atomic_expressions.Add(ctor);
@@ -193,13 +193,13 @@
}
}
- const ast::Type* AtomicTypeFor(const type::Type* ty) {
+ ast::Type AtomicTypeFor(const type::Type* ty) {
return Switch(
ty, //
[&](const type::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
[&](const type::U32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
[&](const sem::Struct* str) { return b.ty(Fork(str->Declaration()).name); },
- [&](const type::Array* arr) -> const ast::Type* {
+ [&](const type::Array* arr) {
if (arr->Count()->Is<type::RuntimeArrayCount>()) {
return b.ty.array(AtomicTypeFor(arr->ElemType()));
}
@@ -221,7 +221,7 @@
[&](Default) {
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled type: " << ty->FriendlyName(ctx.src->Symbols());
- return nullptr;
+ return ast::Type{};
});
}
diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc
index 825685e..4c487d0 100644
--- a/src/tint/transform/std140.cc
+++ b/src/tint/transform/std140.cc
@@ -308,7 +308,7 @@
continue; // Next member
}
- } else if (auto* std140_ty = Std140Type(member->Type())) {
+ } else if (auto std140_ty = Std140Type(member->Type())) {
// Member is of a type that requires forking for std140-layout
fork_std140 = true;
auto attrs = ctx.Clone(member->Declaration()->attributes);
@@ -352,8 +352,8 @@
if (auto* var = global->As<ast::Var>()) {
if (var->declared_address_space == type::AddressSpace::kUniform) {
auto* v = sem.Get(var);
- if (auto* std140_ty = Std140Type(v->Type()->UnwrapRef())) {
- ctx.Replace(global->type, std140_ty);
+ if (auto std140_ty = Std140Type(v->Type()->UnwrapRef())) {
+ ctx.Replace(global->type.expr, b.Expr(std140_ty));
std140_uniforms.Add(v);
}
}
@@ -400,16 +400,16 @@
/// If the semantic type is not split for std140-layout, then nullptr is returned.
/// @note will construct new std140 structures to hold decomposed matrices, populating
/// #std140_mats.
- const ast::Type* Std140Type(const type::Type* ty) {
+ ast::Type Std140Type(const type::Type* ty) {
return Switch(
ty, //
- [&](const sem::Struct* str) -> const ast::Type* {
+ [&](const sem::Struct* str) {
if (auto std140 = std140_structs.Find(str)) {
return b.ty(*std140);
}
- return nullptr;
+ return ast::Type{};
},
- [&](const type::Matrix* mat) -> const ast::Type* {
+ [&](const type::Matrix* mat) {
if (MatrixNeedsDecomposing(mat)) {
auto std140_mat = std140_mats.GetOrCreate(mat, [&] {
auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" +
@@ -426,10 +426,10 @@
});
return b.ty(std140_mat.name);
}
- return nullptr;
+ return ast::Type{};
},
- [&](const type::Array* arr) -> const ast::Type* {
- if (auto* std140 = Std140Type(arr->ElemType())) {
+ [&](const type::Array* arr) {
+ if (auto std140 = Std140Type(arr->ElemType())) {
utils::Vector<const ast::Attribute*, 1> attrs;
if (!arr->IsStrideImplicit()) {
attrs.Push(b.create<ast::StrideAttribute>(arr->Stride()));
@@ -444,10 +444,9 @@
<< "unexpected non-constant array count";
count = 1;
}
- return b.create<ast::Array>(std140, b.Expr(u32(count.value())),
- std::move(attrs));
+ return b.ty.array(std140, b.Expr(u32(count.value())), std::move(attrs));
}
- return nullptr;
+ return ast::Type{};
});
}
@@ -483,7 +482,7 @@
// Build the member
const auto col_name = name_prefix + std::to_string(i);
- const auto* col_ty = CreateASTTypeFor(ctx, mat->ColumnType());
+ const auto col_ty = CreateASTTypeFor(ctx, mat->ColumnType());
const auto* col_member = b.Member(col_name, col_ty, std::move(attributes));
// Record the member for std140_mat_members
out.Push(col_member);
@@ -702,7 +701,7 @@
for (auto* member : str->Members()) {
if (auto col_members = std140_mat_members.Find(member)) {
// std140 decomposed matrix. Reassemble.
- auto* mat_ty = CreateASTTypeFor(ctx, member->Type());
+ auto mat_ty = CreateASTTypeFor(ctx, member->Type());
auto mat_args =
utils::Transform(*col_members, [&](const ast::StructMember* m) {
return b.MemberAccessor(param, m->name->symbol);
@@ -723,7 +722,7 @@
if (TINT_LIKELY(std140_mat)) {
utils::Vector<const ast::Expression*, 8> args;
// std140 decomposed matrix. Reassemble.
- auto* mat_ty = CreateASTTypeFor(ctx, mat);
+ auto mat_ty = CreateASTTypeFor(ctx, mat);
auto mat_args = utils::Transform(std140_mat->columns, [&](Symbol name) {
return b.MemberAccessor(param, name);
});
@@ -764,7 +763,7 @@
});
// Generate the function
- auto* ret_ty = CreateASTTypeFor(ctx, ty);
+ auto ret_ty = CreateASTTypeFor(ctx, ty);
auto fn_sym = b.Symbols().New("conv_" + ConvertSuffix(ty));
b.Func(fn_sym, utils::Vector{param}, ret_ty, std::move(stmts));
return fn_sym;
@@ -1046,7 +1045,7 @@
stmts.Push(b.Return(expr));
// Build the function
- auto* ret_ty = CreateASTTypeFor(ctx, ty);
+ auto ret_ty = CreateASTTypeFor(ctx, ty);
auto fn_sym = b.Symbols().New(name);
b.Func(fn_sym, std::move(dynamic_index_params), ret_ty, std::move(stmts));
return fn_sym;
diff --git a/src/tint/transform/substitute_override.cc b/src/tint/transform/substitute_override.cc
index f5a859b..2d6b995 100644
--- a/src/tint/transform/substitute_override.cc
+++ b/src/tint/transform/substitute_override.cc
@@ -64,7 +64,7 @@
auto source = ctx.Clone(w->source);
auto sym = ctx.Clone(w->name->symbol);
- auto* ty = ctx.Clone(w->type);
+ ast::Type ty = w->type ? ctx.Clone(w->type) : ast::Type{};
// No replacement provided, just clone the override node as a const.
auto iter = data->map.find(sem->OverrideId());
diff --git a/src/tint/transform/texture_1d_to_2d.cc b/src/tint/transform/texture_1d_to_2d.cc
index 84d480a..d1caaab 100644
--- a/src/tint/transform/texture_1d_to_2d.cc
+++ b/src/tint/transform/texture_1d_to_2d.cc
@@ -19,6 +19,7 @@
#include "src/tint/program_builder.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
+#include "src/tint/sem/type_expression.h"
#include "src/tint/type/texture_dimension.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Texture1DTo2D);
@@ -46,7 +47,7 @@
}
for (auto* var : program->AST().GlobalVariables()) {
if (Switch(
- program->Sem().Get(var->type),
+ program->Sem().Get(var)->Type()->UnwrapRef(),
[&](const type::SampledTexture* tex) {
return tex->dim() == type::TextureDimension::k1d;
},
@@ -83,8 +84,7 @@
return SkipTransform;
}
- auto create_var = [&](const ast::Variable* v,
- const ast::Type* type) -> const ast::Variable* {
+ auto create_var = [&](const ast::Variable* v, ast::Type type) -> const ast::Variable* {
if (v->As<ast::Parameter>()) {
return ctx.dst->Param(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes));
} else {
@@ -94,11 +94,11 @@
ctx.ReplaceAll([&](const ast::Variable* v) -> const ast::Variable* {
const ast::Variable* r = Switch(
- sem.Get(v->type),
+ sem.Get(v)->Type()->UnwrapRef(),
[&](const type::SampledTexture* tex) -> const ast::Variable* {
if (tex->dim() == type::TextureDimension::k1d) {
- auto* type = ctx.dst->create<ast::SampledTexture>(
- type::TextureDimension::k2d, CreateASTTypeFor(ctx, tex->type()));
+ auto type = ctx.dst->ty.sampled_texture(type::TextureDimension::k2d,
+ CreateASTTypeFor(ctx, tex->type()));
return create_var(v, type);
} else {
return nullptr;
@@ -106,9 +106,9 @@
},
[&](const type::StorageTexture* storage_tex) -> const ast::Variable* {
if (storage_tex->dim() == type::TextureDimension::k1d) {
- auto* type = ctx.dst->ty.storage_texture(type::TextureDimension::k2d,
- storage_tex->texel_format(),
- storage_tex->access());
+ auto type = ctx.dst->ty.storage_texture(type::TextureDimension::k2d,
+ storage_tex->texel_format(),
+ storage_tex->access());
return create_var(v, type);
} else {
return nullptr;
@@ -172,7 +172,7 @@
}
index++;
}
- return ctx.dst->Call(ctx.Clone(c->target.name), args);
+ return ctx.dst->Call(ctx.Clone(c->target), args);
});
ctx.Clone();
diff --git a/src/tint/transform/transform.cc b/src/tint/transform/transform.cc
index a2ae237..cc09e6c 100644
--- a/src/tint/transform/transform.cc
+++ b/src/tint/transform/transform.cc
@@ -73,9 +73,9 @@
<< "unable to remove statement from parent of type " << sem->TypeInfo().name;
}
-const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
+ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
if (ty->Is<type::Void>()) {
- return nullptr;
+ return ast::Type{};
}
if (ty->Is<type::I32>()) {
return ctx.dst->ty.i32();
@@ -93,21 +93,21 @@
return ctx.dst->ty.bool_();
}
if (auto* m = ty->As<type::Matrix>()) {
- auto* el = CreateASTTypeFor(ctx, m->type());
- return ctx.dst->create<ast::Matrix>(el, m->rows(), m->columns());
+ auto el = CreateASTTypeFor(ctx, m->type());
+ return ctx.dst->ty.mat(el, m->columns(), m->rows());
}
if (auto* v = ty->As<type::Vector>()) {
- auto* el = CreateASTTypeFor(ctx, v->type());
- return ctx.dst->create<ast::Vector>(el, v->Width());
+ auto el = CreateASTTypeFor(ctx, v->type());
+ return ctx.dst->ty.vec(el, v->Width());
}
if (auto* a = ty->As<type::Array>()) {
- auto* el = CreateASTTypeFor(ctx, a->ElemType());
+ auto el = CreateASTTypeFor(ctx, a->ElemType());
utils::Vector<const ast::Attribute*, 1> attrs;
if (!a->IsStrideImplicit()) {
attrs.Push(ctx.dst->create<ast::StrideAttribute>(a->Stride()));
}
if (a->Count()->Is<type::RuntimeArrayCount>()) {
- return ctx.dst->ty.array(el, nullptr, std::move(attrs));
+ return ctx.dst->ty.array(el, std::move(attrs));
}
if (auto* override = a->Count()->As<sem::NamedOverrideArrayCount>()) {
auto* count = ctx.Clone(override->variable->Declaration());
@@ -144,7 +144,7 @@
return CreateASTTypeFor(ctx, s->StoreType());
}
if (auto* a = ty->As<type::Atomic>()) {
- return ctx.dst->create<ast::Atomic>(CreateASTTypeFor(ctx, a->Type()));
+ return ctx.dst->ty.atomic(CreateASTTypeFor(ctx, a->Type()));
}
if (auto* t = ty->As<type::DepthTexture>()) {
return ctx.dst->ty.depth_texture(t->dim());
@@ -156,11 +156,10 @@
return ctx.dst->ty.external_texture();
}
if (auto* t = ty->As<type::MultisampledTexture>()) {
- return ctx.dst->create<ast::MultisampledTexture>(t->dim(),
- CreateASTTypeFor(ctx, t->type()));
+ return ctx.dst->ty.multisampled_texture(t->dim(), CreateASTTypeFor(ctx, t->type()));
}
if (auto* t = ty->As<type::SampledTexture>()) {
- return ctx.dst->create<ast::SampledTexture>(t->dim(), CreateASTTypeFor(ctx, t->type()));
+ return ctx.dst->ty.sampled_texture(t->dim(), CreateASTTypeFor(ctx, t->type()));
}
if (auto* t = ty->As<type::StorageTexture>()) {
return ctx.dst->ty.storage_texture(t->dim(), t->texel_format(), t->access());
@@ -170,7 +169,7 @@
}
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
<< "Unhandled type: " << ty->TypeInfo().name;
- return nullptr;
+ return ast::Type{};
}
} // namespace tint::transform
diff --git a/src/tint/transform/transform.h b/src/tint/transform/transform.h
index cc2f782..dfb9188 100644
--- a/src/tint/transform/transform.h
+++ b/src/tint/transform/transform.h
@@ -189,13 +189,11 @@
/// @param stmt the statement to remove when the program is cloned
static void RemoveStatement(CloneContext& ctx, const ast::Statement* stmt);
- /// CreateASTTypeFor constructs new ast::Type nodes that reconstructs the
- /// semantic type `ty`.
+ /// CreateASTTypeFor constructs new ast::Type that reconstructs the semantic type `ty`.
/// @param ctx the clone context
/// @param ty the semantic type to reconstruct
- /// @returns a ast::Type that when resolved, will produce the semantic type
- /// `ty`.
- static const ast::Type* CreateASTTypeFor(CloneContext& ctx, const type::Type* ty);
+ /// @returns an ast::Type that when resolved, will produce the semantic type `ty`.
+ static ast::Type CreateASTTypeFor(CloneContext& ctx, const type::Type* ty);
};
} // namespace tint::transform
diff --git a/src/tint/transform/transform_test.cc b/src/tint/transform/transform_test.cc
index 6d97614..f27a990 100644
--- a/src/tint/transform/transform_test.cc
+++ b/src/tint/transform/transform_test.cc
@@ -14,6 +14,7 @@
#include <string>
+#include "src/tint/ast/test_helper.h"
#include "src/tint/clone_context.h"
#include "src/tint/program_builder.h"
#include "src/tint/transform/transform.h"
@@ -31,7 +32,7 @@
return SkipTransform;
}
- const ast::Type* create(std::function<type::Type*(ProgramBuilder&)> create_sem_type) {
+ ast::Type create(std::function<type::Type*(ProgramBuilder&)> create_sem_type) {
ProgramBuilder sem_type_builder;
auto* sem_type = create_sem_type(sem_type_builder);
Program program(std::move(sem_type_builder));
@@ -39,71 +40,60 @@
return CreateASTTypeFor(ctx, sem_type);
}
- std::string TypeNameOf(const ast::Type* ty) const {
- if (auto* type_name = ty->As<ast::TypeName>()) {
- return ast_type_builder.Symbols().NameFor(type_name->name->symbol);
- }
- return "<not-a-typename>";
- }
-
ProgramBuilder ast_type_builder;
};
TEST_F(CreateASTTypeForTest, Basic) {
- EXPECT_EQ(TypeNameOf(create([](ProgramBuilder& b) { return b.create<type::I32>(); })), "i32");
- EXPECT_EQ(TypeNameOf(create([](ProgramBuilder& b) { return b.create<type::U32>(); })), "u32");
- EXPECT_EQ(TypeNameOf(create([](ProgramBuilder& b) { return b.create<type::F32>(); })), "f32");
- EXPECT_EQ(TypeNameOf(create([](ProgramBuilder& b) { return b.create<type::Bool>(); })), "bool");
+ auto check = [&](ast::Type ty, const char* expect) {
+ ast::CheckIdentifier(ast_type_builder.Symbols(), ty->identifier, expect);
+ };
+
+ check(create([](ProgramBuilder& b) { return b.create<type::I32>(); }), "i32");
+ check(create([](ProgramBuilder& b) { return b.create<type::U32>(); }), "u32");
+ check(create([](ProgramBuilder& b) { return b.create<type::F32>(); }), "f32");
+ check(create([](ProgramBuilder& b) { return b.create<type::Bool>(); }), "bool");
EXPECT_EQ(create([](ProgramBuilder& b) { return b.create<type::Void>(); }), nullptr);
}
TEST_F(CreateASTTypeForTest, Matrix) {
- auto* mat = create([](ProgramBuilder& b) {
+ auto mat = create([](ProgramBuilder& b) {
auto* column_type = b.create<type::Vector>(b.create<type::F32>(), 2u);
return b.create<type::Matrix>(column_type, 3u);
});
- ASSERT_TRUE(mat->Is<ast::Matrix>());
- EXPECT_EQ(TypeNameOf(mat->As<ast::Matrix>()->type), "f32");
- ASSERT_EQ(mat->As<ast::Matrix>()->columns, 3u);
- ASSERT_EQ(mat->As<ast::Matrix>()->rows, 2u);
+
+ ast::CheckIdentifier(ast_type_builder.Symbols(), mat, ast::Template("mat3x2", "f32"));
}
TEST_F(CreateASTTypeForTest, Vector) {
- auto* vec =
+ auto vec =
create([](ProgramBuilder& b) { return b.create<type::Vector>(b.create<type::F32>(), 2u); });
- ASSERT_TRUE(vec->Is<ast::Vector>());
- EXPECT_EQ(TypeNameOf(vec->As<ast::Vector>()->type), "f32");
- ASSERT_EQ(vec->As<ast::Vector>()->width, 2u);
+
+ ast::CheckIdentifier(ast_type_builder.Symbols(), vec, ast::Template("vec2", "f32"));
}
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
- auto* arr = create([](ProgramBuilder& b) {
+ auto arr = create([](ProgramBuilder& b) {
return b.create<type::Array>(b.create<type::F32>(), b.create<type::ConstantArrayCount>(2u),
4u, 4u, 32u, 32u);
});
- ASSERT_TRUE(arr->Is<ast::Array>());
- EXPECT_EQ(TypeNameOf(arr->As<ast::Array>()->type), "f32");
- ASSERT_EQ(arr->As<ast::Array>()->attributes.Length(), 0u);
- auto* size = arr->As<ast::Array>()->count->As<ast::IntLiteralExpression>();
- ASSERT_NE(size, nullptr);
- EXPECT_EQ(size->value, 2);
+ ast::CheckIdentifier(ast_type_builder.Symbols(), arr, ast::Template("array", "f32", 2_u));
+ auto* tmpl_attr = arr->identifier->As<ast::TemplatedIdentifier>();
+ ASSERT_NE(tmpl_attr, nullptr);
+ EXPECT_TRUE(tmpl_attr->attributes.IsEmpty());
}
TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
- auto* arr = create([](ProgramBuilder& b) {
+ auto arr = create([](ProgramBuilder& b) {
return b.create<type::Array>(b.create<type::F32>(), b.create<type::ConstantArrayCount>(2u),
4u, 4u, 64u, 32u);
});
- ASSERT_TRUE(arr->Is<ast::Array>());
- EXPECT_EQ(TypeNameOf(arr->As<ast::Array>()->type), "f32");
- ASSERT_EQ(arr->As<ast::Array>()->attributes.Length(), 1u);
- ASSERT_TRUE(arr->As<ast::Array>()->attributes[0]->Is<ast::StrideAttribute>());
- ASSERT_EQ(arr->As<ast::Array>()->attributes[0]->As<ast::StrideAttribute>()->stride, 64u);
-
- auto* size = arr->As<ast::Array>()->count->As<ast::IntLiteralExpression>();
- ASSERT_NE(size, nullptr);
- EXPECT_EQ(size->value, 2);
+ ast::CheckIdentifier(ast_type_builder.Symbols(), arr, ast::Template("array", "f32", 2_u));
+ auto* tmpl_attr = arr->identifier->As<ast::TemplatedIdentifier>();
+ ASSERT_NE(tmpl_attr, nullptr);
+ ASSERT_EQ(tmpl_attr->attributes.Length(), 1u);
+ ASSERT_TRUE(tmpl_attr->attributes[0]->Is<ast::StrideAttribute>());
+ ASSERT_EQ(tmpl_attr->attributes[0]->As<ast::StrideAttribute>()->stride, 64u);
}
// crbug.com/tint/1764
@@ -123,19 +113,18 @@
auto* arr_ty = program.Sem().Get(alias);
CloneContext ctx(&ast_type_builder, &program, false);
- auto* ast_ty = tint::As<ast::TypeName>(CreateASTTypeFor(ctx, arr_ty));
- ASSERT_NE(ast_ty, nullptr);
- EXPECT_EQ(ast_type_builder.Symbols().NameFor(ast_ty->name->symbol), "A");
+ auto ast_ty = CreateASTTypeFor(ctx, arr_ty);
+ ast::CheckIdentifier(ast_type_builder.Symbols(), ast_ty, "A");
}
TEST_F(CreateASTTypeForTest, Struct) {
- auto* str = create([](ProgramBuilder& b) {
+ auto str = create([](ProgramBuilder& b) {
auto* decl = b.Structure("S", {});
return b.create<sem::Struct>(decl, decl->source, decl->name->symbol, utils::Empty,
4u /* align */, 4u /* size */, 4u /* size_no_padding */);
});
- ASSERT_TRUE(str->Is<ast::TypeName>());
- EXPECT_EQ(ast_type_builder.Symbols().NameFor(str->As<ast::TypeName>()->name->symbol), "S");
+
+ ast::CheckIdentifier(ast_type_builder.Symbols(), str, "S");
}
} // namespace
diff --git a/src/tint/transform/truncate_interstage_variables.cc b/src/tint/transform/truncate_interstage_variables.cc
index cf15a38..b4ae24e 100644
--- a/src/tint/transform/truncate_interstage_variables.cc
+++ b/src/tint/transform/truncate_interstage_variables.cc
@@ -143,11 +143,12 @@
utils::Vector{b.Param("io", ctx.Clone(func_ast->return_type))},
b.ty(new_struct_sym),
utils::Vector{
- b.Return(b.Call(b.ty(new_struct_sym), std::move(initializer_exprs)))});
+ b.Return(b.Call(new_struct_sym, std::move(initializer_exprs))),
+ });
return TruncatedStructAndConverter{new_struct_sym, mapping_fn_sym};
});
- ctx.Replace(func_ast->return_type, b.ty(entry.truncated_struct));
+ ctx.Replace(func_ast->return_type.expr, b.Expr(entry.truncated_struct));
entry_point_functions_to_truncate_functions.Add(func_sem, entry.truncate_fn);
}
diff --git a/src/tint/transform/unshadow.cc b/src/tint/transform/unshadow.cc
index ac00a39..8f79544 100644
--- a/src/tint/transform/unshadow.cc
+++ b/src/tint/transform/unshadow.cc
@@ -56,7 +56,7 @@
renamed_to.Add(v, symbol);
auto source = ctx.Clone(decl->source);
- auto* type = ctx.Clone(decl->type);
+ auto type = decl->type ? ctx.Clone(decl->type) : ast::Type{};
auto* initializer = ctx.Clone(decl->initializer);
auto attributes = ctx.Clone(decl->attributes);
return Switch(
diff --git a/src/tint/transform/vertex_pulling.cc b/src/tint/transform/vertex_pulling.cc
index 438ede8..11984cd 100644
--- a/src/tint/transform/vertex_pulling.cc
+++ b/src/tint/transform/vertex_pulling.cc
@@ -395,7 +395,7 @@
// Convert the fetched scalar/vector if WGSL variable is of `f16` types
if (var_dt.base_type == BaseWGSLType::kF16) {
// The type of the same element number of base type of target WGSL variable
- const ast::Type* loaded_data_target_type;
+ ast::Type loaded_data_target_type;
if (fmt_dt.width == 1) {
loaded_data_target_type = b.ty.f16();
} else {
@@ -443,8 +443,7 @@
}
}
- const ast::Type* target_ty = CreateASTTypeFor(ctx, var.type);
- value = b.Call(target_ty, values);
+ value = b.Call(CreateASTTypeFor(ctx, var.type), values);
}
// Assign the value to the WGSL variable
@@ -735,7 +734,7 @@
uint32_t offset,
uint32_t buffer,
uint32_t element_stride,
- const ast::Type* base_type,
+ ast::Type base_type,
VertexFormat base_format,
uint32_t count) {
utils::Vector<const ast::Expression*, 8> expr_list;
@@ -745,7 +744,7 @@
expr_list.Push(LoadPrimitive(array_base, primitive_offset, buffer, base_format));
}
- return b.Call(b.create<ast::Vector>(base_type, count), std::move(expr_list));
+ return b.Call(b.ty.vec(base_type, count), std::move(expr_list));
}
/// Process a non-struct entry point parameter.
@@ -757,7 +756,7 @@
if (ast::HasAttribute<ast::LocationAttribute>(param->attributes)) {
// Create a function-scope variable to replace the parameter.
auto func_var_sym = ctx.Clone(param->name->symbol);
- auto* func_var_type = ctx.Clone(param->type);
+ auto func_var_type = ctx.Clone(param->type);
auto* func_var = b.Var(func_var_sym, func_var_type);
ctx.InsertFront(func->body->statements, b.Decl(func_var));
// Capture mapping from location to the new variable.
@@ -856,7 +855,7 @@
utils::Vector<const ast::StructMember*, 8> new_members;
for (auto* member : members_to_clone) {
auto member_name = ctx.Clone(member->name);
- auto* member_type = ctx.Clone(member->type);
+ auto member_type = ctx.Clone(member->type);
auto member_attrs = ctx.Clone(member->attributes);
new_members.Push(b.Member(member_name, member_type, std::move(member_attrs)));
}
@@ -926,7 +925,7 @@
// Rewrite the function header with the new parameters.
auto func_sym = ctx.Clone(func->name->symbol);
- auto* ret_type = ctx.Clone(func->return_type);
+ auto ret_type = ctx.Clone(func->return_type);
auto* body = ctx.Clone(func->body);
auto attrs = ctx.Clone(func->attributes);
auto ret_attrs = ctx.Clone(func->return_type_attributes);