Import Tint changes from Dawn
Changes:
- 1a8a19deddb55155846beee9d262fd8ebe00385d [tint] Rename pointer builders to 'ptr', match WGSL order by Ben Clayton <bclayton@google.com>
- 2c27e79d4b7e1cb6a04f3b55776d815bcb609bf0 [tint][resolver] Fix short-circuiting of const identifiers by Ben Clayton <bclayton@google.com>
- 4b8a3d34d4e5185c580ca613dbdd4e84ba78f3c8 [tint][transform][hlsl] Fix bad materialization by Ben Clayton <bclayton@google.com>
- cf2cf156c4f5fee47bd4e6c8e4cdf2332bac3ee8 [ir] Remove old `Load` constructor. by dan sinclair <dsinclair@chromium.org>
- 455e4b80f6b8d065660ced359bf836f72c656f33 D3D12: Always skip Robustness transform on non-DBO storag... by Jiawei Shao <jiawei.shao@intel.com>
- f3a03513bd3c8c9224d2a5992e8d768a4016018f [tint][type][ir] Add fluent type constructor helpers by Ben Clayton <bclayton@google.com>
- b768af431559c05ef71fd4de28cbbad5e88f9de2 [tint][ir] Use new fluent builder methods by Ben Clayton <bclayton@google.com>
- 039ffdc9c79825c27aca6f4631610b5f5ce4f8ae [tint][ir] Rework Builder to allow for fluent args by Ben Clayton <bclayton@google.com>
- c67a4fa78094b1b79027148f9250c27064bf9d95 [tint][ir] Strip Create prefix from methods by Ben Clayton <bclayton@google.com>
- 4abe1b0cd66a56b5f47b7bb7562672f69789fa2b [tint][ir] Simplify Convert by Ben Clayton <bclayton@google.com>
- 1d0ac04a729edefdc75a15ae89b362f65b3fe4e3 [tint][ir] Rename Builtin and rename Builder Call methods by Ben Clayton <bclayton@google.com>
- bc3111cc84d29603e0d390b80dfc928c4f9ee4f1 [tint][disassembler] Trim 'note' source to block start by Ben Clayton <bclayton@google.com>
- d3fe5f542f0b4402131470c62948b83db06e6ee7 [tint][ir] Validate Access instruction by Ben Clayton <bclayton@google.com>
- 7d7dce35bb0b567a83da1742d02514a7c6fddc6f [ir] Builder::Return takes single Value by James Price <jrprice@google.com>
- 1c9a507b2003560230b7f8df397a6c4e3177cabc [ir][spirv-writer] Use Type::Element() by James Price <jrprice@google.com>
- 8f7092ebe46deda3cbb61c80700a360bdaa7cd60 [tint][type] Add Element(uint32_t) by Ben Clayton <bclayton@google.com>
- 471a015c5b8b6afe318d57f9e354934f7b7219df [tint][type] Rework ElementOf() and DeepestElementOf() by Ben Clayton <bclayton@google.com>
- a5268dff5477b3b7b8bed8216020b0d0b46a5ef7 [ir] Unconst most of the IR. by dan sinclair <dsinclair@chromium.org>
- 5c9fd76edd8cbbcd178fca81aa9de3be865f4098 [tint][type] Add NumericScalar base class. by Ben Clayton <bclayton@google.com>
- 1416b18116faab4b1e1349f0de58d792edd24bab [tint][type] Add Scalar base class. by Ben Clayton <bclayton@google.com>
- bed3fe068c18e42a2b832145e0caac7c6b666928 [ir] Address review comments in VarForDynamicIndex by James Price <jrprice@google.com>
- be3125eb3c6fef4d0760201d4ba5cb8bd5ec33db [ir] Add Value::ReplaceAllUsesWith helper by James Price <jrprice@google.com>
- 55aff4e5ef8787d597646a897a0a9a5646eb3f4e [ir] Return instruction from Block::Append/Prepend by James Price <jrprice@google.com>
- e6b4a908d2e98aeaee2efdbb957c267d3ea2418f [tint][ir] Rename Builder.Declare() to Builder.Var() by Ben Clayton <bclayton@google.com>
- d2e30d0df26be3075de7319eb1ccc124672b63de [tint][ir] Change Var::Type() to return type::Pointer* by Ben Clayton <bclayton@google.com>
- 0ebf677dfe216d923a5e233accb419b9182c4bb0 [tint][ir] Add MultiInBlock that derives from Block by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: 1a8a19deddb55155846beee9d262fd8ebe00385d
Change-Id: Id7732faab084bb70e8c323f329c3e812a3bb0b7f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/136720
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index bd06461..2b450fb 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -862,6 +862,8 @@
"type/multisampled_texture.h",
"type/node.cc",
"type/node.h",
+ "type/numeric_scalar.cc",
+ "type/numeric_scalar.h",
"type/pointer.cc",
"type/pointer.h",
"type/reference.cc",
@@ -872,6 +874,8 @@
"type/sampler.h",
"type/sampler_kind.cc",
"type/sampler_kind.h",
+ "type/scalar.cc",
+ "type/scalar.h",
"type/storage_texture.cc",
"type/storage_texture.h",
"type/struct.cc",
@@ -1230,8 +1234,8 @@
"ir/break_if.h",
"ir/builder.cc",
"ir/builder.h",
- "ir/builtin.cc",
- "ir/builtin.h",
+ "ir/builtin_call.cc",
+ "ir/builtin_call.h",
"ir/call.cc",
"ir/call.h",
"ir/constant.cc",
@@ -1269,6 +1273,8 @@
"ir/loop.h",
"ir/module.cc",
"ir/module.h",
+ "ir/multi_in_block.cc",
+ "ir/multi_in_block.h",
"ir/next_iteration.cc",
"ir/next_iteration.h",
"ir/operand_instruction.cc",
@@ -2302,7 +2308,7 @@
"ir/block_param_test.cc",
"ir/block_test.cc",
"ir/break_if_test.cc",
- "ir/builtin_test.cc",
+ "ir/builtin_call_test.cc",
"ir/constant_test.cc",
"ir/construct_test.cc",
"ir/continue_test.cc",
@@ -2330,6 +2336,7 @@
"ir/load_test.cc",
"ir/loop_test.cc",
"ir/module_test.cc",
+ "ir/multi_in_block_test.cc",
"ir/next_iteration_test.cc",
"ir/program_test_helper.h",
"ir/return_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 764be39..6e6d760 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -491,6 +491,8 @@
type/multisampled_texture.h
type/node.cc
type/node.h
+ type/numeric_scalar.cc
+ type/numeric_scalar.h
type/pointer.cc
type/pointer.h
type/reference.cc
@@ -501,6 +503,8 @@
type/sampler.h
type/sampler_kind.cc
type/sampler_kind.h
+ type/scalar.cc
+ type/scalar.h
type/storage_texture.cc
type/storage_texture.h
type/struct.cc
@@ -733,8 +737,8 @@
ir/break_if.h
ir/builder.cc
ir/builder.h
- ir/builtin.cc
- ir/builtin.h
+ ir/builtin_call.cc
+ ir/builtin_call.h
ir/call.cc
ir/call.h
ir/constant.cc
@@ -774,6 +778,8 @@
ir/loop.h
ir/module.cc
ir/module.h
+ ir/multi_in_block.cc
+ ir/multi_in_block.h
ir/next_iteration.cc
ir/next_iteration.h
ir/operand_instruction.cc
@@ -1503,7 +1509,7 @@
ir/block_param_test.cc
ir/block_test.cc
ir/break_if_test.cc
- ir/builtin_test.cc
+ ir/builtin_call_test.cc
ir/constant_test.cc
ir/construct_test.cc
ir/continue_test.cc
@@ -1531,6 +1537,7 @@
ir/load_test.cc
ir/loop_test.cc
ir/module_test.cc
+ ir/multi_in_block_test.cc
ir/next_iteration_test.cc
ir/program_test_helper.h
ir/return_test.cc
diff --git a/src/tint/ast/transform/builtin_polyfill.cc b/src/tint/ast/transform/builtin_polyfill.cc
index 3a4b01f..05f5da9 100644
--- a/src/tint/ast/transform/builtin_polyfill.cc
+++ b/src/tint/ast/transform/builtin_polyfill.cc
@@ -590,7 +590,7 @@
uint32_t width = WidthOf(ty);
// Currently in WGSL parameters of insertBits must be i32, u32, vecN<i32> or vecN<u32>
- if (TINT_UNLIKELY(((!type::Type::DeepestElementOf(ty)->IsAnyOf<type::I32, type::U32>())))) {
+ if (TINT_UNLIKELY(((!ty->DeepestElement()->IsAnyOf<type::I32, type::U32>())))) {
TINT_ICE(Transform, b.Diagnostics())
<< "insertBits polyfill only support i32, u32, and vector of i32 or u32, got "
<< ty->FriendlyName();
@@ -814,7 +814,7 @@
auto name = b.Symbols().New("tint_workgroupUniformLoad");
b.Func(name,
utils::Vector{
- b.Param("p", b.ty.pointer(T(type), builtin::AddressSpace::kWorkgroup)),
+ b.Param("p", b.ty.ptr(builtin::AddressSpace::kWorkgroup, T(type))),
},
T(type),
utils::Vector{
@@ -883,7 +883,7 @@
const Expression* BitshiftModulo(const BinaryExpression* bin_op) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
- auto* lhs_el_ty = type::Type::DeepestElementOf(lhs_ty);
+ auto* lhs_el_ty = lhs_ty->DeepestElement();
const Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
if (rhs_ty->Is<type::Vector>()) {
mask = b.Call(CreateASTTypeFor(ctx, rhs_ty), mask);
@@ -904,10 +904,8 @@
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
const bool is_div = bin_op->op == BinaryOp::kDivide;
- uint32_t lhs_width = 1;
- uint32_t rhs_width = 1;
- const auto* lhs_el_ty = type::Type::ElementOf(lhs_ty, &lhs_width);
- const auto* rhs_el_ty = type::Type::ElementOf(rhs_ty, &rhs_width);
+ const auto [lhs_el_ty, lhs_width] = lhs_ty->Elements(lhs_ty, 1);
+ const auto [rhs_el_ty, rhs_width] = rhs_ty->Elements(rhs_ty, 1);
const uint32_t width = std::max(lhs_width, rhs_width);
@@ -997,10 +995,8 @@
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
- uint32_t lhs_width = 1;
- uint32_t rhs_width = 1;
- const auto* lhs_el_ty = type::Type::ElementOf(lhs_ty, &lhs_width);
- const auto* rhs_el_ty = type::Type::ElementOf(rhs_ty, &rhs_width);
+ const auto [lhs_el_ty, lhs_width] = lhs_ty->Elements(lhs_ty, 1);
+ const auto [rhs_el_ty, rhs_width] = rhs_ty->Elements(rhs_ty, 1);
const uint32_t width = std::max(lhs_width, rhs_width);
@@ -1249,10 +1245,10 @@
[&](const sem::ValueConversion* conv) {
if (cfg.builtins.conv_f32_to_iu32) {
auto* src_ty = conv->Source();
- if (tint::Is<type::F32>(type::Type::ElementOf(src_ty))) {
+ if (tint::Is<type::F32>(src_ty->Elements(src_ty).type)) {
auto* dst_ty = conv->Target();
if (tint::utils::IsAnyOf<type::I32, type::U32>(
- type::Type::ElementOf(dst_ty))) {
+ dst_ty->Elements(dst_ty).type)) {
return f32_conv_polyfills.GetOrCreate(dst_ty, [&] { //
return ConvF32ToIU32(src_ty, dst_ty);
});
diff --git a/src/tint/ast/transform/calculate_array_length.cc b/src/tint/ast/transform/calculate_array_length.cc
index 59f684a..0e11ad4 100644
--- a/src/tint/ast/transform/calculate_array_length.cc
+++ b/src/tint/ast/transform/calculate_array_length.cc
@@ -107,18 +107,17 @@
auto name = b.Sym();
auto type = CreateASTTypeFor(ctx, buffer_type);
auto* disable_validation = b.Disable(DisabledValidation::kFunctionParameter);
- b.Func(
- name,
- utils::Vector{
- b.Param("buffer",
- b.ty.pointer(type, buffer_type->AddressSpace(), buffer_type->Access()),
- utils::Vector{disable_validation}),
- b.Param("result", b.ty.pointer(b.ty.u32(), builtin::AddressSpace::kFunction)),
- },
- b.ty.void_(), nullptr,
- utils::Vector{
- b.ASTNodes().Create<BufferSizeIntrinsic>(b.ID(), b.AllocateNodeID()),
- });
+ b.Func(name,
+ utils::Vector{
+ b.Param("buffer",
+ b.ty.ptr(buffer_type->AddressSpace(), type, buffer_type->Access()),
+ utils::Vector{disable_validation}),
+ b.Param("result", b.ty.ptr(builtin::AddressSpace::kFunction, b.ty.u32())),
+ },
+ b.ty.void_(), nullptr,
+ utils::Vector{
+ b.ASTNodes().Create<BufferSizeIntrinsic>(b.ID(), b.AllocateNodeID()),
+ });
return name;
});
diff --git a/src/tint/ast/transform/decompose_memory_access.cc b/src/tint/ast/transform/decompose_memory_access.cc
index bd48606..c1b606c 100644
--- a/src/tint/ast/transform/decompose_memory_access.cc
+++ b/src/tint/ast/transform/decompose_memory_access.cc
@@ -656,26 +656,7 @@
<< el_ty->TypeInfo().name;
}
- Type ret_ty;
-
- // For intrinsics that return a struct, there is no AST node for it, so create one now.
- if (intrinsic->Type() == builtin::Function::kAtomicCompareExchangeWeak) {
- auto* str = intrinsic->ReturnType()->As<type::Struct>();
- TINT_ASSERT(Transform, str);
-
- utils::Vector<const StructMember*, 8> ast_members;
- ast_members.Reserve(str->Members().Length());
- for (auto& m : str->Members()) {
- ast_members.Push(
- b.Member(ctx.Clone(m->Name()), CreateASTTypeFor(ctx, m->Type())));
- }
-
- auto name = b.Symbols().New("atomic_compare_exchange_weak_ret_type");
- auto* new_str = b.Structure(name, std::move(ast_members));
- ret_ty = b.ty.Of(new_str);
- } else {
- ret_ty = CreateASTTypeFor(ctx, intrinsic->ReturnType());
- }
+ Type ret_ty = CreateASTTypeFor(ctx, intrinsic->ReturnType());
auto name = b.Symbols().New(buffer.Name() + intrinsic->str());
b.Func(name, std::move(params), ret_ty, nullptr,
diff --git a/src/tint/ast/transform/decompose_memory_access_test.cc b/src/tint/ast/transform/decompose_memory_access_test.cc
index 52f2ff2..3bff4a8 100644
--- a/src/tint/ast/transform/decompose_memory_access_test.cc
+++ b/src/tint/ast/transform/decompose_memory_access_test.cc
@@ -3640,13 +3640,8 @@
@internal(intrinsic_atomic_exchange_storage_i32) @internal(disable_validation__function_has_no_body)
fn sbatomicExchange(offset : u32, param_1 : i32) -> i32
-struct atomic_compare_exchange_weak_ret_type {
- old_value : i32,
- exchanged : bool,
-}
-
@internal(intrinsic_atomic_compare_exchange_weak_storage_i32) @internal(disable_validation__function_has_no_body)
-fn sbatomicCompareExchangeWeak(offset : u32, param_1 : i32, param_2 : i32) -> atomic_compare_exchange_weak_ret_type
+fn sbatomicCompareExchangeWeak(offset : u32, param_1 : i32, param_2 : i32) -> __atomic_compare_exchange_result_i32
@internal(intrinsic_atomic_store_storage_u32) @internal(disable_validation__function_has_no_body)
fn sbatomicStore_1(offset : u32, param_1 : u32)
@@ -3678,13 +3673,8 @@
@internal(intrinsic_atomic_exchange_storage_u32) @internal(disable_validation__function_has_no_body)
fn sbatomicExchange_1(offset : u32, param_1 : u32) -> u32
-struct atomic_compare_exchange_weak_ret_type_1 {
- old_value : u32,
- exchanged : bool,
-}
-
@internal(intrinsic_atomic_compare_exchange_weak_storage_u32) @internal(disable_validation__function_has_no_body)
-fn sbatomicCompareExchangeWeak_1(offset : u32, param_1 : u32, param_2 : u32) -> atomic_compare_exchange_weak_ret_type_1
+fn sbatomicCompareExchangeWeak_1(offset : u32, param_1 : u32, param_2 : u32) -> __atomic_compare_exchange_result_u32
@compute @workgroup_size(1)
fn main() {
@@ -3787,13 +3777,8 @@
@internal(intrinsic_atomic_exchange_storage_i32) @internal(disable_validation__function_has_no_body)
fn sbatomicExchange(offset : u32, param_1 : i32) -> i32
-struct atomic_compare_exchange_weak_ret_type {
- old_value : i32,
- exchanged : bool,
-}
-
@internal(intrinsic_atomic_compare_exchange_weak_storage_i32) @internal(disable_validation__function_has_no_body)
-fn sbatomicCompareExchangeWeak(offset : u32, param_1 : i32, param_2 : i32) -> atomic_compare_exchange_weak_ret_type
+fn sbatomicCompareExchangeWeak(offset : u32, param_1 : i32, param_2 : i32) -> __atomic_compare_exchange_result_i32
@internal(intrinsic_atomic_store_storage_u32) @internal(disable_validation__function_has_no_body)
fn sbatomicStore_1(offset : u32, param_1 : u32)
@@ -3825,13 +3810,8 @@
@internal(intrinsic_atomic_exchange_storage_u32) @internal(disable_validation__function_has_no_body)
fn sbatomicExchange_1(offset : u32, param_1 : u32) -> u32
-struct atomic_compare_exchange_weak_ret_type_1 {
- old_value : u32,
- exchanged : bool,
-}
-
@internal(intrinsic_atomic_compare_exchange_weak_storage_u32) @internal(disable_validation__function_has_no_body)
-fn sbatomicCompareExchangeWeak_1(offset : u32, param_1 : u32, param_2 : u32) -> atomic_compare_exchange_weak_ret_type_1
+fn sbatomicCompareExchangeWeak_1(offset : u32, param_1 : u32, param_2 : u32) -> __atomic_compare_exchange_result_u32
@compute @workgroup_size(1)
fn main() {
diff --git a/src/tint/ast/transform/direct_variable_access.cc b/src/tint/ast/transform/direct_variable_access.cc
index f4160d4..3267ae2 100644
--- a/src/tint/ast/transform/direct_variable_access.cc
+++ b/src/tint/ast/transform/direct_variable_access.cc
@@ -832,8 +832,8 @@
auto& symbols = *variant.ptr_param_symbols.Find(param);
if (symbols.base_ptr.IsValid()) {
auto base_ptr_ty =
- b.ty.pointer(CreateASTTypeFor(ctx, incoming_shape->root.type),
- incoming_shape->root.address_space);
+ b.ty.ptr(incoming_shape->root.address_space,
+ CreateASTTypeFor(ctx, incoming_shape->root.type));
params.Push(b.Param(symbols.base_ptr, base_ptr_ty));
}
if (symbols.indices.IsValid()) {
diff --git a/src/tint/ast/transform/module_scope_var_to_entry_point_param.cc b/src/tint/ast/transform/module_scope_var_to_entry_point_param.cc
index 16257ba..db06414 100644
--- a/src/tint/ast/transform/module_scope_var_to_entry_point_param.cc
+++ b/src/tint/ast/transform/module_scope_var_to_entry_point_param.cc
@@ -160,8 +160,8 @@
}
param_type = sc == builtin::AddressSpace::kStorage
- ? ctx.dst->ty.pointer(param_type, sc, var->Access())
- : ctx.dst->ty.pointer(param_type, sc);
+ ? ctx.dst->ty.ptr(sc, param_type, var->Access())
+ : ctx.dst->ty.ptr(sc, param_type);
auto* param = ctx.dst->Param(new_var_symbol, param_type, attributes);
ctx.InsertFront(func->params, param);
is_pointer = true;
@@ -184,7 +184,7 @@
ctx.dst->MemberAccessor(ctx.dst->Deref(workgroup_param()), member));
auto* local_var = ctx.dst->Let(
new_var_symbol,
- ctx.dst->ty.pointer(store_type(), builtin::AddressSpace::kWorkgroup),
+ ctx.dst->ty.ptr(builtin::AddressSpace::kWorkgroup, store_type()),
member_ptr);
ctx.InsertFront(func->body->statements, ctx.dst->Decl(local_var));
is_pointer = true;
@@ -250,8 +250,8 @@
utils::Vector<const Attribute*, 2> attributes;
if (!ty->is_handle()) {
param_type = sc == builtin::AddressSpace::kStorage
- ? ctx.dst->ty.pointer(param_type, sc, var->Access())
- : ctx.dst->ty.pointer(param_type, sc);
+ ? ctx.dst->ty.ptr(sc, param_type, var->Access())
+ : ctx.dst->ty.ptr(sc, param_type);
is_pointer = true;
// Disable validation of the parameter's address space and of arguments passed to it.
@@ -427,8 +427,8 @@
}
} else {
// Create a parameter that is a pointer to the private variable struct.
- auto ptr = ctx.dst->ty.pointer(ctx.dst->ty(PrivateStructName()),
- builtin::AddressSpace::kPrivate);
+ auto ptr = ctx.dst->ty.ptr(builtin::AddressSpace::kPrivate,
+ ctx.dst->ty(PrivateStructName()));
auto* param = ctx.dst->Param(PrivateStructVariableName(), ptr);
ctx.InsertBack(func_ast->params, param);
}
@@ -490,7 +490,7 @@
auto* str =
ctx.dst->Structure(ctx.dst->Sym(), std::move(workgroup_parameter_members));
auto param_type =
- ctx.dst->ty.pointer(ctx.dst->ty.Of(str), builtin::AddressSpace::kWorkgroup);
+ ctx.dst->ty.ptr(builtin::AddressSpace::kWorkgroup, ctx.dst->ty.Of(str));
auto* param =
ctx.dst->Param(workgroup_param(), param_type,
utils::Vector{
diff --git a/src/tint/ast/transform/packed_vec3.cc b/src/tint/ast/transform/packed_vec3.cc
index 30f3e1c..62f82d2 100644
--- a/src/tint/ast/transform/packed_vec3.cc
+++ b/src/tint/ast/transform/packed_vec3.cc
@@ -382,7 +382,7 @@
? ptr->Access()
: builtin::Access::kUndefined;
auto new_ptr_type =
- b.ty.pointer(new_store_type, ptr->AddressSpace(), access);
+ b.ty.ptr(ptr->AddressSpace(), new_store_type, access);
ctx.Replace(node, new_ptr_type.expr);
}
}
diff --git a/src/tint/ast/transform/preserve_padding.cc b/src/tint/ast/transform/preserve_padding.cc
index ee266a7..7d9143c 100644
--- a/src/tint/ast/transform/preserve_padding.cc
+++ b/src/tint/ast/transform/preserve_padding.cc
@@ -122,8 +122,8 @@
auto helper_name = b.Symbols().New("assign_and_preserve_padding");
utils::Vector<const Parameter*, 2> params = {
b.Param(kDestParamName,
- b.ty.pointer(CreateASTTypeFor(ctx, ty), builtin::AddressSpace::kStorage,
- builtin::Access::kReadWrite)),
+ b.ty.ptr(builtin::AddressSpace::kStorage, CreateASTTypeFor(ctx, ty),
+ builtin::Access::kReadWrite)),
b.Param(kValueParamName, CreateASTTypeFor(ctx, ty)),
};
b.Func(helper_name, params, b.ty.void_(), body());
diff --git a/src/tint/ast/transform/promote_side_effects_to_decl.cc b/src/tint/ast/transform/promote_side_effects_to_decl.cc
index 19136dd..fed8327 100644
--- a/src/tint/ast/transform/promote_side_effects_to_decl.cc
+++ b/src/tint/ast/transform/promote_side_effects_to_decl.cc
@@ -406,7 +406,8 @@
auto clone_maybe_hoisted = [&](const Expression* e) -> const Expression* {
if (to_hoist.count(e)) {
auto name = b.Symbols().New();
- auto* v = b.Let(name, ctx.Clone(e));
+ auto* ty = sem.GetVal(e)->Type();
+ auto* v = b.Let(name, Transform::CreateASTTypeFor(ctx, ty), ctx.Clone(e));
auto* decl = b.Decl(v);
curr_stmts->Push(decl);
return b.Expr(name);
diff --git a/src/tint/ast/transform/promote_side_effects_to_decl_test.cc b/src/tint/ast/transform/promote_side_effects_to_decl_test.cc
index f8ce860..ec54d24 100644
--- a/src/tint/ast/transform/promote_side_effects_to_decl_test.cc
+++ b/src/tint/ast/transform/promote_side_effects_to_decl_test.cc
@@ -25,8 +25,7 @@
auto* src = "";
auto* expect = "";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -44,8 +43,7 @@
auto* expect = src;
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -75,14 +73,13 @@
}
fn f() {
- let tint_symbol = a();
- let tint_symbol_1 = b();
+ let tint_symbol : i32 = a();
+ let tint_symbol_1 : i32 = b();
let r = (tint_symbol + tint_symbol_1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -106,13 +103,12 @@
fn f() {
var b = 1;
- let tint_symbol = a();
+ let tint_symbol : i32 = a();
let r = (tint_symbol + b);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -136,14 +132,13 @@
fn f() {
var b = 1;
- let tint_symbol = b;
- let tint_symbol_1 = a();
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = a();
let r = (tint_symbol + tint_symbol_1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -171,13 +166,12 @@
var b = 1;
var c = 1;
var d = 1;
- let tint_symbol = a();
+ let tint_symbol : i32 = a();
let r = (((tint_symbol + b) + c) + d);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -205,14 +199,13 @@
var b = 1;
var c = 1;
var d = 1;
- let tint_symbol = ((b + c) + d);
- let tint_symbol_1 = a();
+ let tint_symbol : i32 = ((b + c) + d);
+ let tint_symbol_1 : i32 = a();
let r = (tint_symbol + tint_symbol_1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -242,14 +235,13 @@
var c = 1;
var d = 1;
var e = 1;
- let tint_symbol = (b + c);
- let tint_symbol_1 = a();
+ let tint_symbol : i32 = (b + c);
+ let tint_symbol_1 : i32 = a();
let r = (((tint_symbol + tint_symbol_1) + d) + e);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -271,15 +263,14 @@
}
fn f() {
- let tint_symbol = a(0);
- let tint_symbol_1 = a(1);
- let tint_symbol_2 = a(2);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
let r = ((tint_symbol + tint_symbol_1) + tint_symbol_2);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -301,13 +292,12 @@
}
fn f() {
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
let r = ((((1 + tint_symbol) - 2) + 3) - 4);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -331,15 +321,14 @@
fn f() {
var b = 1;
- let tint_symbol = a(0);
- let tint_symbol_1 = b;
- let tint_symbol_2 = a(1);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = b;
+ let tint_symbol_2 : i32 = a(1);
let r = ((((tint_symbol + 1) + tint_symbol_1) + tint_symbol_2) + 2);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -365,13 +354,12 @@
fn main() {
var b = 1;
var c = 1;
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
let r = ((1 + tint_symbol) + b);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -395,14 +383,13 @@
fn main() {
var b = 1;
- let tint_symbol = b;
- let tint_symbol_1 = a(0);
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = a(0);
let r = ((tint_symbol + tint_symbol_1) + 1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -428,13 +415,12 @@
fn main() {
var b = 1;
var c = 1;
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
let r = (((tint_symbol + b) + 1) + c);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -462,13 +448,12 @@
fn f() {
var b = 0;
- let tint_symbol = atomicAdd(&(sb.a), 123);
+ let tint_symbol : i32 = atomicAdd(&(sb.a), 123);
let r = (tint_symbol + b);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -489,8 +474,7 @@
auto* expect = src;
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -526,14 +510,13 @@
fn f() {
var b = 0;
- let tint_symbol = atomicLoad(&(sb.a));
- let tint_symbol_1 = a(0);
+ let tint_symbol : i32 = atomicLoad(&(sb.a));
+ let tint_symbol_1 : i32 = a(0);
let r = (tint_symbol + tint_symbol_1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -559,15 +542,14 @@
fn f() {
var b : vec3<i32>;
var c : i32;
- let tint_symbol = c;
- let tint_symbol_1 = b[tint_symbol];
- let tint_symbol_2 = a();
+ let tint_symbol : i32 = c;
+ let tint_symbol_1 : i32 = b[tint_symbol];
+ let tint_symbol_2 : i32 = a();
let r = (tint_symbol_1 + tint_symbol_2);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -599,17 +581,16 @@
fn f() {
var b = 1;
- let tint_symbol = a(0);
- let tint_symbol_1 = g(tint_symbol);
- let tint_symbol_2 = b;
- let tint_symbol_3 = a(1);
- let tint_symbol_4 = g((tint_symbol_2 + tint_symbol_3));
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = g(tint_symbol);
+ let tint_symbol_2 : i32 = b;
+ let tint_symbol_3 : i32 = a(1);
+ let tint_symbol_4 : i32 = g((tint_symbol_2 + tint_symbol_3));
let r = (tint_symbol_1 - tint_symbol_4);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -634,19 +615,18 @@
fn f() {
var b = 1;
- let tint_symbol = a(0);
- let tint_symbol_1 = i32(tint_symbol);
- let tint_symbol_2 = a(1);
- let tint_symbol_3 = i32((tint_symbol_2 + b));
- let tint_symbol_4 = a(2);
- let tint_symbol_5 = a(3);
- let tint_symbol_6 = i32((tint_symbol_4 - tint_symbol_5));
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = i32(tint_symbol);
+ let tint_symbol_2 : i32 = a(1);
+ let tint_symbol_3 : i32 = i32((tint_symbol_2 + b));
+ let tint_symbol_4 : i32 = a(2);
+ let tint_symbol_5 : i32 = a(3);
+ let tint_symbol_6 : i32 = i32((tint_symbol_4 - tint_symbol_5));
let r = ((tint_symbol_1 + tint_symbol_3) - tint_symbol_6);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -671,16 +651,15 @@
fn f() {
var b = 1u;
- let tint_symbol = a(0);
- let tint_symbol_1 = u32(tint_symbol);
- let tint_symbol_2 = a(1);
- let tint_symbol_3 = u32(tint_symbol_2);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : u32 = u32(tint_symbol);
+ let tint_symbol_2 : i32 = a(1);
+ let tint_symbol_3 : u32 = u32(tint_symbol_2);
let r = ((tint_symbol_1 + tint_symbol_3) - b);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -705,19 +684,18 @@
fn f() {
var b = 1;
- let tint_symbol = a(0);
- let tint_symbol_1 = abs(tint_symbol);
- let tint_symbol_2 = a(1);
- let tint_symbol_3 = abs((tint_symbol_2 + b));
- let tint_symbol_4 = a(2);
- let tint_symbol_5 = a(3);
- let tint_symbol_6 = abs((tint_symbol_4 + tint_symbol_5));
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = abs(tint_symbol);
+ let tint_symbol_2 : i32 = a(1);
+ let tint_symbol_3 : i32 = abs((tint_symbol_2 + b));
+ let tint_symbol_4 : i32 = a(2);
+ let tint_symbol_5 : i32 = a(3);
+ let tint_symbol_6 : i32 = abs((tint_symbol_4 + tint_symbol_5));
let r = ((tint_symbol_1 + tint_symbol_3) - tint_symbol_6);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -750,15 +728,14 @@
fn f() {
var b = 1;
- let tint_symbol = a(0);
- let tint_symbol_1 = b;
- let tint_symbol_2 = a(1);
+ let tint_symbol : S = a(0);
+ let tint_symbol_1 : i32 = b;
+ let tint_symbol_2 : S = a(1);
let r = ((tint_symbol.v + tint_symbol_1) + tint_symbol_2.v);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -782,15 +759,14 @@
fn f() {
var b = 1;
- let tint_symbol = -(a(0));
- let tint_symbol_1 = b;
- let tint_symbol_2 = a(1);
+ let tint_symbol : i32 = -(a(0));
+ let tint_symbol_1 : i32 = b;
+ let tint_symbol_2 : i32 = a(1);
let r = (tint_symbol + -((tint_symbol_1 + tint_symbol_2)));
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -814,14 +790,13 @@
fn f() {
var b = 1;
- let tint_symbol = a(0);
- let tint_symbol_1 = a(1);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
let r = bitcast<u32>((tint_symbol + tint_symbol_1));
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -849,7 +824,7 @@
fn f() {
var b = 1;
{
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
var r = (tint_symbol + b);
loop {
{
@@ -861,8 +836,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -889,7 +863,7 @@
fn f() {
var b = 1;
loop {
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
if (!(((tint_symbol + b) > 0))) {
break;
}
@@ -900,8 +874,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -937,15 +910,14 @@
}
continuing {
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
r = (tint_symbol + b);
}
}
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -978,10 +950,10 @@
var d = 3;
var r = 0;
{
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
var r = (tint_symbol + b);
loop {
- let tint_symbol_1 = a(1);
+ let tint_symbol_1 : i32 = a(1);
if (!(((tint_symbol_1 + c) > 0))) {
break;
}
@@ -990,7 +962,7 @@
}
continuing {
- let tint_symbol_2 = a(2);
+ let tint_symbol_2 : i32 = a(2);
r = (tint_symbol_2 + d);
}
}
@@ -998,8 +970,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1026,7 +997,7 @@
fn f() {
var b = 1;
loop {
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
if (!(((tint_symbol + b) > 0))) {
break;
}
@@ -1037,8 +1008,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1069,7 +1039,7 @@
if (true) {
var marker = 0;
} else {
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
if (((tint_symbol + b) > 0)) {
var marker = 1;
}
@@ -1077,8 +1047,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1119,12 +1088,12 @@
} else if (true) {
var marker = 1;
} else {
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
if (((tint_symbol + b) > 0)) {
var marker = 2;
} else {
- let tint_symbol_1 = a(1);
- let tint_symbol_2 = a(2);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
if (((tint_symbol_1 + tint_symbol_2) > 0)) {
var marker = 3;
} else if (true) {
@@ -1137,8 +1106,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1162,14 +1130,13 @@
fn f() -> i32 {
var b = 1;
- let tint_symbol = b;
- let tint_symbol_1 = a(0);
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = a(0);
return (tint_symbol + tint_symbol_1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1196,8 +1163,8 @@
fn f() {
var b = 1;
- let tint_symbol = b;
- let tint_symbol_1 = a(0);
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = a(0);
switch((tint_symbol + tint_symbol_1)) {
default: {
}
@@ -1205,8 +1172,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1238,8 +1204,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1271,8 +1236,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1302,8 +1266,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1347,8 +1310,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1384,8 +1346,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1429,8 +1390,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1468,8 +1428,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1509,8 +1468,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1548,8 +1506,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1585,8 +1542,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1628,8 +1584,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1685,8 +1640,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1732,8 +1686,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1779,8 +1732,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1818,8 +1770,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1859,8 +1810,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1885,7 +1835,7 @@
fn f() {
var b = true;
- let tint_symbol_2 = a(0);
+ let tint_symbol_2 : bool = a(0);
var tint_symbol_1 = bool(tint_symbol_2);
if (tint_symbol_1) {
var tint_symbol_3 = a(1);
@@ -1906,8 +1856,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1932,10 +1881,10 @@
fn f() {
var b = true;
- let tint_symbol_2 = a(0);
+ let tint_symbol_2 : i32 = a(0);
var tint_symbol_1 = bool(tint_symbol_2);
if (tint_symbol_1) {
- let tint_symbol_3 = a(1);
+ let tint_symbol_3 : i32 = a(1);
tint_symbol_1 = bool(tint_symbol_3);
}
var tint_symbol = tint_symbol_1;
@@ -1946,8 +1895,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -1974,18 +1922,17 @@
fn f() {
var b = 1;
- let tint_symbol_1 = a(0);
+ let tint_symbol_1 : i32 = a(0);
var tint_symbol = bool((tint_symbol_1 == b));
if (tint_symbol) {
- let tint_symbol_2 = a(1);
+ let tint_symbol_2 : i32 = a(1);
tint_symbol = bool((tint_symbol_2 == b));
}
let r = tint_symbol;
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2010,7 +1957,7 @@
fn f() {
var b = true;
- let tint_symbol_2 = a(0);
+ let tint_symbol_2 : bool = a(0);
var tint_symbol_1 = all(tint_symbol_2);
if (tint_symbol_1) {
var tint_symbol_3 = a(1);
@@ -2031,8 +1978,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2065,22 +2011,21 @@
fn f() {
var b = true;
- let tint_symbol_2 = a(0);
+ let tint_symbol_2 : S = a(0);
var tint_symbol_1 = tint_symbol_2.v;
if (tint_symbol_1) {
tint_symbol_1 = b;
}
var tint_symbol = tint_symbol_1;
if (tint_symbol) {
- let tint_symbol_3 = a(1);
+ let tint_symbol_3 : S = a(1);
tint_symbol = tint_symbol_3.v;
}
let r = tint_symbol;
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2113,8 +2058,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2146,8 +2090,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2190,8 +2133,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2232,8 +2174,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2279,8 +2220,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2342,8 +2282,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2384,8 +2323,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2427,8 +2365,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2492,8 +2429,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2517,8 +2453,7 @@
auto* expect = src;
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2550,13 +2485,12 @@
fn f() {
var b = 1;
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
let r = g(tint_symbol, b, 3);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2586,15 +2520,14 @@
}
fn f() {
- let tint_symbol = a(0);
- let tint_symbol_1 = a(1);
- let tint_symbol_2 = a(2);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
let r = g(tint_symbol, tint_symbol_1, tint_symbol_2);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2627,15 +2560,14 @@
fn f() {
var b = 1;
- let tint_symbol = a(0);
- let tint_symbol_1 = b;
- let tint_symbol_2 = a(1);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = b;
+ let tint_symbol_2 : i32 = a(1);
let r = g(tint_symbol, tint_symbol_1, tint_symbol_2);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2671,17 +2603,16 @@
var b = 0;
var c = 0;
var d = 0;
- let tint_symbol = b;
- let tint_symbol_1 = c;
- let tint_symbol_2 = a(0);
- let tint_symbol_3 = g(tint_symbol_1, tint_symbol_2, d);
- let tint_symbol_4 = a(1);
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = c;
+ let tint_symbol_2 : i32 = a(0);
+ let tint_symbol_3 : i32 = g(tint_symbol_1, tint_symbol_2, d);
+ let tint_symbol_4 : i32 = a(1);
let r = ((tint_symbol + tint_symbol_3) + tint_symbol_4);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2707,13 +2638,12 @@
fn f() {
var b = array<array<i32, 10>, 10>();
var c = 1;
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
var r = b[tint_symbol][c];
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2747,8 +2677,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2772,14 +2701,13 @@
fn f() {
var b = array<array<i32, 10>, 10>();
- let tint_symbol = a(0);
- let tint_symbol_1 = a(1);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
var r = b[tint_symbol][tint_symbol_1];
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2805,8 +2733,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2830,13 +2757,12 @@
fn f() {
var b = array<i32, 10>();
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
b[tint_symbol] = a(1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2860,14 +2786,13 @@
fn f() {
var b = array<array<i32, 10>, 10>();
- let tint_symbol = a(0);
- let tint_symbol_1 = a(1);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
b[tint_symbol][tint_symbol_1] = a(2);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2891,15 +2816,14 @@
fn f() {
var b = array<array<array<i32, 10>, 10>, 10>();
- let tint_symbol = a(0);
- let tint_symbol_1 = a(1);
- let tint_symbol_2 = a(2);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
b[tint_symbol][tint_symbol_1][tint_symbol_2] = a(3);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2927,14 +2851,13 @@
var b = array<i32, 3>();
var d = array<array<i32, 3>, 3>();
var a_1 = 0;
- let tint_symbol = a(2);
- let tint_symbol_1 = a(0);
+ let tint_symbol : i32 = a(2);
+ let tint_symbol_1 : i32 = a(0);
b[tint_symbol] = d[tint_symbol_1][a_1];
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2958,13 +2881,12 @@
fn f() {
var b = vec3<i32>();
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
b[tint_symbol] = a(1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -2990,13 +2912,12 @@
fn f() {
var b = vec3<i32>();
var c = 0;
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
b[tint_symbol] = c;
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3022,13 +2943,12 @@
fn f() {
var b = vec3<i32>();
var c = 0;
- let tint_symbol = c;
+ let tint_symbol : i32 = c;
b[tint_symbol] = a(0);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3062,15 +2982,14 @@
}
fn f() {
- let tint_symbol = a(0);
- let tint_symbol_1 = a(1);
- let tint_symbol_2 = a(2);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
var r = S(tint_symbol, tint_symbol_1, tint_symbol_2);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3092,15 +3011,14 @@
}
fn f() {
- let tint_symbol = a(0);
- let tint_symbol_1 = a(1);
- let tint_symbol_2 = a(2);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : i32 = a(2);
var r = array<i32, 3>(tint_symbol, tint_symbol_1, tint_symbol_2);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3122,18 +3040,17 @@
}
fn f() {
- let tint_symbol = a(0);
- let tint_symbol_1 = a(1);
- let tint_symbol_2 = array<i32, 2>(tint_symbol, tint_symbol_1);
- let tint_symbol_3 = a(2);
- let tint_symbol_4 = a(3);
- let tint_symbol_5 = array<i32, 2>(tint_symbol_3, tint_symbol_4);
+ let tint_symbol : i32 = a(0);
+ let tint_symbol_1 : i32 = a(1);
+ let tint_symbol_2 : array<i32, 2u> = array<i32, 2>(tint_symbol, tint_symbol_1);
+ let tint_symbol_3 : i32 = a(2);
+ let tint_symbol_4 : i32 = a(3);
+ let tint_symbol_5 : array<i32, 2u> = array<i32, 2>(tint_symbol_3, tint_symbol_4);
var r = array<array<i32, 2>, 2>(tint_symbol_2, tint_symbol_5);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3155,14 +3072,13 @@
}
fn f() {
- let tint_symbol = a(0);
- let tint_symbol_1 = a(1);
+ let tint_symbol : vec3<i32> = a(0);
+ let tint_symbol_1 : vec3<i32> = a(1);
var r = (tint_symbol.x + tint_symbol_1.y);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3194,14 +3110,13 @@
}
fn f() {
- let tint_symbol = a(0);
- let tint_symbol_1 = a(1);
+ let tint_symbol : S = a(0);
+ let tint_symbol_1 : S = a(1);
var r = (tint_symbol.x + tint_symbol_1.y);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3253,19 +3168,18 @@
var k = 0;
var l = 0;
var m = 0;
- let tint_symbol = a(0);
- let tint_symbol_1 = i;
- let tint_symbol_2 = a(1);
- let tint_symbol_3 = j;
- let tint_symbol_4 = a(2);
- let tint_symbol_5 = k;
- let tint_symbol_6 = b(3);
+ let tint_symbol : S = a(0);
+ let tint_symbol_1 : i32 = i;
+ let tint_symbol_2 : S = a(1);
+ let tint_symbol_3 : i32 = j;
+ let tint_symbol_4 : S = a(2);
+ let tint_symbol_5 : i32 = k;
+ let tint_symbol_6 : i32 = b(3);
var r = (((((tint_symbol.x + tint_symbol_1) + tint_symbol_2.y) + tint_symbol_3) + tint_symbol_4.arr[((tint_symbol_5 + tint_symbol_6) + l)]) + m);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3289,14 +3203,13 @@
fn f() {
var v = array<i32, 10>();
- let tint_symbol = v[0];
- let tint_symbol_1 = a(0);
+ let tint_symbol : i32 = v[0];
+ let tint_symbol_1 : i32 = a(0);
let r = (tint_symbol + tint_symbol_1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3320,13 +3233,12 @@
fn f() {
var v = array<i32, 10>();
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
let r = v[tint_symbol];
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3350,13 +3262,12 @@
fn f() {
var v = array<array<i32, 10>, 10>();
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
let r = v[tint_symbol][0];
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3380,13 +3291,12 @@
fn f() {
var v = array<array<i32, 10>, 10>();
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
let r = v[0][tint_symbol];
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3412,13 +3322,12 @@
fn f() {
var v = array<array<i32, 10>, 10>();
var b : i32;
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
let r = v[tint_symbol][b];
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3444,14 +3353,13 @@
fn f() {
var v = array<array<i32, 10>, 10>();
var b : i32;
- let tint_symbol = b;
- let tint_symbol_1 = a(0);
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = a(0);
let r = v[tint_symbol][tint_symbol_1];
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3477,15 +3385,14 @@
fn f() {
var v = array<i32, 10>();
var b = 0;
- let tint_symbol = b;
- let tint_symbol_1 = v[tint_symbol];
- let tint_symbol_2 = a(0);
+ let tint_symbol : i32 = b;
+ let tint_symbol_1 : i32 = v[tint_symbol];
+ let tint_symbol_2 : i32 = a(0);
let r = (tint_symbol_1 + tint_symbol_2);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3509,14 +3416,13 @@
fn f() {
var v = array<i32, 10>();
- let tint_symbol = v[0];
- let tint_symbol_1 = a(0);
+ let tint_symbol : i32 = v[0];
+ let tint_symbol_1 : i32 = a(0);
let r = (tint_symbol + v[tint_symbol_1]);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3542,13 +3448,12 @@
fn f() {
var v = array<i32, 10>();
var w = array<i32, 10>();
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
v[w[tint_symbol]] = 1;
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3574,14 +3479,13 @@
fn f() {
var v = array<i32, 10>();
var w = array<i32, 10>();
- let tint_symbol = w[0];
- let tint_symbol_1 = a(0);
+ let tint_symbol : i32 = w[0];
+ let tint_symbol_1 : i32 = a(0);
v[(tint_symbol + tint_symbol_1)] = 1;
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3608,14 +3512,13 @@
fn f() {
var v = array<i32, 10>();
var w = array<i32, 10>();
- let tint_symbol = w[0];
- let tint_symbol_1 = a(0);
+ let tint_symbol : i32 = w[0];
+ let tint_symbol_1 : i32 = a(0);
v[(tint_symbol + w[tint_symbol_1])] = 1;
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3645,14 +3548,13 @@
}
fn f() {
- let tint_symbol = b();
- let tint_symbol_1 = a(0);
+ let tint_symbol : array<i32, 10u> = b();
+ let tint_symbol_1 : i32 = a(0);
let r = tint_symbol[tint_symbol_1];
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3682,14 +3584,13 @@
}
fn f() {
- let tint_symbol = b();
- let tint_symbol_1 = a(0);
+ let tint_symbol : array<i32, 10u> = b();
+ let tint_symbol_1 : i32 = a(0);
let r = (tint_symbol[0] + tint_symbol_1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3718,14 +3619,13 @@
fn f() {
var v = vec4<i32>();
- let tint_symbol = v.x;
- let tint_symbol_1 = modify_vec(&(v));
+ let tint_symbol : i32 = v.x;
+ let tint_symbol_1 : i32 = modify_vec(&(v));
let l = (tint_symbol + tint_symbol_1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3755,13 +3655,12 @@
}
fn f() {
- let tint_symbol = get_uv();
+ let tint_symbol : vec2<f32> = get_uv();
let r = textureGather(1, tex, samp, tint_symbol, 1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3801,13 +3700,12 @@
}
fn f() {
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
let r = g(&(b), tint_symbol);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3824,8 +3722,7 @@
auto* expect = src;
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3863,18 +3760,17 @@
}
fn f() {
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
var tint_symbol_1 = b(1);
if (tint_symbol_1) {
tint_symbol_1 = b(2);
}
- let tint_symbol_2 = g(tint_symbol_1);
+ let tint_symbol_2 : i32 = g(tint_symbol_1);
let r = (tint_symbol + tint_symbol_2);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3916,14 +3812,13 @@
if (tint_symbol) {
tint_symbol = b(1);
}
- let tint_symbol_1 = g(tint_symbol);
- let tint_symbol_2 = a(2);
+ let tint_symbol_1 : i32 = g(tint_symbol);
+ let tint_symbol_2 : i32 = a(2);
let r = (tint_symbol_1 + tint_symbol_2);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -3961,7 +3856,7 @@
}
fn f() {
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
var tint_symbol_1 = b(1);
if (tint_symbol_1) {
tint_symbol_1 = b(2);
@@ -3970,8 +3865,7 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -4013,13 +3907,12 @@
if (tint_symbol) {
tint_symbol = b(1);
}
- let tint_symbol_1 = a(2);
+ let tint_symbol_1 : i32 = a(2);
let r = g(tint_symbol, tint_symbol_1);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -4065,22 +3958,21 @@
}
fn f() {
- let tint_symbol = a(0);
+ let tint_symbol : i32 = a(0);
var tint_symbol_1 = b(1);
if (tint_symbol_1) {
- let tint_symbol_2 = a(2);
- let tint_symbol_3 = a(3);
+ let tint_symbol_2 : i32 = a(2);
+ let tint_symbol_3 : i32 = a(3);
tint_symbol_1 = b((tint_symbol_2 + tint_symbol_3));
}
- let tint_symbol_4 = a(4);
- let tint_symbol_5 = g(tint_symbol_1, tint_symbol_4);
- let tint_symbol_6 = a(5);
+ let tint_symbol_4 : i32 = a(4);
+ let tint_symbol_5 : i32 = g(tint_symbol_1, tint_symbol_4);
+ let tint_symbol_6 : i32 = a(5);
let r = ((tint_symbol + tint_symbol_5) + tint_symbol_6);
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
@@ -4118,7 +4010,7 @@
}
fn f(t : texture_2d<f32>, s : sampler) -> vec4<f32> {
- let tint_symbol = side_effects();
+ let tint_symbol : vec2<f32> = side_effects();
return textureSample(t, s, tint_symbol);
}
@@ -4127,8 +4019,79 @@
}
)";
- Transform::DataMap data;
- auto got = Run<PromoteSideEffectsToDecl>(src, data);
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, BuiltinReturnType) {
+ auto* src = R"(
+fn X(a:vec2f, b:vec2f) {
+}
+
+fn Y() -> vec2f { return vec2f(); }
+
+fn f() {
+ var v: vec2f;
+ X(vec2(), v); // okay
+ X(vec2(), Y()); // errors
+}
+)";
+
+ auto* expect = R"(
+fn X(a : vec2f, b : vec2f) {
+}
+
+fn Y() -> vec2f {
+ return vec2f();
+}
+
+fn f() {
+ var v : vec2f;
+ X(vec2(), v);
+ let tint_symbol : vec2<f32> = vec2();
+ let tint_symbol_1 : vec2<f32> = Y();
+ X(tint_symbol, tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PromoteSideEffectsToDeclTest, Bug1963) {
+ auto* src = R"(
+fn X(a:vec2f, b:vec2f) {
+}
+
+fn Y() -> vec2f { return vec2f(); }
+
+fn f() {
+ var v: vec2f;
+ X(vec2(), v); // okay
+ X(vec2(), Y()); // errors
+}
+)";
+
+ auto* expect = R"(
+fn X(a : vec2f, b : vec2f) {
+}
+
+fn Y() -> vec2f {
+ return vec2f();
+}
+
+fn f() {
+ var v : vec2f;
+ X(vec2(), v);
+ let tint_symbol : vec2<f32> = vec2();
+ let tint_symbol_1 : vec2<f32> = Y();
+ X(tint_symbol, tint_symbol_1);
+}
+)";
+
+ auto got = Run<PromoteSideEffectsToDecl>(src);
EXPECT_EQ(expect, str(got));
}
diff --git a/src/tint/ast/transform/robustness.cc b/src/tint/ast/transform/robustness.cc
index 120d256..aa94748 100644
--- a/src/tint/ast/transform/robustness.cc
+++ b/src/tint/ast/transform/robustness.cc
@@ -62,6 +62,9 @@
// obj[idx]
// Array, matrix and vector indexing may require robustness transformation.
auto* expr = sem.Get(e)->Unwrap()->As<sem::IndexAccessorExpression>();
+ if (IsIgnoredResourceBinding(expr->Object()->RootIdentifier())) {
+ return;
+ }
switch (ActionFor(expr)) {
case Action::kPredicate:
PredicateIndexAccessor(expr);
@@ -673,6 +676,22 @@
const CallExpression* CastToUnsigned(const Expression* val, uint32_t width) {
return b.Call(ScalarOrVecTy(b.ty.u32(), width), val);
}
+
+ /// @returns true if the variable represents a resource binding that should be ignored in the
+ /// robustness check.
+ /// TODO(tint:1890): make this function work with unrestricted pointer paramters. Note that this
+ /// depends on transform::DirectVariableAccess to have been run first.
+ bool IsIgnoredResourceBinding(const sem::Variable* variable) const {
+ auto* globalVariable = utils::As<sem::GlobalVariable>(variable);
+ if (globalVariable == nullptr) {
+ return false;
+ }
+ if (!globalVariable->BindingPoint().has_value()) {
+ return false;
+ }
+ sem::BindingPoint bindingPoint = *globalVariable->BindingPoint();
+ return cfg.bindings_ignored.find(bindingPoint) != cfg.bindings_ignored.cend();
+ }
};
Robustness::Config::Config() = default;
diff --git a/src/tint/ast/transform/robustness.h b/src/tint/ast/transform/robustness.h
index 35804b4..7278a41 100644
--- a/src/tint/ast/transform/robustness.h
+++ b/src/tint/ast/transform/robustness.h
@@ -15,7 +15,10 @@
#ifndef SRC_TINT_AST_TRANSFORM_ROBUSTNESS_H_
#define SRC_TINT_AST_TRANSFORM_ROBUSTNESS_H_
+#include <unordered_set>
+
#include "src/tint/ast/transform/transform.h"
+#include "src/tint/sem/binding_point.h"
namespace tint::ast::transform {
@@ -77,6 +80,9 @@
Action uniform_action = Action::kDefault;
/// Robustness action for variables in the 'workgroup' address space
Action workgroup_action = Action::kDefault;
+
+ /// Bindings that should always be applied Actions::kIgnore on.
+ std::unordered_set<tint::sem::BindingPoint> bindings_ignored;
};
/// Constructor
diff --git a/src/tint/ast/transform/spirv_atomic.cc b/src/tint/ast/transform/spirv_atomic.cc
index aff4a7c..3ee89f2 100644
--- a/src/tint/ast/transform/spirv_atomic.cc
+++ b/src/tint/ast/transform/spirv_atomic.cc
@@ -215,8 +215,8 @@
return b.ty.array(AtomicTypeFor(arr->ElemType()), u32(count.value()));
},
[&](const type::Pointer* ptr) {
- return b.ty.pointer(AtomicTypeFor(ptr->StoreType()), ptr->AddressSpace(),
- ptr->Access());
+ return b.ty.ptr(ptr->AddressSpace(), AtomicTypeFor(ptr->StoreType()),
+ ptr->Access());
},
[&](const type::Reference* ref) { return AtomicTypeFor(ref->StoreType()); },
[&](Default) {
diff --git a/src/tint/ast/transform/transform.cc b/src/tint/ast/transform/transform.cc
index bee1f39..6a9eebc 100644
--- a/src/tint/ast/transform/transform.cc
+++ b/src/tint/ast/transform/transform.cc
@@ -169,7 +169,7 @@
auto access = address_space == builtin::AddressSpace::kStorage
? p->Access()
: builtin::Access::kUndefined;
- return ctx.dst->ty.pointer(CreateASTTypeFor(ctx, p->StoreType()), address_space, access);
+ return ctx.dst->ty.ptr(address_space, CreateASTTypeFor(ctx, p->StoreType()), access);
}
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
<< "Unhandled type: " << ty->TypeInfo().name;
diff --git a/src/tint/ast/transform/transform_test.cc b/src/tint/ast/transform/transform_test.cc
index c61a321..8f86596 100644
--- a/src/tint/ast/transform/transform_test.cc
+++ b/src/tint/ast/transform/transform_test.cc
@@ -127,7 +127,7 @@
TEST_F(CreateASTTypeForTest, PrivatePointer) {
auto ptr = create([](ProgramBuilder& b) {
- return b.create<type::Pointer>(b.create<type::I32>(), builtin::AddressSpace::kPrivate,
+ return b.create<type::Pointer>(builtin::AddressSpace::kPrivate, b.create<type::I32>(),
builtin::Access::kReadWrite);
});
@@ -136,7 +136,7 @@
TEST_F(CreateASTTypeForTest, StorageReadWritePointer) {
auto ptr = create([](ProgramBuilder& b) {
- return b.create<type::Pointer>(b.create<type::I32>(), builtin::AddressSpace::kStorage,
+ return b.create<type::Pointer>(builtin::AddressSpace::kStorage, b.create<type::I32>(),
builtin::Access::kReadWrite);
});
diff --git a/src/tint/ast/transform/vectorize_scalar_matrix_initializers.cc b/src/tint/ast/transform/vectorize_scalar_matrix_initializers.cc
index 840f368..44ec2f9 100644
--- a/src/tint/ast/transform/vectorize_scalar_matrix_initializers.cc
+++ b/src/tint/ast/transform/vectorize_scalar_matrix_initializers.cc
@@ -34,7 +34,7 @@
if (auto* call = program->Sem().Get<sem::Call>(node)) {
if (call->Target()->Is<sem::ValueConstructor>() && call->Type()->Is<type::Matrix>()) {
auto& args = call->Arguments();
- if (!args.IsEmpty() && args[0]->Type()->UnwrapRef()->is_scalar()) {
+ if (!args.IsEmpty() && args[0]->Type()->UnwrapRef()->Is<type::Scalar>()) {
return true;
}
}
diff --git a/src/tint/debug.h b/src/tint/debug.h
index 8b8c4b3..dacf76b 100644
--- a/src/tint/debug.h
+++ b/src/tint/debug.h
@@ -136,4 +136,21 @@
} \
} while (false)
+/// TINT_ASSERT_OR_RETURN_VALUE() is a macro for checking the expression is true, triggering a
+/// TINT_ICE if it is not and returning a value from the calling function.
+/// The ICE message contains the callsite's file and line.
+/// @warning: Unlike TINT_ICE() and TINT_UNREACHABLE(), TINT_ASSERT_OR_RETURN_VALUE() does not
+/// append a message to an existing tint::diag::List. As such, TINT_ASSERT_OR_RETURN_VALUE()
+/// may silently fail in builds where SetInternalCompilerErrorReporter() is not
+/// called. Only use in places where there's no sensible place to put proper
+/// error handling.
+#define TINT_ASSERT_OR_RETURN_VALUE(system, condition, value) \
+ do { \
+ if (TINT_UNLIKELY(!(condition))) { \
+ tint::diag::List diagnostics; \
+ TINT_ICE(system, diagnostics) << "TINT_ASSERT(" #system ", " #condition ")"; \
+ return value; \
+ } \
+ } while (false)
+
#endif // SRC_TINT_DEBUG_H_
diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc
index adafe3a..9a6d1e1 100644
--- a/src/tint/inspector/inspector.cc
+++ b/src/tint/inspector/inspector.cc
@@ -80,7 +80,7 @@
TINT_ASSERT(Inspector, type->is_numeric_scalar_or_vector());
ComponentType componentType = Switch(
- type::Type::DeepestElementOf(type), //
+ type->DeepestElement(), //
[&](const type::F32*) { return ComponentType::kF32; },
[&](const type::F16*) { return ComponentType::kF16; },
[&](const type::I32*) { return ComponentType::kI32; },
@@ -200,7 +200,7 @@
override.name = name;
override.id = global->OverrideId();
auto* type = var->Type();
- TINT_ASSERT(Inspector, type->is_scalar());
+ TINT_ASSERT(Inspector, type->Is<type::Scalar>());
if (type->is_bool_scalar_or_vector()) {
override.type = Override::Type::kBool;
} else if (type->is_float_scalar()) {
diff --git a/src/tint/ir/access.h b/src/tint/ir/access.h
index 55f7306..04467e5 100644
--- a/src/tint/ir/access.h
+++ b/src/tint/ir/access.h
@@ -23,6 +23,12 @@
/// An access instruction in the IR.
class Access : public utils::Castable<Access, OperandInstruction<3>> {
public:
+ /// The base offset in Operands() for the object being accessed
+ static constexpr size_t kObjectOperandOffset = 0;
+
+ /// The base offset in Operands() for the access indices
+ static constexpr size_t kIndicesOperandOffset = 1;
+
/// Constructor
/// @param result_type the result type
/// @param object the accessor object
@@ -31,18 +37,13 @@
~Access() override;
/// @returns the type of the value
- const type::Type* Type() const override { return result_type_; }
+ const type::Type* Type() override { return result_type_; }
/// @returns the object used for the access
- Value* Object() const { return operands_[0]; }
+ Value* Object() { return operands_[kObjectOperandOffset]; }
/// @returns the accessor indices
- utils::Slice<Value const* const> Indices() const {
- return operands_.Slice().Offset(1).Reinterpret<Value const* const>();
- }
-
- /// @returns the accessor indices
- utils::Slice<Value*> Indices() { return operands_.Slice().Offset(1); }
+ utils::Slice<Value*> Indices() { return operands_.Slice().Offset(kIndicesOperandOffset); }
private:
const type::Type* result_type_ = nullptr;
diff --git a/src/tint/ir/access_test.cc b/src/tint/ir/access_test.cc
index 9e09930..1f0e947 100644
--- a/src/tint/ir/access_test.cc
+++ b/src/tint/ir/access_test.cc
@@ -24,11 +24,10 @@
using IR_AccessTest = IRTestHelper;
TEST_F(IR_AccessTest, SetsUsage) {
- auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* var = b.Declare(ty);
+ auto* type = ty.ptr<function, i32>();
+ auto* var = b.Var(type);
auto* idx = b.Constant(u32(1));
- auto* a = b.Access(mod.Types().i32(), var, utils::Vector{idx});
+ auto* a = b.Access(ty.i32(), var, idx);
EXPECT_THAT(var->Usages(), testing::UnorderedElementsAre(Usage{a, 0u}));
EXPECT_THAT(idx->Usages(), testing::UnorderedElementsAre(Usage{a, 1u}));
@@ -39,10 +38,9 @@
{
Module mod;
Builder b{mod};
- auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* var = b.Declare(ty);
- b.Access(nullptr, var, utils::Vector{b.Constant(u32(1))});
+ auto* ty = (mod.Types().ptr<function, i32>());
+ auto* var = b.Var(ty);
+ b.Access(nullptr, var, u32(1));
},
"");
}
@@ -52,7 +50,7 @@
{
Module mod;
Builder b{mod};
- b.Access(mod.Types().i32(), nullptr, utils::Vector{b.Constant(u32(1))});
+ b.Access(mod.Types().i32(), nullptr, u32(1));
},
"");
}
@@ -62,9 +60,8 @@
{
Module mod;
Builder b{mod};
- auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* var = b.Declare(ty);
+ auto* ty = (mod.Types().ptr<function, i32>());
+ auto* var = b.Var(ty);
b.Access(mod.Types().i32(), var, utils::Empty);
},
"");
@@ -75,10 +72,9 @@
{
Module mod;
Builder b{mod};
- auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* var = b.Declare(ty);
- b.Access(mod.Types().i32(), var, utils::Vector<Value*, 1>{nullptr});
+ auto* ty = (mod.Types().ptr<function, i32>());
+ auto* var = b.Var(ty);
+ b.Access(mod.Types().i32(), var, nullptr);
},
"");
}
diff --git a/src/tint/ir/binary.h b/src/tint/ir/binary.h
index 0998c85..6c8e6d4 100644
--- a/src/tint/ir/binary.h
+++ b/src/tint/ir/binary.h
@@ -55,16 +55,16 @@
~Binary() override;
/// @returns the kind of the binary instruction
- enum Kind Kind() const { return kind_; }
+ enum Kind Kind() { return kind_; }
/// @returns the type of the value
- const type::Type* Type() const override { return result_type_; }
+ const type::Type* Type() override { return result_type_; }
/// @returns the left-hand-side value for the instruction
- const Value* LHS() const { return operands_[0]; }
+ Value* LHS() { return operands_[0]; }
/// @returns the right-hand-side value for the instruction
- const Value* RHS() const { return operands_[1]; }
+ Value* RHS() { return operands_[1]; }
private:
enum Kind kind_;
diff --git a/src/tint/ir/binary_test.cc b/src/tint/ir/binary_test.cc
index 9bdaba0..8b3a90d 100644
--- a/src/tint/ir/binary_test.cc
+++ b/src/tint/ir/binary_test.cc
@@ -30,7 +30,7 @@
{
Module mod;
Builder b{mod};
- b.Add(nullptr, b.Constant(u32(1)), b.Constant(u32(2)));
+ b.Add(nullptr, u32(1), u32(2));
},
"");
}
@@ -40,7 +40,7 @@
{
Module mod;
Builder b{mod};
- b.Add(mod.Types().u32(), nullptr, b.Constant(u32(2)));
+ b.Add(mod.Types().u32(), nullptr, u32(2));
},
"");
}
@@ -50,13 +50,13 @@
{
Module mod;
Builder b{mod};
- b.Add(mod.Types().u32(), b.Constant(u32(1)), nullptr);
+ b.Add(mod.Types().u32(), u32(1), nullptr);
},
"");
}
TEST_F(IR_BinaryTest, CreateAnd) {
- const auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.And(mod.Types().i32(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd);
@@ -74,7 +74,7 @@
}
TEST_F(IR_BinaryTest, CreateOr) {
- const auto* inst = b.Or(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.Or(mod.Types().i32(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kOr);
@@ -91,7 +91,7 @@
}
TEST_F(IR_BinaryTest, CreateXor) {
- const auto* inst = b.Xor(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.Xor(mod.Types().i32(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kXor);
@@ -108,7 +108,7 @@
}
TEST_F(IR_BinaryTest, CreateEqual) {
- const auto* inst = b.Equal(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.Equal(mod.Types().bool_(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kEqual);
@@ -125,7 +125,7 @@
}
TEST_F(IR_BinaryTest, CreateNotEqual) {
- const auto* inst = b.NotEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.NotEqual(mod.Types().bool_(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kNotEqual);
@@ -142,7 +142,7 @@
}
TEST_F(IR_BinaryTest, CreateLessThan) {
- const auto* inst = b.LessThan(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.LessThan(mod.Types().bool_(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kLessThan);
@@ -159,7 +159,7 @@
}
TEST_F(IR_BinaryTest, CreateGreaterThan) {
- const auto* inst = b.GreaterThan(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.GreaterThan(mod.Types().bool_(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kGreaterThan);
@@ -176,7 +176,7 @@
}
TEST_F(IR_BinaryTest, CreateLessThanEqual) {
- const auto* inst = b.LessThanEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.LessThanEqual(mod.Types().bool_(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kLessThanEqual);
@@ -193,7 +193,7 @@
}
TEST_F(IR_BinaryTest, CreateGreaterThanEqual) {
- const auto* inst = b.GreaterThanEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.GreaterThanEqual(mod.Types().bool_(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kGreaterThanEqual);
@@ -210,7 +210,7 @@
}
TEST_F(IR_BinaryTest, CreateNot) {
- const auto* inst = b.Not(mod.Types().bool_(), b.Constant(true));
+ auto* inst = b.Not(mod.Types().bool_(), true);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kEqual);
@@ -227,7 +227,7 @@
}
TEST_F(IR_BinaryTest, CreateShiftLeft) {
- const auto* inst = b.ShiftLeft(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.ShiftLeft(mod.Types().i32(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kShiftLeft);
@@ -244,7 +244,7 @@
}
TEST_F(IR_BinaryTest, CreateShiftRight) {
- const auto* inst = b.ShiftRight(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.ShiftRight(mod.Types().i32(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kShiftRight);
@@ -261,7 +261,7 @@
}
TEST_F(IR_BinaryTest, CreateAdd) {
- const auto* inst = b.Add(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.Add(mod.Types().i32(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kAdd);
@@ -278,7 +278,7 @@
}
TEST_F(IR_BinaryTest, CreateSubtract) {
- const auto* inst = b.Subtract(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.Subtract(mod.Types().i32(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kSubtract);
@@ -295,7 +295,7 @@
}
TEST_F(IR_BinaryTest, CreateMultiply) {
- const auto* inst = b.Multiply(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.Multiply(mod.Types().i32(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kMultiply);
@@ -312,7 +312,7 @@
}
TEST_F(IR_BinaryTest, CreateDivide) {
- const auto* inst = b.Divide(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.Divide(mod.Types().i32(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kDivide);
@@ -329,7 +329,7 @@
}
TEST_F(IR_BinaryTest, CreateModulo) {
- const auto* inst = b.Modulo(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.Modulo(mod.Types().i32(), 4_i, 2_i);
ASSERT_TRUE(inst->Is<Binary>());
EXPECT_EQ(inst->Kind(), Binary::Kind::kModulo);
@@ -346,7 +346,7 @@
}
TEST_F(IR_BinaryTest, Binary_Usage) {
- auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i));
+ auto* inst = b.And(mod.Types().i32(), 4_i, 2_i);
EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd);
@@ -358,7 +358,7 @@
}
TEST_F(IR_BinaryTest, Binary_Usage_DuplicateValue) {
- auto val = b.Constant(4_i);
+ auto val = 4_i;
auto* inst = b.And(mod.Types().i32(), val, val);
EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd);
@@ -372,7 +372,7 @@
TEST_F(IR_BinaryTest, Binary_Usage_SetOperand) {
auto* rhs_a = b.Constant(2_i);
auto* rhs_b = b.Constant(3_i);
- auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), rhs_a);
+ auto* inst = b.And(mod.Types().i32(), 4_i, rhs_a);
EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd);
diff --git a/src/tint/ir/bitcast_test.cc b/src/tint/ir/bitcast_test.cc
index e8ed473..e925349 100644
--- a/src/tint/ir/bitcast_test.cc
+++ b/src/tint/ir/bitcast_test.cc
@@ -27,12 +27,12 @@
using IR_BitcastTest = IRTestHelper;
TEST_F(IR_BitcastTest, Bitcast) {
- const auto* inst = b.Bitcast(mod.Types().i32(), b.Constant(4_i));
+ auto* inst = b.Bitcast(mod.Types().i32(), 4_i);
ASSERT_TRUE(inst->Is<ir::Bitcast>());
ASSERT_NE(inst->Type(), nullptr);
- const auto args = inst->Args();
+ auto args = inst->Args();
ASSERT_EQ(args.Length(), 1u);
ASSERT_TRUE(args[0]->Is<Constant>());
auto val = args[0]->As<Constant>()->Value();
@@ -41,9 +41,9 @@
}
TEST_F(IR_BitcastTest, Bitcast_Usage) {
- auto* inst = b.Bitcast(mod.Types().i32(), b.Constant(4_i));
+ auto* inst = b.Bitcast(mod.Types().i32(), 4_i);
- const auto args = inst->Args();
+ auto args = inst->Args();
ASSERT_EQ(args.Length(), 1u);
ASSERT_NE(args[0], nullptr);
EXPECT_THAT(args[0]->Usages(), testing::UnorderedElementsAre(Usage{inst, 0u}));
@@ -64,7 +64,7 @@
{
Module mod;
Builder b{mod};
- b.Bitcast(nullptr, b.Constant(u32(1)));
+ b.Bitcast(nullptr, 1_i);
},
"");
}
diff --git a/src/tint/ir/block.cc b/src/tint/ir/block.cc
index b545f34..4dad6c2 100644
--- a/src/tint/ir/block.cc
+++ b/src/tint/ir/block.cc
@@ -22,25 +22,9 @@
Block::~Block() = default;
-void Block::SetParams(utils::VectorRef<const BlockParam*> params) {
- params_ = std::move(params);
-
- for (auto* param : params_) {
- TINT_ASSERT(IR, param != nullptr);
- }
-}
-
-void Block::AddInboundBranch(ir::Branch* node) {
- TINT_ASSERT(IR, node != nullptr);
-
- if (node) {
- inbound_branches_.Push(node);
- }
-}
-
-void Block::Prepend(Instruction* inst) {
- TINT_ASSERT_OR_RETURN(IR, inst);
- TINT_ASSERT_OR_RETURN(IR, inst->Block() == nullptr);
+Instruction* Block::Prepend(Instruction* inst) {
+ TINT_ASSERT_OR_RETURN_VALUE(IR, inst, inst);
+ TINT_ASSERT_OR_RETURN_VALUE(IR, inst->Block() == nullptr, inst);
inst->SetBlock(this);
instructions_.count += 1;
@@ -53,11 +37,13 @@
instructions_.first->prev = inst;
instructions_.first = inst;
}
+
+ return inst;
}
-void Block::Append(Instruction* inst) {
- TINT_ASSERT_OR_RETURN(IR, inst);
- TINT_ASSERT_OR_RETURN(IR, inst->Block() == nullptr);
+Instruction* Block::Append(Instruction* inst) {
+ TINT_ASSERT_OR_RETURN_VALUE(IR, inst, inst);
+ TINT_ASSERT_OR_RETURN_VALUE(IR, inst->Block() == nullptr, inst);
inst->SetBlock(this);
instructions_.count += 1;
@@ -70,6 +56,8 @@
instructions_.last->next = inst;
instructions_.last = inst;
}
+
+ return inst;
}
void Block::InsertBefore(Instruction* before, Instruction* inst) {
@@ -175,4 +163,10 @@
}
}
+void Block::SetInstructions(std::initializer_list<Instruction*> instructions) {
+ for (auto* i : instructions) {
+ Append(i);
+ }
+}
+
} // namespace tint::ir
diff --git a/src/tint/ir/block.h b/src/tint/ir/block.h
index 1da6fb9..9b9a4ca 100644
--- a/src/tint/ir/block.h
+++ b/src/tint/ir/block.h
@@ -17,7 +17,6 @@
#include <utility>
-#include "src/tint/ir/block_param.h"
#include "src/tint/ir/branch.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/utils/vector.h"
@@ -39,12 +38,12 @@
~Block() override;
/// @returns true if this is block has a branch target set
- bool HasBranchTarget() const {
+ bool HasBranchTarget() {
return instructions_.last != nullptr && instructions_.last->Is<ir::Branch>();
}
/// @return the node this block branches to or nullptr if the block doesn't branch
- const ir::Branch* Branch() const {
+ ir::Branch* Branch() {
if (!HasBranchTarget()) {
return nullptr;
}
@@ -55,8 +54,12 @@
/// @param instructions the instructions to set
void SetInstructions(utils::VectorRef<Instruction*> instructions);
+ /// Sets the instructions in the block
+ /// @param instructions the instructions to set
+ void SetInstructions(std::initializer_list<Instruction*> instructions);
+
/// @returns the instructions in the block
- Instruction* Instructions() const { return instructions_.first; }
+ Instruction* Instructions() { return instructions_.first; }
/// Iterator for the instructions inside a block
class Iterator {
@@ -92,23 +95,25 @@
};
/// @returns the iterator pointing to the start of the instruction list
- Iterator begin() const { return Iterator{instructions_.first}; }
+ Iterator begin() { return Iterator{instructions_.first}; }
/// @returns the ending iterator
- Iterator end() const { return Iterator{nullptr}; }
+ Iterator end() { return Iterator{nullptr}; }
/// @returns the first instruction in the instruction list
- Instruction* Front() const { return instructions_.first; }
+ Instruction* Front() { return instructions_.first; }
/// @returns the last instruction in the instruction list
- Instruction* Back() const { return instructions_.last; }
+ Instruction* Back() { return instructions_.last; }
/// Adds the instruction to the beginning of the block
/// @param inst the instruction to add
- void Prepend(Instruction* inst);
+ /// @returns the instruction to allow calls to be chained
+ Instruction* Prepend(Instruction* inst);
/// Adds the instruction to the end of the block
/// @param inst the instruction to add
- void Append(Instruction* inst);
+ /// @returns the instruction to allow calls to be chained
+ Instruction* Append(Instruction* inst);
/// Adds the new instruction before the given instruction
/// @param before the instruction to insert before
/// @param inst the instruction to insert
@@ -126,28 +131,13 @@
void Remove(Instruction* inst);
/// @returns true if the block contains no instructions
- bool IsEmpty() const { return Length() == 0; }
+ bool IsEmpty() { return Length() == 0; }
/// @returns the number of instructions in the block
- size_t Length() const { return instructions_.count; }
-
- /// Sets the params to the block
- /// @param params the params for the block
- void SetParams(utils::VectorRef<const BlockParam*> params);
- /// @return the parameters passed into the block
- utils::VectorRef<const BlockParam*> Params() const { return params_; }
- /// @returns the params to the block
- utils::Vector<const BlockParam*, 0>& Params() { return params_; }
-
- /// @returns the inbound branch list for the block
- utils::VectorRef<ir::Branch*> InboundBranches() const { return inbound_branches_; }
-
- /// Adds the given node to the inbound branches
- /// @param node the node to add
- void AddInboundBranch(ir::Branch* node);
+ size_t Length() { return instructions_.count; }
/// @return the parent instruction that owns this block
- ControlInstruction* Parent() const { return parent_; }
+ ControlInstruction* Parent() { return parent_; }
/// @param parent the parent instruction that owns this block
void SetParent(ControlInstruction* parent) { parent_ = parent; }
@@ -159,15 +149,6 @@
size_t count = 0;
} instructions_;
- utils::Vector<const BlockParam*, 0> params_;
-
- /// The list of branches into this node. This list maybe empty for several
- /// reasons:
- /// - Node is a start node
- /// - Node is a merge target outside control flow (e.g. an if that returns in both branches)
- /// - Node is a continue target outside control flow (e.g. a loop that returns)
- utils::Vector<ir::Branch*, 2> inbound_branches_;
-
ControlInstruction* parent_ = nullptr;
};
diff --git a/src/tint/ir/block_param.h b/src/tint/ir/block_param.h
index 0a90ba8..51a257c 100644
--- a/src/tint/ir/block_param.h
+++ b/src/tint/ir/block_param.h
@@ -29,7 +29,7 @@
~BlockParam() override;
/// @returns the type of the var
- const type::Type* Type() const override { return type_; }
+ const type::Type* Type() override { return type_; }
private:
/// the result type of the instruction
diff --git a/src/tint/ir/block_test.cc b/src/tint/ir/block_test.cc
index 8a5308f..3549a05 100644
--- a/src/tint/ir/block_test.cc
+++ b/src/tint/ir/block_test.cc
@@ -14,7 +14,6 @@
#include "src/tint/ir/block.h"
#include "gtest/gtest-spi.h"
-#include "src/tint/ir/block_param.h"
#include "src/tint/ir/ir_test_helper.h"
namespace tint::ir {
@@ -24,91 +23,73 @@
using IR_BlockTest = IRTestHelper;
TEST_F(IR_BlockTest, HasBranchTarget_Empty) {
- auto* blk = b.CreateBlock();
+ auto* blk = b.Block();
EXPECT_FALSE(blk->HasBranchTarget());
}
TEST_F(IR_BlockTest, HasBranchTarget_NoBranch) {
- auto* blk = b.CreateBlock();
- blk->Append(b.Add(mod.Types().i32(), b.Constant(1_u), b.Constant(2_u)));
+ auto* blk = b.Block();
+ blk->Append(b.Add(mod.Types().i32(), 1_u, 2_u));
EXPECT_FALSE(blk->HasBranchTarget());
}
TEST_F(IR_BlockTest, HasBranchTarget_BreakIf) {
- auto* blk = b.CreateBlock();
- auto* loop = b.CreateLoop();
- blk->Append(b.BreakIf(b.Constant(true), loop));
+ auto* blk = b.Block();
+ auto* loop = b.Loop();
+ blk->Append(b.BreakIf(true, loop));
EXPECT_TRUE(blk->HasBranchTarget());
}
TEST_F(IR_BlockTest, HasBranchTarget_Continue) {
- auto* blk = b.CreateBlock();
- auto* loop = b.CreateLoop();
+ auto* blk = b.Block();
+ auto* loop = b.Loop();
blk->Append(b.Continue(loop));
EXPECT_TRUE(blk->HasBranchTarget());
}
TEST_F(IR_BlockTest, HasBranchTarget_ExitIf) {
- auto* blk = b.CreateBlock();
- auto* if_ = b.CreateIf(b.Constant(true));
+ auto* blk = b.Block();
+ auto* if_ = b.If(true);
blk->Append(b.ExitIf(if_));
EXPECT_TRUE(blk->HasBranchTarget());
}
TEST_F(IR_BlockTest, HasBranchTarget_ExitLoop) {
- auto* blk = b.CreateBlock();
- auto* loop = b.CreateLoop();
+ auto* blk = b.Block();
+ auto* loop = b.Loop();
blk->Append(b.ExitLoop(loop));
EXPECT_TRUE(blk->HasBranchTarget());
}
TEST_F(IR_BlockTest, HasBranchTarget_ExitSwitch) {
- auto* blk = b.CreateBlock();
- auto* s = b.CreateSwitch(b.Constant(1_u));
+ auto* blk = b.Block();
+ auto* s = b.Switch(1_u);
blk->Append(b.ExitSwitch(s));
EXPECT_TRUE(blk->HasBranchTarget());
}
-TEST_F(IR_BlockTest, HasBranchTarget_If) {
- auto* blk = b.CreateBlock();
- blk->Append(b.CreateIf(b.Constant(true)));
- EXPECT_TRUE(blk->HasBranchTarget());
-}
-
-TEST_F(IR_BlockTest, HasBranchTarget_Loop) {
- auto* blk = b.CreateBlock();
- blk->Append(b.CreateLoop());
- EXPECT_TRUE(blk->HasBranchTarget());
-}
-
TEST_F(IR_BlockTest, HasBranchTarget_NextIteration) {
- auto* blk = b.CreateBlock();
- auto* loop = b.CreateLoop();
+ auto* blk = b.Block();
+ auto* loop = b.Loop();
blk->Append(b.NextIteration(loop));
EXPECT_TRUE(blk->HasBranchTarget());
}
TEST_F(IR_BlockTest, HasBranchTarget_Return) {
- auto* f = b.CreateFunction("myFunc", mod.Types().void_());
+ auto* f = b.Function("myFunc", mod.Types().void_());
- auto* blk = b.CreateBlock();
+ auto* blk = b.Block();
blk->Append(b.Return(f));
EXPECT_TRUE(blk->HasBranchTarget());
}
-TEST_F(IR_BlockTest, HasBranchTarget_Switch) {
- auto* blk = b.CreateBlock();
- blk->Append(b.CreateSwitch(b.Constant(true)));
- EXPECT_TRUE(blk->HasBranchTarget());
-}
-
TEST_F(IR_BlockTest, SetInstructions) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* inst3 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* inst3 = b.Loop();
- auto* blk = b.CreateBlock();
- blk->SetInstructions(utils::Vector{inst1, inst2, inst3});
+ auto* blk = b.Block();
+ blk->SetInstructions({inst1, inst2, inst3});
ASSERT_EQ(inst1->Block(), blk);
ASSERT_EQ(inst2->Block(), blk);
@@ -132,14 +113,14 @@
}
TEST_F(IR_BlockTest, Append) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* inst3 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* inst3 = b.Loop();
- auto* blk = b.CreateBlock();
- blk->Append(inst1);
- blk->Append(inst2);
- blk->Append(inst3);
+ auto* blk = b.Block();
+ EXPECT_EQ(blk->Append(inst1), inst1);
+ EXPECT_EQ(blk->Append(inst2), inst2);
+ EXPECT_EQ(blk->Append(inst3), inst3);
ASSERT_EQ(inst1->Block(), blk);
ASSERT_EQ(inst2->Block(), blk);
@@ -163,14 +144,14 @@
}
TEST_F(IR_BlockTest, Prepend) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* inst3 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* inst3 = b.Loop();
- auto* blk = b.CreateBlock();
- blk->Prepend(inst3);
- blk->Prepend(inst2);
- blk->Prepend(inst1);
+ auto* blk = b.Block();
+ EXPECT_EQ(blk->Prepend(inst3), inst3);
+ EXPECT_EQ(blk->Prepend(inst2), inst2);
+ EXPECT_EQ(blk->Prepend(inst1), inst1);
ASSERT_EQ(inst1->Block(), blk);
ASSERT_EQ(inst2->Block(), blk);
@@ -194,10 +175,10 @@
}
TEST_F(IR_BlockTest, InsertBefore_AtStart) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
- auto* blk = b.CreateBlock();
+ auto* blk = b.Block();
blk->Append(inst2);
blk->InsertBefore(inst2, inst1);
@@ -218,11 +199,11 @@
}
TEST_F(IR_BlockTest, InsertBefore_Middle) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* inst3 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* inst3 = b.Loop();
- auto* blk = b.CreateBlock();
+ auto* blk = b.Block();
blk->Append(inst1);
blk->Append(inst3);
blk->InsertBefore(inst3, inst2);
@@ -249,10 +230,10 @@
}
TEST_F(IR_BlockTest, InsertAfter_AtEnd) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
- auto* blk = b.CreateBlock();
+ auto* blk = b.Block();
blk->Append(inst1);
blk->InsertAfter(inst1, inst2);
@@ -273,11 +254,11 @@
}
TEST_F(IR_BlockTest, InsertAfter_Middle) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* inst3 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* inst3 = b.Loop();
- auto* blk = b.CreateBlock();
+ auto* blk = b.Block();
blk->Append(inst1);
blk->Append(inst3);
blk->InsertAfter(inst1, inst2);
@@ -304,13 +285,13 @@
}
TEST_F(IR_BlockTest, Replace_Middle) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* inst3 = b.CreateLoop();
- auto* inst4 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* inst3 = b.Loop();
+ auto* inst4 = b.Loop();
- auto* blk = b.CreateBlock();
- blk->SetInstructions(utils::Vector{inst1, inst4, inst3});
+ auto* blk = b.Block();
+ blk->SetInstructions({inst1, inst4, inst3});
blk->Replace(inst4, inst2);
ASSERT_EQ(inst1->Block(), blk);
@@ -336,12 +317,12 @@
}
TEST_F(IR_BlockTest, Replace_Start) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* inst4 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* inst4 = b.Loop();
- auto* blk = b.CreateBlock();
- blk->SetInstructions(utils::Vector{inst4, inst2});
+ auto* blk = b.Block();
+ blk->SetInstructions({inst4, inst2});
blk->Replace(inst4, inst1);
ASSERT_EQ(inst1->Block(), blk);
@@ -362,12 +343,12 @@
}
TEST_F(IR_BlockTest, Replace_End) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* inst4 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* inst4 = b.Loop();
- auto* blk = b.CreateBlock();
- blk->SetInstructions(utils::Vector{inst1, inst4});
+ auto* blk = b.Block();
+ blk->SetInstructions({inst1, inst4});
blk->Replace(inst4, inst2);
ASSERT_EQ(inst1->Block(), blk);
@@ -388,11 +369,11 @@
}
TEST_F(IR_BlockTest, Replace_OnlyNode) {
- auto* inst1 = b.CreateLoop();
- auto* inst4 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst4 = b.Loop();
- auto* blk = b.CreateBlock();
- blk->SetInstructions(utils::Vector{inst4});
+ auto* blk = b.Block();
+ blk->SetInstructions({inst4});
blk->Replace(inst4, inst1);
ASSERT_EQ(inst1->Block(), blk);
@@ -408,12 +389,12 @@
}
TEST_F(IR_BlockTest, Remove_Middle) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* inst4 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* inst4 = b.Loop();
- auto* blk = b.CreateBlock();
- blk->SetInstructions(utils::Vector{inst1, inst4, inst2});
+ auto* blk = b.Block();
+ blk->SetInstructions({inst1, inst4, inst2});
blk->Remove(inst4);
ASSERT_EQ(inst4->Block(), nullptr);
@@ -432,11 +413,11 @@
}
TEST_F(IR_BlockTest, Remove_Start) {
- auto* inst1 = b.CreateLoop();
- auto* inst4 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst4 = b.Loop();
- auto* blk = b.CreateBlock();
- blk->SetInstructions(utils::Vector{inst4, inst1});
+ auto* blk = b.Block();
+ blk->SetInstructions({inst4, inst1});
blk->Remove(inst4);
ASSERT_EQ(inst4->Block(), nullptr);
@@ -451,11 +432,11 @@
}
TEST_F(IR_BlockTest, Remove_End) {
- auto* inst1 = b.CreateLoop();
- auto* inst4 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst4 = b.Loop();
- auto* blk = b.CreateBlock();
- blk->SetInstructions(utils::Vector{inst1, inst4});
+ auto* blk = b.Block();
+ blk->SetInstructions({inst1, inst4});
blk->Remove(inst4);
ASSERT_EQ(inst4->Block(), nullptr);
@@ -470,10 +451,10 @@
}
TEST_F(IR_BlockTest, Remove_OnlyNode) {
- auto* inst4 = b.CreateLoop();
+ auto* inst4 = b.Loop();
- auto* blk = b.CreateBlock();
- blk->SetInstructions(utils::Vector{inst4});
+ auto* blk = b.Block();
+ blk->SetInstructions({inst4});
blk->Remove(inst4);
ASSERT_EQ(inst4->Block(), nullptr);
@@ -488,7 +469,7 @@
Module mod;
Builder b{mod};
- auto* blk = b.CreateBlock();
+ auto* blk = b.Block();
blk->Prepend(nullptr);
},
"internal compiler error");
@@ -500,8 +481,8 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* blk = b.Block();
blk->Prepend(inst1);
blk->Prepend(inst1);
@@ -515,7 +496,7 @@
Module mod;
Builder b{mod};
- auto* blk = b.CreateBlock();
+ auto* blk = b.Block();
blk->Append(nullptr);
},
"internal compiler error");
@@ -527,8 +508,8 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* blk = b.Block();
blk->Append(inst1);
blk->Append(inst1);
},
@@ -541,8 +522,8 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* blk = b.Block();
blk->InsertBefore(nullptr, inst1);
},
"internal compiler error");
@@ -554,8 +535,8 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* blk = b.Block();
blk->Append(inst1);
blk->InsertBefore(inst1, nullptr);
},
@@ -568,10 +549,10 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* blk1 = b.CreateBlock();
- auto* blk2 = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* blk1 = b.Block();
+ auto* blk2 = b.Block();
blk2->Append(inst1);
blk1->InsertBefore(inst1, inst2);
},
@@ -584,9 +565,9 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* blk1 = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* blk1 = b.Block();
blk1->Append(inst1);
blk1->Append(inst2);
blk1->InsertBefore(inst1, inst2);
@@ -600,8 +581,8 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* blk = b.Block();
blk->InsertAfter(nullptr, inst1);
},
"internal compiler error");
@@ -613,8 +594,8 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* blk = b.Block();
blk->Append(inst1);
blk->InsertAfter(inst1, nullptr);
},
@@ -627,10 +608,10 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* blk1 = b.CreateBlock();
- auto* blk2 = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* blk1 = b.Block();
+ auto* blk2 = b.Block();
blk2->Append(inst1);
blk1->InsertAfter(inst1, inst2);
},
@@ -643,9 +624,9 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* blk1 = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* blk1 = b.Block();
blk1->Append(inst1);
blk1->Append(inst2);
blk1->InsertAfter(inst1, inst2);
@@ -659,8 +640,8 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* blk = b.Block();
blk->Replace(nullptr, inst1);
},
"internal compiler error");
@@ -672,8 +653,8 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* blk = b.Block();
blk->Append(inst1);
blk->Replace(inst1, nullptr);
},
@@ -686,10 +667,10 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* blk1 = b.CreateBlock();
- auto* blk2 = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* blk1 = b.Block();
+ auto* blk2 = b.Block();
blk2->Append(inst1);
blk1->Replace(inst1, inst2);
},
@@ -702,9 +683,9 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* blk1 = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* blk1 = b.Block();
blk1->Append(inst1);
blk1->Append(inst2);
blk1->Replace(inst1, inst2);
@@ -718,7 +699,7 @@
Module mod;
Builder b{mod};
- auto* blk = b.CreateBlock();
+ auto* blk = b.Block();
blk->Remove(nullptr);
},
"internal compiler error");
@@ -730,38 +711,14 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* blk1 = b.CreateBlock();
- auto* blk2 = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* blk1 = b.Block();
+ auto* blk2 = b.Block();
blk2->Append(inst1);
blk1->Remove(inst1);
},
"internal compiler error");
}
-TEST_F(IR_BlockTest, Fail_NullBlockParam) {
- EXPECT_FATAL_FAILURE(
- {
- Module mod;
- Builder b{mod};
-
- auto* blk = b.CreateBlock();
- blk->SetParams(utils::Vector<const BlockParam*, 1>{nullptr});
- },
- "");
-}
-
-TEST_F(IR_BlockTest, Fail_NullInboundBranch) {
- EXPECT_FATAL_FAILURE(
- {
- Module mod;
- Builder b{mod};
-
- auto* blk = b.CreateBlock();
- blk->AddInboundBranch(nullptr);
- },
- "");
-}
-
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/branch.h b/src/tint/ir/branch.h
index f7dfd12..15a7d42 100644
--- a/src/tint/ir/branch.h
+++ b/src/tint/ir/branch.h
@@ -32,7 +32,7 @@
~Branch() override;
/// @returns the branch arguments
- virtual utils::Slice<Value const* const> Args() const { return operands_.Slice(); }
+ virtual utils::Slice<Value* const> Args() { return operands_.Slice(); }
};
} // namespace tint::ir
diff --git a/src/tint/ir/break_if.cc b/src/tint/ir/break_if.cc
index 644502b..4a18fb8 100644
--- a/src/tint/ir/break_if.cc
+++ b/src/tint/ir/break_if.cc
@@ -18,6 +18,7 @@
#include "src/tint/ir/block.h"
#include "src/tint/ir/loop.h"
+#include "src/tint/ir/multi_in_block.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::BreakIf);
@@ -32,8 +33,8 @@
AddOperand(condition);
if (loop_) {
- loop_->Body()->AddInboundBranch(this);
- loop_->Merge()->AddInboundBranch(this);
+ loop_->Body()->AddInboundSiblingBranch(this);
+ loop_->Merge()->AddInboundSiblingBranch(this);
}
AddOperands(std::move(args));
}
diff --git a/src/tint/ir/break_if.h b/src/tint/ir/break_if.h
index b673cd8..1eba580 100644
--- a/src/tint/ir/break_if.h
+++ b/src/tint/ir/break_if.h
@@ -37,15 +37,13 @@
~BreakIf() override;
/// @returns the branch arguments
- utils::Slice<Value const* const> Args() const override {
- return operands_.Slice().Offset(1).Reinterpret<Value const* const>();
- }
+ utils::Slice<Value* const> Args() override { return operands_.Slice().Offset(1); }
/// @returns the break condition
- const Value* Condition() const { return operands_[0]; }
+ Value* Condition() { return operands_[0]; }
/// @returns the loop containing the break-if
- const ir::Loop* Loop() const { return loop_; }
+ ir::Loop* Loop() { return loop_; }
private:
ir::Loop* loop_ = nullptr;
diff --git a/src/tint/ir/break_if_test.cc b/src/tint/ir/break_if_test.cc
index 96363f1..9dd095a 100644
--- a/src/tint/ir/break_if_test.cc
+++ b/src/tint/ir/break_if_test.cc
@@ -25,12 +25,12 @@
using IR_BreakIfTest = IRTestHelper;
TEST_F(IR_BreakIfTest, Usage) {
- auto* loop = b.CreateLoop();
+ auto* loop = b.Loop();
auto* cond = b.Constant(true);
auto* arg1 = b.Constant(1_u);
auto* arg2 = b.Constant(2_u);
- auto* brk = b.BreakIf(cond, loop, utils::Vector{arg1, arg2});
+ auto* brk = b.BreakIf(cond, loop, arg1, arg2);
EXPECT_THAT(cond->Usages(), testing::UnorderedElementsAre(Usage{brk, 0u}));
EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{brk, 1u}));
@@ -42,7 +42,7 @@
{
Module mod;
Builder b{mod};
- b.BreakIf(nullptr, b.CreateLoop());
+ b.BreakIf(nullptr, b.Loop());
},
"");
}
@@ -52,7 +52,7 @@
{
Module mod;
Builder b{mod};
- b.BreakIf(b.Constant(true), nullptr);
+ b.BreakIf(true, nullptr);
},
"");
}
@@ -62,7 +62,7 @@
{
Module mod;
Builder b{mod};
- b.BreakIf(b.Constant(true), b.CreateLoop(), utils::Vector<Value*, 1>{nullptr});
+ b.BreakIf(true, b.Loop(), nullptr);
},
"");
}
diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc
index fdf11b9..e2d2b83 100644
--- a/src/tint/ir/builder.cc
+++ b/src/tint/ir/builder.cc
@@ -26,220 +26,54 @@
Builder::~Builder() = default;
-ir::Block* Builder::CreateRootBlockIfNeeded() {
+ir::Block* Builder::RootBlock() {
if (!ir.root_block) {
- ir.root_block = CreateBlock();
+ ir.root_block = Block();
}
return ir.root_block;
}
-Block* Builder::CreateBlock() {
- return ir.blocks.Create<Block>();
+Block* Builder::Block() {
+ return ir.blocks.Create<ir::Block>();
}
-Function* Builder::CreateFunction(std::string_view name,
- const type::Type* return_type,
- Function::PipelineStage stage,
- std::optional<std::array<uint32_t, 3>> wg_size) {
- auto* ir_func = ir.values.Create<Function>(return_type, stage, wg_size);
- ir_func->SetStartTarget(CreateBlock());
+MultiInBlock* Builder::MultiInBlock() {
+ return ir.blocks.Create<ir::MultiInBlock>();
+}
+
+Function* Builder::Function(std::string_view name,
+ const type::Type* return_type,
+ Function::PipelineStage stage,
+ std::optional<std::array<uint32_t, 3>> wg_size) {
+ auto* ir_func = ir.values.Create<ir::Function>(return_type, stage, wg_size);
+ ir_func->SetStartTarget(Block());
ir.SetName(ir_func, name);
return ir_func;
}
-If* Builder::CreateIf(Value* condition) {
- return ir.values.Create<If>(condition, CreateBlock(), CreateBlock(), CreateBlock());
+ir::Loop* Builder::Loop() {
+ return ir.values.Create<ir::Loop>(Block(), MultiInBlock(), MultiInBlock(), MultiInBlock());
}
-Loop* Builder::CreateLoop() {
- return ir.values.Create<Loop>(CreateBlock(), CreateBlock(), CreateBlock(), CreateBlock());
+Block* Builder::Case(ir::Switch* s, utils::VectorRef<Switch::CaseSelector> selectors) {
+ auto* block = Block();
+ s->Cases().Push(Switch::Case{std::move(selectors), block});
+ block->SetParent(s);
+ return block;
}
-Switch* Builder::CreateSwitch(Value* condition) {
- return ir.values.Create<Switch>(condition, CreateBlock());
-}
-
-Block* Builder::CreateCase(Switch* s, utils::VectorRef<Switch::CaseSelector> selectors) {
- s->Cases().Push(Switch::Case{std::move(selectors), CreateBlock()});
-
- Block* b = s->Cases().Back().Start();
- b->AddInboundBranch(s);
- b->SetParent(s);
- return b;
-}
-
-Binary* Builder::CreateBinary(enum Binary::Kind kind,
- const type::Type* type,
- Value* lhs,
- Value* rhs) {
- return ir.values.Create<ir::Binary>(kind, type, lhs, rhs);
-}
-
-Binary* Builder::And(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kAnd, type, lhs, rhs);
-}
-
-Binary* Builder::Or(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kOr, type, lhs, rhs);
-}
-
-Binary* Builder::Xor(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kXor, type, lhs, rhs);
-}
-
-Binary* Builder::Equal(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kEqual, type, lhs, rhs);
-}
-
-Binary* Builder::NotEqual(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kNotEqual, type, lhs, rhs);
-}
-
-Binary* Builder::LessThan(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kLessThan, type, lhs, rhs);
-}
-
-Binary* Builder::GreaterThan(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kGreaterThan, type, lhs, rhs);
-}
-
-Binary* Builder::LessThanEqual(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kLessThanEqual, type, lhs, rhs);
-}
-
-Binary* Builder::GreaterThanEqual(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kGreaterThanEqual, type, lhs, rhs);
-}
-
-Binary* Builder::ShiftLeft(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kShiftLeft, type, lhs, rhs);
-}
-
-Binary* Builder::ShiftRight(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kShiftRight, type, lhs, rhs);
-}
-
-Binary* Builder::Add(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kAdd, type, lhs, rhs);
-}
-
-Binary* Builder::Subtract(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kSubtract, type, lhs, rhs);
-}
-
-Binary* Builder::Multiply(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kMultiply, type, lhs, rhs);
-}
-
-Binary* Builder::Divide(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kDivide, type, lhs, rhs);
-}
-
-Binary* Builder::Modulo(const type::Type* type, Value* lhs, Value* rhs) {
- return CreateBinary(Binary::Kind::kModulo, type, lhs, rhs);
-}
-
-Unary* Builder::CreateUnary(enum Unary::Kind kind, const type::Type* type, Value* val) {
- return ir.values.Create<ir::Unary>(kind, type, val);
-}
-
-Unary* Builder::Complement(const type::Type* type, Value* val) {
- return CreateUnary(Unary::Kind::kComplement, type, val);
-}
-
-Unary* Builder::Negation(const type::Type* type, Value* val) {
- return CreateUnary(Unary::Kind::kNegation, type, val);
-}
-
-Binary* Builder::Not(const type::Type* type, Value* val) {
- return Equal(type, val, Constant(false));
-}
-
-ir::Bitcast* Builder::Bitcast(const type::Type* type, Value* val) {
- return ir.values.Create<ir::Bitcast>(type, val);
+Block* Builder::Case(ir::Switch* s, std::initializer_list<Switch::CaseSelector> selectors) {
+ return Case(s, utils::Vector<Switch::CaseSelector, 4>(selectors));
}
ir::Discard* Builder::Discard() {
return ir.values.Create<ir::Discard>(ir.Types().void_());
}
-ir::UserCall* Builder::UserCall(const type::Type* type,
- Function* func,
- utils::VectorRef<Value*> args) {
- return ir.values.Create<ir::UserCall>(type, func, std::move(args));
-}
-
-ir::Convert* Builder::Convert(const type::Type* to,
- const type::Type* from,
- utils::VectorRef<Value*> args) {
- return ir.values.Create<ir::Convert>(to, from, std::move(args));
-}
-
-ir::Construct* Builder::Construct(const type::Type* to, utils::VectorRef<Value*> args) {
- return ir.values.Create<ir::Construct>(to, std::move(args));
-}
-
-ir::Builtin* Builder::Builtin(const type::Type* type,
- builtin::Function func,
- utils::VectorRef<Value*> args) {
- return ir.values.Create<ir::Builtin>(type, func, args);
-}
-
-ir::Load* Builder::Load(Value* from) {
- TINT_ASSERT(IR, from != nullptr);
- if (from == nullptr) {
- return nullptr;
- }
-
- auto* ptr = from->Type()->As<type::Pointer>();
- TINT_ASSERT(IR, ptr != nullptr);
- if (ptr == nullptr) {
- return nullptr;
- }
-
- return ir.values.Create<ir::Load>(ptr->StoreType(), from);
-}
-
-ir::Store* Builder::Store(Value* to, Value* from) {
- return ir.values.Create<ir::Store>(to, from);
-}
-
-ir::Var* Builder::Declare(const type::Type* type) {
+ir::Var* Builder::Var(const type::Pointer* type) {
return ir.values.Create<ir::Var>(type);
}
-ir::Return* Builder::Return(Function* func, utils::VectorRef<Value*> args /* = utils::Empty */) {
- return ir.values.Create<ir::Return>(func, std::move(args));
-}
-
-ir::NextIteration* Builder::NextIteration(Loop* loop,
- utils::VectorRef<Value*> args /* = utils::Empty */) {
- return ir.values.Create<ir::NextIteration>(loop, std::move(args));
-}
-
-ir::BreakIf* Builder::BreakIf(Value* condition,
- Loop* loop,
- utils::VectorRef<Value*> args /* = utils::Empty */) {
- return ir.values.Create<ir::BreakIf>(condition, loop, std::move(args));
-}
-
-ir::Continue* Builder::Continue(Loop* loop, utils::VectorRef<Value*> args /* = utils::Empty */) {
- return ir.values.Create<ir::Continue>(loop, std::move(args));
-}
-
-ir::ExitSwitch* Builder::ExitSwitch(Switch* sw,
- utils::VectorRef<Value*> args /* = utils::Empty */) {
- return ir.values.Create<ir::ExitSwitch>(sw, std::move(args));
-}
-
-ir::ExitLoop* Builder::ExitLoop(Loop* loop, utils::VectorRef<Value*> args /* = utils::Empty */) {
- return ir.values.Create<ir::ExitLoop>(loop, std::move(args));
-}
-
-ir::ExitIf* Builder::ExitIf(If* i, utils::VectorRef<Value*> args /* = utils::Empty */) {
- return ir.values.Create<ir::ExitIf>(i, std::move(args));
-}
-
ir::BlockParam* Builder::BlockParam(const type::Type* type) {
return ir.values.Create<ir::BlockParam>(type);
}
@@ -248,16 +82,16 @@
return ir.values.Create<ir::FunctionParam>(type);
}
-ir::Access* Builder::Access(const type::Type* type,
- Value* source,
- utils::VectorRef<Value*> indices) {
- return ir.values.Create<ir::Access>(type, source, indices);
+ir::Swizzle* Builder::Swizzle(const type::Type* type,
+ ir::Value* object,
+ utils::VectorRef<uint32_t> indices) {
+ return ir.values.Create<ir::Swizzle>(type, object, std::move(indices));
}
ir::Swizzle* Builder::Swizzle(const type::Type* type,
- Value* source,
- utils::VectorRef<uint32_t> indices) {
- return ir.values.Create<ir::Swizzle>(type, source, indices);
+ ir::Value* object,
+ std::initializer_list<uint32_t> indices) {
+ return ir.values.Create<ir::Swizzle>(type, object, utils::Vector<uint32_t, 4>(indices));
}
} // namespace tint::ir
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index 4849743..1bb2973 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -23,7 +23,7 @@
#include "src/tint/ir/bitcast.h"
#include "src/tint/ir/block_param.h"
#include "src/tint/ir/break_if.h"
-#include "src/tint/ir/builtin.h"
+#include "src/tint/ir/builtin_call.h"
#include "src/tint/ir/constant.h"
#include "src/tint/ir/construct.h"
#include "src/tint/ir/continue.h"
@@ -38,6 +38,7 @@
#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/module.h"
+#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
@@ -60,6 +61,18 @@
/// Builds an ir::Module
class Builder {
+ /// A helper used to enable overloads if the first type in `TYPES` is a utils::Vector or
+ /// utils::VectorRef.
+ template <typename... TYPES>
+ using EnableIfVectorLike = utils::traits::EnableIf<
+ utils::IsVectorLike<utils::traits::Decay<utils::traits::NthTypeOf<0, TYPES..., void>>>>;
+
+ /// A helper used to disable overloads if the first type in `TYPES` is a utils::Vector or
+ /// utils::VectorRef.
+ template <typename... TYPES>
+ using DisableIfVectorLike = utils::traits::EnableIf<
+ !utils::IsVectorLike<utils::traits::Decay<utils::traits::NthTypeOf<0, TYPES..., void>>>>;
+
public:
/// Constructor
/// @param mod the ir::Module to wrap with this builder
@@ -67,8 +80,11 @@
/// Destructor
~Builder();
- /// @returns a new block flow node
- Block* CreateBlock();
+ /// @returns a new block
+ ir::Block* Block();
+
+ /// @returns a new multi-in block
+ ir::MultiInBlock* MultiInBlock();
/// Creates a function flow node
/// @param name the function name
@@ -76,30 +92,43 @@
/// @param stage the function stage
/// @param wg_size the workgroup_size
/// @returns the flow node
- Function* CreateFunction(std::string_view name,
- const type::Type* return_type,
- Function::PipelineStage stage = Function::PipelineStage::kUndefined,
- std::optional<std::array<uint32_t, 3>> wg_size = {});
+ ir::Function* Function(std::string_view name,
+ const type::Type* return_type,
+ Function::PipelineStage stage = Function::PipelineStage::kUndefined,
+ std::optional<std::array<uint32_t, 3>> wg_size = {});
/// Creates an if flow node
/// @param condition the if condition
/// @returns the flow node
- If* CreateIf(Value* condition);
+ template <typename T>
+ ir::If* If(T&& condition) {
+ return ir.values.Create<ir::If>(Value(std::forward<T>(condition)), Block(), Block(),
+ MultiInBlock());
+ }
/// Creates a loop flow node
/// @returns the flow node
- Loop* CreateLoop();
+ ir::Loop* Loop();
/// Creates a switch flow node
/// @param condition the switch condition
/// @returns the flow node
- Switch* CreateSwitch(Value* condition);
+ template <typename T>
+ ir::Switch* Switch(T&& condition) {
+ return ir.values.Create<ir::Switch>(Value(std::forward<T>(condition)), MultiInBlock());
+ }
/// Creates a case flow node for the given case branch.
/// @param s the switch to create the case into
/// @param selectors the case selectors for the case statement
/// @returns the start block for the case flow node
- Block* CreateCase(Switch* s, utils::VectorRef<Switch::CaseSelector> selectors);
+ ir::Block* Case(ir::Switch* s, utils::VectorRef<Switch::CaseSelector> selectors);
+
+ /// Creates a case flow node for the given case branch.
+ /// @param s the switch to create the case into
+ /// @param selectors the case selectors for the case statement
+ /// @returns the start block for the case flow node
+ ir::Block* Case(ir::Switch* s, std::initializer_list<Switch::CaseSelector> selectors);
/// Creates a new ir::Constant
/// @param val the constant value
@@ -133,156 +162,276 @@
/// @returns the new constant
ir::Constant* Constant(bool v) { return Constant(ir.constant_values.Get(v)); }
+ /// Creates a ir::Constant for the given number
+ /// @param number the number value
+ /// @returns the new constant
+ template <typename T, typename = std::enable_if_t<IsNumeric<T>>>
+ ir::Constant* Value(T&& number) {
+ return Constant(std::forward<T>(number));
+ }
+
+ /// Pass-through overload for Value()
+ /// @param v the ir::Value pointer
+ /// @returns @p v
+ ir::Value* Value(ir::Value* v) { return v; }
+
+ /// Creates a ir::Constant for the given boolean
+ /// @param v the boolean value
+ /// @returns the new constant
+ ir::Constant* Value(bool v) { return Constant(v); }
+
+ /// Pass-through overload for Values() with vector-like argument
+ /// @param vec the vector of ir::Value*
+ /// @return @p vec
+ template <typename VEC, typename = EnableIfVectorLike<utils::traits::Decay<VEC>>>
+ auto Values(VEC&& vec) {
+ return std::forward<VEC>(vec);
+ }
+
+ /// Overload for Values() with utils::Empty argument
+ /// @return utils::Empty
+ utils::EmptyType Values(utils::EmptyType) { return utils::Empty; }
+
+ /// Overload for Values() with no arguments
+ /// @return utils::Empty
+ utils::EmptyType Values() { return utils::Empty; }
+
+ /// @param args the arguments to pass to Value()
+ /// @returns a vector of ir::Value* built from transforming the arguments with Value()
+ template <typename... ARGS, typename = DisableIfVectorLike<ARGS...>>
+ auto Values(ARGS&&... args) {
+ return utils::Vector{Value(std::forward<ARGS>(args))...};
+ }
+
/// Creates an op for `lhs kind rhs`
/// @param kind the kind of operation
/// @param type the result type of the binary expression
/// @param lhs the left-hand-side of the operation
/// @param rhs the right-hand-side of the operation
/// @returns the operation
- Binary* CreateBinary(enum Binary::Kind kind, const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* Binary(enum Binary::Kind kind, const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return ir.values.Create<ir::Binary>(kind, type, Value(std::forward<LHS>(lhs)),
+ Value(std::forward<RHS>(rhs)));
+ }
/// Creates an And operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* And(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* And(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kAnd, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
+ }
/// Creates an Or operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* Or(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* Or(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kOr, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
+ }
/// Creates an Xor operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* Xor(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* Xor(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kXor, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
+ }
/// Creates an Equal operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* Equal(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* Equal(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kEqual, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an NotEqual operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* NotEqual(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* NotEqual(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kNotEqual, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an LessThan operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* LessThan(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* LessThan(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kLessThan, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an GreaterThan operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* GreaterThan(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* GreaterThan(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kGreaterThan, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an LessThanEqual operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* LessThanEqual(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* LessThanEqual(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kLessThanEqual, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an GreaterThanEqual operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* GreaterThanEqual(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* GreaterThanEqual(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kGreaterThanEqual, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an ShiftLeft operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* ShiftLeft(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* ShiftLeft(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kShiftLeft, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an ShiftRight operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* ShiftRight(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* ShiftRight(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kShiftRight, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an Add operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* Add(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* Add(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kAdd, type, std::forward<LHS>(lhs), std::forward<RHS>(rhs));
+ }
/// Creates an Subtract operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* Subtract(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* Subtract(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kSubtract, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an Multiply operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* Multiply(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* Multiply(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kMultiply, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an Divide operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* Divide(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* Divide(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kDivide, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an Modulo operation
/// @param type the result type of the expression
/// @param lhs the lhs of the add
/// @param rhs the rhs of the add
/// @returns the operation
- Binary* Modulo(const type::Type* type, Value* lhs, Value* rhs);
+ template <typename LHS, typename RHS>
+ ir::Binary* Modulo(const type::Type* type, LHS&& lhs, RHS&& rhs) {
+ return Binary(ir::Binary::Kind::kModulo, type, std::forward<LHS>(lhs),
+ std::forward<RHS>(rhs));
+ }
/// Creates an op for `kind val`
/// @param kind the kind of operation
/// @param type the result type of the binary expression
/// @param val the value of the operation
/// @returns the operation
- Unary* CreateUnary(enum Unary::Kind kind, const type::Type* type, Value* val);
+ template <typename VAL>
+ ir::Unary* Unary(enum Unary::Kind kind, const type::Type* type, VAL&& val) {
+ return ir.values.Create<ir::Unary>(kind, type, Value(std::forward<VAL>(val)));
+ }
/// Creates a Complement operation
/// @param type the result type of the expression
/// @param val the value
/// @returns the operation
- Unary* Complement(const type::Type* type, Value* val);
+ template <typename VAL>
+ ir::Unary* Complement(const type::Type* type, VAL&& val) {
+ return Unary(ir::Unary::Kind::kComplement, type, std::forward<VAL>(val));
+ }
/// Creates a Negation operation
/// @param type the result type of the expression
/// @param val the value
/// @returns the operation
- Unary* Negation(const type::Type* type, Value* val);
+ template <typename VAL>
+ ir::Unary* Negation(const type::Type* type, VAL&& val) {
+ return Unary(ir::Unary::Kind::kNegation, type, std::forward<VAL>(val));
+ }
/// Creates a Not operation
/// @param type the result type of the expression
/// @param val the value
/// @returns the operation
- Binary* Not(const type::Type* type, Value* val);
+ template <typename VAL>
+ ir::Binary* Not(const type::Type* type, VAL&& val) {
+ return Equal(type, std::forward<VAL>(val), Constant(false));
+ }
/// Creates a bitcast instruction
/// @param type the result type of the bitcast
/// @param val the value being bitcast
/// @returns the instruction
- ir::Bitcast* Bitcast(const type::Type* type, Value* val);
+ template <typename VAL>
+ ir::Bitcast* Bitcast(const type::Type* type, VAL&& val) {
+ return ir.values.Create<ir::Bitcast>(type, Value(std::forward<VAL>(val)));
+ }
/// Creates a discard instruction
/// @returns the instruction
@@ -290,97 +439,133 @@
/// Creates a user function call instruction
/// @param type the return type of the call
- /// @param func the function being called
+ /// @param func the function to call
/// @param args the call arguments
/// @returns the instruction
- ir::UserCall* UserCall(const type::Type* type,
- Function* func,
- utils::VectorRef<Value*> args = utils::Empty);
+ template <typename... ARGS>
+ ir::UserCall* Call(const type::Type* type, ir::Function* func, ARGS&&... args) {
+ return ir.values.Create<ir::UserCall>(type, func, Values(std::forward<ARGS>(args)...));
+ }
+
+ /// Creates a builtin call instruction
+ /// @param type the return type of the call
+ /// @param func the builtin function to call
+ /// @param args the call arguments
+ /// @returns the instruction
+ template <typename... ARGS>
+ ir::BuiltinCall* Call(const type::Type* type, builtin::Function func, ARGS&&... args) {
+ return ir.values.Create<ir::BuiltinCall>(type, func, Values(std::forward<ARGS>(args)...));
+ }
/// Creates a value conversion instruction
/// @param to the type converted to
- /// @param from the type converted from
- /// @param args the arguments to be converted
+ /// @param val the value to be converted
/// @returns the instruction
- ir::Convert* Convert(const type::Type* to,
- const type::Type* from,
- utils::VectorRef<Value*> args);
+ template <typename VAL>
+ ir::Convert* Convert(const type::Type* to, VAL&& val) {
+ return ir.values.Create<ir::Convert>(to, Value(std::forward<VAL>(val)));
+ }
/// Creates a value constructor instruction
- /// @param to the type being converted
- /// @param args the arguments to be converted
+ /// @param type the type to constructed
+ /// @param args the arguments to the constructor
/// @returns the instruction
- ir::Construct* Construct(const type::Type* to, utils::VectorRef<Value*> args = utils::Empty);
-
- /// Creates a builtin call instruction
- /// @param type the return type
- /// @param func the builtin function
- /// @param args the arguments to be converted
- /// @returns the instruction
- ir::Builtin* Builtin(const type::Type* type,
- builtin::Function func,
- utils::VectorRef<Value*> args = utils::Empty);
+ template <typename... ARGS>
+ ir::Construct* Construct(const type::Type* type, ARGS&&... args) {
+ return ir.values.Create<ir::Construct>(type, Values(std::forward<ARGS>(args)...));
+ }
/// Creates a load instruction
/// @param from the expression being loaded from
/// @returns the instruction
- ir::Load* Load(Value* from);
+ template <typename VAL>
+ ir::Load* Load(VAL&& from) {
+ return ir.values.Create<ir::Load>(Value(std::forward<VAL>(from)));
+ }
/// Creates a store instruction
/// @param to the expression being stored too
/// @param from the expression being stored
/// @returns the instruction
- ir::Store* Store(Value* to, Value* from);
+ template <typename ARG>
+ ir::Store* Store(ir::Value* to, ARG&& from) {
+ return ir.values.Create<ir::Store>(to, Value(std::forward<ARG>(from)));
+ }
/// Creates a new `var` declaration
/// @param type the var type
/// @returns the instruction
- ir::Var* Declare(const type::Type* type);
+ ir::Var* Var(const type::Pointer* type);
/// Creates a return instruction
/// @param func the function being returned
- /// @param args the return arguments
/// @returns the instruction
- ir::Return* Return(Function* func, utils::VectorRef<Value*> args = utils::Empty);
+ ir::Return* Return(ir::Function* func) { return ir.values.Create<ir::Return>(func); }
+
+ /// Creates a return instruction
+ /// @param func the function being returned
+ /// @param value the return value
+ /// @returns the instruction
+ template <typename ARG>
+ ir::Return* Return(ir::Function* func, ARG&& value) {
+ return ir.values.Create<ir::Return>(func, Value(std::forward<ARG>(value)));
+ }
/// Creates a loop next iteration instruction
/// @param loop the loop being iterated
/// @param args the branch arguments
/// @returns the instruction
- ir::NextIteration* NextIteration(Loop* loop, utils::VectorRef<Value*> args = utils::Empty);
+ template <typename... ARGS>
+ ir::NextIteration* NextIteration(ir::Loop* loop, ARGS&&... args) {
+ return ir.values.Create<ir::NextIteration>(loop, Values(std::forward<ARGS>(args)...));
+ }
/// Creates a loop break-if instruction
/// @param condition the break condition
/// @param loop the loop being iterated
/// @param args the branch arguments
/// @returns the instruction
- ir::BreakIf* BreakIf(Value* condition,
- Loop* loop,
- utils::VectorRef<Value*> args = utils::Empty);
+ template <typename CONDITION, typename... ARGS>
+ ir::BreakIf* BreakIf(CONDITION&& condition, ir::Loop* loop, ARGS&&... args) {
+ return ir.values.Create<ir::BreakIf>(Value(std::forward<CONDITION>(condition)), loop,
+ Values(std::forward<ARGS>(args)...));
+ }
/// Creates a continue instruction
/// @param loop the loop being continued
/// @param args the branch arguments
/// @returns the instruction
- ir::Continue* Continue(Loop* loop, utils::VectorRef<Value*> args = utils::Empty);
+ template <typename... ARGS>
+ ir::Continue* Continue(ir::Loop* loop, ARGS&&... args) {
+ return ir.values.Create<ir::Continue>(loop, Values(std::forward<ARGS>(args)...));
+ }
/// Creates an exit switch instruction
/// @param sw the switch being exited
/// @param args the branch arguments
/// @returns the instruction
- ir::ExitSwitch* ExitSwitch(Switch* sw, utils::VectorRef<Value*> args = utils::Empty);
+ template <typename... ARGS>
+ ir::ExitSwitch* ExitSwitch(ir::Switch* sw, ARGS&&... args) {
+ return ir.values.Create<ir::ExitSwitch>(sw, Values(std::forward<ARGS>(args)...));
+ }
/// Creates an exit loop instruction
/// @param loop the loop being exited
/// @param args the branch arguments
/// @returns the instruction
- ir::ExitLoop* ExitLoop(Loop* loop, utils::VectorRef<Value*> args = utils::Empty);
+ template <typename... ARGS>
+ ir::ExitLoop* ExitLoop(ir::Loop* loop, ARGS&&... args) {
+ return ir.values.Create<ir::ExitLoop>(loop, Values(std::forward<ARGS>(args)...));
+ }
/// Creates an exit if instruction
/// @param i the if being exited
/// @param args the branch arguments
/// @returns the instruction
- ir::ExitIf* ExitIf(If* i, utils::VectorRef<Value*> args = utils::Empty);
+ template <typename... ARGS>
+ ir::ExitIf* ExitIf(ir::If* i, ARGS&&... args) {
+ return ir.values.Create<ir::ExitIf>(i, Values(std::forward<ARGS>(args)...));
+ }
/// Creates a new `BlockParam`
/// @param type the parameter type
@@ -394,21 +579,35 @@
/// Creates a new `Access`
/// @param type the return type
- /// @param source the source value
+ /// @param object the object being accessed
/// @param indices the access indices
/// @returns the instruction
- ir::Access* Access(const type::Type* type, Value* source, utils::VectorRef<Value*> indices);
+ template <typename... ARGS>
+ ir::Access* Access(const type::Type* type, ir::Value* object, ARGS&&... indices) {
+ return ir.values.Create<ir::Access>(type, object, Values(std::forward<ARGS>(indices)...));
+ }
/// Creates a new `Swizzle`
/// @param type the return type
- /// @param source the source value
- /// @param indices the access indices
+ /// @param object the object being swizzled
+ /// @param indices the swizzle indices
/// @returns the instruction
- ir::Swizzle* Swizzle(const type::Type* type, Value* source, utils::VectorRef<uint32_t> indices);
+ ir::Swizzle* Swizzle(const type::Type* type,
+ ir::Value* object,
+ utils::VectorRef<uint32_t> indices);
+
+ /// Creates a new `Swizzle`
+ /// @param type the return type
+ /// @param object the object being swizzled
+ /// @param indices the swizzle indices
+ /// @returns the instruction
+ ir::Swizzle* Swizzle(const type::Type* type,
+ ir::Value* object,
+ std::initializer_list<uint32_t> indices);
/// Retrieves the root block for the module, creating if necessary
/// @returns the root block
- ir::Block* CreateRootBlockIfNeeded();
+ ir::Block* RootBlock();
/// The IR module.
Module& ir;
diff --git a/src/tint/ir/builtin.cc b/src/tint/ir/builtin_call.cc
similarity index 75%
rename from src/tint/ir/builtin.cc
rename to src/tint/ir/builtin_call.cc
index c571168..a230510 100644
--- a/src/tint/ir/builtin.cc
+++ b/src/tint/ir/builtin_call.cc
@@ -12,25 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/builtin.h"
+#include "src/tint/ir/builtin_call.h"
#include <utility>
#include "src/tint/debug.h"
-TINT_INSTANTIATE_TYPEINFO(tint::ir::Builtin);
+TINT_INSTANTIATE_TYPEINFO(tint::ir::BuiltinCall);
-// \cond DO_NOT_DOCUMENT
namespace tint::ir {
-Builtin::Builtin(const type::Type* ty, builtin::Function func, utils::VectorRef<Value*> arguments)
+BuiltinCall::BuiltinCall(const type::Type* ty,
+ builtin::Function func,
+ utils::VectorRef<Value*> arguments)
: Base(ty), func_(func) {
TINT_ASSERT(IR, func != builtin::Function::kNone);
TINT_ASSERT(IR, func != builtin::Function::kTintMaterialize);
AddOperands(std::move(arguments));
}
-Builtin::~Builtin() = default;
+BuiltinCall::~BuiltinCall() = default;
} // namespace tint::ir
-// \endcond
diff --git a/src/tint/ir/builtin.h b/src/tint/ir/builtin_call.h
similarity index 70%
rename from src/tint/ir/builtin.h
rename to src/tint/ir/builtin_call.h
index 628a04f..4aef1eb 100644
--- a/src/tint/ir/builtin.h
+++ b/src/tint/ir/builtin_call.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef SRC_TINT_IR_BUILTIN_H_
-#define SRC_TINT_IR_BUILTIN_H_
+#ifndef SRC_TINT_IR_BUILTIN_CALL_H_
+#define SRC_TINT_IR_BUILTIN_CALL_H_
#include "src/tint/builtin/function.h"
#include "src/tint/ir/call.h"
@@ -22,24 +22,24 @@
namespace tint::ir {
/// A builtin call instruction in the IR.
-class Builtin : public utils::Castable<Builtin, Call> {
+class BuiltinCall : public utils::Castable<BuiltinCall, Call> {
public:
/// Constructor
/// @param res_type the result type
/// @param func the builtin function
/// @param args the conversion arguments
- Builtin(const type::Type* res_type,
- builtin::Function func,
- utils::VectorRef<Value*> args = utils::Empty);
- ~Builtin() override;
+ BuiltinCall(const type::Type* res_type,
+ builtin::Function func,
+ utils::VectorRef<Value*> args = utils::Empty);
+ ~BuiltinCall() override;
/// @returns the builtin function
- builtin::Function Func() const { return func_; }
+ builtin::Function Func() { return func_; }
private:
- const builtin::Function func_;
+ builtin::Function func_;
};
} // namespace tint::ir
-#endif // SRC_TINT_IR_BUILTIN_H_
+#endif // SRC_TINT_IR_BUILTIN_CALL_H_
diff --git a/src/tint/ir/builtin_test.cc b/src/tint/ir/builtin_call_test.cc
similarity index 69%
rename from src/tint/ir/builtin_test.cc
rename to src/tint/ir/builtin_call_test.cc
index d5cf04a..fb6580e 100644
--- a/src/tint/ir/builtin_test.cc
+++ b/src/tint/ir/builtin_call_test.cc
@@ -21,55 +21,53 @@
namespace {
using namespace tint::number_suffixes; // NOLINT
-using IR_BuiltinTest = IRTestHelper;
+using IR_BuiltinCallTest = IRTestHelper;
-TEST_F(IR_BuiltinTest, Usage) {
+TEST_F(IR_BuiltinCallTest, Usage) {
auto* arg1 = b.Constant(1_u);
auto* arg2 = b.Constant(2_u);
- auto* builtin =
- b.Builtin(mod.Types().f32(), builtin::Function::kAbs, utils::Vector{arg1, arg2});
+ auto* builtin = b.Call(mod.Types().f32(), builtin::Function::kAbs, arg1, arg2);
EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{builtin, 0u}));
EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{builtin, 1u}));
}
-TEST_F(IR_BuiltinTest, Fail_NullType) {
+TEST_F(IR_BuiltinCallTest, Fail_NullType) {
EXPECT_FATAL_FAILURE(
{
Module mod;
Builder b{mod};
- b.Builtin(nullptr, builtin::Function::kAbs);
+ b.Call(nullptr, builtin::Function::kAbs);
},
"");
}
-TEST_F(IR_BuiltinTest, Fail_NoneFunction) {
+TEST_F(IR_BuiltinCallTest, Fail_NoneFunction) {
EXPECT_FATAL_FAILURE(
{
Module mod;
Builder b{mod};
- b.Builtin(mod.Types().f32(), builtin::Function::kNone);
+ b.Call(mod.Types().f32(), builtin::Function::kNone);
},
"");
}
-TEST_F(IR_BuiltinTest, Fail_TintMaterializeFunction) {
+TEST_F(IR_BuiltinCallTest, Fail_TintMaterializeFunction) {
EXPECT_FATAL_FAILURE(
{
Module mod;
Builder b{mod};
- b.Builtin(mod.Types().f32(), builtin::Function::kTintMaterialize);
+ b.Call(mod.Types().f32(), builtin::Function::kTintMaterialize);
},
"");
}
-TEST_F(IR_BuiltinTest, Fail_NullArg) {
+TEST_F(IR_BuiltinCallTest, Fail_NullArg) {
EXPECT_FATAL_FAILURE(
{
Module mod;
Builder b{mod};
- b.Builtin(mod.Types().f32(), builtin::Function::kAbs,
- utils::Vector<Value*, 1>{nullptr});
+ b.Call(mod.Types().f32(), builtin::Function::kAbs, nullptr);
},
"");
}
diff --git a/src/tint/ir/call.h b/src/tint/ir/call.h
index 669aaf5..c26aa5d 100644
--- a/src/tint/ir/call.h
+++ b/src/tint/ir/call.h
@@ -26,10 +26,10 @@
~Call() override;
/// @returns the type of the value
- const type::Type* Type() const override { return result_type_; }
+ const type::Type* Type() override { return result_type_; }
/// @returns the call arguments
- virtual utils::Slice<Value const* const> Args() const { return operands_.Slice(); }
+ virtual utils::Slice<Value* const> Args() { return operands_.Slice(); }
protected:
/// Constructor
diff --git a/src/tint/ir/constant.h b/src/tint/ir/constant.h
index cf9ed40..90acb04 100644
--- a/src/tint/ir/constant.h
+++ b/src/tint/ir/constant.h
@@ -29,10 +29,10 @@
~Constant() override;
/// @returns the constants value
- const constant::Value* Value() const { return value_; }
+ const constant::Value* Value() { return value_; }
/// @returns the type of the constant
- const type::Type* Type() const override { return value_->Type(); }
+ const type::Type* Type() override { return value_->Type(); }
private:
const constant::Value* const value_ = nullptr;
diff --git a/src/tint/ir/construct_test.cc b/src/tint/ir/construct_test.cc
index cdab810..fa31ff9 100644
--- a/src/tint/ir/construct_test.cc
+++ b/src/tint/ir/construct_test.cc
@@ -27,7 +27,7 @@
TEST_F(IR_ConstructTest, Usage) {
auto* arg1 = b.Constant(true);
auto* arg2 = b.Constant(false);
- auto* c = b.Construct(mod.Types().f32(), utils::Vector{arg1, arg2});
+ auto* c = b.Construct(mod.Types().f32(), arg1, arg2);
EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{c, 0u}));
EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{c, 1u}));
@@ -48,7 +48,7 @@
{
Module mod;
Builder b{mod};
- b.Construct(mod.Types().f32(), utils::Vector<Value*, 1>{nullptr});
+ b.Construct(mod.Types().f32(), nullptr);
},
"");
}
diff --git a/src/tint/ir/continue.cc b/src/tint/ir/continue.cc
index a7282ac..e1e8d84 100644
--- a/src/tint/ir/continue.cc
+++ b/src/tint/ir/continue.cc
@@ -18,6 +18,7 @@
#include "src/tint/ir/block.h"
#include "src/tint/ir/loop.h"
+#include "src/tint/ir/multi_in_block.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::Continue);
@@ -28,7 +29,7 @@
TINT_ASSERT(IR, loop_);
if (loop_) {
- loop_->Continuing()->AddInboundBranch(this);
+ loop_->Continuing()->AddInboundSiblingBranch(this);
}
AddOperands(std::move(args));
}
diff --git a/src/tint/ir/continue.h b/src/tint/ir/continue.h
index a954e74..f9c5455 100644
--- a/src/tint/ir/continue.h
+++ b/src/tint/ir/continue.h
@@ -35,7 +35,7 @@
~Continue() override;
/// @returns the loop owning the continue block
- const ir::Loop* Loop() const { return loop_; }
+ ir::Loop* Loop() { return loop_; }
private:
ir::Loop* loop_ = nullptr;
diff --git a/src/tint/ir/continue_test.cc b/src/tint/ir/continue_test.cc
index 631d696..29ea605 100644
--- a/src/tint/ir/continue_test.cc
+++ b/src/tint/ir/continue_test.cc
@@ -25,11 +25,11 @@
using IR_ContinueTest = IRTestHelper;
TEST_F(IR_ContinueTest, Usage) {
- auto* loop = b.CreateLoop();
+ auto* loop = b.Loop();
auto* arg1 = b.Constant(1_u);
auto* arg2 = b.Constant(2_u);
- auto* brk = b.Continue(loop, utils::Vector{arg1, arg2});
+ auto* brk = b.Continue(loop, arg1, arg2);
EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{brk, 0u}));
EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{brk, 1u}));
@@ -50,7 +50,7 @@
{
Module mod;
Builder b{mod};
- b.Continue(b.CreateLoop(), utils::Vector<Value*, 1>{nullptr});
+ b.Continue(b.Loop(), nullptr);
},
"");
}
diff --git a/src/tint/ir/convert.cc b/src/tint/ir/convert.cc
index b55243b..ed5961b 100644
--- a/src/tint/ir/convert.cc
+++ b/src/tint/ir/convert.cc
@@ -22,13 +22,10 @@
namespace tint::ir {
-Convert::Convert(const type::Type* to_type,
- const type::Type* from_type,
- utils::VectorRef<Value*> arguments)
- : Base(to_type), from_type_(from_type) {
- TINT_ASSERT(IR, from_type_);
- TINT_ASSERT(IR, !arguments.IsEmpty());
- AddOperands(std::move(arguments));
+Convert::Convert(const type::Type* to_type, Value* value) : Base(to_type) {
+ TINT_ASSERT_OR_RETURN(IR, value);
+
+ AddOperand(value);
}
Convert::~Convert() = default;
diff --git a/src/tint/ir/convert.h b/src/tint/ir/convert.h
index 77778e0..0bb07cf 100644
--- a/src/tint/ir/convert.h
+++ b/src/tint/ir/convert.h
@@ -25,21 +25,10 @@
class Convert : public utils::Castable<Convert, Call> {
public:
/// Constructor
- /// @param result_type the result type
- /// @param from_type the type being converted from
- /// @param args the conversion arguments
- Convert(const type::Type* result_type,
- const type::Type* from_type,
- utils::VectorRef<Value*> args);
+ /// @param to_type the target conversion type
+ /// @param value the value to convert
+ Convert(const type::Type* to_type, Value* value);
~Convert() override;
-
- /// @returns the from type
- const type::Type* FromType() const { return from_type_; }
- /// @returns the to type
- const type::Type* ToType() const { return Type(); }
-
- private:
- const type::Type* from_type_ = nullptr;
};
} // namespace tint::ir
diff --git a/src/tint/ir/convert_test.cc b/src/tint/ir/convert_test.cc
index d0ef71d..d1affd4 100644
--- a/src/tint/ir/convert_test.cc
+++ b/src/tint/ir/convert_test.cc
@@ -27,17 +27,7 @@
{
Module mod;
Builder b{mod};
- b.Convert(nullptr, mod.Types().f32(), utils::Vector{b.Constant(1_u)});
- },
- "");
-}
-
-TEST_F(IR_ConvertTest, Fail_NullFromType) {
- EXPECT_FATAL_FAILURE(
- {
- Module mod;
- Builder b{mod};
- b.Convert(mod.Types().f32(), nullptr, utils::Vector{b.Constant(1_u)});
+ b.Convert(nullptr, 1_u);
},
"");
}
@@ -47,17 +37,7 @@
{
Module mod;
Builder b{mod};
- b.Convert(mod.Types().f32(), mod.Types().i32(), utils::Empty);
- },
- "");
-}
-
-TEST_F(IR_ConvertTest, Fail_NullArg) {
- EXPECT_FATAL_FAILURE(
- {
- Module mod;
- Builder b{mod};
- b.Convert(mod.Types().f32(), mod.Types().i32(), utils::Vector<Value*, 1>{nullptr});
+ b.Convert(mod.Types().f32(), nullptr);
},
"");
}
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 9f7aa8d..6d17316 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -22,8 +22,9 @@
#include "src/tint/ir/binary.h"
#include "src/tint/ir/bitcast.h"
#include "src/tint/ir/block.h"
+#include "src/tint/ir/block_param.h"
#include "src/tint/ir/break_if.h"
-#include "src/tint/ir/builtin.h"
+#include "src/tint/ir/builtin_call.h"
#include "src/tint/ir/construct.h"
#include "src/tint/ir/continue.h"
#include "src/tint/ir/convert.h"
@@ -34,6 +35,7 @@
#include "src/tint/ir/if.h"
#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
+#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
@@ -63,7 +65,7 @@
} // namespace
-Disassembler::Disassembler(const Module& mod) : mod_(mod) {}
+Disassembler::Disassembler(Module& mod) : mod_(mod) {}
Disassembler::~Disassembler() = default;
@@ -80,19 +82,19 @@
current_output_start_pos_ = out_.tellp();
}
-void Disassembler::EmitBlockInstructions(const Block* b) {
- for (const auto* inst : *b) {
+void Disassembler::EmitBlockInstructions(Block* b) {
+ for (auto* inst : *b) {
Indent();
EmitInstruction(inst);
}
}
-size_t Disassembler::IdOf(const Block* node) {
+size_t Disassembler::IdOf(Block* node) {
TINT_ASSERT(IR, node);
return block_ids_.GetOrCreate(node, [&] { return block_ids_.Count(); });
}
-std::string_view Disassembler::IdOf(const Value* value) {
+std::string_view Disassembler::IdOf(Value* value) {
TINT_ASSERT(IR, value);
return value_ids_.GetOrCreate(value, [&] {
if (auto sym = mod_.NameOf(value)) {
@@ -126,7 +128,7 @@
return out_.str();
}
-void Disassembler::Walk(const Block* blk) {
+void Disassembler::Walk(Block* blk) {
if (visited_.Contains(blk)) {
return;
}
@@ -134,14 +136,19 @@
WalkInternal(blk);
}
-void Disassembler::WalkInternal(const Block* blk) {
+void Disassembler::WalkInternal(Block* blk) {
+ Indent();
+
SourceMarker sm(this);
- Indent() << "%b" << IdOf(blk) << " = block";
- if (!blk->Params().IsEmpty()) {
- out_ << " (";
- EmitValueList(blk->Params().Slice());
- out_ << ")";
+ out_ << "%b" << IdOf(blk) << " = block";
+ if (auto* merge = blk->As<MultiInBlock>()) {
+ if (!merge->Params().IsEmpty()) {
+ out_ << " (";
+ EmitValueList(merge->Params().Slice());
+ out_ << ")";
+ }
}
+ sm.Store(blk);
out_ << " {";
EmitLine();
@@ -150,7 +157,6 @@
EmitBlockInstructions(blk);
}
Indent() << "}";
- sm.Store(blk);
EmitLine();
}
@@ -172,7 +178,7 @@
}
}
-void Disassembler::EmitParamAttributes(const FunctionParam* p) {
+void Disassembler::EmitParamAttributes(FunctionParam* p) {
if (!p->Invariant() && !p->Location().has_value() && !p->BindingPoint().has_value() &&
!p->Builtin().has_value()) {
return;
@@ -209,7 +215,7 @@
out_ << "]";
}
-void Disassembler::EmitReturnAttributes(const Function* func) {
+void Disassembler::EmitReturnAttributes(Function* func) {
if (!func->ReturnInvariant() && !func->ReturnLocation().has_value() &&
!func->ReturnBuiltin().has_value()) {
return;
@@ -241,7 +247,7 @@
out_ << "]";
}
-void Disassembler::EmitFunction(const Function* func) {
+void Disassembler::EmitFunction(Function* func) {
in_function_ = true;
Indent() << "%" << IdOf(func) << " =";
@@ -256,7 +262,7 @@
out_ << " func(";
- for (const auto* p : func->Params()) {
+ for (auto* p : func->Params()) {
if (p != func->Params().Front()) {
out_ << ", ";
}
@@ -279,17 +285,17 @@
EmitLine();
}
-void Disassembler::EmitValueWithType(const Value* val) {
+void Disassembler::EmitValueWithType(Value* val) {
EmitValue(val);
if (auto* i = val->As<ir::Instruction>(); i->Type() != nullptr) {
out_ << ":" << i->Type()->FriendlyName();
}
}
-void Disassembler::EmitValue(const Value* val) {
+void Disassembler::EmitValue(Value* val) {
tint::Switch(
val,
- [&](const ir::Constant* constant) {
+ [&](ir::Constant* constant) {
std::function<void(const constant::Value*)> emit = [&](const constant::Value* c) {
tint::Switch(
c,
@@ -332,29 +338,27 @@
};
emit(constant->Value());
},
- [&](const ir::Instruction* i) { out_ << "%" << IdOf(i); },
- [&](const ir::BlockParam* p) {
- out_ << "%" << IdOf(p) << ":" << p->Type()->FriendlyName();
- },
- [&](const ir::FunctionParam* p) { out_ << "%" << IdOf(p); },
+ [&](ir::Instruction* i) { out_ << "%" << IdOf(i); },
+ [&](ir::BlockParam* p) { out_ << "%" << IdOf(p) << ":" << p->Type()->FriendlyName(); },
+ [&](ir::FunctionParam* p) { out_ << "%" << IdOf(p); },
[&](Default) { out_ << "Unknown value: " << val->TypeInfo().name; });
}
-void Disassembler::EmitInstructionName(std::string_view name, const Instruction* inst) {
+void Disassembler::EmitInstructionName(std::string_view name, Instruction* inst) {
SourceMarker sm(this);
out_ << name;
sm.Store(inst);
}
-void Disassembler::EmitInstruction(const Instruction* inst) {
+void Disassembler::EmitInstruction(Instruction* inst) {
tint::Switch(
- inst, //
- [&](const ir::Switch* s) { EmitSwitch(s); }, //
- [&](const ir::If* i) { EmitIf(i); }, //
- [&](const ir::Loop* l) { EmitLoop(l); }, //
- [&](const ir::Binary* b) { EmitBinary(b); }, //
- [&](const ir::Unary* u) { EmitUnary(u); },
- [&](const ir::Bitcast* b) {
+ inst, //
+ [&](Switch* s) { EmitSwitch(s); }, //
+ [&](If* i) { EmitIf(i); }, //
+ [&](Loop* l) { EmitLoop(l); }, //
+ [&](Binary* b) { EmitBinary(b); }, //
+ [&](Unary* u) { EmitUnary(u); },
+ [&](Bitcast* b) {
EmitValueWithType(b);
out_ << " = ";
EmitInstructionName("bitcast", b);
@@ -362,11 +366,11 @@
EmitArgs(b);
EmitLine();
},
- [&](const ir::Discard* d) {
+ [&](Discard* d) {
EmitInstructionName("discard", d);
EmitLine();
},
- [&](const ir::Builtin* b) {
+ [&](BuiltinCall* b) {
EmitValueWithType(b);
out_ << " = ";
EmitInstructionName(builtin::str(b->Func()), b);
@@ -374,7 +378,7 @@
EmitArgs(b);
EmitLine();
},
- [&](const ir::Construct* c) {
+ [&](Construct* c) {
EmitValueWithType(c);
out_ << " = ";
EmitInstructionName("construct", c);
@@ -382,15 +386,15 @@
EmitArgs(c);
EmitLine();
},
- [&](const ir::Convert* c) {
+ [&](Convert* c) {
EmitValueWithType(c);
out_ << " = ";
EmitInstructionName("convert", c);
- out_ << " " << c->FromType()->FriendlyName() << ", ";
+ out_ << " ";
EmitArgs(c);
EmitLine();
},
- [&](const ir::Load* l) {
+ [&](Load* l) {
EmitValueWithType(l);
out_ << " = ";
EmitInstructionName("load", l);
@@ -398,7 +402,7 @@
EmitValue(l->From());
EmitLine();
},
- [&](const ir::Store* s) {
+ [&](Store* s) {
EmitInstructionName("store", s);
out_ << " ";
EmitValue(s->To());
@@ -406,7 +410,7 @@
EmitValue(s->From());
EmitLine();
},
- [&](const ir::UserCall* uc) {
+ [&](UserCall* uc) {
EmitValueWithType(uc);
out_ << " = ";
EmitInstructionName("call", uc);
@@ -417,7 +421,7 @@
EmitArgs(uc);
EmitLine();
},
- [&](const ir::Var* v) {
+ [&](Var* v) {
EmitValueWithType(v);
out_ << " = ";
EmitInstructionName("var", v);
@@ -431,22 +435,17 @@
}
EmitLine();
},
- [&](const ir::Access* a) {
+ [&](Access* a) {
EmitValueWithType(a);
out_ << " = ";
EmitInstructionName("access", a);
out_ << " ";
- EmitValue(a->Object());
+ EmitOperand(a, a->Object(), Access::kObjectOperandOffset);
out_ << ", ";
- for (size_t i = 0; i < a->Indices().Length(); ++i) {
- if (i > 0) {
- out_ << ", ";
- }
- EmitValue(a->Indices()[i]);
- }
+ EmitOperandList(a, a->Indices(), Access::kIndicesOperandOffset);
EmitLine();
},
- [&](const ir::Swizzle* s) {
+ [&](Swizzle* s) {
EmitValueWithType(s);
out_ << " = ";
EmitInstructionName("swizzle", s);
@@ -471,20 +470,32 @@
}
EmitLine();
},
- [&](const ir::Branch* b) { EmitBranch(b); },
+ [&](Branch* b) { EmitBranch(b); },
[&](Default) { out_ << "Unknown instruction: " << inst->TypeInfo().name; });
}
-void Disassembler::EmitOperand(const Value* val, const Instruction* inst, uint32_t index) {
+void Disassembler::EmitOperand(Instruction* inst, Value* val, size_t index) {
SourceMarker condMarker(this);
EmitValue(val);
- condMarker.Store(Operand{inst, index});
+ condMarker.Store(Usage{inst, static_cast<uint32_t>(index)});
}
-void Disassembler::EmitIf(const If* i) {
+void Disassembler::EmitOperandList(Instruction* inst,
+ utils::Slice<Value* const> operands,
+ size_t start_index) {
+ size_t index = start_index;
+ for (auto* operand : operands) {
+ if (index != start_index) {
+ out_ << ", ";
+ }
+ EmitOperand(inst, operand, index++);
+ }
+}
+
+void Disassembler::EmitIf(If* i) {
SourceMarker sm(this);
out_ << "if ";
- EmitOperand(i->Condition(), i, If::kConditionOperandIndex);
+ EmitOperand(i, i->Condition(), If::kConditionOperandOffset);
bool has_true = i->True()->HasBranchTarget();
bool has_false = i->False()->HasBranchTarget();
@@ -531,7 +542,7 @@
}
}
-void Disassembler::EmitLoop(const Loop* l) {
+void Disassembler::EmitLoop(Loop* l) {
utils::Vector<std::string, 4> parts;
if (l->Initializer()->HasBranchTarget()) {
parts.Push("i: %b" + std::to_string(IdOf(l->Initializer())));
@@ -583,16 +594,16 @@
}
}
-void Disassembler::EmitSwitch(const Switch* s) {
+void Disassembler::EmitSwitch(Switch* s) {
out_ << "switch ";
EmitValue(s->Condition());
out_ << " [";
- for (const auto& c : s->Cases()) {
+ for (auto& c : s->Cases()) {
if (&c != &s->Cases().Front()) {
out_ << ", ";
}
out_ << "c: (";
- for (const auto& selector : c.selectors) {
+ for (auto& selector : c.selectors) {
if (&selector != &c.selectors.Front()) {
out_ << " ";
}
@@ -628,21 +639,17 @@
}
}
-void Disassembler::EmitBranch(const Branch* b) {
+void Disassembler::EmitBranch(Branch* b) {
SourceMarker sm(this);
tint::Switch(
b, //
- [&](const ir::Return*) { out_ << "ret"; },
- [&](const ir::Continue* cont) {
- out_ << "continue %b" << IdOf(cont->Loop()->Continuing());
- },
- [&](const ir::ExitIf* ei) { out_ << "exit_if %b" << IdOf(ei->If()->Merge()); },
- [&](const ir::ExitSwitch* es) { out_ << "exit_switch %b" << IdOf(es->Switch()->Merge()); },
- [&](const ir::ExitLoop* el) { out_ << "exit_loop %b" << IdOf(el->Loop()->Merge()); },
- [&](const ir::NextIteration* ni) {
- out_ << "next_iteration %b" << IdOf(ni->Loop()->Body());
- },
- [&](const ir::BreakIf* bi) {
+ [&](Return*) { out_ << "ret"; },
+ [&](Continue* cont) { out_ << "continue %b" << IdOf(cont->Loop()->Continuing()); },
+ [&](ExitIf* ei) { out_ << "exit_if %b" << IdOf(ei->If()->Merge()); },
+ [&](ExitSwitch* es) { out_ << "exit_switch %b" << IdOf(es->Switch()->Merge()); },
+ [&](ExitLoop* el) { out_ << "exit_loop %b" << IdOf(el->Loop()->Merge()); },
+ [&](NextIteration* ni) { out_ << "next_iteration %b" << IdOf(ni->Loop()->Body()); },
+ [&](BreakIf* bi) {
out_ << "break_if ";
EmitValue(bi->Condition());
out_ << " %b" << IdOf(bi->Loop()->Body());
@@ -658,7 +665,7 @@
EmitLine();
}
-void Disassembler::EmitValueList(utils::Slice<Value const* const> values) {
+void Disassembler::EmitValueList(utils::Slice<Value* const> values) {
for (auto* v : values) {
if (v != values.Front()) {
out_ << ", ";
@@ -667,11 +674,11 @@
}
}
-void Disassembler::EmitArgs(const Call* call) {
+void Disassembler::EmitArgs(Call* call) {
EmitValueList(call->Args());
}
-void Disassembler::EmitBinary(const Binary* b) {
+void Disassembler::EmitBinary(Binary* b) {
EmitValueWithType(b);
out_ << " = ";
switch (b->Kind()) {
@@ -731,7 +738,7 @@
EmitLine();
}
-void Disassembler::EmitUnary(const Unary* u) {
+void Disassembler::EmitUnary(Unary* u) {
EmitValueWithType(u);
out_ << " = ";
switch (u->Kind()) {
diff --git a/src/tint/ir/disassembler.h b/src/tint/ir/disassembler.h
index 6c58639..6d7248c 100644
--- a/src/tint/ir/disassembler.h
+++ b/src/tint/ir/disassembler.h
@@ -39,33 +39,9 @@
/// Helper class to disassemble the IR
class Disassembler {
public:
- /// An operand used in an instruction
- struct Operand {
- /// The instruction
- const Instruction* instruction = nullptr;
- /// The operand index
- uint32_t operand_index = 0u;
-
- /// A specialization of utils::Hasher for Operand.
- struct Hasher {
- /// @param u the operand to hash
- /// @returns a hash of the operand
- inline std::size_t operator()(const Operand& u) const {
- return utils::Hash(u.instruction, u.operand_index);
- }
- };
-
- /// An equality helper for Operand.
- /// @param other the operand to compare against
- /// @returns true if the two operands are equal
- bool operator==(const Operand& other) const {
- return instruction == other.instruction && operand_index == other.operand_index;
- }
- };
-
/// Constructor
/// @param mod the module
- explicit Disassembler(const Module& mod);
+ explicit Disassembler(Module& mod);
~Disassembler();
/// Returns the module as a string
@@ -74,41 +50,39 @@
/// Writes the block instructions to the stream
/// @param b the block containing the instructions
- void EmitBlockInstructions(const Block* b);
+ void EmitBlockInstructions(Block* b);
/// @returns the string representation
std::string AsString() const { return out_.str(); }
/// @param inst the instruction to retrieve
/// @returns the source for the instruction
- Source InstructionSource(const Instruction* inst) {
+ Source InstructionSource(Instruction* inst) {
return instruction_to_src_.Get(inst).value_or(Source{});
}
/// @param operand the operand to retrieve
/// @returns the source for the operand
- Source OperandSource(Operand operand) {
- return operand_to_src_.Get(operand).value_or(Source{});
- }
+ Source OperandSource(Usage operand) { return operand_to_src_.Get(operand).value_or(Source{}); }
/// @param blk teh block to retrieve
/// @returns the source for the block
- Source BlockSource(const Block* blk) { return block_to_src_.Get(blk).value_or(Source{}); }
+ Source BlockSource(Block* blk) { return block_to_src_.Get(blk).value_or(Source{}); }
/// Stores the given @p src location for @p inst instruction
/// @param inst the instruction to store
/// @param src the source location
- void SetSource(const Instruction* inst, Source src) { instruction_to_src_.Add(inst, src); }
+ void SetSource(Instruction* inst, Source src) { instruction_to_src_.Add(inst, src); }
/// Stores the given @p src location for @p blk block
/// @param blk the block to store
/// @param src the source location
- void SetSource(const Block* blk, Source src) { block_to_src_.Add(blk, src); }
+ void SetSource(Block* blk, Source src) { block_to_src_.Add(blk, src); }
/// Stores the given @p src location for @p op operand
/// @param op the operand to store
/// @param src the source location
- void SetSource(Operand op, Source src) { operand_to_src_.Add(op, src); }
+ void SetSource(Usage op, Source src) { operand_to_src_.Add(op, src); }
/// @returns the source location for the current emission location
Source::Location MakeCurrentLocation();
@@ -119,11 +93,11 @@
explicit SourceMarker(Disassembler* d) : dis_(d), begin_(dis_->MakeCurrentLocation()) {}
~SourceMarker() = default;
- void Store(const Instruction* inst) { dis_->SetSource(inst, MakeSource()); }
+ void Store(Instruction* inst) { dis_->SetSource(inst, MakeSource()); }
- void Store(const Block* blk) { dis_->SetSource(blk, MakeSource()); }
+ void Store(Block* blk) { dis_->SetSource(blk, MakeSource()); }
- void Store(Operand operand) { dis_->SetSource(operand, MakeSource()); }
+ void Store(Usage operand) { dis_->SetSource(operand, MakeSource()); }
Source MakeSource() const {
return Source(Source::Range(begin_, dis_->MakeCurrentLocation()));
@@ -136,46 +110,49 @@
utils::StringStream& Indent();
- size_t IdOf(const Block* blk);
- std::string_view IdOf(const Value* node);
+ size_t IdOf(Block* blk);
+ std::string_view IdOf(Value* node);
- void Walk(const Block* blk);
- void WalkInternal(const Block* blk);
- void EmitFunction(const Function* func);
- void EmitParamAttributes(const FunctionParam* p);
- void EmitReturnAttributes(const Function* func);
+ void Walk(Block* blk);
+ void WalkInternal(Block* blk);
+ void EmitFunction(Function* func);
+ void EmitParamAttributes(FunctionParam* p);
+ void EmitReturnAttributes(Function* func);
void EmitBindingPoint(BindingPoint p);
void EmitLocation(Location loc);
- void EmitInstruction(const Instruction* inst);
- void EmitValueWithType(const Value* val);
- void EmitValue(const Value* val);
- void EmitValueList(utils::Slice<ir::Value const* const> values);
- void EmitArgs(const Call* call);
- void EmitBinary(const Binary* b);
- void EmitUnary(const Unary* b);
- void EmitBranch(const Branch* b);
- void EmitSwitch(const Switch* s);
- void EmitLoop(const Loop* l);
- void EmitIf(const If* i);
+ void EmitInstruction(Instruction* inst);
+ void EmitValueWithType(Value* val);
+ void EmitValue(Value* val);
+ void EmitValueList(utils::Slice<ir::Value* const> values);
+ void EmitArgs(Call* call);
+ void EmitBinary(Binary* b);
+ void EmitUnary(Unary* b);
+ void EmitBranch(Branch* b);
+ void EmitSwitch(Switch* s);
+ void EmitLoop(Loop* l);
+ void EmitIf(If* i);
void EmitStructDecl(const type::Struct* str);
void EmitLine();
- void EmitOperand(const Value* val, const Instruction* inst, uint32_t index);
- void EmitInstructionName(std::string_view name, const Instruction* inst);
+ void EmitOperand(Instruction* inst, Value* val, size_t index);
+ void EmitOperandList(Instruction* inst,
+ utils::Slice<Value* const> operands,
+ size_t start_index);
+ void EmitInstructionName(std::string_view name, Instruction* inst);
- const Module& mod_;
+ Module& mod_;
utils::StringStream out_;
- utils::Hashset<const Block*, 32> visited_;
- utils::Hashmap<const Block*, size_t, 32> block_ids_;
- utils::Hashmap<const Value*, std::string, 32> value_ids_;
+ utils::Hashset<Block*, 32> visited_;
+ utils::Hashmap<Block*, size_t, 32> block_ids_;
+ utils::Hashmap<Value*, std::string, 32> value_ids_;
uint32_t indent_size_ = 0;
bool in_function_ = false;
uint32_t current_output_line_ = 1;
uint32_t current_output_start_pos_ = 0;
- utils::Hashmap<const Block*, Source, 8> block_to_src_;
- utils::Hashmap<const Instruction*, Source, 8> instruction_to_src_;
- utils::Hashmap<Operand, Source, 8, Operand::Hasher> operand_to_src_;
+ utils::Hashmap<Block*, Source, 8> block_to_src_;
+ utils::Hashmap<Instruction*, Source, 8> instruction_to_src_;
+ utils::Hashmap<Usage, Source, 8, Usage::Hasher> operand_to_src_;
};
} // namespace tint::ir
diff --git a/src/tint/ir/discard_test.cc b/src/tint/ir/discard_test.cc
index 92eb6da..91e1ef1 100644
--- a/src/tint/ir/discard_test.cc
+++ b/src/tint/ir/discard_test.cc
@@ -23,7 +23,7 @@
using IR_DiscardTest = IRTestHelper;
TEST_F(IR_DiscardTest, Discard) {
- const auto* inst = b.Discard();
+ auto* inst = b.Discard();
ASSERT_TRUE(inst->Is<ir::Discard>());
}
diff --git a/src/tint/ir/exit_if.cc b/src/tint/ir/exit_if.cc
index 44a7002..391c786 100644
--- a/src/tint/ir/exit_if.cc
+++ b/src/tint/ir/exit_if.cc
@@ -16,8 +16,8 @@
#include <utility>
-#include "src/tint/ir/block.h"
#include "src/tint/ir/if.h"
+#include "src/tint/ir/multi_in_block.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::ExitIf);
@@ -27,7 +27,7 @@
TINT_ASSERT(IR, if_);
if (if_) {
- if_->Merge()->AddInboundBranch(this);
+ if_->Merge()->AddInboundSiblingBranch(this);
}
AddOperands(std::move(args));
}
diff --git a/src/tint/ir/exit_if.h b/src/tint/ir/exit_if.h
index 153b714..5ddaabf 100644
--- a/src/tint/ir/exit_if.h
+++ b/src/tint/ir/exit_if.h
@@ -35,7 +35,7 @@
~ExitIf() override;
/// @returns the if being exited
- const ir::If* If() const { return if_; }
+ ir::If* If() { return if_; }
private:
ir::If* if_ = nullptr;
diff --git a/src/tint/ir/exit_if_test.cc b/src/tint/ir/exit_if_test.cc
index 13e0cc1..f496d6c 100644
--- a/src/tint/ir/exit_if_test.cc
+++ b/src/tint/ir/exit_if_test.cc
@@ -27,8 +27,8 @@
TEST_F(IR_ExitIfTest, Usage) {
auto* arg1 = b.Constant(1_u);
auto* arg2 = b.Constant(2_u);
- auto* if_ = b.CreateIf(b.Constant(true));
- auto* e = b.ExitIf(if_, utils::Vector{arg1, arg2});
+ auto* if_ = b.If(true);
+ auto* e = b.ExitIf(if_, arg1, arg2);
EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{e, 0u}));
EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{e, 1u}));
@@ -49,7 +49,7 @@
{
Module mod;
Builder b{mod};
- b.ExitIf(b.CreateIf(b.Constant(false)), utils::Vector<Value*, 1>{nullptr});
+ b.ExitIf(b.If(false), nullptr);
},
"");
}
diff --git a/src/tint/ir/exit_loop.cc b/src/tint/ir/exit_loop.cc
index 729e466..865b25c 100644
--- a/src/tint/ir/exit_loop.cc
+++ b/src/tint/ir/exit_loop.cc
@@ -18,6 +18,7 @@
#include "src/tint/ir/block.h"
#include "src/tint/ir/loop.h"
+#include "src/tint/ir/multi_in_block.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::ExitLoop);
@@ -28,7 +29,7 @@
TINT_ASSERT(IR, loop_);
if (loop_) {
- loop_->Merge()->AddInboundBranch(this);
+ loop_->Merge()->AddInboundSiblingBranch(this);
}
AddOperands(std::move(args));
}
diff --git a/src/tint/ir/exit_loop.h b/src/tint/ir/exit_loop.h
index 4ef8110..2e7f2ea 100644
--- a/src/tint/ir/exit_loop.h
+++ b/src/tint/ir/exit_loop.h
@@ -35,7 +35,7 @@
~ExitLoop() override;
/// @returns the loop being exited
- const ir::Loop* Loop() const { return loop_; }
+ ir::Loop* Loop() { return loop_; }
private:
ir::Loop* loop_ = nullptr;
diff --git a/src/tint/ir/exit_loop_test.cc b/src/tint/ir/exit_loop_test.cc
index 11a6b28..2b8b3de 100644
--- a/src/tint/ir/exit_loop_test.cc
+++ b/src/tint/ir/exit_loop_test.cc
@@ -27,8 +27,8 @@
TEST_F(IR_ExitLoopTest, Usage) {
auto* arg1 = b.Constant(1_u);
auto* arg2 = b.Constant(2_u);
- auto* loop = b.CreateLoop();
- auto* e = b.ExitLoop(loop, utils::Vector{arg1, arg2});
+ auto* loop = b.Loop();
+ auto* e = b.ExitLoop(loop, arg1, arg2);
EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{e, 0u}));
EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{e, 1u}));
@@ -49,7 +49,7 @@
{
Module mod;
Builder b{mod};
- b.ExitLoop(b.CreateLoop(), utils::Vector<Value*, 1>{nullptr});
+ b.ExitLoop(b.Loop(), nullptr);
},
"");
}
diff --git a/src/tint/ir/exit_switch.cc b/src/tint/ir/exit_switch.cc
index 29c6a24..3372d6d 100644
--- a/src/tint/ir/exit_switch.cc
+++ b/src/tint/ir/exit_switch.cc
@@ -16,7 +16,7 @@
#include <utility>
-#include "src/tint/ir/block.h"
+#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/switch.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::ExitSwitch);
@@ -28,7 +28,7 @@
TINT_ASSERT(IR, switch_);
if (switch_) {
- switch_->Merge()->AddInboundBranch(this);
+ switch_->Merge()->AddInboundSiblingBranch(this);
}
AddOperands(std::move(args));
}
diff --git a/src/tint/ir/exit_switch.h b/src/tint/ir/exit_switch.h
index 31a23e2..706ac69 100644
--- a/src/tint/ir/exit_switch.h
+++ b/src/tint/ir/exit_switch.h
@@ -35,7 +35,7 @@
~ExitSwitch() override;
/// @returns the switch being exited
- const ir::Switch* Switch() const { return switch_; }
+ ir::Switch* Switch() { return switch_; }
private:
ir::Switch* switch_ = nullptr;
diff --git a/src/tint/ir/exit_switch_test.cc b/src/tint/ir/exit_switch_test.cc
index 945bb5c..85cb228 100644
--- a/src/tint/ir/exit_switch_test.cc
+++ b/src/tint/ir/exit_switch_test.cc
@@ -27,8 +27,8 @@
TEST_F(IR_ExitSwitchTest, Usage) {
auto* arg1 = b.Constant(1_u);
auto* arg2 = b.Constant(2_u);
- auto* switch_ = b.CreateSwitch(b.Constant(true));
- auto* e = b.ExitSwitch(switch_, utils::Vector{arg1, arg2});
+ auto* switch_ = b.Switch(true);
+ auto* e = b.ExitSwitch(switch_, arg1, arg2);
EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{e, 0u}));
EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{e, 1u}));
@@ -49,7 +49,7 @@
{
Module mod;
Builder b{mod};
- b.ExitSwitch(b.CreateSwitch(b.Constant(false)), utils::Vector<Value*, 1>{nullptr});
+ b.ExitSwitch(b.Switch(false), nullptr);
},
"");
}
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index ad4d29f..8ffa513 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -109,8 +109,8 @@
using ResultType = utils::Result<Module, diag::List>;
-bool IsConnected(const Block* b) {
- return b->InboundBranches().Length() > 0;
+bool IsConnected(MultiInBlock* b) {
+ return b->InboundSiblingBranches().Length() > 0;
}
/// Impl is the private-implementation of FromProgram().
@@ -217,7 +217,7 @@
[&](const ast::Variable* var) {
// Setup the current flow node to be the root block for the module. The builder
// will handle creating it if it doesn't exist already.
- TINT_SCOPED_ASSIGNMENT(current_block_, builder_.CreateRootBlockIfNeeded());
+ TINT_SCOPED_ASSIGNMENT(current_block_, builder_.RootBlock());
EmitVariable(var);
},
[&](const ast::Function* func) { EmitFunction(func); },
@@ -264,8 +264,8 @@
const auto* sem = program_->Sem().Get(ast_func);
- auto* ir_func = builder_.CreateFunction(ast_func->name->symbol.NameView(),
- sem->ReturnType()->Clone(clone_ctx_.type_ctx));
+ auto* ir_func = builder_.Function(ast_func->name->symbol.NameView(),
+ sem->ReturnType()->Clone(clone_ctx_.type_ctx));
current_function_ = ir_func;
builder_.ir.functions.Push(ir_func);
@@ -610,7 +610,7 @@
if (!reg) {
return;
}
- auto* if_inst = builder_.CreateIf(reg.Get());
+ auto* if_inst = builder_.If(reg.Get());
current_block_->Append(if_inst);
{
@@ -645,12 +645,9 @@
}
void EmitLoop(const ast::LoopStatement* stmt) {
- auto* loop_inst = builder_.CreateLoop();
+ auto* loop_inst = builder_.Loop();
current_block_->Append(loop_inst);
- // Loop branches directly to the body (no initializer)
- loop_inst->Body()->AddInboundBranch(loop_inst);
-
{
ControlStackScope scope(this, loop_inst);
current_block_ = loop_inst->Body();
@@ -692,12 +689,9 @@
}
void EmitWhile(const ast::WhileStatement* stmt) {
- auto* loop_inst = builder_.CreateLoop();
+ auto* loop_inst = builder_.Loop();
current_block_->Append(loop_inst);
- // Loop branches directly to the body (no initializer)
- loop_inst->Body()->AddInboundBranch(loop_inst);
-
// Continue is always empty, just go back to the start
current_block_ = loop_inst->Continuing();
SetBranch(builder_.NextIteration(loop_inst));
@@ -714,7 +708,7 @@
}
// Create an `if (cond) {} else {break;}` control flow
- auto* if_inst = builder_.CreateIf(reg.Get());
+ auto* if_inst = builder_.If(reg.Get());
current_block_->Append(if_inst);
current_block_ = if_inst->True();
@@ -736,7 +730,7 @@
}
void EmitForLoop(const ast::ForLoopStatement* stmt) {
- auto* loop_inst = builder_.CreateLoop();
+ auto* loop_inst = builder_.Loop();
current_block_->Append(loop_inst);
// Make sure the initializer ends up in a contained scope
@@ -747,16 +741,10 @@
ControlStackScope scope(this, loop_inst);
if (stmt->initializer) {
- // Loop branches to the initializer
- loop_inst->Initializer()->AddInboundBranch(loop_inst);
-
// Emit the for initializer before branching to the body
current_block_ = loop_inst->Initializer();
EmitStatement(stmt->initializer);
SetBranch(builder_.NextIteration(loop_inst));
- } else {
- // If there's no initializer, then the loop branches directly to the body block
- loop_inst->Body()->AddInboundBranch(loop_inst);
}
current_block_ = loop_inst->Body();
@@ -768,7 +756,7 @@
}
// Create an `if (cond) {} else {break;}` control flow
- auto* if_inst = builder_.CreateIf(reg.Get());
+ auto* if_inst = builder_.If(reg.Get());
current_block_->Append(if_inst);
current_block_ = if_inst->True();
@@ -803,7 +791,7 @@
if (!reg) {
return;
}
- auto* switch_inst = builder_.CreateSwitch(reg.Get());
+ auto* switch_inst = builder_.Switch(reg.Get());
current_block_->Append(switch_inst);
{
@@ -820,7 +808,7 @@
}
}
- current_block_ = builder_.CreateCase(switch_inst, selectors);
+ current_block_ = builder_.Case(switch_inst, selectors);
EmitBlock(c->Body()->Declaration());
if (NeedBranch()) {
@@ -836,15 +824,19 @@
}
void EmitReturn(const ast::ReturnStatement* stmt) {
- utils::Vector<Value*, 1> ret_value;
+ Value* ret_value = nullptr;
if (stmt->value) {
auto ret = EmitExpression(stmt->value);
if (!ret) {
return;
}
- ret_value.Push(ret.Get());
+ ret_value = ret.Get();
}
- SetBranch(builder_.Return(current_function_, std::move(ret_value)));
+ if (ret_value) {
+ SetBranch(builder_.Return(current_function_, ret_value));
+ } else {
+ SetBranch(builder_.Return(current_function_));
+ }
}
void EmitBreak(const ast::BreakStatement*) {
@@ -956,9 +948,9 @@
// The access result type should match the source result type. If the source is a pointer,
// we generate a pointer.
const type::Type* ty = nullptr;
- if (info.object->Type()->Is<type::Pointer>() && !info.result_type->Is<type::Pointer>()) {
- auto* ptr = info.object->Type()->As<type::Pointer>();
- ty = builder_.ir.Types().pointer(info.result_type, ptr->AddressSpace(), ptr->Access());
+ if (auto* ptr = info.object->Type()->As<type::Pointer>();
+ ptr && !info.result_type->Is<type::Pointer>()) {
+ ty = builder_.ir.Types().ptr(ptr->AddressSpace(), info.result_type, ptr->Access());
} else {
ty = info.result_type;
}
@@ -1086,10 +1078,10 @@
[&](const ast::Var* v) {
auto* ref = sem->Type()->As<type::Reference>();
auto* ty = builder_.ir.Types().Get<type::Pointer>(
- ref->StoreType()->Clone(clone_ctx_.type_ctx), ref->AddressSpace(),
+ ref->AddressSpace(), ref->StoreType()->Clone(clone_ctx_.type_ctx),
ref->Access());
- auto* val = builder_.Declare(ty);
+ auto* val = builder_.Var(ty);
if (v->initializer) {
auto init = EmitExpression(v->initializer);
if (!init) {
@@ -1192,11 +1184,11 @@
return utils::Failure;
}
- auto* if_inst = builder_.CreateIf(lhs.Get());
+ auto* if_inst = builder_.If(lhs.Get());
current_block_->Append(if_inst);
auto* result = builder_.BlockParam(builder_.ir.Types().bool_());
- if_inst->Merge()->SetParams(utils::Vector{result});
+ if_inst->Merge()->SetParams({result});
utils::Result<Value*> rhs;
{
@@ -1374,20 +1366,19 @@
// If this is a builtin function, emit the specific builtin value
if (auto* b = sem->Target()->As<sem::Builtin>()) {
- inst = builder_.Builtin(ty, b->Type(), args);
+ inst = builder_.Call(ty, b->Type(), args);
} else if (sem->Target()->As<sem::ValueConstructor>()) {
inst = builder_.Construct(ty, std::move(args));
- } else if (auto* conv = sem->Target()->As<sem::ValueConversion>()) {
- auto* from = conv->Source()->Clone(clone_ctx_.type_ctx);
- inst = builder_.Convert(ty, from, std::move(args));
+ } else if (sem->Target()->Is<sem::ValueConversion>()) {
+ inst = builder_.Convert(ty, args[0]);
} else if (expr->target->identifier->Is<ast::TemplatedIdentifier>()) {
TINT_UNIMPLEMENTED(IR, diagnostics_) << "missing templated ident support";
return utils::Failure;
} else {
// Not a builtin and not a templated call, so this is a user function.
- inst = builder_.UserCall(
- ty, scopes_.Get(expr->target->identifier->symbol)->As<ir::Function>(),
- std::move(args));
+ inst =
+ builder_.Call(ty, scopes_.Get(expr->target->identifier->symbol)->As<ir::Function>(),
+ std::move(args));
}
if (inst == nullptr) {
return utils::Failure;
diff --git a/src/tint/ir/from_program_call_test.cc b/src/tint/ir/from_program_call_test.cc
index b2194ec..ce89975 100644
--- a/src/tint/ir/from_program_call_test.cc
+++ b/src/tint/ir/from_program_call_test.cc
@@ -106,7 +106,7 @@
%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
%b2 = block {
%3:i32 = load %i
- %tint_symbol:f32 = convert i32, %3
+ %tint_symbol:f32 = convert %3
ret
}
}
diff --git a/src/tint/ir/from_program_literal_test.cc b/src/tint/ir/from_program_literal_test.cc
index d542a63..e0e6f8b 100644
--- a/src/tint/ir/from_program_literal_test.cc
+++ b/src/tint/ir/from_program_literal_test.cc
@@ -24,13 +24,13 @@
namespace tint::ir {
namespace {
-const Value* GlobalVarInitializer(const Module& m) {
+Value* GlobalVarInitializer(Module& m) {
if (m.root_block->Length() == 0u) {
ADD_FAILURE() << "m.root_block has no instruction";
return nullptr;
}
- const auto instr = m.root_block->Instructions();
+ auto instr = m.root_block->Instructions();
auto* var = instr->As<ir::Var>();
if (!var) {
ADD_FAILURE() << "m.root_block.instructions[0] was not a var";
diff --git a/src/tint/ir/from_program_test.cc b/src/tint/ir/from_program_test.cc
index e016481..679ec8a 100644
--- a/src/tint/ir/from_program_test.cc
+++ b/src/tint/ir/from_program_test.cc
@@ -19,6 +19,7 @@
#include "src/tint/ir/block.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/loop.h"
+#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/program_test_helper.h"
#include "src/tint/ir/switch.h"
@@ -134,13 +135,11 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* flow = FindSingleValue<ir::If>(m);
+ auto* if_ = FindSingleValue<ir::If>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, flow->True()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->False()->InboundBranches().Length());
- EXPECT_EQ(2u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(2u, if_->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -174,13 +173,11 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* flow = FindSingleValue<ir::If>(m);
+ auto* if_ = FindSingleValue<ir::If>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, flow->True()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->False()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(1u, if_->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -214,13 +211,11 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* flow = FindSingleValue<ir::If>(m);
+ auto* if_ = FindSingleValue<ir::If>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, flow->True()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->False()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(1u, if_->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -254,13 +249,11 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* flow = FindSingleValue<ir::If>(m);
+ auto* if_ = FindSingleValue<ir::If>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, flow->True()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->False()->InboundBranches().Length());
- EXPECT_EQ(0u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(0u, if_->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -290,11 +283,6 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* if_flow = FindSingleValue<ir::If>(m);
- ASSERT_NE(if_flow, nullptr);
-
- auto* loop_flow = FindSingleValue<ir::Loop>(m);
- ASSERT_NE(loop_flow, nullptr);
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -338,13 +326,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* flow = FindSingleValue<ir::Loop>(m);
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, flow->Body()->InboundBranches().Length());
- EXPECT_EQ(0u, flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(0u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -374,18 +362,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* loop_flow = FindSingleValue<ir::Loop>(m);
-
- auto* if_flow = FindSingleValue<ir::If>(m);
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(2u, loop_flow->Body()->InboundBranches().Length());
- EXPECT_EQ(1u, loop_flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(1u, loop_flow->Merge()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->False()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(1u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -435,13 +418,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* loop_flow = FindSingleValue<ir::Loop>(m);
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(2u, loop_flow->Body()->InboundBranches().Length());
- EXPECT_EQ(1u, loop_flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(1u, loop_flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(1u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -510,17 +493,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* loop_flow = FindSingleValue<ir::Loop>(m);
- auto* if_flow = FindSingleValue<ir::If>(m);
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(2u, loop_flow->Body()->InboundBranches().Length());
- EXPECT_EQ(1u, loop_flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(0u, loop_flow->Merge()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->False()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(1u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -564,13 +543,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* loop_flow = FindSingleValue<ir::Loop>(m);
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, loop_flow->Body()->InboundBranches().Length());
- EXPECT_EQ(0u, loop_flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(0u, loop_flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(0u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -603,13 +582,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* loop_flow = FindSingleValue<ir::Loop>(m);
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, loop_flow->Body()->InboundBranches().Length());
- EXPECT_EQ(0u, loop_flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(0u, loop_flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(0u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -634,17 +613,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* loop_flow = FindSingleValue<ir::Loop>(m);
- auto* if_flow = FindSingleValue<ir::If>(m);
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, loop_flow->Body()->InboundBranches().Length());
- EXPECT_EQ(0u, loop_flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(2u, loop_flow->Merge()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->False()->InboundBranches().Length());
- EXPECT_EQ(0u, if_flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(0u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(2u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -809,20 +784,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* flow = FindSingleValue<ir::Loop>(m);
-
- ASSERT_NE(flow->Body()->Branch(), nullptr);
- ASSERT_TRUE(flow->Body()->Branch()->Is<ir::If>());
- auto* if_flow = flow->Body()->Branch()->As<ir::If>();
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(2u, flow->Body()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->False()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(1u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -871,20 +839,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* flow = FindSingleValue<ir::Loop>(m);
-
- ASSERT_NE(flow->Body()->Branch(), nullptr);
- ASSERT_TRUE(flow->Body()->Branch()->Is<ir::If>());
- auto* if_flow = flow->Body()->Branch()->As<ir::If>();
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(2u, flow->Body()->InboundBranches().Length());
- EXPECT_EQ(0u, flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->False()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(1u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -946,20 +907,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* flow = FindSingleValue<ir::Loop>(m);
-
- ASSERT_NE(flow->Body()->Branch(), nullptr);
- ASSERT_TRUE(flow->Body()->Branch()->Is<ir::If>());
- auto* if_flow = flow->Body()->Branch()->As<ir::If>();
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(2u, flow->Body()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(2u, flow->Merge()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->True()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->False()->InboundBranches().Length());
- EXPECT_EQ(1u, if_flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(2u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(2u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m), R"()");
}
@@ -972,14 +926,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* flow = FindSingleValue<ir::Loop>(m);
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, flow->Initializer()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Body()->InboundBranches().Length());
- EXPECT_EQ(0u, flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(1u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -1014,13 +967,13 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- auto* flow = FindSingleValue<ir::Loop>(m);
+ auto* loop = FindSingleValue<ir::Loop>(m);
ASSERT_EQ(1u, m.functions.Length());
- EXPECT_EQ(1u, flow->Body()->InboundBranches().Length());
- EXPECT_EQ(0u, flow->Continuing()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(0u, loop->Body()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(1u, loop->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -1072,10 +1025,7 @@
ASSERT_EQ(1u, cases[2].selectors.Length());
EXPECT_TRUE(cases[2].selectors[0].IsDefault());
- EXPECT_EQ(1u, cases[0].Start()->InboundBranches().Length());
- EXPECT_EQ(1u, cases[1].Start()->InboundBranches().Length());
- EXPECT_EQ(1u, cases[2].Start()->InboundBranches().Length());
- EXPECT_EQ(3u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(3u, flow->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -1135,8 +1085,7 @@
EXPECT_TRUE(cases[0].selectors[2].IsDefault());
- EXPECT_EQ(1u, cases[0].Start()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(1u, flow->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -1174,8 +1123,7 @@
ASSERT_EQ(1u, cases[0].selectors.Length());
EXPECT_TRUE(cases[0].selectors[0].IsDefault());
- EXPECT_EQ(1u, cases[0].Start()->InboundBranches().Length());
- EXPECT_EQ(1u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(1u, flow->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -1220,9 +1168,7 @@
ASSERT_EQ(1u, cases[1].selectors.Length());
EXPECT_TRUE(cases[1].selectors[0].IsDefault());
- EXPECT_EQ(1u, cases[0].Start()->InboundBranches().Length());
- EXPECT_EQ(1u, cases[1].Start()->InboundBranches().Length());
- EXPECT_EQ(2u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(2u, flow->Merge()->InboundSiblingBranches().Length());
// This is 1 because the if is dead-code eliminated and the return doesn't happen.
EXPECT_EQ(Disassemble(m),
@@ -1260,7 +1206,6 @@
ASSERT_TRUE(res) << (!res ? res.Failure() : "");
auto m = res.Move();
- ASSERT_EQ(FindSingleValue<ir::If>(m), nullptr);
auto* flow = FindSingleValue<ir::Switch>(m);
@@ -1276,9 +1221,7 @@
ASSERT_EQ(1u, cases[1].selectors.Length());
EXPECT_TRUE(cases[1].selectors[0].IsDefault());
- EXPECT_EQ(1u, cases[0].Start()->InboundBranches().Length());
- EXPECT_EQ(1u, cases[1].Start()->InboundBranches().Length());
- EXPECT_EQ(0u, flow->Merge()->InboundBranches().Length());
+ EXPECT_EQ(0u, flow->Merge()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
diff --git a/src/tint/ir/function.cc b/src/tint/ir/function.cc
index fa13fb8..10b896a 100644
--- a/src/tint/ir/function.cc
+++ b/src/tint/ir/function.cc
@@ -14,6 +14,8 @@
#include "src/tint/ir/function.h"
+#include "src/tint/utils/predicates.h"
+
TINT_INSTANTIATE_TYPEINFO(tint::ir::Function);
namespace tint::ir {
@@ -31,9 +33,12 @@
void Function::SetParams(utils::VectorRef<FunctionParam*> params) {
params_ = std::move(params);
- for (auto* param : params_) {
- TINT_ASSERT(IR, param != nullptr);
- }
+ TINT_ASSERT(IR, !params_.Any(utils::IsNull));
+}
+
+void Function::SetParams(std::initializer_list<FunctionParam*> params) {
+ params_ = params;
+ TINT_ASSERT(IR, !params_.Any(utils::IsNull));
}
utils::StringStream& operator<<(utils::StringStream& out, Function::PipelineStage value) {
diff --git a/src/tint/ir/function.h b/src/tint/ir/function.h
index 775033d..a9e1801 100644
--- a/src/tint/ir/function.h
+++ b/src/tint/ir/function.h
@@ -71,7 +71,7 @@
void SetStage(PipelineStage stage) { pipeline_stage_ = stage; }
/// @returns the function pipeline stage
- PipelineStage Stage() const { return pipeline_stage_; }
+ PipelineStage Stage() { return pipeline_stage_; }
/// Sets the workgroup size
/// @param x the x size
@@ -80,10 +80,10 @@
void SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { workgroup_size_ = {x, y, z}; }
/// @returns the workgroup size information
- std::optional<std::array<uint32_t, 3>> WorkgroupSize() const { return workgroup_size_; }
+ std::optional<std::array<uint32_t, 3>> WorkgroupSize() { return workgroup_size_; }
/// @returns the return type for the function
- const type::Type* ReturnType() const { return return_.type; }
+ const type::Type* ReturnType() { return return_.type; }
/// Sets the return attributes
/// @param builtin the builtin to set
@@ -92,7 +92,7 @@
return_.builtin = builtin;
}
/// @returns the return builtin attribute
- std::optional<enum ReturnBuiltin> ReturnBuiltin() const { return return_.builtin; }
+ std::optional<enum ReturnBuiltin> ReturnBuiltin() { return return_.builtin; }
/// Sets the return location
/// @param loc the location to set
@@ -101,19 +101,24 @@
return_.location = {loc, interp};
}
/// @returns the return location
- std::optional<Location> ReturnLocation() const { return return_.location; }
+ std::optional<Location> ReturnLocation() { return return_.location; }
/// Sets the return as invariant
/// @param val the invariant value to set
void SetReturnInvariant(bool val) { return_.invariant = val; }
/// @returns the return invariant value
- bool ReturnInvariant() const { return return_.invariant; }
+ bool ReturnInvariant() { return return_.invariant; }
/// Sets the function parameters
- /// @param params the function paramters
+ /// @param params the function parameters
void SetParams(utils::VectorRef<FunctionParam*> params);
+
+ /// Sets the function parameters
+ /// @param params the function parameters
+ void SetParams(std::initializer_list<FunctionParam*> params);
+
/// @returns the function parameters
- utils::VectorRef<FunctionParam*> Params() const { return params_; }
+ const utils::VectorRef<FunctionParam*> Params() { return params_; }
/// Sets the start target for the function
/// @param target the start target
@@ -122,7 +127,7 @@
start_target_ = target;
}
/// @returns the function start target
- Block* StartTarget() const { return start_target_; }
+ Block* StartTarget() { return start_target_; }
private:
PipelineStage pipeline_stage_;
diff --git a/src/tint/ir/function_param.h b/src/tint/ir/function_param.h
index 45ce4e9..6a1e8ca 100644
--- a/src/tint/ir/function_param.h
+++ b/src/tint/ir/function_param.h
@@ -60,7 +60,7 @@
~FunctionParam() override;
/// @returns the type of the var
- const type::Type* Type() const override { return type_; }
+ const type::Type* Type() override { return type_; }
/// Sets the builtin information. Note, it is currently an error if the builtin is already set.
/// @param val the builtin to set
@@ -69,13 +69,13 @@
builtin_ = val;
}
/// @returns the builtin set for the parameter
- std::optional<FunctionParam::Builtin> Builtin() const { return builtin_; }
+ std::optional<FunctionParam::Builtin> Builtin() { return builtin_; }
/// Sets the parameter as invariant
/// @param val the value to set for invariant
void SetInvariant(bool val) { invariant_ = val; }
/// @returns true if parameter is invariant
- bool Invariant() const { return invariant_; }
+ bool Invariant() { return invariant_; }
/// Sets the location
/// @param loc the location value
@@ -84,14 +84,14 @@
location_ = {loc, interpolation};
}
/// @returns the location if `Attributes` contains `kLocation`
- std::optional<struct Location> Location() const { return location_; }
+ std::optional<struct Location> Location() { return location_; }
/// Sets the binding point
/// @param group the group
/// @param binding the binding
void SetBindingPoint(uint32_t group, uint32_t binding) { binding_point_ = {group, binding}; }
/// @returns the binding points if `Attributes` contains `kBindingPoint`
- std::optional<struct BindingPoint> BindingPoint() const { return binding_point_; }
+ std::optional<struct BindingPoint> BindingPoint() { return binding_point_; }
private:
const type::Type* type_ = nullptr;
diff --git a/src/tint/ir/function_test.cc b/src/tint/ir/function_test.cc
index e8ee091..35076df 100644
--- a/src/tint/ir/function_test.cc
+++ b/src/tint/ir/function_test.cc
@@ -27,7 +27,7 @@
{
Module mod;
Builder b{mod};
- b.CreateFunction("my_func", nullptr);
+ b.Function("my_func", nullptr);
},
"");
}
@@ -37,7 +37,7 @@
{
Module mod;
Builder b{mod};
- auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ auto* f = b.Function("my_func", mod.Types().void_());
f->SetReturnBuiltin(Function::ReturnBuiltin::kFragDepth);
f->SetReturnBuiltin(Function::ReturnBuiltin::kPosition);
},
@@ -49,8 +49,8 @@
{
Module mod;
Builder b{mod};
- auto* f = b.CreateFunction("my_func", mod.Types().void_());
- f->SetParams(utils::Vector<FunctionParam*, 1>{nullptr});
+ auto* f = b.Function("my_func", mod.Types().void_());
+ f->SetParams({nullptr});
},
"");
}
@@ -60,7 +60,7 @@
{
Module mod;
Builder b{mod};
- auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ auto* f = b.Function("my_func", mod.Types().void_());
f->SetStartTarget(nullptr);
},
"");
diff --git a/src/tint/ir/if.cc b/src/tint/ir/if.cc
index 29bcbb2..7f34986 100644
--- a/src/tint/ir/if.cc
+++ b/src/tint/ir/if.cc
@@ -16,11 +16,12 @@
TINT_INSTANTIATE_TYPEINFO(tint::ir::If);
-#include "src/tint/ir/block.h"
+#include "src/tint/ir/multi_in_block.h"
namespace tint::ir {
-If::If(Value* cond, ir::Block* t, ir::Block* f, ir::Block* m) : true_(t), false_(f), merge_(m) {
+If::If(Value* cond, ir::Block* t, ir::Block* f, ir::MultiInBlock* m)
+ : true_(t), false_(f), merge_(m) {
TINT_ASSERT(IR, cond);
TINT_ASSERT(IR, true_);
TINT_ASSERT(IR, false_);
@@ -28,11 +29,9 @@
AddOperand(cond);
if (true_) {
- true_->AddInboundBranch(this);
true_->SetParent(this);
}
if (false_) {
- false_->AddInboundBranch(this);
false_->SetParent(this);
}
if (merge_) {
diff --git a/src/tint/ir/if.h b/src/tint/ir/if.h
index 4566ef8..4434116 100644
--- a/src/tint/ir/if.h
+++ b/src/tint/ir/if.h
@@ -17,6 +17,11 @@
#include "src/tint/ir/control_instruction.h"
+// Forward declarations
+namespace tint::ir {
+class MultiInBlock;
+} // namespace tint::ir
+
namespace tint::ir {
/// If instruction.
@@ -41,43 +46,35 @@
class If : public utils::Castable<If, ControlInstruction> {
public:
/// The index of the condition operand
- static constexpr size_t kConditionOperandIndex = 0;
+ static constexpr size_t kConditionOperandOffset = 0;
/// Constructor
/// @param cond the if condition
/// @param t the true block
/// @param f the false block
/// @param m the merge block
- explicit If(Value* cond, ir::Block* t, ir::Block* f, ir::Block* m);
+ If(Value* cond, ir::Block* t, ir::Block* f, ir::MultiInBlock* m);
~If() override;
/// @returns the branch arguments
- utils::Slice<Value const* const> Args() const override { return utils::Slice<Value*>{}; }
+ utils::Slice<Value* const> Args() override { return utils::Slice<Value*>{}; }
/// @returns the if condition
- const Value* Condition() const { return operands_[kConditionOperandIndex]; }
- /// @returns the if condition
- Value* Condition() { return operands_[kConditionOperandIndex]; }
+ Value* Condition() { return operands_[kConditionOperandOffset]; }
/// @returns the true branch block
- const ir::Block* True() const { return true_; }
- /// @returns the true branch block
ir::Block* True() { return true_; }
/// @returns the false branch block
- const ir::Block* False() const { return false_; }
- /// @returns the false branch block
ir::Block* False() { return false_; }
/// @returns the merge branch block
- const ir::Block* Merge() const { return merge_; }
- /// @returns the merge branch block
- ir::Block* Merge() { return merge_; }
+ ir::MultiInBlock* Merge() { return merge_; }
private:
ir::Block* true_ = nullptr;
ir::Block* false_ = nullptr;
- ir::Block* merge_ = nullptr;
+ ir::MultiInBlock* merge_ = nullptr;
};
} // namespace tint::ir
diff --git a/src/tint/ir/if_test.cc b/src/tint/ir/if_test.cc
index 4d5bcdf..60915a6 100644
--- a/src/tint/ir/if_test.cc
+++ b/src/tint/ir/if_test.cc
@@ -25,13 +25,13 @@
TEST_F(IR_IfTest, Usage) {
auto* cond = b.Constant(true);
- auto* if_ = b.CreateIf(cond);
+ auto* if_ = b.If(cond);
EXPECT_THAT(cond->Usages(), testing::UnorderedElementsAre(Usage{if_, 0u}));
}
TEST_F(IR_IfTest, Parent) {
auto* cond = b.Constant(true);
- auto* if_ = b.CreateIf(cond);
+ auto* if_ = b.If(cond);
EXPECT_EQ(if_->True()->Parent(), if_);
EXPECT_EQ(if_->False()->Parent(), if_);
EXPECT_EQ(if_->Merge()->Parent(), if_);
@@ -42,7 +42,7 @@
{
Module mod;
Builder b{mod};
- b.CreateIf(nullptr);
+ b.If(nullptr);
},
"");
}
@@ -52,7 +52,7 @@
{
Module mod;
Builder b{mod};
- If if_(b.Constant(false), nullptr, b.CreateBlock(), b.CreateBlock());
+ If if_(b.Constant(false), nullptr, b.Block(), b.MultiInBlock());
},
"");
}
@@ -62,17 +62,17 @@
{
Module mod;
Builder b{mod};
- If if_(b.Constant(false), b.CreateBlock(), nullptr, b.CreateBlock());
+ If if_(b.Constant(false), b.Block(), nullptr, b.MultiInBlock());
},
"");
}
-TEST_F(IR_IfTest, Fail_NullMergeBlock) {
+TEST_F(IR_IfTest, Fail_NullMultiInBlock) {
EXPECT_FATAL_FAILURE(
{
Module mod;
Builder b{mod};
- If if_(b.Constant(false), b.CreateBlock(), b.CreateBlock(), nullptr);
+ If if_(b.Constant(false), b.Block(), b.Block(), nullptr);
},
"");
}
diff --git a/src/tint/ir/instruction.h b/src/tint/ir/instruction.h
index c19171b..4cf7d70 100644
--- a/src/tint/ir/instruction.h
+++ b/src/tint/ir/instruction.h
@@ -37,8 +37,6 @@
/// @returns the block that owns this instruction
ir::Block* Block() { return block_; }
- /// @returns the block that owns this instruction
- const ir::Block* Block() const { return block_; }
/// Adds the new instruction before the given instruction in the owning block
/// @param before the instruction to insert before
diff --git a/src/tint/ir/instruction_test.cc b/src/tint/ir/instruction_test.cc
index 3a19cd1..2da2102 100644
--- a/src/tint/ir/instruction_test.cc
+++ b/src/tint/ir/instruction_test.cc
@@ -24,9 +24,9 @@
using IR_InstructionTest = IRTestHelper;
TEST_F(IR_InstructionTest, InsertBefore) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* blk = b.Block();
blk->Append(inst2);
inst1->InsertBefore(inst2);
EXPECT_EQ(2u, blk->Length());
@@ -39,7 +39,7 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
+ auto* inst1 = b.Loop();
inst1->InsertBefore(nullptr);
},
"");
@@ -51,17 +51,17 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
inst1->InsertBefore(inst2);
},
"");
}
TEST_F(IR_InstructionTest, InsertAfter) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* blk = b.Block();
blk->Append(inst2);
inst1->InsertAfter(inst2);
EXPECT_EQ(2u, blk->Length());
@@ -74,7 +74,7 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
+ auto* inst1 = b.Loop();
inst1->InsertAfter(nullptr);
},
"");
@@ -86,17 +86,17 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
inst1->InsertAfter(inst2);
},
"");
}
TEST_F(IR_InstructionTest, ReplaceWith) {
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
+ auto* blk = b.Block();
blk->Append(inst2);
inst2->ReplaceWith(inst1);
EXPECT_EQ(1u, blk->Length());
@@ -110,8 +110,8 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* blk = b.Block();
blk->Append(inst1);
inst1->ReplaceWith(nullptr);
},
@@ -124,16 +124,16 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
- auto* inst2 = b.CreateLoop();
+ auto* inst1 = b.Loop();
+ auto* inst2 = b.Loop();
inst1->ReplaceWith(inst2);
},
"");
}
TEST_F(IR_InstructionTest, Remove) {
- auto* inst1 = b.CreateLoop();
- auto* blk = b.CreateBlock();
+ auto* inst1 = b.Loop();
+ auto* blk = b.Block();
blk->Append(inst1);
EXPECT_EQ(1u, blk->Length());
@@ -148,7 +148,7 @@
Module mod;
Builder b{mod};
- auto* inst1 = b.CreateLoop();
+ auto* inst1 = b.Loop();
inst1->Remove();
},
"");
diff --git a/src/tint/ir/ir_test_helper.h b/src/tint/ir/ir_test_helper.h
index a7826ec..737a596 100644
--- a/src/tint/ir/ir_test_helper.h
+++ b/src/tint/ir/ir_test_helper.h
@@ -32,6 +32,17 @@
Module mod;
/// The IR builder
Builder b{mod};
+ /// The type manager
+ type::Manager& ty{mod.Types()};
+
+ /// Alias to builtin::AddressSpace::kStorage
+ static constexpr auto storage = builtin::AddressSpace::kStorage;
+ /// Alias to builtin::AddressSpace::kUniform
+ static constexpr auto uniform = builtin::AddressSpace::kUniform;
+ /// Alias to builtin::AddressSpace::kPrivate
+ static constexpr auto private_ = builtin::AddressSpace::kPrivate;
+ /// Alias to builtin::AddressSpace::kFunction
+ static constexpr auto function = builtin::AddressSpace::kFunction;
};
using IRTestHelper = IRTestHelperBase<testing::Test>;
diff --git a/src/tint/ir/load.cc b/src/tint/ir/load.cc
index 7761192..4beb11e 100644
--- a/src/tint/ir/load.cc
+++ b/src/tint/ir/load.cc
@@ -13,17 +13,20 @@
// limitations under the License.
#include "src/tint/ir/load.h"
+
#include "src/tint/debug.h"
+#include "src/tint/type/pointer.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::Load);
namespace tint::ir {
-Load::Load(const type::Type* type, Value* f) : Base(), result_type_(type) {
- TINT_ASSERT(IR, result_type_);
- TINT_ASSERT(IR, f);
+Load::Load(Value* from) {
+ TINT_ASSERT_OR_RETURN(IR, from);
+ TINT_ASSERT_OR_RETURN(IR, tint::Is<type::Pointer>(from->Type()));
- AddOperand(f);
+ result_type_ = from->Type()->UnwrapPtr();
+ AddOperand(from);
}
Load::~Load() = default;
diff --git a/src/tint/ir/load.h b/src/tint/ir/load.h
index 3a33320..29b3b22 100644
--- a/src/tint/ir/load.h
+++ b/src/tint/ir/load.h
@@ -23,17 +23,17 @@
/// A load instruction in the IR.
class Load : public utils::Castable<Load, OperandInstruction<1>> {
public:
- /// Constructor
- /// @param type the result type
+ /// Constructor (infers type)
/// @param from the value being loaded from
- Load(const type::Type* type, Value* from);
+ explicit Load(Value* from);
+
~Load() override;
/// @returns the type of the value
- const type::Type* Type() const override { return result_type_; }
+ const type::Type* Type() override { return result_type_; }
/// @returns the value being loaded from
- Value* From() const { return operands_[0]; }
+ Value* From() { return operands_[0]; }
private:
const type::Type* result_type_ = nullptr;
diff --git a/src/tint/ir/load_test.cc b/src/tint/ir/load_test.cc
index ad21d13..0cfba85 100644
--- a/src/tint/ir/load_test.cc
+++ b/src/tint/ir/load_test.cc
@@ -26,10 +26,10 @@
using IR_LoadTest = IRTestHelper;
TEST_F(IR_LoadTest, Create) {
- auto* store_type = mod.Types().i32();
- auto* var = b.Declare(mod.Types().pointer(store_type, builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite));
- const auto* inst = b.Load(var);
+ auto* store_type = ty.i32();
+ auto* var =
+ b.Var(ty.ptr(builtin::AddressSpace::kFunction, store_type, builtin::Access::kReadWrite));
+ auto* inst = b.Load(var);
ASSERT_TRUE(inst->Is<Load>());
ASSERT_EQ(inst->From(), var);
@@ -41,35 +41,21 @@
}
TEST_F(IR_LoadTest, Usage) {
- auto* store_type = mod.Types().i32();
- auto* var = b.Declare(mod.Types().pointer(store_type, builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite));
+ auto* store_type = ty.i32();
+ auto* var =
+ b.Var(ty.ptr(builtin::AddressSpace::kFunction, store_type, builtin::Access::kReadWrite));
auto* inst = b.Load(var);
ASSERT_NE(inst->From(), nullptr);
EXPECT_THAT(inst->From()->Usages(), testing::UnorderedElementsAre(Usage{inst, 0u}));
}
-TEST_F(IR_LoadTest, Fail_NullType) {
- EXPECT_FATAL_FAILURE(
- {
- Module mod;
- Builder b{mod};
-
- auto* store_type = mod.Types().i32();
- auto* var = b.Declare(mod.Types().pointer(store_type, builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite));
- Load l(nullptr, var);
- },
- "");
-}
-
TEST_F(IR_LoadTest, Fail_NonPtr_Builder) {
EXPECT_FATAL_FAILURE(
{
Module mod;
Builder b{mod};
- b.Load(b.Declare(mod.Types().f32()));
+ b.Load(b.Constant(1_i));
},
"");
}
@@ -89,7 +75,7 @@
{
Module mod;
Builder b{mod};
- Load l(mod.Types().f32(), nullptr);
+ Load l(nullptr);
},
"");
}
diff --git a/src/tint/ir/loop.cc b/src/tint/ir/loop.cc
index 37c4f00..4f25fcd 100644
--- a/src/tint/ir/loop.cc
+++ b/src/tint/ir/loop.cc
@@ -16,13 +16,13 @@
#include <utility>
-#include "src/tint/ir/block.h"
+#include "src/tint/ir/multi_in_block.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::Loop);
namespace tint::ir {
-Loop::Loop(ir::Block* i, ir::Block* b, ir::Block* c, ir::Block* m)
+Loop::Loop(ir::Block* i, ir::MultiInBlock* b, ir::MultiInBlock* c, ir::MultiInBlock* m)
: initializer_(i), body_(b), continuing_(c), merge_(m) {
TINT_ASSERT(IR, initializer_);
TINT_ASSERT(IR, body_);
@@ -45,7 +45,7 @@
Loop::~Loop() = default;
-bool Loop::HasInitializer() const {
+bool Loop::HasInitializer() {
return initializer_->HasBranchTarget();
}
diff --git a/src/tint/ir/loop.h b/src/tint/ir/loop.h
index 50c69c7..31f18fe 100644
--- a/src/tint/ir/loop.h
+++ b/src/tint/ir/loop.h
@@ -17,6 +17,11 @@
#include "src/tint/ir/control_instruction.h"
+// Forward declarations
+namespace tint::ir {
+class MultiInBlock;
+} // namespace tint::ir
+
namespace tint::ir {
/// Loop instruction.
@@ -62,38 +67,30 @@
/// @param b the body block
/// @param c the continuing block
/// @param m the merge block
- Loop(ir::Block* i, ir::Block* b, ir::Block* c, ir::Block* m);
+ Loop(ir::Block* i, ir::MultiInBlock* b, ir::MultiInBlock* c, ir::MultiInBlock* m);
~Loop() override;
/// @returns the switch initializer block
- const ir::Block* Initializer() const { return initializer_; }
- /// @returns the switch initializer block
ir::Block* Initializer() { return initializer_; }
/// @returns true if the loop uses an initializer block. If true, then the Loop first branches
/// to the initializer block, otherwise it first branches to the body block.
- bool HasInitializer() const;
+ bool HasInitializer();
/// @returns the switch start block
- const ir::Block* Body() const { return body_; }
- /// @returns the switch start block
- ir::Block* Body() { return body_; }
+ ir::MultiInBlock* Body() { return body_; }
/// @returns the switch continuing block
- const ir::Block* Continuing() const { return continuing_; }
- /// @returns the switch continuing block
- ir::Block* Continuing() { return continuing_; }
+ ir::MultiInBlock* Continuing() { return continuing_; }
/// @returns the switch merge branch
- const ir::Block* Merge() const { return merge_; }
- /// @returns the switch merge branch
- ir::Block* Merge() { return merge_; }
+ ir::MultiInBlock* Merge() { return merge_; }
private:
ir::Block* initializer_ = nullptr;
- ir::Block* body_ = nullptr;
- ir::Block* continuing_ = nullptr;
- ir::Block* merge_ = nullptr;
+ ir::MultiInBlock* body_ = nullptr;
+ ir::MultiInBlock* continuing_ = nullptr;
+ ir::MultiInBlock* merge_ = nullptr;
};
} // namespace tint::ir
diff --git a/src/tint/ir/loop_test.cc b/src/tint/ir/loop_test.cc
index e51a005..38fecfa 100644
--- a/src/tint/ir/loop_test.cc
+++ b/src/tint/ir/loop_test.cc
@@ -23,7 +23,7 @@
using IR_LoopTest = IRTestHelper;
TEST_F(IR_LoopTest, Parent) {
- auto* loop = b.CreateLoop();
+ auto* loop = b.Loop();
EXPECT_EQ(loop->Initializer()->Parent(), loop);
EXPECT_EQ(loop->Body()->Parent(), loop);
EXPECT_EQ(loop->Continuing()->Parent(), loop);
@@ -35,7 +35,7 @@
{
Module mod;
Builder b{mod};
- Loop loop(nullptr, b.CreateBlock(), b.CreateBlock(), b.CreateBlock());
+ Loop loop(nullptr, b.MultiInBlock(), b.MultiInBlock(), b.MultiInBlock());
},
"");
}
@@ -45,7 +45,7 @@
{
Module mod;
Builder b{mod};
- Loop loop(b.CreateBlock(), nullptr, b.CreateBlock(), b.CreateBlock());
+ Loop loop(b.Block(), nullptr, b.MultiInBlock(), b.MultiInBlock());
},
"");
}
@@ -55,17 +55,17 @@
{
Module mod;
Builder b{mod};
- Loop loop(b.CreateBlock(), b.CreateBlock(), nullptr, b.CreateBlock());
+ Loop loop(b.Block(), b.MultiInBlock(), nullptr, b.MultiInBlock());
},
"");
}
-TEST_F(IR_LoopTest, Fail_NullMergeBlock) {
+TEST_F(IR_LoopTest, Fail_NullMultiInBlock) {
EXPECT_FATAL_FAILURE(
{
Module mod;
Builder b{mod};
- Loop loop(b.CreateBlock(), b.CreateBlock(), b.CreateBlock(), nullptr);
+ Loop loop(b.Block(), b.MultiInBlock(), b.MultiInBlock(), nullptr);
},
"");
}
diff --git a/src/tint/ir/module.cc b/src/tint/ir/module.cc
index acee4ae..3b86a80 100644
--- a/src/tint/ir/module.cc
+++ b/src/tint/ir/module.cc
@@ -26,11 +26,11 @@
Module& Module::operator=(Module&&) = default;
-Symbol Module::NameOf(const Value* value) const {
+Symbol Module::NameOf(Value* value) {
return value_to_id_.Get(value).value_or(Symbol{});
}
-Symbol Module::SetName(const Value* value, std::string_view name) {
+Symbol Module::SetName(Value* value, std::string_view name) {
TINT_ASSERT(IR, !name.empty());
if (auto old = value_to_id_.Get(value)) {
diff --git a/src/tint/ir/module.h b/src/tint/ir/module.h
index d3393e6..4f33f26 100644
--- a/src/tint/ir/module.h
+++ b/src/tint/ir/module.h
@@ -39,10 +39,10 @@
ProgramID prog_id_;
/// Map of value to pre-declared identifier
- utils::Hashmap<const Value*, Symbol, 32> value_to_id_;
+ utils::Hashmap<Value*, Symbol, 32> value_to_id_;
/// Map of pre-declared identifier to value
- utils::Hashmap<Symbol, const Value*, 32> id_to_value_;
+ utils::Hashmap<Symbol, Value*, 32> id_to_value_;
public:
/// Constructor
@@ -60,15 +60,12 @@
/// @param value the value
/// @return the name of the given value, or an invalid symbol if the value is not named.
- Symbol NameOf(const Value* value) const;
+ Symbol NameOf(Value* value);
/// @param value the value to name.
/// @param name the desired name of the value. May be suffixed on collision.
/// @return the unique symbol of the given value.
- Symbol SetName(const Value* value, std::string_view name);
-
- /// @return the type manager for the module
- const type::Manager& Types() const { return constant_values.types; }
+ Symbol SetName(Value* value, std::string_view name);
/// @return the type manager for the module
type::Manager& Types() { return constant_values.types; }
diff --git a/src/tint/ir/module_test.cc b/src/tint/ir/module_test.cc
index 945e499..61cd5d4 100644
--- a/src/tint/ir/module_test.cc
+++ b/src/tint/ir/module_test.cc
@@ -21,30 +21,35 @@
using namespace tint::number_suffixes; // NOLINT
-using IR_ModuleTest = IRTestHelper;
+class IR_ModuleTest : public IRTestHelper {
+ protected:
+ const type::Pointer* ptr(const type::Type* elem) {
+ return ty.ptr(builtin::AddressSpace::kFunction, elem, builtin::Access::kReadWrite);
+ }
+};
TEST_F(IR_ModuleTest, NameOfUnnamed) {
- auto* v = mod.values.Create<ir::Var>(mod.Types().i32());
+ auto* v = mod.values.Create<ir::Var>(ptr(ty.i32()));
EXPECT_FALSE(mod.NameOf(v).IsValid());
}
TEST_F(IR_ModuleTest, SetName) {
- auto* v = mod.values.Create<ir::Var>(mod.Types().i32());
+ auto* v = mod.values.Create<ir::Var>(ptr(ty.i32()));
EXPECT_EQ(mod.SetName(v, "a").Name(), "a");
EXPECT_EQ(mod.NameOf(v).Name(), "a");
}
TEST_F(IR_ModuleTest, SetNameRename) {
- auto* v = mod.values.Create<ir::Var>(mod.Types().i32());
+ auto* v = mod.values.Create<ir::Var>(ptr(ty.i32()));
EXPECT_EQ(mod.SetName(v, "a").Name(), "a");
EXPECT_EQ(mod.SetName(v, "b").Name(), "b");
EXPECT_EQ(mod.NameOf(v).Name(), "b");
}
TEST_F(IR_ModuleTest, SetNameCollision) {
- auto* a = mod.values.Create<ir::Var>(mod.Types().i32());
- auto* b = mod.values.Create<ir::Var>(mod.Types().i32());
- auto* c = mod.values.Create<ir::Var>(mod.Types().i32());
+ auto* a = mod.values.Create<ir::Var>(ptr(ty.i32()));
+ auto* b = mod.values.Create<ir::Var>(ptr(ty.i32()));
+ auto* c = mod.values.Create<ir::Var>(ptr(ty.i32()));
EXPECT_EQ(mod.SetName(a, "x").Name(), "x");
EXPECT_EQ(mod.SetName(b, "x_1").Name(), "x_1");
EXPECT_EQ(mod.SetName(c, "x").Name(), "x_2");
diff --git a/src/tint/ir/multi_in_block.cc b/src/tint/ir/multi_in_block.cc
new file mode 100644
index 0000000..a86db11
--- /dev/null
+++ b/src/tint/ir/multi_in_block.cc
@@ -0,0 +1,47 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/multi_in_block.h"
+
+#include "src/tint/utils/predicates.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ir::MultiInBlock);
+
+namespace tint::ir {
+
+MultiInBlock::MultiInBlock() : Base() {}
+
+MultiInBlock::~MultiInBlock() = default;
+
+void MultiInBlock::SetParams(utils::VectorRef<BlockParam*> params) {
+ params_ = std::move(params);
+
+ TINT_ASSERT(IR, !params_.Any(utils::IsNull));
+}
+
+void MultiInBlock::SetParams(std::initializer_list<BlockParam*> params) {
+ params_ = std::move(params);
+
+ TINT_ASSERT(IR, !params_.Any(utils::IsNull));
+}
+
+void MultiInBlock::AddInboundSiblingBranch(ir::Branch* node) {
+ TINT_ASSERT(IR, node != nullptr);
+
+ if (node) {
+ inbound_sibling_branches_.Push(node);
+ }
+}
+
+} // namespace tint::ir
diff --git a/src/tint/ir/multi_in_block.h b/src/tint/ir/multi_in_block.h
new file mode 100644
index 0000000..6f9b3ee
--- /dev/null
+++ b/src/tint/ir/multi_in_block.h
@@ -0,0 +1,67 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_IR_MULTI_IN_BLOCK_H_
+#define SRC_TINT_IR_MULTI_IN_BLOCK_H_
+
+#include <utility>
+
+#include "src/tint/ir/block.h"
+
+// Forward declarations
+namespace tint::ir {
+class BlockParam;
+}
+
+namespace tint::ir {
+
+/// A block that can be the target of multiple branches.
+/// MultiInBlocks maintain a list of inbound branches from branch instructions excluding ir::If,
+/// ir::Switch and ir::Loop which implicitly branch to the internal block.
+/// MultiInBlocks hold a number of BlockParam parameters, used to pass values from the branch source
+/// to this target.
+class MultiInBlock : public utils::Castable<MultiInBlock, Block> {
+ public:
+ /// Constructor
+ MultiInBlock();
+ ~MultiInBlock() override;
+
+ /// Sets the params to the block
+ /// @param params the params for the block
+ void SetParams(utils::VectorRef<BlockParam*> params);
+
+ /// Sets the params to the block
+ /// @param params the params for the block
+ void SetParams(std::initializer_list<BlockParam*> params);
+
+ /// @returns the params to the block
+ const utils::Vector<BlockParam*, 2>& Params() { return params_; }
+
+ /// @returns branches made to this block by sibling blocks
+ const utils::VectorRef<ir::Branch*> InboundSiblingBranches() {
+ return inbound_sibling_branches_;
+ }
+
+ /// Adds the given branch to the list of branches made to this block by sibling blocks
+ /// @param branch the branch to add
+ void AddInboundSiblingBranch(ir::Branch* branch);
+
+ private:
+ utils::Vector<BlockParam*, 2> params_;
+ utils::Vector<ir::Branch*, 2> inbound_sibling_branches_;
+};
+
+} // namespace tint::ir
+
+#endif // SRC_TINT_IR_MULTI_IN_BLOCK_H_
diff --git a/src/tint/ir/multi_in_block_test.cc b/src/tint/ir/multi_in_block_test.cc
new file mode 100644
index 0000000..2d82715
--- /dev/null
+++ b/src/tint/ir/multi_in_block_test.cc
@@ -0,0 +1,51 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ir/multi_in_block.h"
+#include "gtest/gtest-spi.h"
+#include "src/tint/ir/block_param.h"
+#include "src/tint/ir/ir_test_helper.h"
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using IR_MultiInBlockTest = IRTestHelper;
+
+TEST_F(IR_MultiInBlockTest, Fail_NullBlockParam) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+
+ auto* blk = b.MultiInBlock();
+ blk->SetParams({nullptr});
+ },
+ "");
+}
+
+TEST_F(IR_MultiInBlockTest, Fail_NullInboundBranch) {
+ EXPECT_FATAL_FAILURE(
+ {
+ Module mod;
+ Builder b{mod};
+
+ auto* blk = b.MultiInBlock();
+ blk->AddInboundSiblingBranch(nullptr);
+ },
+ "");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/next_iteration.cc b/src/tint/ir/next_iteration.cc
index b1bf620..dc1ccd2 100644
--- a/src/tint/ir/next_iteration.cc
+++ b/src/tint/ir/next_iteration.cc
@@ -16,8 +16,8 @@
#include <utility>
-#include "src/tint/ir/block.h"
#include "src/tint/ir/loop.h"
+#include "src/tint/ir/multi_in_block.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::NextIteration);
@@ -26,9 +26,8 @@
NextIteration::NextIteration(ir::Loop* loop, utils::VectorRef<Value*> args /* = utils::Empty */)
: loop_(loop) {
TINT_ASSERT(IR, loop_);
-
if (loop_) {
- loop_->Body()->AddInboundBranch(this);
+ loop_->Body()->AddInboundSiblingBranch(this);
}
AddOperands(std::move(args));
}
diff --git a/src/tint/ir/next_iteration.h b/src/tint/ir/next_iteration.h
index 064ad3f..e80f628 100644
--- a/src/tint/ir/next_iteration.h
+++ b/src/tint/ir/next_iteration.h
@@ -35,7 +35,7 @@
~NextIteration() override;
/// @returns the loop being iterated
- const ir::Loop* Loop() const { return loop_; }
+ ir::Loop* Loop() { return loop_; }
private:
ir::Loop* loop_ = nullptr;
diff --git a/src/tint/ir/next_iteration_test.cc b/src/tint/ir/next_iteration_test.cc
index 33f99da..c0d821c 100644
--- a/src/tint/ir/next_iteration_test.cc
+++ b/src/tint/ir/next_iteration_test.cc
@@ -37,7 +37,7 @@
{
Module mod;
Builder b{mod};
- b.NextIteration(b.CreateLoop(), utils::Vector<Value*, 1>{nullptr});
+ b.NextIteration(b.Loop(), nullptr);
},
"");
}
diff --git a/src/tint/ir/program_test_helper.h b/src/tint/ir/program_test_helper.h
index c0ecdb4..3ce88f0 100644
--- a/src/tint/ir/program_test_helper.h
+++ b/src/tint/ir/program_test_helper.h
@@ -51,7 +51,7 @@
/// @param mod the module
/// @returns the disassembly string of the module
- std::string Disassemble(const Module& mod) const {
+ std::string Disassemble(Module& mod) {
Disassembler d(mod);
return d.Disassemble();
}
diff --git a/src/tint/ir/return.cc b/src/tint/ir/return.cc
index 73bcde0..8c0de42 100644
--- a/src/tint/ir/return.cc
+++ b/src/tint/ir/return.cc
@@ -22,13 +22,19 @@
namespace tint::ir {
-Return::Return(Function* func, utils::VectorRef<Value*> args) : func_(func) {
- TINT_ASSERT(IR, func_);
+Return::Return(Function* func) : func_(func) {
+ TINT_ASSERT_OR_RETURN(IR, func_);
- if (func_) {
- func_->AddUsage({this, 0u});
- }
- AddOperands(std::move(args));
+ func_->AddUsage({this, 0u});
+}
+
+Return::Return(Function* func, Value* arg) : func_(func) {
+ TINT_ASSERT_OR_RETURN(IR, func_);
+ TINT_ASSERT_OR_RETURN(IR, arg);
+
+ func_->AddUsage({this, 0u});
+
+ AddOperand(arg);
}
Return::~Return() = default;
diff --git a/src/tint/ir/return.h b/src/tint/ir/return.h
index 4e4bfba..55463e2 100644
--- a/src/tint/ir/return.h
+++ b/src/tint/ir/return.h
@@ -28,14 +28,19 @@
/// A return instruction.
class Return : public utils::Castable<Return, Branch> {
public:
+ /// Constructor (no return value)
+ /// @param func the function being returned
+ explicit Return(Function* func);
+
/// Constructor
/// @param func the function being returned
- /// @param args the branch arguments
- explicit Return(Function* func, utils::VectorRef<Value*> args = {});
+ /// @param arg the return value
+ Return(Function* func, Value* arg);
+
~Return() override;
/// @returns the function being returned
- const Function* Func() const { return func_; }
+ Function* Func() { return func_; }
private:
Function* func_ = nullptr;
diff --git a/src/tint/ir/return_test.cc b/src/tint/ir/return_test.cc
index e7a032e..fbe3330 100644
--- a/src/tint/ir/return_test.cc
+++ b/src/tint/ir/return_test.cc
@@ -13,6 +13,8 @@
// limitations under the License.
#include "src/tint/ir/return.h"
+
+#include "gmock/gmock.h"
#include "gtest/gtest-spi.h"
#include "src/tint/ir/ir_test_helper.h"
@@ -37,11 +39,23 @@
{
Module mod;
Builder b{mod};
- b.Return(b.CreateFunction("myfunc", mod.Types().void_()),
- utils::Vector<Value*, 1>{nullptr});
+ mod.values.Create<Return>(b.Function("myfunc", mod.Types().void_()), nullptr);
},
"");
}
+TEST_F(IR_ReturnTest, ImplicitNoValue) {
+ auto* ret = b.Return(b.Function("myfunc", ty.void_()));
+ EXPECT_TRUE(ret->Args().IsEmpty());
+}
+
+TEST_F(IR_ReturnTest, WithValue) {
+ auto* val = b.Constant(42_i);
+ auto* ret = b.Return(b.Function("myfunc", ty.i32()), val);
+ ASSERT_EQ(ret->Args().Length(), 1u);
+ EXPECT_EQ(ret->Args()[0], val);
+ EXPECT_THAT(val->Usages(), testing::UnorderedElementsAre(Usage{ret, 0u}));
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/store.h b/src/tint/ir/store.h
index 0dee73c..41289f2 100644
--- a/src/tint/ir/store.h
+++ b/src/tint/ir/store.h
@@ -30,10 +30,10 @@
~Store() override;
/// @returns the value being stored too
- Value* To() const { return operands_[0]; }
+ Value* To() { return operands_[0]; }
/// @returns the value being stored
- Value* From() const { return operands_[1]; }
+ Value* From() { return operands_[1]; }
};
} // namespace tint::ir
diff --git a/src/tint/ir/store_test.cc b/src/tint/ir/store_test.cc
index a28ecad..5578005 100644
--- a/src/tint/ir/store_test.cc
+++ b/src/tint/ir/store_test.cc
@@ -26,9 +26,9 @@
using IR_StoreTest = IRTestHelper;
TEST_F(IR_StoreTest, CreateStore) {
- auto* to = b.Declare(mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kPrivate,
- builtin::Access::kReadWrite));
- const auto* inst = b.Store(to, b.Constant(4_i));
+ auto* to = b.Var(mod.Types().ptr(builtin::AddressSpace::kPrivate, mod.Types().i32(),
+ builtin::Access::kReadWrite));
+ auto* inst = b.Store(to, 4_i);
ASSERT_TRUE(inst->Is<Store>());
ASSERT_EQ(inst->To(), to);
@@ -41,7 +41,7 @@
TEST_F(IR_StoreTest, Store_Usage) {
auto* to = b.Discard();
- auto* inst = b.Store(to, b.Constant(4_i));
+ auto* inst = b.Store(to, 4_i);
ASSERT_NE(inst->To(), nullptr);
EXPECT_THAT(inst->To()->Usages(), testing::UnorderedElementsAre(Usage{inst, 0u}));
@@ -55,7 +55,7 @@
{
Module mod;
Builder b{mod};
- b.Store(nullptr, b.Constant(1_u));
+ b.Store(nullptr, 1_u);
},
"");
}
@@ -65,8 +65,8 @@
{
Module mod;
Builder b{mod};
- auto* to = b.Declare(mod.Types().pointer(
- mod.Types().i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite));
+ auto* to = b.Var(mod.Types().ptr(builtin::AddressSpace::kPrivate, mod.Types().i32(),
+ builtin::Access::kReadWrite));
b.Store(to, nullptr);
},
"");
diff --git a/src/tint/ir/switch.cc b/src/tint/ir/switch.cc
index 68137ac..8b70cfc 100644
--- a/src/tint/ir/switch.cc
+++ b/src/tint/ir/switch.cc
@@ -16,11 +16,11 @@
TINT_INSTANTIATE_TYPEINFO(tint::ir::Switch);
-#include "src/tint/ir/block.h"
+#include "src/tint/ir/multi_in_block.h"
namespace tint::ir {
-Switch::Switch(Value* cond, ir::Block* m) : merge_(m) {
+Switch::Switch(Value* cond, ir::MultiInBlock* m) : merge_(m) {
TINT_ASSERT(IR, cond);
TINT_ASSERT(IR, merge_);
diff --git a/src/tint/ir/switch.h b/src/tint/ir/switch.h
index b065074..7791c2c 100644
--- a/src/tint/ir/switch.h
+++ b/src/tint/ir/switch.h
@@ -20,6 +20,7 @@
// Forward declarations
namespace tint::ir {
class Constant;
+class MultiInBlock;
} // namespace tint::ir
namespace tint::ir {
@@ -48,7 +49,7 @@
/// A case selector
struct CaseSelector {
/// @returns true if this is a default selector
- bool IsDefault() const { return val == nullptr; }
+ bool IsDefault() { return val == nullptr; }
/// The selector value, or nullptr if this is the default selector
Constant* val = nullptr;
@@ -62,37 +63,29 @@
ir::Block* start = nullptr;
/// @returns the case start target
- const ir::Block* Start() const { return start; }
- /// @returns the case start target
ir::Block* Start() { return start; }
};
/// Constructor
/// @param cond the condition
/// @param m the merge block
- explicit Switch(Value* cond, ir::Block* m);
+ explicit Switch(Value* cond, ir::MultiInBlock* m);
~Switch() override;
/// @returns the switch merge branch
- const ir::Block* Merge() const { return merge_; }
- /// @returns the switch merge branch
- ir::Block* Merge() { return merge_; }
+ ir::MultiInBlock* Merge() { return merge_; }
/// @returns the switch cases
- utils::VectorRef<Case> Cases() const { return cases_; }
- /// @returns the switch cases
utils::Vector<Case, 4>& Cases() { return cases_; }
/// @returns the branch arguments
- utils::Slice<Value const* const> Args() const override { return {}; }
+ utils::Slice<Value* const> Args() override { return {}; }
/// @returns the condition
- const Value* Condition() const { return operands_[0]; }
- /// @returns the condition
Value* Condition() { return operands_[0]; }
private:
- ir::Block* merge_ = nullptr;
+ ir::MultiInBlock* merge_ = nullptr;
utils::Vector<Case, 4> cases_;
};
diff --git a/src/tint/ir/switch_test.cc b/src/tint/ir/switch_test.cc
index cbcf11a..72e0dda 100644
--- a/src/tint/ir/switch_test.cc
+++ b/src/tint/ir/switch_test.cc
@@ -26,13 +26,13 @@
TEST_F(IR_SwitchTest, Usage) {
auto* cond = b.Constant(true);
- auto* switch_ = b.CreateSwitch(cond);
+ auto* switch_ = b.Switch(cond);
EXPECT_THAT(cond->Usages(), testing::UnorderedElementsAre(Usage{switch_, 0u}));
}
TEST_F(IR_SwitchTest, Parent) {
- auto* switch_ = b.CreateSwitch(b.Constant(1_i));
- b.CreateCase(switch_, utils::Vector{Switch::CaseSelector{nullptr}});
+ auto* switch_ = b.Switch(1_i);
+ b.Case(switch_, {Switch::CaseSelector{nullptr}});
EXPECT_THAT(switch_->Merge()->Parent(), switch_);
EXPECT_THAT(switch_->Cases().Front().Start()->Parent(), switch_);
}
@@ -42,12 +42,12 @@
{
Module mod;
Builder b{mod};
- b.CreateSwitch(nullptr);
+ b.Switch(nullptr);
},
"");
}
-TEST_F(IR_SwitchTest, Fail_NullMergeBlock) {
+TEST_F(IR_SwitchTest, Fail_NullMultiInBlock) {
EXPECT_FATAL_FAILURE(
{
Module mod;
diff --git a/src/tint/ir/swizzle.h b/src/tint/ir/swizzle.h
index 0dde62d..f340a01 100644
--- a/src/tint/ir/swizzle.h
+++ b/src/tint/ir/swizzle.h
@@ -31,13 +31,13 @@
~Swizzle() override;
/// @returns the type of the value
- const type::Type* Type() const override { return result_type_; }
+ const type::Type* Type() override { return result_type_; }
/// @returns the object used for the access
- Value* Object() const { return operands_[0]; }
+ Value* Object() { return operands_[0]; }
/// @returns the swizzle indices
- utils::VectorRef<uint32_t> Indices() const { return indices_; }
+ utils::VectorRef<uint32_t> Indices() { return indices_; }
private:
const type::Type* result_type_ = nullptr;
diff --git a/src/tint/ir/swizzle_test.cc b/src/tint/ir/swizzle_test.cc
index 64df5c6..ada269d 100644
--- a/src/tint/ir/swizzle_test.cc
+++ b/src/tint/ir/swizzle_test.cc
@@ -24,10 +24,8 @@
using IR_SwizzleTest = IRTestHelper;
TEST_F(IR_SwizzleTest, SetsUsage) {
- auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* var = b.Declare(ty);
- auto* a = b.Swizzle(mod.Types().i32(), var, utils::Vector{1u});
+ auto* var = b.Var(ty.ptr<function, i32>());
+ auto* a = b.Swizzle(mod.Types().i32(), var, {1u});
EXPECT_THAT(var->Usages(), testing::UnorderedElementsAre(Usage{a, 0u}));
}
@@ -37,10 +35,8 @@
{
Module mod;
Builder b{mod};
- auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* var = b.Declare(ty);
- b.Swizzle(nullptr, var, utils::Vector{1u});
+ auto* var = b.Var(mod.Types().ptr<function, i32>());
+ b.Swizzle(nullptr, var, {1u});
},
"");
}
@@ -50,7 +46,7 @@
{
Module mod;
Builder b{mod};
- b.Swizzle(mod.Types().i32(), nullptr, utils::Vector{1u});
+ b.Swizzle(mod.Types().i32(), nullptr, {1u});
},
"");
}
@@ -60,9 +56,7 @@
{
Module mod;
Builder b{mod};
- auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* var = b.Declare(ty);
+ auto* var = b.Var(mod.Types().ptr<function, i32>());
b.Swizzle(mod.Types().i32(), var, utils::Empty);
},
"");
@@ -73,10 +67,8 @@
{
Module mod;
Builder b{mod};
- auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* var = b.Declare(ty);
- b.Swizzle(mod.Types().i32(), var, utils::Vector{1u, 1u, 1u, 1u, 1u});
+ auto* var = b.Var(mod.Types().ptr<function, i32>());
+ b.Swizzle(mod.Types().i32(), var, {1u, 1u, 1u, 1u, 1u});
},
"");
}
@@ -86,10 +78,8 @@
{
Module mod;
Builder b{mod};
- auto* ty = mod.Types().pointer(mod.Types().i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite);
- auto* var = b.Declare(ty);
- b.Swizzle(mod.Types().i32(), var, utils::Vector{4u});
+ auto* var = b.Var(mod.Types().ptr<function, i32>());
+ b.Swizzle(mod.Types().i32(), var, {4u});
},
"");
}
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index b226236..c536ae6 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -26,6 +26,7 @@
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/load.h"
#include "src/tint/ir/module.h"
+#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
#include "src/tint/ir/switch.h"
@@ -64,7 +65,7 @@
class State {
public:
- explicit State(const Module& m) : mod(m) {}
+ explicit State(Module& m) : mod(m) {}
Program Run() {
// TODO(crbug.com/tint/1902): Emit root block
@@ -77,13 +78,13 @@
private:
/// The source IR module
- const Module& mod;
+ Module& mod;
/// The target ProgramBuilder
ProgramBuilder b;
/// A hashmap of value to symbol used in the emitted AST
- utils::Hashmap<const Value*, Symbol, 32> value_names_;
+ utils::Hashmap<Value*, Symbol, 32> value_names_;
// The nesting depth of the currently generated AST
// 0 is module scope
@@ -91,12 +92,12 @@
// 2+ is within control flow
uint32_t nesting_depth_ = 0;
- const ast::Function* Fn(const Function* fn) {
+ const ast::Function* Fn(Function* fn) {
SCOPED_NESTING();
// TODO(crbug.com/tint/1915): Properly implement this when we've fleshed out Function
static constexpr size_t N = decltype(ast::Function::params)::static_length;
- auto params = utils::Transform<N>(fn->Params(), [&](const ir::FunctionParam* param) {
+ auto params = utils::Transform<N>(fn->Params(), [&](FunctionParam* param) {
auto name = AssignNameTo(param);
auto ty = Type(param->Type());
return b.Param(name, ty);
@@ -111,13 +112,13 @@
std::move(ret_attrs));
}
- const ast::BlockStatement* BlockGraph(const ir::Block* start_node) {
+ const ast::BlockStatement* BlockGraph(ir::Block* start_node) {
// TODO(crbug.com/tint/1902): Check if the block is dead
utils::Vector<const ast::Statement*,
decltype(ast::BlockStatement::statements)::static_length>
stmts;
- const ir::Block* block = start_node;
+ ir::Block* block = start_node;
// TODO(crbug.com/tint/1902): Handle block arguments.
@@ -155,18 +156,18 @@
/// @param inst the ir::Instruction
/// @return an ast::Statement from @p inst, or nullptr if there was an error
- const ast::Statement* Stmt(const ir::Instruction* inst) {
+ const ast::Statement* Stmt(ir::Instruction* inst) {
return tint::Switch(
- inst, //
- [&](const ir::Store* i) { return Store(i); }, //
- [&](const ir::Call* i) { return CallStmt(i); }, //
- [&](const ir::Var* i) { return Var(i); }, //
- [&](const ir::If* if_) { return If(if_); }, //
- [&](const ir::Switch* switch_) { return Switch(switch_); }, //
- [&](const ir::Return* ret) { return Return(ret); }, //
- [&](const ir::Value*) { return ValueStmt(inst); },
+ inst, //
+ [&](ir::Store* i) { return Store(i); }, //
+ [&](ir::Call* i) { return CallStmt(i); }, //
+ [&](ir::Var* i) { return Var(i); }, //
+ [&](ir::If* if_) { return If(if_); }, //
+ [&](ir::Switch* switch_) { return Switch(switch_); }, //
+ [&](ir::Return* ret) { return Return(ret); }, //
+ [&](ir::Value*) { return ValueStmt(inst); },
// TODO(dsinclair): Remove when branch is only a parent ...
- [&](const ir::Branch*) { return nullptr; },
+ [&](ir::Branch*) { return nullptr; },
[&](Default) {
UNHANDLED_CASE(inst);
return nullptr;
@@ -175,7 +176,7 @@
/// @param i the ir::If
/// @return an ast::IfStatement from @p i, or nullptr if there was an error
- const ast::IfStatement* If(const ir::If* i) {
+ const ast::IfStatement* If(ir::If* i) {
SCOPED_NESTING();
auto* cond = Expr(i->Condition());
auto* t = BlockGraph(i->True());
@@ -210,7 +211,7 @@
/// @param s the ir::Switch
/// @return an ast::SwitchStatement from @p s, or nullptr if there was an error
- const ast::SwitchStatement* Switch(const ir::Switch* s) {
+ const ast::SwitchStatement* Switch(ir::Switch* s) {
SCOPED_NESTING();
auto* cond = Expr(s->Condition());
@@ -218,33 +219,33 @@
return nullptr;
}
- auto cases = utils::Transform<2>(
- s->Cases(), //
- [&](const ir::Switch::Case c) -> const tint::ast::CaseStatement* {
- SCOPED_NESTING();
- auto* body = BlockGraph(c.start);
- if (!body) {
- return nullptr;
- }
+ auto cases =
+ utils::Transform(s->Cases(), //
+ [&](ir::Switch::Case c) -> const tint::ast::CaseStatement* {
+ SCOPED_NESTING();
+ auto* body = BlockGraph(c.start);
+ if (!body) {
+ return nullptr;
+ }
- auto selectors = utils::Transform(
- c.selectors, //
- [&](const ir::Switch::CaseSelector& cs) -> const ast::CaseSelector* {
- if (cs.IsDefault()) {
- return b.DefaultCaseSelector();
- }
- auto* expr = Expr(cs.val);
- if (!expr) {
- return nullptr;
- }
- return b.CaseSelector(expr);
- });
- if (selectors.Any(utils::IsNull)) {
- return nullptr;
- }
+ auto selectors = utils::Transform(
+ c.selectors, //
+ [&](ir::Switch::CaseSelector cs) -> const ast::CaseSelector* {
+ if (cs.IsDefault()) {
+ return b.DefaultCaseSelector();
+ }
+ auto* expr = Expr(cs.val);
+ if (!expr) {
+ return nullptr;
+ }
+ return b.CaseSelector(expr);
+ });
+ if (selectors.Any(utils::IsNull)) {
+ return nullptr;
+ }
- return b.Case(std::move(selectors), body);
- });
+ return b.Case(std::move(selectors), body);
+ });
if (cases.Any(utils::IsNull)) {
return nullptr;
}
@@ -254,7 +255,7 @@
/// @param ret the ir::Return
/// @return an ast::ReturnStatement from @p ret, or nullptr if there was an error
- const ast::ReturnStatement* Return(const ir::Return* ret) {
+ const ast::ReturnStatement* Return(ir::Return* ret) {
if (ret->Args().IsEmpty()) {
// Return has no arguments.
// If this block is nested withing some control flow, then we must
@@ -283,13 +284,13 @@
/// @param call the ir::Call
/// @return an ast::CallStatement from @p call, or nullptr if there was an error
- const ast::CallStatement* CallStmt(const ir::Call* call) { return b.CallStmt(Call(call)); }
+ const ast::CallStatement* CallStmt(ir::Call* call) { return b.CallStmt(Call(call)); }
/// @param var the ir::Var
/// @return an ast::VariableDeclStatement from @p var
- const ast::VariableDeclStatement* Var(const ir::Var* var) {
+ const ast::VariableDeclStatement* Var(ir::Var* var) {
Symbol name = AssignNameTo(var);
- auto* ptr = var->Type()->As<type::Pointer>();
+ auto* ptr = var->Type();
auto ty = Type(ptr->StoreType());
const ast::Expression* init = nullptr;
if (var->Initializer()) {
@@ -307,14 +308,14 @@
/// @param store the ir::Store
/// @return an ast::AssignmentStatement from @p call
- const ast::AssignmentStatement* Store(const ir::Store* store) {
+ const ast::AssignmentStatement* Store(ir::Store* store) {
auto* expr = Expr(store->From());
return b.Assign(AssignNameTo(store->To()), expr);
}
/// @param val the ir::Value
/// @return an ast::Statement from @p val, or nullptr if the value does not produce a statement.
- const ast::Statement* ValueStmt(const ir::Value* val) {
+ const ast::Statement* ValueStmt(ir::Value* val) {
// As we're visiting this value's declaration it shouldn't already have a name reserved.
TINT_ASSERT(IR, !value_names_.Contains(val));
@@ -344,17 +345,17 @@
/// @param val the ir::Expression
/// @return an ast::Expression from @p val.
/// @note May be a semantically-invalid placeholder expression on error.
- const ast::Expression* Expr(const ir::Value* val) {
+ const ast::Expression* Expr(ir::Value* val) {
if (auto name = value_names_.Get(val)) {
return b.Expr(name.value());
}
return tint::Switch(
- val, //
- [&](const ir::Constant* c) { return ConstExpr(c); },
- [&](const ir::Load* l) { return LoadExpr(l); },
- [&](const ir::Unary* u) { return UnaryExpr(u); },
- [&](const ir::Binary* u) { return BinaryExpr(u); },
+ val, //
+ [&](ir::Constant* c) { return ConstExpr(c); }, //
+ [&](ir::Load* l) { return LoadExpr(l); }, //
+ [&](ir::Unary* u) { return UnaryExpr(u); }, //
+ [&](ir::Binary* u) { return BinaryExpr(u); }, //
[&](Default) {
UNHANDLED_CASE(val);
return b.Expr("<error>");
@@ -364,12 +365,11 @@
/// @param call the ir::Call
/// @return an ast::CallExpression from @p call.
/// @note May be a semantically-invalid placeholder expression on error.
- const ast::CallExpression* Call(const ir::Call* call) {
- auto args =
- utils::Transform<2>(call->Args(), [&](const ir::Value* arg) { return Expr(arg); });
+ const ast::CallExpression* Call(ir::Call* call) {
+ auto args = utils::Transform<2>(call->Args(), [&](ir::Value* arg) { return Expr(arg); });
return tint::Switch(
call, //
- [&](const ir::UserCall* c) { return b.Call(AssignNameTo(c->Func()), std::move(args)); },
+ [&](ir::UserCall* c) { return b.Call(AssignNameTo(c->Func()), std::move(args)); },
[&](Default) {
UNHANDLED_CASE(call);
return b.Call("<error>");
@@ -379,7 +379,7 @@
/// @param c the ir::Constant
/// @return an ast::Expression from @p c.
/// @note May be a semantically-invalid placeholder expression on error.
- const ast::Expression* ConstExpr(const ir::Constant* c) {
+ const ast::Expression* ConstExpr(ir::Constant* c) {
return tint::Switch(
c->Type(), //
[&](const type::I32*) { return b.Expr(c->Value()->ValueAs<i32>()); },
@@ -396,12 +396,12 @@
/// @param l the ir::Load
/// @return an ast::Expression from @p l.
/// @note May be a semantically-invalid placeholder expression on error.
- const ast::Expression* LoadExpr(const ir::Load* l) { return Expr(l->From()); }
+ const ast::Expression* LoadExpr(ir::Load* l) { return Expr(l->From()); }
/// @param u the ir::Unary
/// @return an ast::UnaryOpExpression from @p u.
/// @note May be a semantically-invalid placeholder expression on error.
- const ast::Expression* UnaryExpr(const ir::Unary* u) {
+ const ast::Expression* UnaryExpr(ir::Unary* u) {
switch (u->Kind()) {
case ir::Unary::Kind::kComplement:
return b.Complement(Expr(u->Val()));
@@ -414,7 +414,7 @@
/// @param e the ir::Binary
/// @return an ast::BinaryOpExpression from @p e.
/// @note May be a semantically-invalid placeholder expression on error.
- const ast::Expression* BinaryExpr(const ir::Binary* e) {
+ const ast::Expression* BinaryExpr(ir::Binary* e) {
if (e->Kind() == ir::Binary::Kind::kEqual) {
auto* rhs = e->RHS()->As<ir::Constant>();
if (rhs && rhs->Type()->Is<type::Bool>() && rhs->Value()->ValueAs<bool>() == false) {
@@ -538,7 +538,7 @@
auto access = address_space == builtin::AddressSpace::kStorage
? p->Access()
: builtin::Access::kUndefined;
- return b.ty.pointer(el, address_space, access);
+ return b.ty.ptr(address_space, el, access);
},
[&](const type::Reference*) {
TINT_ICE(IR, b.Diagnostics()) << "reference types should never appear in the IR";
@@ -557,7 +557,7 @@
/// Creates and returns a new, unique name for the given value, or returns the previously
/// created name.
/// @return the value's name
- Symbol AssignNameTo(const Value* value) {
+ Symbol AssignNameTo(Value* value) {
TINT_ASSERT(IR, value);
return value_names_.GetOrCreate(value, [&] {
if (auto sym = mod.NameOf(value)) {
@@ -570,7 +570,7 @@
} // namespace
-Program ToProgram(const Module& i) {
+Program ToProgram(Module& i) {
return State{i}.Run();
}
diff --git a/src/tint/ir/to_program.h b/src/tint/ir/to_program.h
index d4cf2c6..eecaf39 100644
--- a/src/tint/ir/to_program.h
+++ b/src/tint/ir/to_program.h
@@ -27,7 +27,7 @@
/// @param module the IR module
/// @return the tint::Program.
/// @note Check the returned Program::Diagnostics() for any errors.
-Program ToProgram(const Module& module);
+Program ToProgram(Module& module);
} // namespace tint::ir
diff --git a/src/tint/ir/transform/add_empty_entry_point.cc b/src/tint/ir/transform/add_empty_entry_point.cc
index 9cb58e2..7a5506d 100644
--- a/src/tint/ir/transform/add_empty_entry_point.cc
+++ b/src/tint/ir/transform/add_empty_entry_point.cc
@@ -35,9 +35,9 @@
}
ir::Builder builder(*ir);
- auto* ep = builder.CreateFunction("unused_entry_point", ir->Types().void_(),
- Function::PipelineStage::kCompute, std::array{1u, 1u, 1u});
- ep->StartTarget()->SetInstructions(utils::Vector{builder.Return(ep)});
+ auto* ep = builder.Function("unused_entry_point", ir->Types().void_(),
+ Function::PipelineStage::kCompute, std::array{1u, 1u, 1u});
+ ep->StartTarget()->SetInstructions({builder.Return(ep)});
ir->functions.Push(ep);
}
diff --git a/src/tint/ir/transform/add_empty_entry_point_test.cc b/src/tint/ir/transform/add_empty_entry_point_test.cc
index 1da2e7d..558e2af 100644
--- a/src/tint/ir/transform/add_empty_entry_point_test.cc
+++ b/src/tint/ir/transform/add_empty_entry_point_test.cc
@@ -38,8 +38,8 @@
}
TEST_F(IR_AddEmptyEntryPointTest, ExistingEntryPoint) {
- auto* ep = b.CreateFunction("main", mod.Types().void_(), Function::PipelineStage::kFragment);
- ep->StartTarget()->SetInstructions(utils::Vector{b.Return(ep)});
+ auto* ep = b.Function("main", mod.Types().void_(), Function::PipelineStage::kFragment);
+ ep->StartTarget()->SetInstructions({b.Return(ep)});
mod.functions.Push(ep);
auto* expect = R"(
diff --git a/src/tint/ir/transform/block_decorated_structs.cc b/src/tint/ir/transform/block_decorated_structs.cc
index f9f191c..5c466a5 100644
--- a/src/tint/ir/transform/block_decorated_structs.cc
+++ b/src/tint/ir/transform/block_decorated_structs.cc
@@ -93,24 +93,21 @@
// Replace the old variable declaration with one that uses the block-decorated struct type.
auto* new_var =
- builder.Declare(ir->Types().pointer(block_struct, ptr->AddressSpace(), ptr->Access()));
+ builder.Var(ir->Types().ptr(ptr->AddressSpace(), block_struct, ptr->Access()));
new_var->SetBindingPoint(var->BindingPoint()->group, var->BindingPoint()->binding);
var->ReplaceWith(new_var);
// Replace uses of the old variable.
- while (!var->Usages().IsEmpty()) {
- auto& use = *var->Usages().begin();
+ var->ReplaceAllUsesWith([&](Usage use) -> Value* {
if (wrapped) {
// The structure has been wrapped, so replace all uses of the old variable with a
// member accessor on the new variable.
- auto* access =
- builder.Access(var->Type(), new_var, utils::Vector{builder.Constant(0_u)});
+ auto* access = builder.Access(var->Type(), new_var, 0_u);
access->InsertBefore(use.instruction);
- use.instruction->SetOperand(use.operand_index, access);
- } else {
- use.instruction->SetOperand(use.operand_index, new_var);
+ return access;
}
- }
+ return new_var;
+ });
}
}
diff --git a/src/tint/ir/transform/block_decorated_structs_test.cc b/src/tint/ir/transform/block_decorated_structs_test.cc
index a308f68..e86c563 100644
--- a/src/tint/ir/transform/block_decorated_structs_test.cc
+++ b/src/tint/ir/transform/block_decorated_structs_test.cc
@@ -29,7 +29,7 @@
using namespace tint::number_suffixes; // NOLINT
TEST_F(IR_BlockDecoratedStructsTest, NoRootBlock) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->Append(b.Return(func));
mod.functions.Push(func);
@@ -47,15 +47,14 @@
}
TEST_F(IR_BlockDecoratedStructsTest, Scalar_Uniform) {
- auto* buffer = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kUniform, builtin::Access::kReadWrite));
+ auto* buffer = b.Var(ty.ptr<uniform, i32>());
buffer->SetBindingPoint(0, 0);
- b.CreateRootBlockIfNeeded()->Append(buffer);
+ b.RootBlock()->Append(buffer);
- auto* func = b.CreateFunction("foo", ty.i32());
- auto* load = b.Load(buffer);
- func->StartTarget()->Append(load);
- func->StartTarget()->Append(b.Return(func, utils::Vector{load}));
+ auto* func = b.Function("foo", ty.i32());
+ auto* block = func->StartTarget();
+ auto* load = block->Append(b.Load(buffer));
+ block->Append(b.Return(func, load));
mod.functions.Push(func);
auto* expect = R"(
@@ -83,13 +82,12 @@
}
TEST_F(IR_BlockDecoratedStructsTest, Scalar_Storage) {
- auto* buffer = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ auto* buffer = b.Var(ty.ptr<storage, i32>());
buffer->SetBindingPoint(0, 0);
- b.CreateRootBlockIfNeeded()->Append(buffer);
+ b.RootBlock()->Append(buffer);
- auto* func = b.CreateFunction("foo", ty.void_());
- func->StartTarget()->Append(b.Store(buffer, b.Constant(42_i)));
+ auto* func = b.Function("foo", ty.void_());
+ func->StartTarget()->Append(b.Store(buffer, 42_i));
func->StartTarget()->Append(b.Return(func));
mod.functions.Push(func);
@@ -118,18 +116,15 @@
}
TEST_F(IR_BlockDecoratedStructsTest, RuntimeArray) {
- auto* buffer = b.Declare(ty.pointer(ty.runtime_array(ty.i32()), builtin::AddressSpace::kStorage,
- builtin::Access::kReadWrite));
+ auto* buffer = b.Var(ty.ptr(storage, ty.runtime_array(ty.i32()), builtin::Access::kReadWrite));
buffer->SetBindingPoint(0, 0);
- b.CreateRootBlockIfNeeded()->Append(buffer);
+ b.RootBlock()->Append(buffer);
- auto* func = b.CreateFunction("foo", ty.void_());
- auto* access =
- b.Access(ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite),
- buffer, utils::Vector{b.Constant(1_u)});
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(b.Store(access, b.Constant(42_i)));
- func->StartTarget()->Append(b.Return(func));
+ auto* func = b.Function("foo", ty.void_());
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ty.ptr<storage, i32>(), buffer, 1_u));
+ block->Append(b.Store(access, 42_i));
+ block->Append(b.Return(func));
mod.functions.Push(func);
auto* expect = R"(
@@ -165,23 +160,19 @@
4u, 4u, type::StructMemberAttributes{}));
auto* structure = ty.Get<type::Struct>(mod.symbols.New(), members, 4u, 8u, 8u);
- auto* buffer = b.Declare(
- ty.pointer(structure, builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ auto* buffer = b.Var(ty.ptr(storage, structure, builtin::Access::kReadWrite));
buffer->SetBindingPoint(0, 0);
- b.CreateRootBlockIfNeeded()->Append(buffer);
+ b.RootBlock()->Append(buffer);
- auto* i32_ptr =
- ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite);
+ auto* i32_ptr = ty.ptr<storage, i32>();
- auto* func = b.CreateFunction("foo", ty.void_());
- auto* val_ptr = b.Access(i32_ptr, buffer, utils::Vector{b.Constant(0_u)});
- auto* load = b.Load(val_ptr);
- auto* elem_ptr = b.Access(i32_ptr, buffer, utils::Vector{b.Constant(1_u), b.Constant(3_u)});
- func->StartTarget()->Append(val_ptr);
- func->StartTarget()->Append(load);
- func->StartTarget()->Append(elem_ptr);
- func->StartTarget()->Append(b.Store(elem_ptr, load));
- func->StartTarget()->Append(b.Return(func));
+ auto* func = b.Function("foo", ty.void_());
+ auto* block = func->StartTarget();
+ auto* val_ptr = block->Append(b.Access(i32_ptr, buffer, 0_u));
+ auto* load = block->Append(b.Load(val_ptr));
+ auto* elem_ptr = block->Append(b.Access(i32_ptr, buffer, 1_u, 3_u));
+ block->Append(b.Store(elem_ptr, load));
+ block->Append(b.Return(func));
mod.functions.Push(func);
auto* expect = R"(
@@ -224,16 +215,15 @@
type::StructMemberAttributes{}));
auto* structure = ty.Get<type::Struct>(mod.symbols.New(), members, 4u, 8u, 8u);
- auto* buffer = b.Declare(
- ty.pointer(structure, builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ auto* buffer = b.Var(ty.ptr(storage, structure, builtin::Access::kReadWrite));
buffer->SetBindingPoint(0, 0);
- b.CreateRootBlockIfNeeded()->Append(buffer);
+ b.RootBlock()->Append(buffer);
- auto* private_var = b.Declare(
- ty.pointer(structure, builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite));
- b.CreateRootBlockIfNeeded()->Append(private_var);
+ auto* private_var =
+ b.Var(ty.ptr(builtin::AddressSpace::kPrivate, structure, builtin::Access::kReadWrite));
+ b.RootBlock()->Append(private_var);
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->Append(b.Store(buffer, private_var));
func->StartTarget()->Append(b.Return(func));
mod.functions.Push(func);
@@ -269,27 +259,23 @@
}
TEST_F(IR_BlockDecoratedStructsTest, MultipleBuffers) {
- auto* buffer_a = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
- auto* buffer_b = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
- auto* buffer_c = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ auto* buffer_a = b.Var(ty.ptr<storage, i32>());
+ auto* buffer_b = b.Var(ty.ptr<storage, i32>());
+ auto* buffer_c = b.Var(ty.ptr<storage, i32>());
buffer_a->SetBindingPoint(0, 0);
buffer_b->SetBindingPoint(0, 1);
buffer_c->SetBindingPoint(0, 2);
- auto* root = b.CreateRootBlockIfNeeded();
+ auto* root = b.RootBlock();
root->Append(buffer_a);
root->Append(buffer_b);
root->Append(buffer_c);
- auto* func = b.CreateFunction("foo", ty.void_());
- auto* load_b = b.Load(buffer_b);
- auto* load_c = b.Load(buffer_c);
- func->StartTarget()->Append(load_b);
- func->StartTarget()->Append(load_c);
- func->StartTarget()->Append(b.Store(buffer_a, b.Add(ty.i32(), load_b, load_c)));
- func->StartTarget()->Append(b.Return(func));
+ auto* func = b.Function("foo", ty.void_());
+ auto* block = func->StartTarget();
+ auto* load_b = block->Append(b.Load(buffer_b));
+ auto* load_c = block->Append(b.Load(buffer_c));
+ block->Append(b.Store(buffer_a, b.Add(ty.i32(), load_b, load_c)));
+ block->Append(b.Return(func));
mod.functions.Push(func);
auto* expect = R"(
diff --git a/src/tint/ir/transform/test_helper.h b/src/tint/ir/transform/test_helper.h
index 83ff8c2..ea75fcf 100644
--- a/src/tint/ir/transform/test_helper.h
+++ b/src/tint/ir/transform/test_helper.h
@@ -60,6 +60,15 @@
/// The type manager.
type::Manager& ty{mod.Types()};
+ /// Alias to builtin::AddressSpace::kStorage
+ static constexpr auto storage = builtin::AddressSpace::kStorage;
+ /// Alias to builtin::AddressSpace::kUniform
+ static constexpr auto uniform = builtin::AddressSpace::kUniform;
+ /// Alias to builtin::AddressSpace::kPrivate
+ static constexpr auto private_ = builtin::AddressSpace::kPrivate;
+ /// Alias to builtin::AddressSpace::kFunction
+ static constexpr auto function = builtin::AddressSpace::kFunction;
+
private:
std::vector<std::unique_ptr<Source::File>> files_;
};
diff --git a/src/tint/ir/transform/var_for_dynamic_index.cc b/src/tint/ir/transform/var_for_dynamic_index.cc
index fa063e7..b72dc34 100644
--- a/src/tint/ir/transform/var_for_dynamic_index.cc
+++ b/src/tint/ir/transform/var_for_dynamic_index.cc
@@ -18,11 +18,9 @@
#include "src/tint/ir/builder.h"
#include "src/tint/ir/module.h"
-#include "src/tint/switch.h"
#include "src/tint/type/array.h"
#include "src/tint/type/matrix.h"
#include "src/tint/type/pointer.h"
-#include "src/tint/type/struct.h"
#include "src/tint/type/vector.h"
#include "src/tint/utils/hashmap.h"
@@ -63,58 +61,48 @@
return base == other.base && indices == other.indices;
}
};
+
+std::optional<AccessToReplace> ShouldReplace(Access* access) {
+ if (access->Type()->Is<type::Pointer>()) {
+ // No need to modify accesses into pointer types.
+ return {};
+ }
+
+ // Find the first dynamic index, if any.
+ const auto& indices = access->Indices();
+ auto* source_type = access->Object()->Type();
+ for (uint32_t i = 0; i < indices.Length(); i++) {
+ if (source_type->Is<type::Vector>()) {
+ // Stop if we hit a vector, as they can support dynamic accesses.
+ return {};
+ }
+
+ // Check if the index is dynamic.
+ auto* const_idx = indices[i]->As<Constant>();
+ if (!const_idx) {
+ return AccessToReplace{access, i, source_type};
+ }
+
+ // Update the current source object type.
+ source_type = source_type->Element(const_idx->Value()->ValueAs<u32>());
+ }
+ // No dynamic indices were found.
+ return {};
+}
+
} // namespace
VarForDynamicIndex::VarForDynamicIndex() = default;
VarForDynamicIndex::~VarForDynamicIndex() = default;
-static std::optional<AccessToReplace> ShouldReplace(Access* access) {
- AccessToReplace to_replace{access, 0, access->Object()->Type()};
-
- // Find the first dynamic index, if any.
- bool has_dynamic_index = false;
- for (auto* idx : access->Indices()) {
- if (to_replace.dynamic_index_source_type->Is<type::Vector>()) {
- // Stop if we hit a vector, as they can support dynamic accesses.
- break;
- }
-
- // Check if the index is dynamic.
- auto* const_idx = idx->As<Constant>();
- if (!const_idx) {
- has_dynamic_index = true;
- break;
- }
- to_replace.first_dynamic_index++;
-
- // Update the current object type.
- to_replace.dynamic_index_source_type = tint::Switch(
- to_replace.dynamic_index_source_type, //
- [&](const type::Array* arr) { return arr->ElemType(); },
- [&](const type::Matrix* mat) { return mat->ColumnType(); },
- [&](const type::Struct* str) {
- return str->Members()[const_idx->Value()->ValueAs<u32>()]->Type();
- },
- [&](const type::Vector* vec) { return vec->type(); }, //
- [&](Default) { return nullptr; });
- }
- if (!has_dynamic_index) {
- // No need to modify accesses that only use constant indices.
- return {};
- }
-
- return to_replace;
-}
-
void VarForDynamicIndex::Run(ir::Module* ir, const DataMap&, DataMap&) const {
ir::Builder builder(*ir);
// Find the access instructions that need replacing.
utils::Vector<AccessToReplace, 4> worklist;
for (auto* inst : ir->values.Objects()) {
- auto* access = inst->As<Access>();
- if (access && !access->Type()->Is<type::Pointer>()) {
+ if (auto* access = inst->As<Access>()) {
if (auto to_replace = ShouldReplace(access)) {
worklist.Push(to_replace.value());
}
@@ -126,7 +114,7 @@
utils::Hashmap<PartialAccess, Value*, 4, PartialAccess::Hasher> source_object_to_value;
for (const auto& to_replace : worklist) {
auto* access = to_replace.access;
- Value* source_object = access->Object();
+ auto* source_object = access->Object();
// If the access starts with at least one constant index, extract the source of the first
// dynamic access to avoid copying the whole object.
@@ -143,9 +131,9 @@
// Declare a local variable and copy the source object to it.
auto* local = object_to_local.GetOrCreate(source_object, [&]() {
- auto* decl = builder.Declare(ir->Types().pointer(to_replace.dynamic_index_source_type,
- builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite));
+ auto* decl =
+ builder.Var(ir->Types().ptr(builtin::AddressSpace::kFunction, source_object->Type(),
+ builtin::Access::kReadWrite));
decl->SetInitializer(source_object);
decl->InsertBefore(access);
return decl;
@@ -154,8 +142,8 @@
// Create a new access instruction using the local variable as the source.
utils::Vector<Value*, 4> indices{access->Indices().Offset(to_replace.first_dynamic_index)};
auto* new_access =
- builder.Access(ir->Types().pointer(access->Type(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite),
+ builder.Access(ir->Types().ptr(builtin::AddressSpace::kFunction, access->Type(),
+ builtin::Access::kReadWrite),
local, indices);
access->ReplaceWith(new_access);
@@ -164,10 +152,7 @@
load->InsertAfter(new_access);
// Replace all uses of the old access instruction with the loaded result.
- while (!access->Usages().IsEmpty()) {
- auto& use = *access->Usages().begin();
- use.instruction->SetOperand(use.operand_index, load);
- }
+ access->ReplaceAllUsesWith([&](Usage) { return load; });
}
}
diff --git a/src/tint/ir/transform/var_for_dynamic_index_test.cc b/src/tint/ir/transform/var_for_dynamic_index_test.cc
index 0f6af06..6b64025 100644
--- a/src/tint/ir/transform/var_for_dynamic_index_test.cc
+++ b/src/tint/ir/transform/var_for_dynamic_index_test.cc
@@ -29,18 +29,18 @@
class IR_VarForDynamicIndexTest : public TransformTest {
protected:
const type::Type* ptr(const type::Type* elem) {
- return ty.pointer(elem, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ return ty.ptr(builtin::AddressSpace::kFunction, elem, builtin::Access::kReadWrite);
}
};
TEST_F(IR_VarForDynamicIndexTest, NoModify_ConstantIndex_ArrayValue) {
- auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u));
- auto* func = b.CreateFunction("foo", ty.i32());
- func->SetParams(utils::Vector{arr});
+ auto* arr = b.FunctionParam(ty.array<i32, 4u>());
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({arr});
- auto* access = b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_i)});
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ty.i32(), arr, 1_i));
+ block->Append(b.Return(func, access));
mod.functions.Push(func);
auto* expect = R"(
@@ -58,13 +58,13 @@
}
TEST_F(IR_VarForDynamicIndexTest, NoModify_ConstantIndex_MatrixValue) {
- auto* mat = b.FunctionParam(ty.mat2x2(ty.f32()));
- auto* func = b.CreateFunction("foo", ty.f32());
- func->SetParams(utils::Vector{mat});
+ auto* mat = b.FunctionParam(ty.mat2x2<f32>());
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({mat});
- auto* access = b.Access(ty.f32(), mat, utils::Vector{b.Constant(1_i), b.Constant(0_i)});
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ty.f32(), mat, 1_i, 0_i));
+ block->Append(b.Return(func, access));
mod.functions.Push(func);
auto* expect = R"(
@@ -82,16 +82,15 @@
}
TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_ArrayPointer) {
- auto* arr = b.FunctionParam(ptr(ty.array(ty.i32(), 4u)));
+ auto* arr = b.FunctionParam(ptr(ty.array<i32, 4u>()));
auto* idx = b.FunctionParam(ty.i32());
- auto* func = b.CreateFunction("foo", ty.i32());
- func->SetParams(utils::Vector{arr, idx});
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({arr, idx});
- auto* access = b.Access(ptr(ty.i32()), arr, utils::Vector{idx});
- auto* load = b.Load(access);
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(load);
- func->StartTarget()->Append(b.Return(func, utils::Vector{load}));
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ptr(ty.i32()), arr, idx));
+ auto* load = block->Append(b.Load(access));
+ block->Append(b.Return(func, load));
mod.functions.Push(func);
auto* expect = R"(
@@ -110,16 +109,15 @@
}
TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_MatrixPointer) {
- auto* mat = b.FunctionParam(ptr(ty.mat2x2(ty.f32())));
+ auto* mat = b.FunctionParam(ptr(ty.mat2x2<f32>()));
auto* idx = b.FunctionParam(ty.i32());
- auto* func = b.CreateFunction("foo", ty.f32());
- func->SetParams(utils::Vector{mat, idx});
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({mat, idx});
- auto* access = b.Access(ptr(ty.f32()), mat, utils::Vector{idx, idx});
- auto* load = b.Load(access);
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(load);
- func->StartTarget()->Append(b.Return(func, utils::Vector{load}));
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ptr(ty.f32()), mat, idx, idx));
+ auto* load = block->Append(b.Load(access));
+ block->Append(b.Return(func, load));
mod.functions.Push(func);
auto* expect = R"(
@@ -138,14 +136,14 @@
}
TEST_F(IR_VarForDynamicIndexTest, NoModify_DynamicIndex_VectorValue) {
- auto* vec = b.FunctionParam(ty.vec4(ty.f32()));
+ auto* vec = b.FunctionParam(ty.vec4<f32>());
auto* idx = b.FunctionParam(ty.i32());
- auto* func = b.CreateFunction("foo", ty.f32());
- func->SetParams(utils::Vector{vec, idx});
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({vec, idx});
- auto* access = b.Access(ty.f32(), vec, utils::Vector{idx});
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ty.f32(), vec, idx));
+ block->Append(b.Return(func, access));
mod.functions.Push(func);
auto* expect = R"(
@@ -163,14 +161,14 @@
}
TEST_F(IR_VarForDynamicIndexTest, DynamicIndex_ArrayValue) {
- auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u));
+ auto* arr = b.FunctionParam(ty.array<i32, 4u>());
auto* idx = b.FunctionParam(ty.i32());
- auto* func = b.CreateFunction("foo", ty.i32());
- func->SetParams(utils::Vector{arr, idx});
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({arr, idx});
- auto* access = b.Access(ty.i32(), arr, utils::Vector{idx});
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ty.i32(), arr, idx));
+ block->Append(b.Return(func, access));
mod.functions.Push(func);
auto* expect = R"(
@@ -190,14 +188,14 @@
}
TEST_F(IR_VarForDynamicIndexTest, DynamicIndex_MatrixValue) {
- auto* arr = b.FunctionParam(ty.mat2x2(ty.f32()));
+ auto* arr = b.FunctionParam(ty.mat2x2<f32>());
auto* idx = b.FunctionParam(ty.i32());
- auto* func = b.CreateFunction("foo", ty.f32());
- func->SetParams(utils::Vector{arr, idx});
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({arr, idx});
- auto* access = b.Access(ty.f32(), arr, utils::Vector{idx});
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ty.f32(), arr, idx));
+ block->Append(b.Return(func, access));
mod.functions.Push(func);
auto* expect = R"(
@@ -217,14 +215,14 @@
}
TEST_F(IR_VarForDynamicIndexTest, AccessChain) {
- auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u));
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array<i32, 4u>(), 4u), 4u));
auto* idx = b.FunctionParam(ty.i32());
- auto* func = b.CreateFunction("foo", ty.i32());
- func->SetParams(utils::Vector{arr, idx});
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({arr, idx});
- auto* access = b.Access(ty.i32(), arr, utils::Vector{idx, b.Constant(1_u), idx});
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ty.i32(), arr, idx, 1_u, idx));
+ block->Append(b.Return(func, access));
mod.functions.Push(func);
auto* expect = R"(
@@ -244,14 +242,14 @@
}
TEST_F(IR_VarForDynamicIndexTest, AccessChain_SkipConstantIndices) {
- auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u));
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array<i32, 4u>(), 4u), 4u));
auto* idx = b.FunctionParam(ty.i32());
- auto* func = b.CreateFunction("foo", ty.i32());
- func->SetParams(utils::Vector{arr, idx});
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({arr, idx});
- auto* access = b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx});
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ty.i32(), arr, 1_u, 2_u, idx));
+ block->Append(b.Return(func, access));
mod.functions.Push(func);
auto* expect = R"(
@@ -272,15 +270,14 @@
}
TEST_F(IR_VarForDynamicIndexTest, AccessChain_SkipConstantIndices_Interleaved) {
- auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u), 4u));
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.array<i32, 4u>(), 4u), 4u), 4u));
auto* idx = b.FunctionParam(ty.i32());
- auto* func = b.CreateFunction("foo", ty.i32());
- func->SetParams(utils::Vector{arr, idx});
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({arr, idx});
- auto* access =
- b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), idx, b.Constant(2_u), idx});
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ty.i32(), arr, 1_u, idx, 2_u, idx));
+ block->Append(b.Return(func, access));
mod.functions.Push(func);
auto* expect = R"(
@@ -304,23 +301,22 @@
auto* str_ty = ty.Get<type::Struct>(
mod.symbols.Register("MyStruct"),
utils::Vector{
- ty.Get<type::StructMember>(mod.symbols.Register("arr1"), ty.array(ty.f32(), 1024u), 0u,
- 0u, 4u, 4096u, type::StructMemberAttributes{}),
- ty.Get<type::StructMember>(mod.symbols.Register("mat"), ty.mat4x4(ty.f32()), 1u, 4096u,
+ ty.Get<type::StructMember>(mod.symbols.Register("arr1"), ty.array<f32, 1024>(), 0u, 0u,
+ 4u, 4096u, type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.Register("mat"), ty.mat4x4<f32>(), 1u, 4096u,
16u, 64u, type::StructMemberAttributes{}),
- ty.Get<type::StructMember>(mod.symbols.Register("arr2"), ty.array(ty.f32(), 1024u), 2u,
+ ty.Get<type::StructMember>(mod.symbols.Register("arr2"), ty.array<f32, 1024>(), 2u,
4160u, 4u, 4096u, type::StructMemberAttributes{}),
},
16u, 32u, 32u);
auto* str_val = b.FunctionParam(str_ty);
auto* idx = b.FunctionParam(ty.i32());
- auto* func = b.CreateFunction("foo", ty.f32());
- func->SetParams(utils::Vector{str_val, idx});
+ auto* func = b.Function("foo", ty.f32());
+ func->SetParams({str_val, idx});
- auto* access =
- b.Access(ty.f32(), str_val, utils::Vector{b.Constant(1_u), idx, b.Constant(0_u)});
- func->StartTarget()->Append(access);
- func->StartTarget()->Append(b.Return(func, utils::Vector{access}));
+ auto* block = func->StartTarget();
+ auto* access = block->Append(b.Access(ty.f32(), str_val, 1_u, idx, 0_u));
+ block->Append(b.Return(func, access));
mod.functions.Push(func);
auto* expect = R"(
@@ -347,20 +343,18 @@
}
TEST_F(IR_VarForDynamicIndexTest, MultipleAccessesFromSameSource) {
- auto* arr = b.FunctionParam(ty.array(ty.i32(), 4u));
+ auto* arr = b.FunctionParam(ty.array<i32, 4u>());
auto* idx_a = b.FunctionParam(ty.i32());
auto* idx_b = b.FunctionParam(ty.i32());
auto* idx_c = b.FunctionParam(ty.i32());
- auto* func = b.CreateFunction("foo", ty.i32());
- func->SetParams(utils::Vector{arr, idx_a, idx_b, idx_c});
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({arr, idx_a, idx_b, idx_c});
- auto* access_a = b.Access(ty.i32(), arr, utils::Vector{idx_a});
- auto* access_b = b.Access(ty.i32(), arr, utils::Vector{idx_b});
- auto* access_c = b.Access(ty.i32(), arr, utils::Vector{idx_c});
- func->StartTarget()->Append(access_a);
- func->StartTarget()->Append(access_b);
- func->StartTarget()->Append(access_c);
- func->StartTarget()->Append(b.Return(func, utils::Vector{access_c}));
+ auto* block = func->StartTarget();
+ block->Append(b.Access(ty.i32(), arr, idx_a));
+ block->Append(b.Access(ty.i32(), arr, idx_b));
+ auto* access_c = block->Append(b.Access(ty.i32(), arr, idx_c));
+ block->Append(b.Return(func, access_c));
mod.functions.Push(func);
auto* expect = R"(
@@ -384,23 +378,18 @@
}
TEST_F(IR_VarForDynamicIndexTest, MultipleAccessesFromSameSource_SkipConstantIndices) {
- auto* arr = b.FunctionParam(ty.array(ty.array(ty.array(ty.i32(), 4u), 4u), 4u));
+ auto* arr = b.FunctionParam(ty.array(ty.array(ty.array<i32, 4u>(), 4u), 4u));
auto* idx_a = b.FunctionParam(ty.i32());
auto* idx_b = b.FunctionParam(ty.i32());
auto* idx_c = b.FunctionParam(ty.i32());
- auto* func = b.CreateFunction("foo", ty.i32());
- func->SetParams(utils::Vector{arr, idx_a, idx_b, idx_c});
+ auto* func = b.Function("foo", ty.i32());
+ func->SetParams({arr, idx_a, idx_b, idx_c});
- auto* access_a =
- b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_a});
- auto* access_b =
- b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_b});
- auto* access_c =
- b.Access(ty.i32(), arr, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx_c});
- func->StartTarget()->Append(access_a);
- func->StartTarget()->Append(access_b);
- func->StartTarget()->Append(access_c);
- func->StartTarget()->Append(b.Return(func, utils::Vector{access_c}));
+ auto* block = func->StartTarget();
+ block->Append(b.Access(ty.i32(), arr, 1_u, 2_u, idx_a));
+ block->Append(b.Access(ty.i32(), arr, 1_u, 2_u, idx_b));
+ auto* access_c = block->Append(b.Access(ty.i32(), arr, 1_u, 2_u, idx_c));
+ block->Append(b.Return(func, access_c));
mod.functions.Push(func);
auto* expect = R"(
diff --git a/src/tint/ir/unary.h b/src/tint/ir/unary.h
index 0ac191a..baf1218 100644
--- a/src/tint/ir/unary.h
+++ b/src/tint/ir/unary.h
@@ -37,15 +37,13 @@
~Unary() override;
/// @returns the type of the value
- const type::Type* Type() const override { return result_type_; }
+ const type::Type* Type() override { return result_type_; }
/// @returns the value for the instruction
- const Value* Val() const { return operands_[0]; }
- /// @returns the value for the instruction
Value* Val() { return operands_[0]; }
/// @returns the kind of unary instruction
- enum Kind Kind() const { return kind_; }
+ enum Kind Kind() { return kind_; }
private:
enum Kind kind_;
diff --git a/src/tint/ir/unary_test.cc b/src/tint/ir/unary_test.cc
index 2f2307e..7802dd4 100644
--- a/src/tint/ir/unary_test.cc
+++ b/src/tint/ir/unary_test.cc
@@ -26,7 +26,7 @@
using IR_UnaryTest = IRTestHelper;
TEST_F(IR_UnaryTest, CreateComplement) {
- auto* inst = b.Complement(mod.Types().i32(), b.Constant(4_i));
+ auto* inst = b.Complement(mod.Types().i32(), 4_i);
ASSERT_TRUE(inst->Is<Unary>());
EXPECT_EQ(inst->Kind(), Unary::Kind::kComplement);
@@ -38,7 +38,7 @@
}
TEST_F(IR_UnaryTest, CreateNegation) {
- auto* inst = b.Negation(mod.Types().i32(), b.Constant(4_i));
+ auto* inst = b.Negation(mod.Types().i32(), 4_i);
ASSERT_TRUE(inst->Is<Unary>());
EXPECT_EQ(inst->Kind(), Unary::Kind::kNegation);
@@ -50,7 +50,7 @@
}
TEST_F(IR_UnaryTest, Unary_Usage) {
- auto* inst = b.Negation(mod.Types().i32(), b.Constant(4_i));
+ auto* inst = b.Negation(mod.Types().i32(), 4_i);
EXPECT_EQ(inst->Kind(), Unary::Kind::kNegation);
@@ -63,7 +63,7 @@
{
Module mod;
Builder b{mod};
- b.Negation(nullptr, b.Constant(1_i));
+ b.Negation(nullptr, 1_i);
},
"");
}
diff --git a/src/tint/ir/user_call.h b/src/tint/ir/user_call.h
index d165715..2d08482 100644
--- a/src/tint/ir/user_call.h
+++ b/src/tint/ir/user_call.h
@@ -32,12 +32,10 @@
~UserCall() override;
/// @returns the call arguments
- utils::Slice<Value const* const> Args() const override {
- return operands_.Slice().Offset(1).Reinterpret<Value const* const>();
- }
+ utils::Slice<Value* const> Args() override { return operands_.Slice().Offset(1); }
/// @returns the called function name
- const Function* Func() const { return operands_.Front()->As<ir::Function>(); }
+ Function* Func() { return operands_.Front()->As<ir::Function>(); }
private:
};
diff --git a/src/tint/ir/user_call_test.cc b/src/tint/ir/user_call_test.cc
index 2fd9151..6b73272 100644
--- a/src/tint/ir/user_call_test.cc
+++ b/src/tint/ir/user_call_test.cc
@@ -25,10 +25,10 @@
using IR_UserCallTest = IRTestHelper;
TEST_F(IR_UserCallTest, Usage) {
- auto* func = b.CreateFunction("myfunc", mod.Types().void_());
+ auto* func = b.Function("myfunc", mod.Types().void_());
auto* arg1 = b.Constant(1_u);
auto* arg2 = b.Constant(2_u);
- auto* e = b.UserCall(mod.Types().void_(), func, utils::Vector{arg1, arg2});
+ auto* e = b.Call(mod.Types().void_(), func, utils::Vector{arg1, arg2});
EXPECT_THAT(func->Usages(), testing::UnorderedElementsAre(Usage{e, 0u}));
EXPECT_THAT(arg1->Usages(), testing::UnorderedElementsAre(Usage{e, 1u}));
EXPECT_THAT(arg2->Usages(), testing::UnorderedElementsAre(Usage{e, 2u}));
@@ -39,7 +39,7 @@
{
Module mod;
Builder b{mod};
- b.UserCall(nullptr, b.CreateFunction("myfunc", mod.Types().void_()));
+ b.Call(nullptr, b.Function("myfunc", mod.Types().void_()));
},
"");
}
@@ -49,7 +49,7 @@
{
Module mod;
Builder b{mod};
- b.UserCall(mod.Types().f32(), nullptr);
+ b.Call(mod.Types().f32(), nullptr);
},
"");
}
@@ -59,8 +59,8 @@
{
Module mod;
Builder b{mod};
- b.UserCall(mod.Types().void_(), b.CreateFunction("myfunc", mod.Types().void_()),
- utils::Vector<Value*, 1>{nullptr});
+ b.Call(mod.Types().void_(), b.Function("myfunc", mod.Types().void_()),
+ utils::Vector<Value*, 1>{nullptr});
},
"");
}
diff --git a/src/tint/ir/validate.cc b/src/tint/ir/validate.cc
index 029ed8f..19beb21 100644
--- a/src/tint/ir/validate.cc
+++ b/src/tint/ir/validate.cc
@@ -22,7 +22,7 @@
#include "src/tint/ir/binary.h"
#include "src/tint/ir/bitcast.h"
#include "src/tint/ir/break_if.h"
-#include "src/tint/ir/builtin.h"
+#include "src/tint/ir/builtin_call.h"
#include "src/tint/ir/construct.h"
#include "src/tint/ir/continue.h"
#include "src/tint/ir/convert.h"
@@ -60,7 +60,7 @@
utils::Result<Success, diag::List> IsValid() {
CheckRootBlock(mod_.root_block);
- for (const auto* func : mod_.functions) {
+ for (auto* func : mod_.functions) {
CheckFunction(func);
}
@@ -81,7 +81,7 @@
diag::List diagnostics_;
Disassembler dis_{mod_};
- const Block* current_block_ = nullptr;
+ Block* current_block_ = nullptr;
void DisassembleIfNeeded() {
if (mod_.disassembly_file) {
@@ -90,83 +90,86 @@
mod_.disassembly_file = std::make_unique<Source::File>("", dis_.Disassemble());
}
- void AddError(const Instruction* inst, const std::string& err) {
+ void AddError(Instruction* inst, std::string err) {
DisassembleIfNeeded();
auto src = dis_.InstructionSource(inst);
src.file = mod_.disassembly_file.get();
- AddError(err, src);
+ AddError(std::move(err), src);
if (current_block_) {
AddNote(current_block_, "In block");
}
}
- void AddError(const Instruction* inst, uint32_t idx, const std::string& err) {
+ void AddError(Instruction* inst, size_t idx, std::string err) {
DisassembleIfNeeded();
- auto src = dis_.OperandSource(Disassembler::Operand{inst, idx});
+ auto src = dis_.OperandSource(Usage{inst, static_cast<uint32_t>(idx)});
src.file = mod_.disassembly_file.get();
- AddError(err, src);
+ AddError(std::move(err), src);
if (current_block_) {
AddNote(current_block_, "In block");
}
}
- void AddError(const Block* blk, const std::string& err) {
+ void AddError(Block* blk, std::string err) {
DisassembleIfNeeded();
auto src = dis_.BlockSource(blk);
src.file = mod_.disassembly_file.get();
- AddError(err, src);
+ AddError(std::move(err), src);
}
- void AddNote(const Block* blk, const std::string& err) {
+ void AddNote(Instruction* inst, size_t idx, std::string err) {
+ DisassembleIfNeeded();
+ auto src = dis_.OperandSource(Usage{inst, static_cast<uint32_t>(idx)});
+ src.file = mod_.disassembly_file.get();
+ AddNote(std::move(err), src);
+ }
+
+ void AddNote(Block* blk, std::string err) {
DisassembleIfNeeded();
auto src = dis_.BlockSource(blk);
src.file = mod_.disassembly_file.get();
- AddNote(err, src);
+ AddNote(std::move(err), src);
}
- void AddError(const std::string& err, Source src = {}) {
- diagnostics_.add_error(tint::diag::System::IR, err, src);
+ void AddError(std::string err, Source src = {}) {
+ diagnostics_.add_error(tint::diag::System::IR, std::move(err), src);
}
- void AddNote(const std::string& note, Source src = {}) {
- diagnostics_.add_note(tint::diag::System::IR, note, src);
+ void AddNote(std::string note, Source src = {}) {
+ diagnostics_.add_note(tint::diag::System::IR, std::move(note), src);
}
- std::string Name(const Value* v) { return mod_.NameOf(v).Name(); }
+ // std::string Name(Value* v) { return mod_.NameOf(v).Name(); }
- void CheckRootBlock(const Block* blk) {
+ void CheckRootBlock(Block* blk) {
if (!blk) {
return;
}
TINT_SCOPED_ASSIGNMENT(current_block_, blk);
- for (const auto* inst : *blk) {
+ for (auto* inst : *blk) {
auto* var = inst->As<ir::Var>();
if (!var) {
AddError(inst,
std::string("root block: invalid instruction: ") + inst->TypeInfo().name);
continue;
}
- if (!var->Type()->Is<type::Pointer>()) {
- AddError(inst, std::string("root block: 'var' ") + Name(var) +
- "type is not a pointer: " + var->Type()->TypeInfo().name);
- }
}
}
- void CheckFunction(const Function* func) { CheckBlock(func->StartTarget()); }
+ void CheckFunction(Function* func) { CheckBlock(func->StartTarget()); }
- void CheckBlock(const Block* blk) {
+ void CheckBlock(Block* blk) {
TINT_SCOPED_ASSIGNMENT(current_block_, blk);
if (!blk->HasBranchTarget()) {
AddError(blk, "block: does not end in a branch");
}
- for (const auto* inst : *blk) {
+ for (auto* inst : *blk) {
if (inst->Is<ir::Branch>() && inst != blk->Branch()) {
AddError(inst, "block: branch which isn't the final instruction");
continue;
@@ -176,68 +179,136 @@
}
}
- void CheckInstruction(const Instruction* inst) {
+ void CheckInstruction(Instruction* inst) {
tint::Switch(
- inst, //
- [&](const ir::Access*) {}, //
- [&](const ir::Binary*) {}, //
- [&](const ir::Branch* b) { CheckBranch(b); }, //
- [&](const ir::Call* c) { CheckCall(c); }, //
- [&](const ir::Load*) {}, //
- [&](const ir::Store*) {}, //
- [&](const ir::Swizzle*) {}, //
- [&](const ir::Unary*) {}, //
- [&](const ir::Var*) {}, //
+ inst, //
+ [&](Access* a) { CheckAccess(a); }, //
+ [&](Binary*) {}, //
+ [&](Branch* b) { CheckBranch(b); }, //
+ [&](Call* c) { CheckCall(c); }, //
+ [&](Load*) {}, //
+ [&](Store*) {}, //
+ [&](Swizzle*) {}, //
+ [&](Unary*) {}, //
+ [&](Var*) {}, //
[&](Default) {
AddError(std::string("missing validation of: ") + inst->TypeInfo().name);
});
}
- void CheckCall(const ir::Call* call) {
+ void CheckCall(Call* call) {
tint::Switch(
- call, //
- [&](const ir::Bitcast*) {}, //
- [&](const ir::Builtin*) {}, //
- [&](const ir::Construct*) {}, //
- [&](const ir::Convert*) {}, //
- [&](const ir::Discard*) {}, //
- [&](const ir::UserCall*) {}, //
+ call, //
+ [&](Bitcast*) {}, //
+ [&](BuiltinCall*) {}, //
+ [&](Construct*) {}, //
+ [&](Convert*) {}, //
+ [&](Discard*) {}, //
+ [&](UserCall*) {}, //
[&](Default) {
AddError(std::string("missing validation of call: ") + call->TypeInfo().name);
});
}
- void CheckBranch(const ir::Branch* b) {
+ void CheckAccess(ir::Access* a) {
+ bool is_ptr = a->Object()->Type()->Is<type::Pointer>();
+ auto* ty = a->Object()->Type()->UnwrapPtr();
+
+ auto current = [&] {
+ return is_ptr ? "ptr<" + ty->FriendlyName() + ">" : ty->FriendlyName();
+ };
+
+ for (size_t i = 0; i < a->Indices().Length(); i++) {
+ auto err = [&](std::string msg) {
+ AddError(a, i + Access::kIndicesOperandOffset, std::move(msg));
+ };
+ auto note = [&](std::string msg) {
+ AddNote(a, i + Access::kIndicesOperandOffset, std::move(msg));
+ };
+
+ auto* index = a->Indices()[i];
+ if (TINT_UNLIKELY(!index->Type()->is_integer_scalar())) {
+ err("access: index must be integer, got " + index->Type()->FriendlyName());
+ return;
+ }
+
+ if (auto* const_index = index->As<ir::Constant>()) {
+ auto* value = const_index->Value();
+ if (value->Type()->is_signed_integer_scalar()) {
+ // index is a signed integer scalar. Check that the index isn't negative.
+ // If the index is unsigned, we can skip this.
+ auto idx = value->ValueAs<AInt>();
+ if (TINT_UNLIKELY(idx < 0)) {
+ err("access: constant index must be positive, got " + std::to_string(idx));
+ return;
+ }
+ }
+
+ auto idx = value->ValueAs<uint32_t>();
+ auto* el = ty->Element(idx);
+ if (TINT_UNLIKELY(!el)) {
+ // Is index in bounds?
+ if (auto el_count = ty->Elements().count; el_count != 0 && idx >= el_count) {
+ err("access: index out of bounds for type " + current());
+ note("acceptable range: [0.." + std::to_string(el_count - 1) + "]");
+ return;
+ }
+ err("access: type " + current() + " cannot be indexed");
+ return;
+ }
+ ty = el;
+ } else {
+ auto* el = ty->Elements().type;
+ if (TINT_UNLIKELY(!el)) {
+ err("access: type " + current() + " cannot be dynamically indexed");
+ return;
+ }
+ ty = el;
+ }
+ }
+
+ auto* want_ty = a->Type()->UnwrapPtr();
+ bool want_ptr = a->Type()->Is<type::Pointer>();
+ if (TINT_UNLIKELY(ty != want_ty || is_ptr != want_ptr)) {
+ std::string want =
+ want_ptr ? "ptr<" + want_ty->FriendlyName() + ">" : want_ty->FriendlyName();
+ AddError(a, "access: result of access chain is type " + current() +
+ " but instruction type is " + want);
+ return;
+ }
+ }
+
+ void CheckBranch(ir::Branch* b) {
tint::Switch(
- b, //
- [&](const ir::BreakIf*) {}, //
- [&](const ir::Continue*) {}, //
- [&](const ir::ExitIf*) {}, //
- [&](const ir::ExitLoop*) {}, //
- [&](const ir::ExitSwitch*) {}, //
- [&](const ir::If* if_) { CheckIf(if_); }, //
- [&](const ir::Loop*) {}, //
- [&](const ir::NextIteration*) {}, //
- [&](const ir::Return* ret) {
+ b, //
+ [&](BreakIf*) {}, //
+ [&](Continue*) {}, //
+ [&](ExitIf*) {}, //
+ [&](ExitLoop*) {}, //
+ [&](ExitSwitch*) {}, //
+ [&](If* if_) { CheckIf(if_); }, //
+ [&](Loop*) {}, //
+ [&](NextIteration*) {}, //
+ [&](Return* ret) {
if (ret->Func() == nullptr) {
AddError("return: null function");
}
- }, //
- [&](const ir::Switch*) {}, //
+ }, //
+ [&](Switch*) {}, //
[&](Default) {
AddError(std::string("missing validation of branch: ") + b->TypeInfo().name);
});
}
- void CheckIf(const ir::If* if_) {
+ void CheckIf(If* if_) {
if (!if_->Condition()) {
AddError(if_, "if: condition is nullptr");
}
if (if_->Condition() && !if_->Condition()->Type()->Is<type::Bool>()) {
- AddError(if_, If::kConditionOperandIndex, "if: condition must be a `bool` type");
+ AddError(if_, If::kConditionOperandOffset, "if: condition must be a `bool` type");
}
}
-};
+}; // namespace
} // namespace
diff --git a/src/tint/ir/validate_test.cc b/src/tint/ir/validate_test.cc
index 0fab76f..d2f07d0 100644
--- a/src/tint/ir/validate_test.cc
+++ b/src/tint/ir/validate_test.cc
@@ -12,11 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/ir/validate.h"
+#include <utility>
+
#include "gmock/gmock.h"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/ir_test_helper.h"
+#include "src/tint/ir/validate.h"
+#include "src/tint/type/matrix.h"
#include "src/tint/type/pointer.h"
+#include "src/tint/type/struct.h"
namespace tint::ir {
namespace {
@@ -26,18 +30,18 @@
using IR_ValidateTest = IRTestHelper;
TEST_F(IR_ValidateTest, RootBlock_Var) {
- mod.root_block = b.CreateRootBlockIfNeeded();
- mod.root_block->Append(b.Declare(mod.Types().pointer(
- mod.Types().i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite)));
+ mod.root_block = b.RootBlock();
+ mod.root_block->Append(
+ b.Var(ty.ptr(builtin::AddressSpace::kPrivate, ty.i32(), builtin::Access::kReadWrite)));
auto res = ir::Validate(mod);
EXPECT_TRUE(res) << res.Failure().str();
}
TEST_F(IR_ValidateTest, RootBlock_NonVar) {
- auto* l = b.CreateLoop();
+ auto* l = b.Loop();
l->Body()->Append(b.Continue(l));
- mod.root_block = b.CreateRootBlockIfNeeded();
+ mod.root_block = b.RootBlock();
mod.root_block->Append(l);
auto res = ir::Validate(mod);
@@ -48,21 +52,7 @@
:2:1 note: In block
%b1 = block {
-^^^^^^^^^^^^^
- loop [b: %b2]
-^^^^^^^^^^^^^^^
- # Body block
-^^^^^^^^^^^^^^^^
- %b2 = block {
-^^^^^^^^^^^^^^^^^
- continue %b3
-^^^^^^^^^^^^^^^^^^
- }
-^^^^^
-
-
-}
-^
+^^^^^^^^^^^
note: # Disassembly
# Root block
@@ -78,55 +68,25 @@
)");
}
-TEST_F(IR_ValidateTest, RootBlock_VarBadType) {
- mod.root_block = b.CreateRootBlockIfNeeded();
- mod.root_block->Append(b.Declare(mod.Types().i32()));
- auto res = ir::Validate(mod);
- ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(),
- R"(:3:12 error: root block: 'var' type is not a pointer: tint::type::I32
- %1:i32 = var
- ^^^
-
-:2:1 note: In block
-%b1 = block {
-^^^^^^^^^^^^^
- %1:i32 = var
-^^^^^^^^^^^^^^
-}
-^
-
-note: # Disassembly
-# Root block
-%b1 = block {
- %1:i32 = var
-}
-
-)");
-}
-
TEST_F(IR_ValidateTest, Function) {
- auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ auto* f = b.Function("my_func", ty.void_());
mod.functions.Push(f);
- f->SetParams(
- utils::Vector{b.FunctionParam(mod.Types().i32()), b.FunctionParam(mod.Types().f32())});
- f->StartTarget()->SetInstructions(utils::Vector{b.Return(f)});
+ f->SetParams({b.FunctionParam(ty.i32()), b.FunctionParam(ty.f32())});
+ f->StartTarget()->SetInstructions({b.Return(f)});
auto res = ir::Validate(mod);
EXPECT_TRUE(res) << res.Failure().str();
}
TEST_F(IR_ValidateTest, Block_NoBranchAtEnd) {
- auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ auto* f = b.Function("my_func", ty.void_());
mod.functions.Push(f);
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
- EXPECT_EQ(res.Failure().str(), R"(:2:1 error: block: does not end in a branch
+ EXPECT_EQ(res.Failure().str(), R"(:2:3 error: block: does not end in a branch
%b1 = block {
-^^^^^^^^^^^^^^^
- }
-^^^
+ ^^^^^^^^^^^
note: # Disassembly
%my_func = func():void -> %b1 {
@@ -136,26 +96,385 @@
)");
}
-TEST_F(IR_ValidateTest, Block_BranchInMiddle) {
- auto* f = b.CreateFunction("my_func", mod.Types().void_());
+TEST_F(IR_ValidateTest, Valid_Access_Value) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.mat3x2<f32>());
+ f->SetParams({obj});
mod.functions.Push(f);
- f->StartTarget()->SetInstructions(utils::Vector{b.Return(f), b.Return(f)});
+ f->StartTarget()->Append(b.Access(ty.f32(), obj, 1_u, 0_u));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ EXPECT_TRUE(res) << res.Failure().str();
+}
+
+TEST_F(IR_ValidateTest, Valid_Access_Ptr) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(
+ ty.ptr(builtin::AddressSpace::kPrivate, ty.mat3x2<f32>(), builtin::Access::kReadWrite));
+ f->SetParams({obj});
+ mod.functions.Push(f);
+
+ f->StartTarget()->Append(b.Access(ty.ptr<private_, f32>(), obj, 1_u, 0_u));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ EXPECT_TRUE(res) << res.Failure().str();
+}
+
+TEST_F(IR_ValidateTest, Access_NegativeIndex) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.vec3<f32>());
+ f->SetParams({obj});
+ mod.functions.Push(f);
+
+ f->StartTarget()->Append(b.Access(ty.f32(), obj, -1_i));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:25 error: access: constant index must be positive, got -1
+ %3:f32 = access %2, -1i
+ ^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func(%2:vec3<f32>):void -> %b1 {
+ %b1 = block {
+ %3:f32 = access %2, -1i
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Access_OOB_Index_Value) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.mat3x2<f32>());
+ f->SetParams({obj});
+ mod.functions.Push(f);
+
+ f->StartTarget()->Append(b.Access(ty.f32(), obj, 1_u, 3_u));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:29 error: access: index out of bounds for type vec2<f32>
+ %3:f32 = access %2, 1u, 3u
+ ^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+:3:29 note: acceptable range: [0..1]
+ %3:f32 = access %2, 1u, 3u
+ ^^
+
+note: # Disassembly
+%my_func = func(%2:mat3x2<f32>):void -> %b1 {
+ %b1 = block {
+ %3:f32 = access %2, 1u, 3u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Access_OOB_Index_Ptr) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(
+ ty.ptr(builtin::AddressSpace::kPrivate, ty.mat3x2<f32>(), builtin::Access::kReadWrite));
+ f->SetParams({obj});
+ mod.functions.Push(f);
+
+ f->StartTarget()->Append(b.Access(ty.ptr<private_, f32>(), obj, 1_u, 3_u));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(),
+ R"(:3:55 error: access: index out of bounds for type ptr<vec2<f32>>
+ %3:ptr<private, f32, read_write> = access %2, 1u, 3u
+ ^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+:3:55 note: acceptable range: [0..1]
+ %3:ptr<private, f32, read_write> = access %2, 1u, 3u
+ ^^
+
+note: # Disassembly
+%my_func = func(%2:ptr<private, mat3x2<f32>, read_write>):void -> %b1 {
+ %b1 = block {
+ %3:ptr<private, f32, read_write> = access %2, 1u, 3u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Access_StaticallyUnindexableType_Value) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.f32());
+ f->SetParams({obj});
+ mod.functions.Push(f);
+
+ f->StartTarget()->Append(b.Access(ty.f32(), obj, 1_u));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:25 error: access: type f32 cannot be indexed
+ %3:f32 = access %2, 1u
+ ^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func(%2:f32):void -> %b1 {
+ %b1 = block {
+ %3:f32 = access %2, 1u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Access_StaticallyUnindexableType_Ptr) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.ptr<private_, f32>());
+ f->SetParams({obj});
+ mod.functions.Push(f);
+
+ f->StartTarget()->Append(b.Access(ty.ptr<private_, f32>(), obj, 1_u));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:51 error: access: type ptr<f32> cannot be indexed
+ %3:ptr<private, f32, read_write> = access %2, 1u
+ ^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func(%2:ptr<private, f32, read_write>):void -> %b1 {
+ %b1 = block {
+ %3:ptr<private, f32, read_write> = access %2, 1u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Access_DynamicallyUnindexableType_Value) {
+ utils::Vector members{
+ ty.Get<type::StructMember>(mod.symbols.New(), ty.i32(), 0u, 0u, 4u, 4u,
+ type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.New(), ty.i32(), 1u, 4u, 4u, 4u,
+ type::StructMemberAttributes{}),
+ };
+ auto* str_ty = ty.Get<type::Struct>(mod.symbols.New(), std::move(members), 4u, 8u, 8u);
+
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(str_ty);
+ auto* idx = b.FunctionParam(ty.i32());
+ f->SetParams({obj, idx});
+ mod.functions.Push(f);
+
+ f->StartTarget()->Append(b.Access(ty.i32(), obj, idx));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(),
+ R"(:8:25 error: access: type tint_symbol_2 cannot be dynamically indexed
+ %4:i32 = access %2, %3
+ ^^
+
+:7:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+tint_symbol_2 = struct @align(4) {
+ tint_symbol:i32 @offset(0)
+ tint_symbol_1:i32 @offset(4)
+}
+
+%my_func = func(%2:tint_symbol_2, %3:i32):void -> %b1 {
+ %b1 = block {
+ %4:i32 = access %2, %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Access_DynamicallyUnindexableType_Ptr) {
+ utils::Vector members{
+ ty.Get<type::StructMember>(mod.symbols.New(), ty.i32(), 0u, 0u, 4u, 4u,
+ type::StructMemberAttributes{}),
+ ty.Get<type::StructMember>(mod.symbols.New(), ty.i32(), 1u, 4u, 4u, 4u,
+ type::StructMemberAttributes{}),
+ };
+ auto* str_ty = ty.Get<type::Struct>(mod.symbols.New(), std::move(members), 4u, 8u, 8u);
+
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(
+ ty.ptr(builtin::AddressSpace::kPrivate, str_ty, builtin::Access::kReadWrite));
+ auto* idx = b.FunctionParam(ty.i32());
+ f->SetParams({obj, idx});
+ mod.functions.Push(f);
+
+ f->StartTarget()->Append(b.Access(ty.i32(), obj, idx));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(),
+ R"(:8:25 error: access: type ptr<tint_symbol_2> cannot be dynamically indexed
+ %4:i32 = access %2, %3
+ ^^
+
+:7:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+tint_symbol_2 = struct @align(4) {
+ tint_symbol:i32 @offset(0)
+ tint_symbol_1:i32 @offset(4)
+}
+
+%my_func = func(%2:ptr<private, tint_symbol_2, read_write>, %3:i32):void -> %b1 {
+ %b1 = block {
+ %4:i32 = access %2, %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Access_Incorrect_Type_Value_Value) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(ty.mat3x2<f32>());
+ f->SetParams({obj});
+ mod.functions.Push(f);
+
+ f->StartTarget()->Append(b.Access(ty.i32(), obj, 1_u, 1_u));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(),
+ R"(:3:14 error: access: result of access chain is type f32 but instruction type is i32
+ %3:i32 = access %2, 1u, 1u
+ ^^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func(%2:mat3x2<f32>):void -> %b1 {
+ %b1 = block {
+ %3:i32 = access %2, 1u, 1u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Access_Incorrect_Type_Ptr_Ptr) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(
+ ty.ptr(builtin::AddressSpace::kPrivate, ty.mat3x2<f32>(), builtin::Access::kReadWrite));
+ f->SetParams({obj});
+ mod.functions.Push(f);
+
+ f->StartTarget()->Append(b.Access(ty.ptr<private_, i32>(), obj, 1_u, 1_u));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(
+ res.Failure().str(),
+ R"(:3:40 error: access: result of access chain is type ptr<f32> but instruction type is ptr<i32>
+ %3:ptr<private, i32, read_write> = access %2, 1u, 1u
+ ^^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func(%2:ptr<private, mat3x2<f32>, read_write>):void -> %b1 {
+ %b1 = block {
+ %3:ptr<private, i32, read_write> = access %2, 1u, 1u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Access_Incorrect_Type_Ptr_Value) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* obj = b.FunctionParam(
+ ty.ptr(builtin::AddressSpace::kPrivate, ty.mat3x2<f32>(), builtin::Access::kReadWrite));
+ f->SetParams({obj});
+ mod.functions.Push(f);
+
+ f->StartTarget()->Append(b.Access(ty.f32(), obj, 1_u, 1_u));
+ f->StartTarget()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(
+ res.Failure().str(),
+ R"(:3:14 error: access: result of access chain is type ptr<f32> but instruction type is f32
+ %3:f32 = access %2, 1u, 1u
+ ^^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func(%2:ptr<private, mat3x2<f32>, read_write>):void -> %b1 {
+ %b1 = block {
+ %3:f32 = access %2, 1u, 1u
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Block_BranchInMiddle) {
+ auto* f = b.Function("my_func", ty.void_());
+ mod.functions.Push(f);
+
+ f->StartTarget()->SetInstructions({b.Return(f), b.Return(f)});
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
EXPECT_EQ(res.Failure().str(), R"(:3:5 error: block: branch which isn't the final instruction
ret
^^^
-:2:1 note: In block
+:2:3 note: In block
%b1 = block {
-^^^^^^^^^^^^^^^
- ret
-^^^^^^^
- ret
-^^^^^^^
- }
-^^^
+ ^^^^^^^^^^^
note: # Disassembly
%my_func = func():void -> %b1 {
@@ -168,10 +487,10 @@
}
TEST_F(IR_ValidateTest, If_ConditionIsBool) {
- auto* f = b.CreateFunction("my_func", mod.Types().void_());
+ auto* f = b.Function("my_func", ty.void_());
mod.functions.Push(f);
- auto* if_ = b.CreateIf(b.Constant(1_i));
+ auto* if_ = b.If(1_i);
if_->True()->Append(b.Return(f));
if_->False()->Append(b.Return(f));
@@ -183,33 +502,9 @@
if 1i [t: %b2, f: %b3]
^^
-:2:1 note: In block
+:2:3 note: In block
%b1 = block {
-^^^^^^^^^^^^^^^
- if 1i [t: %b2, f: %b3]
-^^^^^^^^^^^^^^^^^^^^^^^^^^
- # True block
-^^^^^^^^^^^^^^^^^^
- %b2 = block {
-^^^^^^^^^^^^^^^^^^^
- ret
-^^^^^^^^^^^
- }
-^^^^^^^
-
-
- # False block
-^^^^^^^^^^^^^^^^^^^
- %b3 = block {
-^^^^^^^^^^^^^^^^^^^
- ret
-^^^^^^^^^^^
- }
-^^^^^^^
-
-
- }
-^^^
+ ^^^^^^^^^^^
note: # Disassembly
%my_func = func():void -> %b1 {
diff --git a/src/tint/ir/value.cc b/src/tint/ir/value.cc
index c735cd9..5eb2e1e 100644
--- a/src/tint/ir/value.cc
+++ b/src/tint/ir/value.cc
@@ -15,6 +15,7 @@
#include "src/tint/ir/value.h"
#include "src/tint/ir/constant.h"
+#include "src/tint/ir/instruction.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::Value);
@@ -24,4 +25,12 @@
Value::~Value() = default;
+void Value::ReplaceAllUsesWith(std::function<Value*(Usage use)> replacer) {
+ while (!uses_.IsEmpty()) {
+ auto& use = *uses_.begin();
+ auto* replacement = replacer(use);
+ use.instruction->SetOperand(use.operand_index, replacement);
+ }
+}
+
} // namespace tint::ir
diff --git a/src/tint/ir/value.h b/src/tint/ir/value.h
index 3166168..1d7f88d 100644
--- a/src/tint/ir/value.h
+++ b/src/tint/ir/value.h
@@ -66,10 +66,14 @@
/// @returns the set of usages of this value. An instruction may appear multiple times if it
/// uses the value for multiple different operands.
- const utils::Hashset<Usage, 4, Usage::Hasher>& Usages() const { return uses_; }
+ const utils::Hashset<Usage, 4, Usage::Hasher>& Usages() { return uses_; }
/// @returns the type of the value
- virtual const type::Type* Type() const { return nullptr; }
+ virtual const type::Type* Type() { return nullptr; }
+
+ /// Replace all uses of the value.
+ /// @param replacer a function which returns a replacement for a given use
+ void ReplaceAllUsesWith(std::function<Value*(Usage use)> replacer);
protected:
/// Constructor
diff --git a/src/tint/ir/var.cc b/src/tint/ir/var.cc
index 8b0225f..4dd00ce 100644
--- a/src/tint/ir/var.cc
+++ b/src/tint/ir/var.cc
@@ -19,7 +19,7 @@
namespace tint::ir {
-Var::Var(const type::Type* ty) : type_(ty) {
+Var::Var(const type::Pointer* ty) : type_(ty) {
TINT_ASSERT(IR, type_ != nullptr);
// Default to no initializer.
diff --git a/src/tint/ir/var.h b/src/tint/ir/var.h
index 3c633cc..6f22c7d 100644
--- a/src/tint/ir/var.h
+++ b/src/tint/ir/var.h
@@ -19,6 +19,7 @@
#include "src/tint/builtin/address_space.h"
#include "src/tint/ir/binding_point.h"
#include "src/tint/ir/operand_instruction.h"
+#include "src/tint/type/pointer.h"
#include "src/tint/utils/castable.h"
#include "src/tint/utils/vector.h"
@@ -29,27 +30,27 @@
public:
/// Constructor
/// @param type the type of the var
- explicit Var(const type::Type* type);
+ explicit Var(const type::Pointer* type);
~Var() override;
/// @returns the type of the var
- const type::Type* Type() const override { return type_; }
+ const type::Pointer* Type() override { return type_; }
/// Sets the var initializer
/// @param initializer the initializer
void SetInitializer(Value* initializer);
/// @returns the initializer
- const Value* Initializer() const { return operands_[0]; }
+ Value* Initializer() { return operands_[0]; }
/// Sets the binding point
/// @param group the group
/// @param binding the binding
void SetBindingPoint(uint32_t group, uint32_t binding) { binding_point_ = {group, binding}; }
/// @returns the binding points if `Attributes` contains `kBindingPoint`
- std::optional<struct BindingPoint> BindingPoint() const { return binding_point_; }
+ std::optional<struct BindingPoint> BindingPoint() { return binding_point_; }
private:
- const type::Type* type_ = nullptr;
+ const type::Pointer* type_ = nullptr;
std::optional<struct BindingPoint> binding_point_;
};
diff --git a/src/tint/ir/var_test.cc b/src/tint/ir/var_test.cc
index ebde1bc..7173237 100644
--- a/src/tint/ir/var_test.cc
+++ b/src/tint/ir/var_test.cc
@@ -31,7 +31,7 @@
{
Module mod;
Builder b{mod};
- b.Declare(nullptr);
+ b.Var(nullptr);
},
"");
}
@@ -39,7 +39,7 @@
TEST_F(IR_VarTest, Initializer_Usage) {
Module mod;
Builder b{mod};
- auto* var = b.Declare(mod.Types().f32());
+ auto* var = b.Var(ty.ptr<function, f32>());
auto* init = b.Constant(1_f);
var->SetInitializer(init);
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 459edaa..52a73b3 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -117,30 +117,6 @@
namespace tint {
-namespace detail {
-
-/// IsVectorLike<T>::value is true if T is a utils::Vector or utils::VectorRef.
-template <typename T>
-struct IsVectorLike {
- /// Non-specialized form of IsVectorLike defaults to false
- static constexpr bool value = false;
-};
-
-/// IsVectorLike specialization for utils::Vector
-template <typename T, size_t N>
-struct IsVectorLike<utils::Vector<T, N>> {
- /// True for the IsVectorLike specialization of utils::Vector
- static constexpr bool value = true;
-};
-
-/// IsVectorLike specialization for utils::VectorRef
-template <typename T>
-struct IsVectorLike<utils::VectorRef<T>> {
- /// True for the IsVectorLike specialization of utils::VectorRef
- static constexpr bool value = true;
-};
-} // namespace detail
-
// A sentinel type used by some template arguments to signal that the a type should be inferred.
struct Infer {};
@@ -194,11 +170,11 @@
using EnableIfScalar = utils::traits::EnableIf<
IsScalar<utils::traits::Decay<utils::traits::NthTypeOf<0, TYPES..., void>>>>;
- /// A helper used to disable overloads if the first type in `TYPES` is a utils::Vector,
- /// utils::VectorRef or utils::VectorRef.
+ /// A helper used to disable overloads if the first type in `TYPES` is a utils::Vector or
+ /// utils::VectorRef.
template <typename... TYPES>
- using DisableIfVectorLike = utils::traits::EnableIf<!detail::IsVectorLike<
- utils::traits::Decay<utils::traits::NthTypeOf<0, TYPES..., void>>>::value>;
+ using DisableIfVectorLike = utils::traits::EnableIf<
+ !utils::IsVectorLike<utils::traits::Decay<utils::traits::NthTypeOf<0, TYPES..., void>>>>;
/// A helper used to enable overloads if the first type in `TYPES` is identifier-like.
template <typename... TYPES>
@@ -1148,25 +1124,25 @@
type);
}
- /// @param type the type of the pointer
/// @param address_space the address space of the pointer
+ /// @param type the type of the pointer
/// @param access the optional access control of the pointer
/// @return the pointer to `type` with the given builtin::AddressSpace
- ast::Type pointer(ast::Type type,
- builtin::AddressSpace address_space,
- builtin::Access access = builtin::Access::kUndefined) const {
- return pointer(builder->source_, type, address_space, access);
+ ast::Type ptr(builtin::AddressSpace address_space,
+ ast::Type type,
+ builtin::Access access = builtin::Access::kUndefined) const {
+ return ptr(builder->source_, address_space, type, access);
}
/// @param source the Source of the node
- /// @param type the type of the pointer
/// @param address_space the address space of the pointer
+ /// @param type the type of the pointer
/// @param access the optional access control of the pointer
/// @return the pointer to `type` with the given builtin::AddressSpace
- ast::Type pointer(const Source& source,
- ast::Type type,
- builtin::AddressSpace address_space,
- builtin::Access access = builtin::Access::kUndefined) const {
+ ast::Type ptr(const Source& source,
+ builtin::AddressSpace address_space,
+ ast::Type type,
+ builtin::Access access = builtin::Access::kUndefined) const {
if (access != builtin::Access::kUndefined) {
return (*this)(source, "ptr", address_space, type, access);
} else {
@@ -1178,9 +1154,9 @@
/// @param access the optional access control of the pointer
/// @return the pointer to type `T` with the given builtin::AddressSpace.
template <typename T>
- ast::Type pointer(builtin::AddressSpace address_space,
- builtin::Access access = builtin::Access::kUndefined) const {
- return pointer<T>(builder->source_, address_space, access);
+ ast::Type ptr(builtin::AddressSpace address_space,
+ builtin::Access access = builtin::Access::kUndefined) const {
+ return ptr<T>(builder->source_, address_space, access);
}
/// @param source the Source of the node
@@ -1188,9 +1164,9 @@
/// @param access the optional access control of the pointer
/// @return the pointer to type `T` with the given builtin::AddressSpace.
template <typename T>
- ast::Type pointer(const Source& source,
- builtin::AddressSpace address_space,
- builtin::Access access = builtin::Access::kUndefined) const {
+ ast::Type ptr(const Source& source,
+ builtin::AddressSpace address_space,
+ builtin::Access access = builtin::Access::kUndefined) const {
if (access != builtin::Access::kUndefined) {
return (*this)(source, "ptr", address_space, Of<T>(), access);
} else {
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index d0268c6..c78b9a0 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -2512,7 +2512,7 @@
Attributes{});
auto* var_decl_stmt = create<ast::VariableDeclStatement>(Source{}, var);
AddStatement(var_decl_stmt);
- auto* var_type = ty_.Reference(var_store_type, builtin::AddressSpace::kUndefined);
+ auto* var_type = ty_.Reference(builtin::AddressSpace::kUndefined, var_store_type);
identifier_types_.emplace(inst.result_id(), var_type);
}
return success();
@@ -3358,7 +3358,7 @@
Source{},
parser_impl_.MakeVar(id, builtin::AddressSpace::kUndefined, builtin::Access::kUndefined,
store_type, nullptr, Attributes{})));
- auto* type = ty_.Reference(store_type, builtin::AddressSpace::kUndefined);
+ auto* type = ty_.Reference(builtin::AddressSpace::kUndefined, store_type);
identifier_types_.emplace(id, type);
}
@@ -4892,11 +4892,11 @@
const Type* FunctionEmitter::RemapPointerProperties(const Type* type, uint32_t result_id) {
if (auto* ast_ptr_type = As<Pointer>(type)) {
const auto pi = GetPointerInfo(result_id);
- return ty_.Pointer(ast_ptr_type->type, pi.address_space, pi.access);
+ return ty_.Pointer(pi.address_space, ast_ptr_type->type, pi.access);
}
if (auto* ast_ptr_type = As<Reference>(type)) {
const auto pi = GetPointerInfo(result_id);
- return ty_.Reference(ast_ptr_type->type, pi.address_space, pi.access);
+ return ty_.Reference(pi.address_space, ast_ptr_type->type, pi.access);
}
return type;
}
@@ -6338,7 +6338,7 @@
return {};
}
return {
- ty_.Pointer(ref->type, ref->address_space),
+ ty_.Pointer(ref->address_space, ref->type),
create<ast::UnaryOpExpression>(Source{}, ast::UnaryOp::kAddressOf, expr.expr),
};
}
diff --git a/src/tint/reader/spirv/parser_impl.cc b/src/tint/reader/spirv/parser_impl.cc
index fa9572a..7a3089d 100644
--- a/src/tint/reader/spirv/parser_impl.cc
+++ b/src/tint/reader/spirv/parser_impl.cc
@@ -1226,9 +1226,9 @@
}
switch (ptr_as) {
case PtrAs::Ref:
- return ty_.Reference(ast_elem_ty, ast_address_space);
+ return ty_.Reference(ast_address_space, ast_elem_ty);
case PtrAs::Ptr:
- return ty_.Pointer(ast_elem_ty, ast_address_space);
+ return ty_.Pointer(ast_address_space, ast_elem_ty);
}
Fail() << "invalid value for ptr_as: " << static_cast<int>(ptr_as);
return nullptr;
diff --git a/src/tint/reader/spirv/parser_type.cc b/src/tint/reader/spirv/parser_type.cc
index 5abd739..8ace8a0 100644
--- a/src/tint/reader/spirv/parser_type.cc
+++ b/src/tint/reader/spirv/parser_type.cc
@@ -55,13 +55,13 @@
namespace {
struct PointerHasher {
size_t operator()(const Pointer& t) const {
- return utils::Hash(t.type, t.address_space, t.access);
+ return utils::Hash(t.address_space, t.type, t.access);
}
};
struct ReferenceHasher {
size_t operator()(const Reference& t) const {
- return utils::Hash(t.type, t.address_space, t.access);
+ return utils::Hash(t.address_space, t.type, t.access);
}
};
@@ -178,8 +178,8 @@
Texture::~Texture() = default;
-Pointer::Pointer(const Type* t, builtin::AddressSpace s, builtin::Access a)
- : type(t), address_space(s), access(a) {}
+Pointer::Pointer(builtin::AddressSpace s, const Type* t, builtin::Access a)
+ : address_space(s), type(t), access(a) {}
Pointer::Pointer(const Pointer&) = default;
ast::Type Pointer::Build(ProgramBuilder& b) const {
@@ -189,11 +189,11 @@
// types.
return b.ty("invalid_spirv_ptr_type");
}
- return b.ty.pointer(type->Build(b), address_space, access);
+ return b.ty.ptr(address_space, type->Build(b), access);
}
-Reference::Reference(const Type* t, builtin::AddressSpace s, builtin::Access a)
- : type(t), address_space(s), access(a) {}
+Reference::Reference(builtin::AddressSpace s, const Type* t, builtin::Access a)
+ : address_space(s), type(t), access(a) {}
Reference::Reference(const Reference&) = default;
ast::Type Reference::Build(ProgramBuilder& b) const {
@@ -487,16 +487,16 @@
});
}
-const spirv::Pointer* TypeManager::Pointer(const Type* el,
- builtin::AddressSpace address_space,
+const spirv::Pointer* TypeManager::Pointer(builtin::AddressSpace address_space,
+ const Type* el,
builtin::Access access) {
- return state->pointers_.Get(el, address_space, access);
+ return state->pointers_.Get(address_space, el, access);
}
-const spirv::Reference* TypeManager::Reference(const Type* el,
- builtin::AddressSpace address_space,
+const spirv::Reference* TypeManager::Reference(builtin::AddressSpace address_space,
+ const Type* el,
builtin::Access access) {
- return state->references_.Get(el, address_space, access);
+ return state->references_.Get(address_space, el, access);
}
const spirv::Vector* TypeManager::Vector(const Type* el, uint32_t size) {
diff --git a/src/tint/reader/spirv/parser_type.h b/src/tint/reader/spirv/parser_type.h
index d3295ee..e47646b 100644
--- a/src/tint/reader/spirv/parser_type.h
+++ b/src/tint/reader/spirv/parser_type.h
@@ -159,10 +159,10 @@
/// `ptr<SC, T, AM>` type
struct Pointer final : public utils::Castable<Pointer, Type> {
/// Constructor
- /// @param ty the store type
/// @param sc the pointer address space
+ /// @param ty the store type
/// @param access the declared access mode
- Pointer(const Type* ty, builtin::AddressSpace sc, builtin::Access access);
+ Pointer(builtin::AddressSpace sc, const Type* ty, builtin::Access access);
/// Copy constructor
/// @param other the other type to copy
@@ -177,10 +177,10 @@
std::string String() const override;
#endif // NDEBUG
- /// the store type
- Type const* const type;
/// the pointer address space
builtin::AddressSpace const address_space;
+ /// the store type
+ Type const* const type;
/// the pointer declared access mode
builtin::Access const access;
};
@@ -190,10 +190,10 @@
/// reader.
struct Reference final : public utils::Castable<Reference, Type> {
/// Constructor
- /// @param ty the referenced type
/// @param sc the reference address space
+ /// @param ty the referenced type
/// @param access the reference declared access mode
- Reference(const Type* ty, builtin::AddressSpace sc, builtin::Access access);
+ Reference(builtin::AddressSpace sc, const Type* ty, builtin::Access access);
/// Copy constructor
/// @param other the other type to copy
@@ -208,10 +208,10 @@
std::string String() const override;
#endif // NDEBUG
- /// the store type
- Type const* const type;
/// the pointer address space
builtin::AddressSpace const address_space;
+ /// the store type
+ Type const* const type;
/// the pointer declared access mode
builtin::Access const access;
};
@@ -543,21 +543,21 @@
/// otherwise nullptr.
const Type* AsUnsigned(const Type* ty);
- /// @param ty the store type
/// @param address_space the pointer address space
+ /// @param ty the store type
/// @param access the declared access mode
/// @return a Pointer type. Repeated calls with the same arguments will return
/// the same pointer.
- const spirv::Pointer* Pointer(const Type* ty,
- builtin::AddressSpace address_space,
+ const spirv::Pointer* Pointer(builtin::AddressSpace address_space,
+ const Type* ty,
builtin::Access access = builtin::Access::kUndefined);
- /// @param ty the referenced type
/// @param address_space the reference address space
+ /// @param ty the referenced type
/// @param access the declared access mode
/// @return a Reference type. Repeated calls with the same arguments will
/// return the same pointer.
- const spirv::Reference* Reference(const Type* ty,
- builtin::AddressSpace address_space,
+ const spirv::Reference* Reference(builtin::AddressSpace address_space,
+ const Type* ty,
builtin::Access access = builtin::Access::kUndefined);
/// @param ty the element type
/// @param sz the number of elements in the vector
diff --git a/src/tint/reader/spirv/parser_type_test.cc b/src/tint/reader/spirv/parser_type_test.cc
index f864818..9ebb047 100644
--- a/src/tint/reader/spirv/parser_type_test.cc
+++ b/src/tint/reader/spirv/parser_type_test.cc
@@ -29,8 +29,8 @@
EXPECT_EQ(ty.U32(), ty.U32());
EXPECT_EQ(ty.F32(), ty.F32());
EXPECT_EQ(ty.I32(), ty.I32());
- EXPECT_EQ(ty.Pointer(ty.I32(), builtin::AddressSpace::kUndefined),
- ty.Pointer(ty.I32(), builtin::AddressSpace::kUndefined));
+ EXPECT_EQ(ty.Pointer(builtin::AddressSpace::kUndefined, ty.I32()),
+ ty.Pointer(builtin::AddressSpace::kUndefined, ty.I32()));
EXPECT_EQ(ty.Vector(ty.I32(), 3), ty.Vector(ty.I32(), 3));
EXPECT_EQ(ty.Matrix(ty.I32(), 3, 2), ty.Matrix(ty.I32(), 3, 2));
EXPECT_EQ(ty.Array(ty.I32(), 3, 2), ty.Array(ty.I32(), 3, 2));
@@ -54,10 +54,10 @@
Symbol sym_b(Symbol(2, {}, "2"));
TypeManager ty;
- EXPECT_NE(ty.Pointer(ty.I32(), builtin::AddressSpace::kUndefined),
- ty.Pointer(ty.U32(), builtin::AddressSpace::kUndefined));
- EXPECT_NE(ty.Pointer(ty.I32(), builtin::AddressSpace::kUndefined),
- ty.Pointer(ty.I32(), builtin::AddressSpace::kIn));
+ EXPECT_NE(ty.Pointer(builtin::AddressSpace::kUndefined, ty.I32()),
+ ty.Pointer(builtin::AddressSpace::kUndefined, ty.U32()));
+ EXPECT_NE(ty.Pointer(builtin::AddressSpace::kUndefined, ty.I32()),
+ ty.Pointer(builtin::AddressSpace::kIn, ty.I32()));
EXPECT_NE(ty.Vector(ty.I32(), 3), ty.Vector(ty.U32(), 3));
EXPECT_NE(ty.Vector(ty.I32(), 3), ty.Vector(ty.I32(), 2));
EXPECT_NE(ty.Matrix(ty.I32(), 3, 2), ty.Matrix(ty.U32(), 3, 2));
diff --git a/src/tint/resolver/address_space_validation_test.cc b/src/tint/resolver/address_space_validation_test.cc
index a1b5ea8..fae4a25 100644
--- a/src/tint/resolver/address_space_validation_test.cc
+++ b/src/tint/resolver/address_space_validation_test.cc
@@ -67,8 +67,8 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Private_RuntimeArray) {
// type t : ptr<private, array<i32>>;
- Alias("t", ty.pointer(Source{{56, 78}}, ty.array(Source{{12, 34}}, ty.i32()),
- builtin::AddressSpace::kPrivate));
+ Alias("t", ty.ptr(Source{{56, 78}}, builtin::AddressSpace::kPrivate,
+ ty.array(Source{{12, 34}}, ty.i32())));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@@ -93,7 +93,7 @@
// struct S { m : array<i32> };
// type t = ptr<private, S>;
Structure("S", utils::Vector{Member(Source{{12, 34}}, "m", ty.array(ty.i32()))});
- Alias("t", ty.pointer(ty("S"), builtin::AddressSpace::kPrivate));
+ Alias("t", ty.ptr(builtin::AddressSpace::kPrivate, ty("S")));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@@ -115,7 +115,7 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Workgroup_RuntimeArray) {
// type t = ptr<workgroup, array<i32>>;
- Alias("t", ty.pointer(ty.array(Source{{12, 34}}, ty.i32()), builtin::AddressSpace::kWorkgroup));
+ Alias("t", ty.ptr(builtin::AddressSpace::kWorkgroup, ty.array(Source{{12, 34}}, ty.i32())));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@@ -140,7 +140,7 @@
// struct S { m : array<i32> };
// type t = ptr<workgroup, S>;
Structure("S", utils::Vector{Member(Source{{12, 34}}, "m", ty.array(ty.i32()))});
- Alias(Source{{56, 78}}, "t", ty.pointer(ty("S"), builtin::AddressSpace::kWorkgroup));
+ Alias(Source{{56, 78}}, "t", ty.ptr(builtin::AddressSpace::kWorkgroup, ty("S")));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@@ -165,7 +165,7 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_Bool) {
// type t = ptr<storage, bool>;
Alias(Source{{56, 78}}, "t",
- ty.pointer(ty.bool_(Source{{12, 34}}), builtin::AddressSpace::kStorage));
+ ty.ptr(builtin::AddressSpace::kStorage, ty.bool_(Source{{12, 34}})));
ASSERT_FALSE(r()->Resolve());
@@ -195,7 +195,7 @@
// type t = ptr<storage, a>;
Alias("a", ty.bool_());
Alias(Source{{56, 78}}, "t",
- ty.pointer(ty(Source{{12, 34}}, "a"), builtin::AddressSpace::kStorage));
+ ty.ptr(builtin::AddressSpace::kStorage, ty(Source{{12, 34}}, "a")));
ASSERT_FALSE(r()->Resolve());
@@ -208,7 +208,7 @@
TEST_F(ResolverAddressSpaceValidationTest, GlobalVariable_Storage_Pointer) {
// var<storage> g : ptr<private, f32>;
GlobalVar(Source{{56, 78}}, "g",
- ty.pointer(Source{{12, 34}}, ty.f32(), builtin::AddressSpace::kPrivate),
+ ty.ptr(Source{{12, 34}}, builtin::AddressSpace::kPrivate, ty.f32()),
builtin::AddressSpace::kStorage, Binding(0_a), Group(0_a));
ASSERT_FALSE(r()->Resolve());
@@ -221,9 +221,8 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_Pointer) {
// type t = ptr<storage, ptr<private, f32>>;
- Alias("t", ty.pointer(Source{{56, 78}},
- ty.pointer(Source{{12, 34}}, ty.f32(), builtin::AddressSpace::kPrivate),
- builtin::AddressSpace::kStorage));
+ Alias("t", ty.ptr(Source{{56, 78}}, builtin::AddressSpace::kStorage,
+ ty.ptr(Source{{12, 34}}, builtin::AddressSpace::kPrivate, ty.f32())));
ASSERT_FALSE(r()->Resolve());
@@ -242,7 +241,7 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_IntScalar) {
// type t = ptr<storage, i32;
- Alias("t", ty.pointer(ty.i32(), builtin::AddressSpace::kStorage));
+ Alias("t", ty.ptr(builtin::AddressSpace::kStorage, ty.i32()));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -262,7 +261,7 @@
// type t = ptr<storage, f16>;
Enable(builtin::Extension::kF16);
- Alias("t", ty.pointer(ty.f16(), builtin::AddressSpace::kStorage));
+ Alias("t", ty.ptr(builtin::AddressSpace::kStorage, ty.f16()));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -286,7 +285,7 @@
Enable(builtin::Extension::kF16);
Alias("a", ty.f16());
- Alias("t", ty.pointer(ty("a"), builtin::AddressSpace::kStorage));
+ Alias("t", ty.ptr(builtin::AddressSpace::kStorage, ty("a")));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -300,7 +299,7 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_VectorF32) {
// type t = ptr<storage, vec4<f32>>;
- Alias("t", ty.pointer(ty.vec4<f32>(), builtin::AddressSpace::kStorage));
+ Alias("t", ty.ptr(builtin::AddressSpace::kStorage, ty.vec4<f32>()));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -316,7 +315,7 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_VectorF16) {
// type t = ptr<storage, vec4<f16>>;
Enable(builtin::Extension::kF16);
- Alias("t", ty.pointer(ty.vec(ty.f16(), 4u), builtin::AddressSpace::kStorage));
+ Alias("t", ty.ptr(builtin::AddressSpace::kStorage, ty.vec(ty.f16(), 4u)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -335,7 +334,7 @@
// struct S{ a : f32 };
// type t = ptr<storage, array<S, 3u>>;
Structure("S", utils::Vector{Member("a", ty.f32())});
- Alias("t", ty.pointer(ty.array(ty("S"), 3_u), builtin::AddressSpace::kStorage));
+ Alias("t", ty.ptr(builtin::AddressSpace::kStorage, ty.array(ty("S"), 3_u)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -356,12 +355,12 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_ArrayF16) {
// enable f16;
// struct S{ a : f16 };
- // type t = ptr<storage, read, array<S, 3u>>;
+ // type t = ptr<storage, array<S, 3u>, read>;
Enable(builtin::Extension::kF16);
Structure("S", utils::Vector{Member("a", ty.f16())});
- Alias("t", ty.pointer(ty.array(ty("S"), 3_u), builtin::AddressSpace::kStorage,
- builtin::Access::kRead));
+ Alias("t",
+ ty.ptr(builtin::AddressSpace::kStorage, ty.array(ty("S"), 3_u), builtin::Access::kRead));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -378,9 +377,9 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_StructI32) {
// struct S { x : i32 };
- // type t = ptr<storage, read, S>;
+ // type t = ptr<storage, S, read>;
Structure("S", utils::Vector{Member("x", ty.i32())});
- Alias("t", ty.pointer(ty("S"), builtin::AddressSpace::kStorage, builtin::Access::kRead));
+ Alias("t", ty.ptr(builtin::AddressSpace::kStorage, ty("S"), builtin::Access::kRead));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -401,11 +400,11 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_StructI32Aliases) {
// struct S { x : i32 };
// type a1 = S;
- // type t = ptr<storage, read, a1>;
+ // type t = ptr<storage, a1, read>;
Structure("S", utils::Vector{Member("x", ty.i32())});
Alias("a1", ty("S"));
Alias("a2", ty("a1"));
- Alias("t", ty.pointer(ty("a2"), builtin::AddressSpace::kStorage, builtin::Access::kRead));
+ Alias("t", ty.ptr(builtin::AddressSpace::kStorage, ty("a2"), builtin::Access::kRead));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -424,11 +423,11 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_StructF16) {
// struct S { x : f16 };
- // type t = ptr<storage, read, S>;
+ // type t = ptr<storage, S, read>;
Enable(builtin::Extension::kF16);
Structure("S", utils::Vector{Member("x", ty.f16())});
- Alias("t", ty.pointer(ty("S"), builtin::AddressSpace::kStorage, builtin::Access::kRead));
+ Alias("t", ty.ptr(builtin::AddressSpace::kStorage, ty("S"), builtin::Access::kRead));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -451,13 +450,13 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_StructF16Aliases) {
// struct S { x : f16 };
// type a1 = S;
- // type t = ptr<storage, read, a1>;
+ // type t = ptr<storage, a1, read>;
Enable(builtin::Extension::kF16);
Structure("S", utils::Vector{Member("x", ty.f16())});
Alias("a1", ty("S"));
Alias("a2", ty("a1"));
- Alias("g", ty.pointer(ty("a2"), builtin::AddressSpace::kStorage, builtin::Access::kRead));
+ Alias("g", ty.ptr(builtin::AddressSpace::kStorage, ty("a2"), builtin::Access::kRead));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -476,8 +475,8 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_NotStorage_AccessMode) {
// type t = ptr<private, i32, read>;
- Alias("t", ty.pointer(Source{{12, 34}}, ty.i32(), builtin::AddressSpace::kPrivate,
- builtin::Access::kRead));
+ Alias("t", ty.ptr(Source{{12, 34}}, builtin::AddressSpace::kPrivate, ty.i32(),
+ builtin::Access::kRead));
ASSERT_FALSE(r()->Resolve());
@@ -495,8 +494,8 @@
}
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_ReadAccessMode) {
- // type t = ptr<storage, read, i32>;
- Alias("t", ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kRead));
+ // type t = ptr<storage, i32, read>;
+ Alias("t", ty.ptr(builtin::AddressSpace::kStorage, ty.i32(), builtin::Access::kRead));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -510,8 +509,8 @@
}
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_ReadWriteAccessMode) {
- // type t = ptr<storage, read_write, i32>;
- Alias("t", ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ // type t = ptr<storage, i32, read_write>;
+ Alias("t", ty.ptr(builtin::AddressSpace::kStorage, ty.i32(), builtin::Access::kReadWrite));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -528,9 +527,9 @@
}
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_Storage_WriteAccessMode) {
- // type t = ptr<storage, read_write, i32>;
- Alias("t", ty.pointer(Source{{12, 34}}, ty.i32(), builtin::AddressSpace::kStorage,
- builtin::Access::kWrite));
+ // type t = ptr<storage, i32, read_write>;
+ Alias("t", ty.ptr(Source{{12, 34}}, builtin::AddressSpace::kStorage, ty.i32(),
+ builtin::Access::kWrite));
ASSERT_FALSE(r()->Resolve());
@@ -562,7 +561,7 @@
Structure("S",
utils::Vector{Member(Source{{56, 78}}, "m", ty.array(Source{{12, 34}}, ty.i32()))});
- Alias("t", ty.pointer(Source{{90, 12}}, ty("S"), builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(Source{{90, 12}}, builtin::AddressSpace::kUniform, ty("S")));
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(
@@ -590,8 +589,8 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_UniformBufferBool) {
// type t = ptr<uniform, bool>;
- Alias("t", ty.pointer(Source{{56, 78}}, ty.bool_(Source{{12, 34}}),
- builtin::AddressSpace::kUniform));
+ Alias("t",
+ ty.ptr(Source{{56, 78}}, builtin::AddressSpace::kUniform, ty.bool_(Source{{12, 34}})));
ASSERT_FALSE(r()->Resolve());
@@ -621,7 +620,7 @@
// type t = ptr<uniform, a>;
Alias("a", ty.bool_());
Alias("t",
- ty.pointer(Source{{56, 78}}, ty(Source{{12, 34}}, "a"), builtin::AddressSpace::kUniform));
+ ty.ptr(Source{{56, 78}}, builtin::AddressSpace::kUniform, ty(Source{{12, 34}}, "a")));
ASSERT_FALSE(r()->Resolve());
@@ -634,7 +633,7 @@
TEST_F(ResolverAddressSpaceValidationTest, GlobalVariable_UniformPointer) {
// var<uniform> g : ptr<private, f32>;
GlobalVar(Source{{56, 78}}, "g",
- ty.pointer(Source{{12, 34}}, ty.f32(), builtin::AddressSpace::kPrivate),
+ ty.ptr(Source{{12, 34}}, builtin::AddressSpace::kPrivate, ty.f32()),
builtin::AddressSpace::kUniform, Binding(0_a), Group(0_a));
ASSERT_FALSE(r()->Resolve());
@@ -647,9 +646,8 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_UniformPointer) {
// type t = ptr<uniform, ptr<private, f32>>;
- Alias("t", ty.pointer(Source{{56, 78}},
- ty.pointer(Source{{12, 34}}, ty.f32(), builtin::AddressSpace::kPrivate),
- builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(Source{{56, 78}}, builtin::AddressSpace::kUniform,
+ ty.ptr(Source{{12, 34}}, builtin::AddressSpace::kPrivate, ty.f32())));
ASSERT_FALSE(r()->Resolve());
@@ -669,7 +667,7 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_UniformBufferIntScalar) {
// type t = ptr<uniform, i32>;
- Alias("t", ty.pointer(ty.i32(), builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(builtin::AddressSpace::kUniform, ty.i32()));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -689,7 +687,7 @@
// type t = ptr<uniform, f16>;
Enable(builtin::Extension::kF16);
- Alias("t", ty.pointer(ty.f16(), builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(builtin::AddressSpace::kUniform, ty.f16()));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -703,7 +701,7 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_UniformBufferVectorF32) {
// type t = ptr<uniform, vec4<f32>>;
- Alias("t", ty.pointer(ty.vec4<f32>(), builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(builtin::AddressSpace::kUniform, ty.vec4<f32>()));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -723,7 +721,7 @@
// type t = ptr<uniform, vec4<f16>>;
Enable(builtin::Extension::kF16);
- Alias("t", ty.pointer(ty.vec4<f16>(), builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(builtin::AddressSpace::kUniform, ty.vec4<f16>()));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -746,7 +744,7 @@
// }
// type t = ptr<uniform, array<S, 3u>>;
Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{MemberSize(16_a)})});
- Alias("t", ty.pointer(ty.array(ty("S"), 3_u), builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(builtin::AddressSpace::kUniform, ty.array(ty("S"), 3_u)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -775,7 +773,7 @@
Enable(builtin::Extension::kF16);
Structure("S", utils::Vector{Member("a", ty.f16(), utils::Vector{MemberSize(16_a)})});
- Alias("t", ty.pointer(ty.array(ty("S"), 3_u), builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(builtin::AddressSpace::kUniform, ty.array(ty("S"), 3_u)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -793,7 +791,7 @@
// struct S { x : i32 };
// type t = ptr<uniform, S>;
Structure("S", utils::Vector{Member("x", ty.i32())});
- Alias("t", ty.pointer(ty("S"), builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(builtin::AddressSpace::kUniform, ty("S")));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -815,7 +813,7 @@
// type t = ptr<uniform, a1>;
Structure("S", utils::Vector{Member("x", ty.i32())});
Alias("a1", ty("S"));
- Alias("t", ty.pointer(ty("a1"), builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(builtin::AddressSpace::kUniform, ty("a1")));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -839,7 +837,7 @@
Enable(builtin::Extension::kF16);
Structure("S", utils::Vector{Member("x", ty.f16())});
- Alias("t", ty.pointer(ty("S"), builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(builtin::AddressSpace::kUniform, ty("S")));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -867,7 +865,7 @@
Structure("S", utils::Vector{Member("x", ty.f16())});
Alias("a1", ty("S"));
- Alias("t", ty.pointer(ty("a1"), builtin::AddressSpace::kUniform));
+ Alias("t", ty.ptr(builtin::AddressSpace::kUniform, ty("a1")));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -891,7 +889,7 @@
// type t = ptr<push_constant, bool>;
Enable(builtin::Extension::kChromiumExperimentalPushConstant);
Alias(Source{{56, 78}}, "t",
- ty.pointer(ty.bool_(Source{{12, 34}}), builtin::AddressSpace::kPushConstant));
+ ty.ptr(builtin::AddressSpace::kPushConstant, ty.bool_(Source{{12, 34}})));
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(
@@ -919,7 +917,7 @@
// type t = ptr<push_constant, f16>;
Enable(builtin::Extension::kF16);
Enable(builtin::Extension::kChromiumExperimentalPushConstant);
- Alias("t", ty.pointer(ty.f16(Source{{56, 78}}), builtin::AddressSpace::kPushConstant));
+ Alias("t", ty.ptr(builtin::AddressSpace::kPushConstant, ty.f16(Source{{56, 78}})));
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@@ -931,7 +929,7 @@
// var<push_constant> g : ptr<private, f32>;
Enable(builtin::Extension::kChromiumExperimentalPushConstant);
GlobalVar(Source{{56, 78}}, "g",
- ty.pointer(Source{{12, 34}}, ty.f32(), builtin::AddressSpace::kPrivate),
+ ty.ptr(Source{{12, 34}}, builtin::AddressSpace::kPrivate, ty.f32()),
builtin::AddressSpace::kPushConstant);
ASSERT_FALSE(r()->Resolve());
@@ -946,8 +944,8 @@
// type t = ptr<push_constant, ptr<private, f32>>;
Enable(builtin::Extension::kChromiumExperimentalPushConstant);
Alias(Source{{56, 78}}, "t",
- ty.pointer(ty.pointer(Source{{12, 34}}, ty.f32(), builtin::AddressSpace::kPrivate),
- builtin::AddressSpace::kPushConstant));
+ ty.ptr(builtin::AddressSpace::kPushConstant,
+ ty.ptr(Source{{12, 34}}, builtin::AddressSpace::kPrivate, ty.f32())));
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(
@@ -969,7 +967,7 @@
// enable chromium_experimental_push_constant;
// type t = ptr<push_constant, i32>;
Enable(builtin::Extension::kChromiumExperimentalPushConstant);
- Alias("t", ty.pointer(ty.i32(), builtin::AddressSpace::kPushConstant));
+ Alias("t", ty.ptr(builtin::AddressSpace::kPushConstant, ty.i32()));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -987,7 +985,7 @@
// enable chromium_experimental_push_constant;
// var<push_constant> g : vec4<f32>;
Enable(builtin::Extension::kChromiumExperimentalPushConstant);
- Alias("t", ty.pointer(ty.vec4<f32>(), builtin::AddressSpace::kPushConstant));
+ Alias("t", ty.ptr(builtin::AddressSpace::kPushConstant, ty.vec4<f32>()));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
@@ -1009,7 +1007,7 @@
// type t = ptr<push_constant, array<S, 3u>>;
Enable(builtin::Extension::kChromiumExperimentalPushConstant);
Structure("S", utils::Vector{Member("a", ty.f32())});
- Alias("t", ty.pointer(ty.array(ty("S"), 3_u), builtin::AddressSpace::kPushConstant));
+ Alias("t", ty.ptr(builtin::AddressSpace::kPushConstant, ty.array(ty("S"), 3_u)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
diff --git a/src/tint/resolver/alias_analysis_test.cc b/src/tint/resolver/alias_analysis_test.cc
index 895ed88..7eab9bc 100644
--- a/src/tint/resolver/alias_analysis_test.cc
+++ b/src/tint/resolver/alias_analysis_test.cc
@@ -59,8 +59,8 @@
auto addrspace = GetParam().address_space;
Func("target",
utils::Vector{
- Param("p1", ty.pointer<i32>(addrspace)),
- Param("p2", ty.pointer<i32>(addrspace)),
+ Param("p1", ty.ptr<i32>(addrspace)),
+ Param("p2", ty.ptr<i32>(addrspace)),
},
ty.void_(), std::move(body));
if (GetParam().aliased && err) {
@@ -129,8 +129,8 @@
// f1(p1, p2);
Func("f2",
utils::Vector{
- Param("p1", ty.pointer<i32>(GetParam().address_space)),
- Param("p2", ty.pointer<i32>(GetParam().address_space)),
+ Param("p1", ty.ptr<i32>(GetParam().address_space)),
+ Param("p2", ty.ptr<i32>(GetParam().address_space)),
},
ty.void_(),
utils::Vector{
@@ -139,8 +139,8 @@
});
Func("f1",
utils::Vector{
- Param("p1", ty.pointer<i32>(GetParam().address_space)),
- Param("p2", ty.pointer<i32>(GetParam().address_space)),
+ Param("p1", ty.ptr<i32>(GetParam().address_space)),
+ Param("p2", ty.ptr<i32>(GetParam().address_space)),
},
ty.void_(),
utils::Vector{
@@ -166,7 +166,7 @@
// f2(p2);
Func("f1",
utils::Vector<const ast::Parameter*, 4>{
- Param("p1", ty.pointer<i32>(GetParam().address_space)),
+ Param("p1", ty.ptr<i32>(GetParam().address_space)),
},
ty.void_(),
utils::Vector{
@@ -174,7 +174,7 @@
});
Func("f2",
utils::Vector<const ast::Parameter*, 4>{
- Param("p2", ty.pointer<i32>(GetParam().address_space)),
+ Param("p2", ty.ptr<i32>(GetParam().address_space)),
},
ty.void_(),
utils::Vector{
@@ -228,7 +228,7 @@
void Run(utils::Vector<const ast::Statement*, 4>&& body, const char* err = nullptr) {
Func("target",
utils::Vector<const ast::Parameter*, 4>{
- Param("p1", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Param("p1", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
},
ty.void_(), std::move(body));
if (GetParam() && err) {
@@ -297,7 +297,7 @@
// f1(p1);
Func("f2",
utils::Vector<const ast::Parameter*, 4>{
- Param("p1", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Param("p1", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
},
ty.void_(),
utils::Vector{
@@ -305,7 +305,7 @@
});
Func("f1",
utils::Vector<const ast::Parameter*, 4>{
- Param("p1", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Param("p1", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
},
ty.void_(),
utils::Vector{
@@ -332,7 +332,7 @@
// f1(p1);
Func("f2",
utils::Vector<const ast::Parameter*, 4>{
- Param("p1", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Param("p1", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
},
ty.void_(),
utils::Vector{
@@ -341,7 +341,7 @@
});
Func("f1",
utils::Vector<const ast::Parameter*, 4>{
- Param("p1", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Param("p1", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
},
ty.void_(),
utils::Vector{
@@ -367,7 +367,7 @@
// f1(p1);
Func("f2",
utils::Vector<const ast::Parameter*, 4>{
- Param("p1", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Param("p1", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
},
ty.void_(),
utils::Vector{
@@ -375,7 +375,7 @@
});
Func("f1",
utils::Vector<const ast::Parameter*, 4>{
- Param("p1", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Param("p1", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
},
ty.void_(),
utils::Vector{
@@ -402,7 +402,7 @@
// f1(p1);
Func("f2",
utils::Vector{
- Param("p1", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Param("p1", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
},
ty.void_(),
utils::Vector{
@@ -411,7 +411,7 @@
});
Func("f1",
utils::Vector{
- Param("p1", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Param("p1", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
},
ty.void_(),
utils::Vector{
@@ -437,7 +437,7 @@
// f2();
Func("f1",
utils::Vector{
- Param("p1", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Param("p1", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
},
ty.void_(),
utils::Vector{
@@ -489,8 +489,8 @@
void Run(const ast::Statement* stmt, const char* err = nullptr) {
Func("target",
utils::Vector{
- Param("p1", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
- Param("p2", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p1", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
+ Param("p2", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(),
utils::Vector{
@@ -604,7 +604,7 @@
// foo(p2);
Func("foo",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.i32(),
utils::Vector{
@@ -661,8 +661,8 @@
void Run(const ast::Statement* stmt, const char* err = nullptr) {
Func("target",
utils::Vector{
- Param("p1", ty.pointer<bool>(builtin::AddressSpace::kFunction)),
- Param("p2", ty.pointer<bool>(builtin::AddressSpace::kFunction)),
+ Param("p1", ty.ptr<bool>(builtin::AddressSpace::kFunction)),
+ Param("p2", ty.ptr<bool>(builtin::AddressSpace::kFunction)),
},
ty.void_(),
utils::Vector{
@@ -726,8 +726,8 @@
Structure("S", utils::Vector{Member("a", ty.i32())});
Func("f2",
utils::Vector{
- Param("p1", ty.pointer(ty("S"), builtin::AddressSpace::kFunction)),
- Param("p2", ty.pointer(ty("S"), builtin::AddressSpace::kFunction)),
+ Param("p1", ty.ptr(builtin::AddressSpace::kFunction, ty("S"))),
+ Param("p2", ty.ptr(builtin::AddressSpace::kFunction, ty("S"))),
},
ty.void_(),
utils::Vector{
@@ -755,8 +755,8 @@
Structure("S", utils::Vector{Member("a", ty.i32())});
Func("f2",
utils::Vector{
- Param("p1", ty.pointer(ty("S"), builtin::AddressSpace::kFunction)),
- Param("p2", ty.pointer(ty("S"), builtin::AddressSpace::kFunction)),
+ Param("p1", ty.ptr(builtin::AddressSpace::kFunction, ty("S"))),
+ Param("p2", ty.ptr(builtin::AddressSpace::kFunction, ty("S"))),
},
ty.void_(),
utils::Vector{
@@ -787,8 +787,8 @@
Structure("S", utils::Vector{Member("a", ty.i32())});
Func("f2",
utils::Vector{
- Param("p1", ty.pointer(ty("S"), builtin::AddressSpace::kFunction)),
- Param("p2", ty.pointer(ty("S"), builtin::AddressSpace::kFunction)),
+ Param("p1", ty.ptr(builtin::AddressSpace::kFunction, ty("S"))),
+ Param("p2", ty.ptr(builtin::AddressSpace::kFunction, ty("S"))),
},
ty.void_(),
utils::Vector{
@@ -818,8 +818,8 @@
Structure("S", utils::Vector{Member("a", ty.i32())});
Func("f2",
utils::Vector{
- Param("p1", ty.pointer(ty.vec4<f32>(), builtin::AddressSpace::kFunction)),
- Param("p2", ty.pointer(ty.vec4<f32>(), builtin::AddressSpace::kFunction)),
+ Param("p1", ty.ptr(builtin::AddressSpace::kFunction, ty.vec4<f32>())),
+ Param("p2", ty.ptr(builtin::AddressSpace::kFunction, ty.vec4<f32>())),
},
ty.void_(),
utils::Vector{
@@ -850,7 +850,7 @@
// }
Func("f1",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(),
utils::Vector{
@@ -903,7 +903,7 @@
// }
Func("f2",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(),
utils::Vector{
@@ -911,7 +911,7 @@
});
Func("f3",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(),
utils::Vector{
diff --git a/src/tint/resolver/array_accessor_test.cc b/src/tint/resolver/array_accessor_test.cc
index e18a618..118bcff 100644
--- a/src/tint/resolver/array_accessor_test.cc
+++ b/src/tint/resolver/array_accessor_test.cc
@@ -316,7 +316,7 @@
// let x: f32 = (*p)[idx];
// return x;
// }
- auto* p = Param("p", ty.pointer(ty.vec4<f32>(), builtin::AddressSpace::kFunction));
+ auto* p = Param("p", ty.ptr(builtin::AddressSpace::kFunction, ty.vec4<f32>()));
auto* idx = Let("idx", ty.u32(), Call<u32>());
auto* star_p = Deref(p);
auto* acc = IndexAccessor(Source{{12, 34}}, star_p, idx);
@@ -337,7 +337,7 @@
// let x: f32 = *p[idx];
// return x;
// }
- auto* p = Param("p", ty.pointer(ty.vec4<f32>(), builtin::AddressSpace::kFunction));
+ auto* p = Param("p", ty.ptr(builtin::AddressSpace::kFunction, ty.vec4<f32>()));
auto* idx = Let("idx", ty.u32(), Call<u32>());
auto* accessor_expr = IndexAccessor(Source{{12, 34}}, p, idx);
auto* star_p = Deref(accessor_expr);
diff --git a/src/tint/resolver/assignment_validation_test.cc b/src/tint/resolver/assignment_validation_test.cc
index 1c57463..bd4cd71 100644
--- a/src/tint/resolver/assignment_validation_test.cc
+++ b/src/tint/resolver/assignment_validation_test.cc
@@ -183,8 +183,8 @@
// let b : ptr<function,i32> = &a;
// *b = 2i;
const auto func = builtin::AddressSpace::kFunction;
- WrapInFunction(Var("a", ty.i32(), func, Expr(2_i)), //
- Let("b", ty.pointer<i32>(func), AddressOf(Expr("a"))), //
+ WrapInFunction(Var("a", ty.i32(), func, Expr(2_i)), //
+ Let("b", ty.ptr<i32>(func), AddressOf(Expr("a"))), //
Assign(Deref("b"), 2_i));
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -196,7 +196,7 @@
// *b = 2;
const auto func = builtin::AddressSpace::kFunction;
auto* var_a = Var("a", ty.i32(), func, Expr(2_i));
- auto* var_b = Let("b", ty.pointer<i32>(func), AddressOf(Expr("a")));
+ auto* var_b = Let("b", ty.ptr<i32>(func), AddressOf(Expr("a")));
WrapInFunction(var_a, var_b, Assign(Deref("b"), 2_a));
EXPECT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/call_validation_test.cc b/src/tint/resolver/call_validation_test.cc
index b00db68..bf533cc 100644
--- a/src/tint/resolver/call_validation_test.cc
+++ b/src/tint/resolver/call_validation_test.cc
@@ -103,7 +103,7 @@
// var z: i32 = 1i;
// foo(&z);
// }
- auto* param = Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction));
+ auto* param = Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction));
Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
@@ -120,7 +120,7 @@
// let z: i32 = 1i;
// foo(&z);
// }
- auto* param = Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction));
+ auto* param = Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction));
Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
@@ -142,7 +142,7 @@
auto* S = Structure("S", utils::Vector{
Member("m", ty.i32()),
});
- auto* param = Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction));
+ auto* param = Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction));
Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
@@ -169,7 +169,7 @@
auto* S = Structure("S", utils::Vector{
Member("m", ty.i32()),
});
- auto* param = Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction));
+ auto* param = Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction));
Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
@@ -189,7 +189,7 @@
auto* S = Structure("S", utils::Vector{
Member("m", ty.i32()),
});
- auto* param = Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction));
+ auto* param = Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction));
Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
@@ -208,12 +208,12 @@
// }
Func("foo",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("bar",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(),
utils::Vector{
@@ -235,12 +235,12 @@
// }
Func("foo",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("bar",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(),
utils::Vector{
@@ -268,13 +268,13 @@
// }
Func("x",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("v", ty.i32())),
- Decl(Let("p", ty.pointer(ty.i32(), builtin::AddressSpace::kFunction), AddressOf("v"))),
+ Decl(Let("p", ty.ptr(builtin::AddressSpace::kFunction, ty.i32()), AddressOf("v"))),
CallStmt(Call("x", "p")),
},
utils::Vector{
@@ -293,13 +293,13 @@
// }
Func("foo",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
},
ty.void_(), utils::Empty);
GlobalVar("v", ty.i32(), builtin::AddressSpace::kPrivate);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
- Decl(Let("p", ty.pointer(ty.i32(), builtin::AddressSpace::kPrivate), AddressOf("v"))),
+ Decl(Let("p", ty.ptr(builtin::AddressSpace::kPrivate, ty.i32()), AddressOf("v"))),
CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
},
utils::Vector{
@@ -318,13 +318,13 @@
// }
Func("foo",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("v", ty.array<i32, 4>())),
- Decl(Let("p", ty.pointer(ty.i32(), builtin::AddressSpace::kFunction),
+ Decl(Let("p", ty.ptr(builtin::AddressSpace::kFunction, ty.i32()),
AddressOf(IndexAccessor("v", 0_a)))),
CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
},
@@ -349,13 +349,13 @@
Enable(builtin::Extension::kChromiumExperimentalFullPtrParameters);
Func("foo",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("v", ty.array<i32, 4>())),
- Decl(Let("p", ty.pointer(ty.i32(), builtin::AddressSpace::kFunction),
+ Decl(Let("p", ty.ptr(builtin::AddressSpace::kFunction, ty.i32()),
AddressOf(IndexAccessor("v", 0_a)))),
CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
},
@@ -377,7 +377,7 @@
// }
Func("foo",
utils::Vector{
- Param("p", ty.pointer(ty.array<i32, 4>(), builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr(builtin::AddressSpace::kFunction, ty.array<i32, 4>())),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
@@ -406,7 +406,7 @@
// }
Func("foo",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
@@ -440,7 +440,7 @@
Enable(builtin::Extension::kChromiumExperimentalFullPtrParameters);
Func("foo",
utils::Vector{
- Param("p", ty.pointer<i32>(builtin::AddressSpace::kFunction)),
+ Param("p", ty.ptr<i32>(builtin::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
diff --git a/src/tint/resolver/compound_assignment_validation_test.cc b/src/tint/resolver/compound_assignment_validation_test.cc
index ee18de3..0343f08 100644
--- a/src/tint/resolver/compound_assignment_validation_test.cc
+++ b/src/tint/resolver/compound_assignment_validation_test.cc
@@ -53,7 +53,7 @@
// *b += 2;
const auto func = builtin::AddressSpace::kFunction;
auto* var_a = Var("a", ty.i32(), func, Expr(2_i));
- auto* var_b = Let("b", ty.pointer<i32>(func), AddressOf(Expr("a")));
+ auto* var_b = Let("b", ty.ptr<i32>(func), AddressOf(Expr("a")));
WrapInFunction(var_a, var_b,
CompoundAssign(Source{{12, 34}}, Deref("b"), 2_i, ast::BinaryOp::kAdd));
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index caf3e37..3fa71cd 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -438,7 +438,7 @@
}
}
} else {
- target_el_ty = type::Type::ElementOf(convert->target_ty);
+ target_el_ty = convert->target_ty->Elements(convert->target_ty).type;
}
// Convert the single splatted element type.
@@ -465,7 +465,7 @@
}
} else {
// Non-struct composites have the same type for all elements.
- auto* el_ty = type::Type::ElementOf(convert->target_ty);
+ auto* el_ty = convert->target_ty->Elements(convert->target_ty).type;
for (size_t i = 0; i < el_count; i++) {
auto* el = composite->Index(i);
pending.Push(ActionConvert{el, el_ty});
@@ -491,10 +491,8 @@
F&& f,
size_t index,
CONSTANTS&&... cs) {
- uint32_t n = 0;
- auto* ty = First(cs...)->Type();
- auto* el_ty = type::Type::ElementOf(ty, &n);
- if (el_ty == ty) {
+ auto [el_ty, n] = First(cs...)->Type()->Elements();
+ if (!el_ty) {
constexpr bool kHasIndexParam =
utils::traits::IsType<size_t, utils::traits::LastParameterType<F>>;
if constexpr (kHasIndexParam) {
@@ -503,11 +501,14 @@
return f(cs...);
}
}
+
+ auto* composite_el_ty = composite_ty->Elements(composite_ty).type;
+
utils::Vector<const constant::Value*, 8> els;
els.Reserve(n);
for (uint32_t i = 0; i < n; i++) {
- if (auto el = detail::TransformElements(builder, type::Type::ElementOf(composite_ty),
- std::forward<F>(f), index + i, cs->Index(i)...)) {
+ if (auto el = detail::TransformElements(builder, composite_el_ty, std::forward<F>(f),
+ index + i, cs->Index(i)...)) {
els.Push(el.Get());
} else {
@@ -541,16 +542,16 @@
F&& f,
const constant::Value* c0,
const constant::Value* c1) {
- uint32_t n0 = 0;
- type::Type::ElementOf(c0->Type(), &n0);
- uint32_t n1 = 0;
- type::Type::ElementOf(c1->Type(), &n1);
+ uint32_t n0 = c0->Type()->Elements(nullptr, 1).count;
+ uint32_t n1 = c1->Type()->Elements(nullptr, 1).count;
uint32_t max_n = std::max(n0, n1);
// If arity of both constants is 1, invoke callback
if (max_n == 1u) {
return f(c0, c1);
}
+ const auto* element_ty = composite_ty->Elements(composite_ty).type;
+
utils::Vector<const constant::Value*, 8> els;
els.Reserve(max_n);
for (uint32_t i = 0; i < max_n; i++) {
@@ -560,9 +561,8 @@
}
return c->Index(i);
};
- if (auto el = TransformBinaryElements(builder, type::Type::ElementOf(composite_ty),
- std::forward<F>(f), nested_or_self(c0, n0),
- nested_or_self(c1, n1))) {
+ if (auto el = TransformBinaryElements(builder, element_ty, std::forward<F>(f),
+ nested_or_self(c0, n0), nested_or_self(c1, n1))) {
els.Push(el.Get());
} else {
return el.Failure();
@@ -578,7 +578,7 @@
template <typename T>
ConstEval::Result ConstEval::CreateScalar(const Source& source, const type::Type* t, T v) {
static_assert(IsNumber<T> || std::is_same_v<T, bool>, "T must be a Number or bool");
- TINT_ASSERT(Resolver, t->is_scalar());
+ TINT_ASSERT(Resolver, t->Is<type::Scalar>());
if constexpr (IsFloatingPoint<T>) {
if (!std::isfinite(v.value)) {
@@ -1314,8 +1314,7 @@
ConstEval::Result ConstEval::Conv(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- uint32_t el_count = 0;
- auto* el_ty = type::Type::ElementOf(ty, &el_count);
+ auto* el_ty = ty->Elements(ty).type;
if (!el_ty) {
return nullptr;
}
@@ -1411,8 +1410,7 @@
return nullptr;
}
- uint32_t el_count = 0;
- type::Type::ElementOf(obj_expr->Type()->UnwrapRef(), &el_count);
+ uint32_t el_count = obj_expr->Type()->UnwrapRef()->Elements().count;
AInt idx = idx_val->ValueAs<AInt>();
if (idx < 0 || (el_count > 0 && idx >= el_count)) {
@@ -1464,7 +1462,7 @@
ConstEval::Result ConstEval::Bitcast(const type::Type* ty,
const constant::Value* value,
const Source& source) {
- auto* el_ty = type::Type::DeepestElementOf(ty);
+ auto* el_ty = ty->DeepestElement();
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto e) {
return Switch(
@@ -1746,7 +1744,7 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(source, type::Type::DeepestElementOf(ty), i == j);
+ return CreateScalar(source, ty->DeepestElement(), i == j);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
@@ -1759,7 +1757,7 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(source, type::Type::DeepestElementOf(ty), i != j);
+ return CreateScalar(source, ty->DeepestElement(), i != j);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
@@ -1772,7 +1770,7 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(source, type::Type::DeepestElementOf(ty), i < j);
+ return CreateScalar(source, ty->DeepestElement(), i < j);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
@@ -1785,7 +1783,7 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(source, type::Type::DeepestElementOf(ty), i > j);
+ return CreateScalar(source, ty->DeepestElement(), i > j);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
@@ -1798,7 +1796,7 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(source, type::Type::DeepestElementOf(ty), i <= j);
+ return CreateScalar(source, ty->DeepestElement(), i <= j);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
@@ -1811,7 +1809,7 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(source, type::Type::DeepestElementOf(ty), i >= j);
+ return CreateScalar(source, ty->DeepestElement(), i >= j);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
@@ -1822,16 +1820,18 @@
ConstEval::Result ConstEval::OpLogicalAnd(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- // Note: Due to short-circuiting, this function is only called if lhs is true, so we could
- // technically only return the value of the rhs.
- return CreateScalar(source, ty, args[0]->ValueAs<bool>() && args[1]->ValueAs<bool>());
+ // Due to short-circuiting, this function is only called if lhs is true, so we only return the
+ // value of the rhs.
+ TINT_ASSERT(Resolver, args[0]->ValueAs<bool>());
+ return CreateScalar(source, ty, args[1]->ValueAs<bool>());
}
ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- // Note: Due to short-circuiting, this function is only called if lhs is false, so we could
- // technically only return the value of the rhs.
+ // Due to short-circuiting, this function is only called if lhs is false, so we only only return
+ // the value of the rhs.
+ TINT_ASSERT(Resolver, !args[0]->ValueAs<bool>());
return CreateScalar(source, ty, args[1]->ValueAs<bool>());
}
@@ -1847,7 +1847,7 @@
} else { // integral
result = i & j;
}
- return CreateScalar(source, type::Type::DeepestElementOf(ty), result);
+ return CreateScalar(source, ty->DeepestElement(), result);
};
return Dispatch_ia_iu32_bool(create, c0, c1);
};
@@ -1867,7 +1867,7 @@
} else { // integral
result = i | j;
}
- return CreateScalar(source, type::Type::DeepestElementOf(ty), result);
+ return CreateScalar(source, ty->DeepestElement(), result);
};
return Dispatch_ia_iu32_bool(create, c0, c1);
};
@@ -1880,7 +1880,7 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(source, type::Type::DeepestElementOf(ty), decltype(i){i ^ j});
+ return CreateScalar(source, ty->DeepestElement(), decltype(i){i ^ j});
};
return Dispatch_ia_iu32(create, c0, c1);
};
@@ -1972,12 +1972,12 @@
// Avoid UB by left shifting as unsigned value
auto result = static_cast<T>(static_cast<UT>(e1) << e2u);
- return CreateScalar(source, type::Type::DeepestElementOf(ty), NumberT{result});
+ return CreateScalar(source, ty->DeepestElement(), NumberT{result});
};
return Dispatch_ia_iu32(create, c0, c1);
};
- if (TINT_UNLIKELY(!type::Type::DeepestElementOf(args[1]->Type())->Is<type::U32>())) {
+ if (TINT_UNLIKELY(!args[1]->Type()->DeepestElement()->Is<type::U32>())) {
TINT_ICE(Resolver, builder.Diagnostics())
<< "Element type of rhs of ShiftLeft must be a u32";
return utils::Failure;
@@ -2040,12 +2040,12 @@
result = e1 >> e2u;
}
}
- return CreateScalar(source, type::Type::DeepestElementOf(ty), NumberT{result});
+ return CreateScalar(source, ty->DeepestElement(), NumberT{result});
};
return Dispatch_ia_iu32(create, c0, c1);
};
- if (TINT_UNLIKELY(!type::Type::DeepestElementOf(args[1]->Type())->Is<type::U32>())) {
+ if (TINT_UNLIKELY(!args[1]->Type()->DeepestElement()->Is<type::U32>())) {
TINT_ICE(Resolver, builder.Diagnostics())
<< "Element type of rhs of ShiftLeft must be a u32";
return utils::Failure;
@@ -2901,7 +2901,7 @@
}
}
- auto target_ty = type::Type::DeepestElementOf(ty);
+ auto target_ty = ty->DeepestElement();
auto r = std::ldexp(e1, static_cast<int>(e2));
return CreateScalar(source, target_ty, E1Type{r});
@@ -3067,7 +3067,7 @@
ConstEval::Result ConstEval::normalize(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- auto* len_ty = type::Type::DeepestElementOf(ty);
+ auto* len_ty = ty->DeepestElement();
auto len = Length(source, len_ty, args[0]);
if (!len) {
AddNote("when calculating normalize", source);
@@ -3460,7 +3460,7 @@
auto cond = args[2]->ValueAs<bool>();
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto f, auto t) -> ConstEval::Result {
- return CreateScalar(source, type::Type::DeepestElementOf(ty), cond ? t : f);
+ return CreateScalar(source, ty->DeepestElement(), cond ? t : f);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
@@ -3475,7 +3475,7 @@
auto create = [&](auto f, auto t) -> ConstEval::Result {
// Get corresponding bool value at the current vector value index
auto cond = args[2]->Index(index)->ValueAs<bool>();
- return CreateScalar(source, type::Type::DeepestElementOf(ty), cond ? t : f);
+ return CreateScalar(source, ty->DeepestElement(), cond ? t : f);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
@@ -3667,7 +3667,7 @@
ConstEval::Result ConstEval::unpack2x16float(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- auto* inner_ty = type::Type::DeepestElementOf(ty);
+ auto* inner_ty = ty->DeepestElement();
auto e = args[0]->ValueAs<u32>().value;
utils::Vector<const constant::Value*, 2> els;
@@ -3695,7 +3695,7 @@
ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- auto* inner_ty = type::Type::DeepestElementOf(ty);
+ auto* inner_ty = ty->DeepestElement();
auto e = args[0]->ValueAs<u32>().value;
utils::Vector<const constant::Value*, 2> els;
@@ -3715,7 +3715,7 @@
ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- auto* inner_ty = type::Type::DeepestElementOf(ty);
+ auto* inner_ty = ty->DeepestElement();
auto e = args[0]->ValueAs<u32>().value;
utils::Vector<const constant::Value*, 2> els;
@@ -3734,7 +3734,7 @@
ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- auto* inner_ty = type::Type::DeepestElementOf(ty);
+ auto* inner_ty = ty->DeepestElement();
auto e = args[0]->ValueAs<u32>().value;
utils::Vector<const constant::Value*, 4> els;
@@ -3754,7 +3754,7 @@
ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- auto* inner_ty = type::Type::DeepestElementOf(ty);
+ auto* inner_ty = ty->DeepestElement();
auto e = args[0]->ValueAs<u32>().value;
utils::Vector<const constant::Value*, 4> els;
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index 6bf5830..e61599f 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -2239,6 +2239,84 @@
}
////////////////////////////////////////////////
+// Short-Circuit with RHS Variable Access
+////////////////////////////////////////////////
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_RHSConstDecl) {
+ // const FALSE = false;
+ // const result = FALSE && FALSE;
+ GlobalConst("FALSE", Expr(false));
+ auto* binary = LogicalAnd("FALSE", "FALSE");
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateAnd(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_RHSConstDecl) {
+ // const TRUE = true;
+ // const result = TRUE || TRUE;
+ GlobalConst("TRUE", Expr(true));
+ auto* binary = LogicalOr("TRUE", "TRUE");
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateOr(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_RHSLetDecl) {
+ // fn f() {
+ // let b = false;
+ // let result = false && b;
+ // }
+ auto* binary = LogicalAnd(false, "b");
+ WrapInFunction(Decl(Let("b", Expr(false))), binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateAnd(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_RHSLetDecl) {
+ // fn f() {
+ // let b = false;
+ // let result = true || b;
+ // }
+ auto* binary = LogicalOr(true, "b");
+ WrapInFunction(Decl(Let("b", Expr(false))), binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateOr(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_RHSVarDecl) {
+ // fn f() {
+ // var b = false;
+ // let result = false && b;
+ // }
+ auto* binary = LogicalAnd(false, "b");
+ WrapInFunction(Decl(Var("b", Expr(false))), binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kRuntime);
+ EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), sem::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), sem::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_RHSVarDecl) {
+ // fn f() {
+ // var b = false;
+ // let result = true || b;
+ // }
+ auto* binary = LogicalOr(true, "b");
+ WrapInFunction(Decl(Var("b", Expr(false))), binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kRuntime);
+ EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), sem::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), sem::EvaluationStage::kRuntime);
+}
+
+////////////////////////////////////////////////
// Short-Circuit Swizzle
////////////////////////////////////////////////
diff --git a/src/tint/resolver/const_eval_construction_test.cc b/src/tint/resolver/const_eval_construction_test.cc
index d8efe8e..3d59514 100644
--- a/src/tint/resolver/const_eval_construction_test.cc
+++ b/src/tint/resolver/const_eval_construction_test.cc
@@ -154,7 +154,7 @@
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
EXPECT_TRUE(sem->ConstantValue()->AllZero());
- if (sem->Type()->is_scalar()) {
+ if (sem->Type()->Is<type::Scalar>()) {
EXPECT_EQ(sem->ConstantValue()->Index(0), nullptr);
EXPECT_EQ(sem->ConstantValue()->ValueAs<f32>(), 0.0f);
} else if (auto* vec = sem->Type()->As<type::Vector>()) {
diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc
index 1850d6f..ee509a7 100644
--- a/src/tint/resolver/dependency_graph_test.cc
+++ b/src/tint/resolver/dependency_graph_test.cc
@@ -1732,7 +1732,7 @@
GlobalVar(Sym(), ty.array(T, V));
GlobalVar(Sym(), ty.vec3(T));
GlobalVar(Sym(), ty.mat3x2(T));
- GlobalVar(Sym(), ty.pointer(T, builtin::AddressSpace::kPrivate));
+ GlobalVar(Sym(), ty.ptr(builtin::AddressSpace::kPrivate, T));
GlobalVar(Sym(), ty.sampled_texture(type::TextureDimension::k2d, T));
GlobalVar(Sym(), ty.depth_texture(type::TextureDimension::k2d));
GlobalVar(Sym(), ty.depth_multisampled_texture(type::TextureDimension::k2d));
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index 6654c43..b97e7f3 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -926,7 +926,7 @@
}
TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_NonPlain) {
- auto ret_type = ty.pointer(Source{{12, 34}}, ty.i32(), builtin::AddressSpace::kFunction);
+ auto ret_type = ty.ptr(Source{{12, 34}}, builtin::AddressSpace::kFunction, ty.i32());
Func("f", utils::Empty, ret_type, utils::Empty);
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/increment_decrement_validation_test.cc b/src/tint/resolver/increment_decrement_validation_test.cc
index c5a6c98..c3a0c78 100644
--- a/src/tint/resolver/increment_decrement_validation_test.cc
+++ b/src/tint/resolver/increment_decrement_validation_test.cc
@@ -65,7 +65,7 @@
// let b : ptr<function,i32> = &a;
// *b++;
auto* var_a = Var("a", ty.i32(), builtin::AddressSpace::kFunction);
- auto* var_b = Let("b", ty.pointer<i32>(builtin::AddressSpace::kFunction), AddressOf(Expr("a")));
+ auto* var_b = Let("b", ty.ptr<i32>(builtin::AddressSpace::kFunction), AddressOf(Expr("a")));
WrapInFunction(var_a, var_b, Increment(Source{{12, 34}}, Deref("b")));
EXPECT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc
index 4295bb3..27cf5cf 100644
--- a/src/tint/resolver/intrinsic_table.cc
+++ b/src/tint/resolver/intrinsic_table.cc
@@ -586,7 +586,7 @@
}
const type::Pointer* build_ptr(MatchState& state, Number S, const type::Type* T, Number& A) {
- return state.builder.create<type::Pointer>(T, static_cast<builtin::AddressSpace>(S.Value()),
+ return state.builder.create<type::Pointer>(static_cast<builtin::AddressSpace>(S.Value()), T,
static_cast<builtin::Access>(A.Value()));
}
diff --git a/src/tint/resolver/intrinsic_table_test.cc b/src/tint/resolver/intrinsic_table_test.cc
index 19ead21..1d5627c 100644
--- a/src/tint/resolver/intrinsic_table_test.cc
+++ b/src/tint/resolver/intrinsic_table_test.cc
@@ -230,7 +230,7 @@
TEST_F(IntrinsicTableTest, MatchPointer) {
auto* i32 = create<type::I32>();
auto* atomicI32 = create<type::Atomic>(i32);
- auto* ptr = create<type::Pointer>(atomicI32, builtin::AddressSpace::kWorkgroup,
+ auto* ptr = create<type::Pointer>(builtin::AddressSpace::kWorkgroup, atomicI32,
builtin::Access::kReadWrite);
auto result = table->Lookup(builtin::Function::kAtomicLoad, utils::Vector{ptr},
sem::EvaluationStage::kConstant, Source{});
@@ -255,7 +255,7 @@
auto* arr =
create<type::Array>(create<type::U32>(), create<type::RuntimeArrayCount>(), 4u, 4u, 4u, 4u);
auto* arr_ptr =
- create<type::Pointer>(arr, builtin::AddressSpace::kStorage, builtin::Access::kReadWrite);
+ create<type::Pointer>(builtin::AddressSpace::kStorage, arr, builtin::Access::kReadWrite);
auto result = table->Lookup(builtin::Function::kArrayLength, utils::Vector{arr_ptr},
sem::EvaluationStage::kConstant, Source{});
ASSERT_NE(result.sem, nullptr) << Diagnostics().str();
@@ -450,7 +450,7 @@
auto* f32 = create<type::F32>();
auto result = table->Lookup(builtin::Function::kCos,
utils::Vector{
- create<type::Reference>(f32, builtin::AddressSpace::kFunction,
+ create<type::Reference>(builtin::AddressSpace::kFunction, f32,
builtin::Access::kReadWrite),
},
sem::EvaluationStage::kConstant, Source{});
@@ -552,7 +552,7 @@
TEST_F(IntrinsicTableTest, MatchDifferentArgsElementType_Builtin_RuntimeEval) {
auto* af = create<type::AbstractFloat>();
- auto* bool_ref = create<type::Reference>(create<type::Bool>(), builtin::AddressSpace::kFunction,
+ auto* bool_ref = create<type::Reference>(builtin::AddressSpace::kFunction, create<type::Bool>(),
builtin::Access::kReadWrite);
auto result = table->Lookup(builtin::Function::kSelect, utils::Vector{af, af, bool_ref},
sem::EvaluationStage::kRuntime, Source{});
diff --git a/src/tint/resolver/is_host_shareable_test.cc b/src/tint/resolver/is_host_shareable_test.cc
index 4570d4d..32ff06d 100644
--- a/src/tint/resolver/is_host_shareable_test.cc
+++ b/src/tint/resolver/is_host_shareable_test.cc
@@ -95,7 +95,7 @@
}
TEST_F(ResolverIsHostShareable, Pointer) {
- auto* ptr = create<type::Pointer>(create<type::I32>(), builtin::AddressSpace::kPrivate,
+ auto* ptr = create<type::Pointer>(builtin::AddressSpace::kPrivate, create<type::I32>(),
builtin::Access::kReadWrite);
EXPECT_FALSE(r()->IsHostShareable(ptr));
}
diff --git a/src/tint/resolver/is_storeable_test.cc b/src/tint/resolver/is_storeable_test.cc
index 43abdbc..65c1ab9 100644
--- a/src/tint/resolver/is_storeable_test.cc
+++ b/src/tint/resolver/is_storeable_test.cc
@@ -78,7 +78,7 @@
}
TEST_F(ResolverIsStorableTest, Pointer) {
- auto* ptr = create<type::Pointer>(create<type::I32>(), builtin::AddressSpace::kPrivate,
+ auto* ptr = create<type::Pointer>(builtin::AddressSpace::kPrivate, create<type::I32>(),
builtin::Access::kReadWrite);
EXPECT_FALSE(r()->IsStorable(ptr));
}
@@ -112,7 +112,7 @@
TEST_F(ResolverIsStorableTest, Struct_SomeMembersNonStorable) {
Structure("S", utils::Vector{
Member("a", ty.i32()),
- Member("b", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Member("b", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
});
EXPECT_FALSE(r()->Resolve());
@@ -138,7 +138,7 @@
auto* non_storable =
Structure("nonstorable", utils::Vector{
Member("a", ty.i32()),
- Member("b", ty.pointer<i32>(builtin::AddressSpace::kPrivate)),
+ Member("b", ty.ptr<i32>(builtin::AddressSpace::kPrivate)),
});
Structure("S", utils::Vector{
Member("a", ty.i32()),
diff --git a/src/tint/resolver/ptr_ref_test.cc b/src/tint/resolver/ptr_ref_test.cc
index 1e702a5..ddb2e3b 100644
--- a/src/tint/resolver/ptr_ref_test.cc
+++ b/src/tint/resolver/ptr_ref_test.cc
@@ -76,15 +76,15 @@
GlobalVar("sb", ty.Of(buf), builtin::AddressSpace::kStorage, Binding(1_a), Group(0_a));
auto* function_ptr =
- Let("f_ptr", ty.pointer(ty.i32(), builtin::AddressSpace::kFunction), AddressOf(function));
+ Let("f_ptr", ty.ptr(builtin::AddressSpace::kFunction, ty.i32()), AddressOf(function));
auto* private_ptr =
- Let("p_ptr", ty.pointer(ty.i32(), builtin::AddressSpace::kPrivate), AddressOf(private_));
+ Let("p_ptr", ty.ptr(builtin::AddressSpace::kPrivate, ty.i32()), AddressOf(private_));
auto* workgroup_ptr =
- Let("w_ptr", ty.pointer(ty.i32(), builtin::AddressSpace::kWorkgroup), AddressOf(workgroup));
+ Let("w_ptr", ty.ptr(builtin::AddressSpace::kWorkgroup, ty.i32()), AddressOf(workgroup));
auto* uniform_ptr =
- Let("ub_ptr", ty.pointer(ty.Of(buf), builtin::AddressSpace::kUniform), AddressOf(uniform));
+ Let("ub_ptr", ty.ptr(builtin::AddressSpace::kUniform, ty.Of(buf)), AddressOf(uniform));
auto* storage_ptr =
- Let("sb_ptr", ty.pointer(ty.Of(buf), builtin::AddressSpace::kStorage), AddressOf(storage));
+ Let("sb_ptr", ty.ptr(builtin::AddressSpace::kStorage, ty.Of(buf)), AddressOf(storage));
WrapInFunction(function, function_ptr, private_ptr, workgroup_ptr, uniform_ptr, storage_ptr);
diff --git a/src/tint/resolver/ptr_ref_validation_test.cc b/src/tint/resolver/ptr_ref_validation_test.cc
index d042402..f6da233 100644
--- a/src/tint/resolver/ptr_ref_validation_test.cc
+++ b/src/tint/resolver/ptr_ref_validation_test.cc
@@ -147,8 +147,8 @@
builtin::Access::kReadWrite, Binding(0_a), Group(0_a));
auto* expr = IndexAccessor(MemberAccessor(MemberAccessor(storage, "inner"), "arr"), 2_i);
- auto* ptr = Let(Source{{12, 34}}, "p", ty.pointer<i32>(builtin::AddressSpace::kStorage),
- AddressOf(expr));
+ auto* ptr =
+ Let(Source{{12, 34}}, "p", ty.ptr<i32>(builtin::AddressSpace::kStorage), AddressOf(expr));
WrapInFunction(ptr);
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index fb06435..74a462a 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -559,7 +559,7 @@
return nullptr;
}
- auto* var_ty = builder_->create<type::Reference>(storage_ty, address_space, access);
+ auto* var_ty = builder_->create<type::Reference>(address_space, storage_ty, access);
if (!ApplyAddressSpaceUsageToType(address_space, var_ty,
var->type ? var->type->source : var->source)) {
@@ -1869,7 +1869,7 @@
}
bool Resolver::ShouldMaterializeArgument(const type::Type* parameter_ty) const {
- const auto* param_el_ty = type::Type::DeepestElementOf(parameter_ty);
+ const auto* param_el_ty = parameter_ty->DeepestElement();
return param_el_ty && !param_el_ty->Is<type::AbstractNumeric>();
}
@@ -1939,7 +1939,7 @@
// If we're extracting from a reference, we return a reference.
if (auto* ref = obj_raw_ty->As<type::Reference>()) {
- ty = builder_->create<type::Reference>(ty, ref->AddressSpace(), ref->Access());
+ ty = builder_->create<type::Reference>(ref->AddressSpace(), ty, ref->Access());
}
const constant::Value* val = nullptr;
@@ -2583,7 +2583,7 @@
access = access_expr->Value();
}
- auto* out = b.create<type::Pointer>(store_ty, address_space, access);
+ auto* out = b.create<type::Pointer>(address_space, store_ty, access);
if (!validator_.Pointer(tmpl_ident, out)) {
return nullptr;
}
@@ -3057,9 +3057,16 @@
return Switch(
resolved_node, //
[&](sem::Variable* variable) -> sem::VariableUser* {
- auto symbol = ident->symbol;
- auto* user =
- builder_->create<sem::VariableUser>(expr, current_statement_, variable);
+ auto stage = variable->Stage();
+ const constant::Value* value = variable->ConstantValue();
+ if (skip_const_eval_.Contains(expr)) {
+ // This expression is short-circuited by an ancestor expression.
+ // Do not const-eval.
+ stage = sem::EvaluationStage::kNotEvaluated;
+ value = nullptr;
+ }
+ auto* user = builder_->create<sem::VariableUser>(expr, stage, current_statement_,
+ value, variable);
if (current_statement_) {
// If identifier is part of a loop continuing block, make sure it
@@ -3073,6 +3080,7 @@
if (loop_block->FirstContinue()) {
// If our identifier is in loop_block->decls, make sure its index is
// less than first_continue
+ auto symbol = ident->symbol;
if (auto decl = loop_block->Decls().Find(symbol)) {
if (decl->order >= loop_block->NumDeclsAtFirstContinue()) {
AddError("continue statement bypasses declaration of '" +
@@ -3118,7 +3126,7 @@
// Note: The spec is currently vague around the rules here. See
// https://github.com/gpuweb/gpuweb/issues/3081. Remove this comment when
// resolved.
- std::string desc = "var '" + symbol.Name() + "' ";
+ std::string desc = "var '" + ident->symbol.Name() + "' ";
AddError(desc + "cannot be referenced at module-scope", expr->source);
AddNote(desc + "declared here", variable->Declaration()->source);
return nullptr;
@@ -3267,7 +3275,7 @@
// If we're extracting from a reference, we return a reference.
if (auto* ref = object_ty->As<type::Reference>()) {
- ty = builder_->create<type::Reference>(ty, ref->AddressSpace(), ref->Access());
+ ty = builder_->create<type::Reference>(ref->AddressSpace(), ty, ref->Access());
}
auto val = const_eval_.MemberAccess(object, member);
@@ -3336,7 +3344,7 @@
ty = vec->type();
// If we're extracting from a reference, we return a reference.
if (auto* ref = object_ty->As<type::Reference>()) {
- ty = builder_->create<type::Reference>(ty, ref->AddressSpace(), ref->Access());
+ ty = builder_->create<type::Reference>(ref->AddressSpace(), ty, ref->Access());
}
} else {
// The vector will have a number of components equal to the length of
@@ -3368,8 +3376,19 @@
if (!lhs || !rhs) {
return nullptr;
}
- auto* lhs_ty = lhs->Type()->UnwrapRef();
- auto* rhs_ty = rhs->Type()->UnwrapRef();
+
+ // Load arguments if they are references
+ lhs = Load(lhs);
+ if (!lhs) {
+ return nullptr;
+ }
+ rhs = Load(rhs);
+ if (!rhs) {
+ return nullptr;
+ }
+
+ auto* lhs_ty = lhs->Type();
+ auto* rhs_ty = rhs->Type();
auto stage = sem::EarliestStage(lhs->Stage(), rhs->Stage());
auto op = intrinsic_table_->Lookup(expr->op, lhs_ty, rhs_ty, stage, expr->source, false);
@@ -3389,16 +3408,6 @@
}
}
- // Load arguments if they are references
- lhs = Load(lhs);
- if (!lhs) {
- return nullptr;
- }
- rhs = Load(rhs);
- if (!rhs) {
- return nullptr;
- }
-
const constant::Value* value = nullptr;
if (skip_const_eval_.Contains(expr)) {
// This expression is short-circuited by an ancestor expression.
@@ -3471,7 +3480,7 @@
return nullptr;
}
- ty = builder_->create<type::Pointer>(ref->StoreType(), ref->AddressSpace(),
+ ty = builder_->create<type::Pointer>(ref->AddressSpace(), ref->StoreType(),
ref->Access());
root_ident = expr->RootIdentifier();
@@ -3483,7 +3492,7 @@
case ast::UnaryOp::kIndirection:
if (auto* ptr = expr_ty->As<type::Pointer>()) {
- ty = builder_->create<type::Reference>(ptr->StoreType(), ptr->AddressSpace(),
+ ty = builder_->create<type::Reference>(ptr->AddressSpace(), ptr->StoreType(),
ptr->Access());
root_ident = expr->RootIdentifier();
} else {
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index dfe42ae..eab9b86 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -803,7 +803,7 @@
auto* v = Expr("v");
auto* p = Expr("p");
auto* v_decl = Decl(Var("v", ty.f32()));
- auto* p_decl = Decl(Let("p", ty.pointer<f32>(builtin::AddressSpace::kFunction), AddressOf(v)));
+ auto* p_decl = Decl(Let("p", ty.ptr<f32>(builtin::AddressSpace::kFunction), AddressOf(v)));
auto* assign = Assign(Deref(p), 1.23_f);
Func("my_func", utils::Empty, ty.void_(),
utils::Vector{
@@ -2299,10 +2299,10 @@
Func("helper",
utils::Vector{
- Param("sl", ty.pointer(ty.sampler(type::SamplerKind::kSampler),
- builtin::AddressSpace::kFunction)),
- Param("tl", ty.pointer(ty.sampled_texture(type::TextureDimension::k2d, ty.f32()),
- builtin::AddressSpace::kFunction)),
+ Param("sl", ty.ptr(builtin::AddressSpace::kFunction,
+ ty.sampler(type::SamplerKind::kSampler))),
+ Param("tl", ty.ptr(builtin::AddressSpace::kFunction,
+ ty.sampled_texture(type::TextureDimension::k2d, ty.f32()))),
},
ty.vec4<f32>(),
utils::Vector{
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index c4aec90..706ba89 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -628,13 +628,13 @@
/// @param b the ProgramBuilder
/// @return a new AST alias type
static inline ast::Type AST(ProgramBuilder& b) {
- return b.ty.pointer(DataType<T>::AST(b), builtin::AddressSpace::kPrivate,
- builtin::Access::kUndefined);
+ return b.ty.ptr(builtin::AddressSpace::kPrivate, DataType<T>::AST(b),
+ builtin::Access::kUndefined);
}
/// @param b the ProgramBuilder
/// @return the semantic aliased type
static inline const type::Type* Sem(ProgramBuilder& b) {
- return b.create<type::Pointer>(DataType<T>::Sem(b), builtin::AddressSpace::kPrivate,
+ return b.create<type::Pointer>(builtin::AddressSpace::kPrivate, DataType<T>::Sem(b),
builtin::Access::kReadWrite);
}
diff --git a/src/tint/resolver/root_identifier_test.cc b/src/tint/resolver/root_identifier_test.cc
index 0353ff5..4ff43cb 100644
--- a/src/tint/resolver/root_identifier_test.cc
+++ b/src/tint/resolver/root_identifier_test.cc
@@ -142,7 +142,7 @@
// {
// let b = a;
// }
- auto* param = Param("a", ty.pointer(ty.f32(), builtin::AddressSpace::kFunction));
+ auto* param = Param("a", ty.ptr(builtin::AddressSpace::kFunction, ty.f32()));
auto* expr_param = Expr(param);
auto* let = Let("b", expr_param);
auto* expr_let = Expr("b");
diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc
index fe9622b..de79f85 100644
--- a/src/tint/resolver/type_validation_test.cc
+++ b/src/tint/resolver/type_validation_test.cc
@@ -807,8 +807,8 @@
TEST_F(ResolverTypeValidationTest, PtrToRuntimeArrayAsPointerParameter_Fail) {
// fn func(a : ptr<workgroup, array<u32>>) {}
- auto* param = Param("a", ty.pointer(Source{{56, 78}}, ty.array(Source{{12, 34}}, ty.i32()),
- builtin::AddressSpace::kWorkgroup));
+ auto* param = Param("a", ty.ptr(Source{{56, 78}}, builtin::AddressSpace::kWorkgroup,
+ ty.array(Source{{12, 34}}, ty.i32())));
Func("func", utils::Vector{param}, ty.void_(),
utils::Vector{
@@ -881,7 +881,7 @@
}
TEST_F(ResolverTypeValidationTest, ArrayOfNonStorableTypeWithStride) {
- auto ptr_ty = ty.pointer<u32>(Source{{12, 34}}, builtin::AddressSpace::kUniform);
+ auto ptr_ty = ty.ptr<u32>(Source{{12, 34}}, builtin::AddressSpace::kUniform);
GlobalVar("arr", ty.array(ptr_ty, 4_i, utils::Vector{Stride(16)}),
builtin::AddressSpace::kPrivate);
diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc
index 0bf2436..e890bcb 100644
--- a/src/tint/resolver/uniformity_test.cc
+++ b/src/tint/resolver/uniformity_test.cc
@@ -5301,8 +5301,8 @@
}
foo_body.Push(b.Decl(b.Let("rhs", rhs_init)));
for (int i = 0; i < 255; i++) {
- params.Push(b.Param("p" + std::to_string(i),
- ty.pointer(ty.i32(), builtin::AddressSpace::kFunction)));
+ params.Push(
+ b.Param("p" + std::to_string(i), ty.ptr(builtin::AddressSpace::kFunction, ty.i32())));
if (i > 0) {
foo_body.Push(b.Assign(b.Deref("p" + std::to_string(i)), "rhs"));
}
diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc
index 6b84d85..ca958ad 100644
--- a/src/tint/resolver/validation_test.cc
+++ b/src/tint/resolver/validation_test.cc
@@ -380,7 +380,7 @@
// let x: f32 = (*p).z;
// return x;
// }
- auto* p = Param("p", ty.pointer(ty.vec4<f32>(), builtin::AddressSpace::kFunction));
+ auto* p = Param("p", ty.ptr(builtin::AddressSpace::kFunction, ty.vec4<f32>()));
auto* star_p = Deref(p);
auto* accessor_expr = MemberAccessor(star_p, "z");
auto* x = Var("x", ty.f32(), accessor_expr);
@@ -397,7 +397,7 @@
// let x: f32 = *p.z;
// return x;
// }
- auto* p = Param("p", ty.pointer(ty.vec4<f32>(), builtin::AddressSpace::kFunction));
+ auto* p = Param("p", ty.ptr(builtin::AddressSpace::kFunction, ty.vec4<f32>()));
auto* accessor_expr = MemberAccessor(p, Ident(Source{{12, 34}}, "z"));
auto* star_p = Deref(accessor_expr);
auto* x = Var("x", ty.f32(), star_p);
@@ -1234,9 +1234,8 @@
TEST_F(ResolverTest, Expr_Initializer_Cast_Pointer) {
auto* vf = Var("vf", ty.f32());
- auto* c =
- Call(Source{{12, 34}}, ty.pointer<i32>(builtin::AddressSpace::kFunction), ExprList(vf));
- auto* ip = Let("ip", ty.pointer<i32>(builtin::AddressSpace::kFunction), c);
+ auto* c = Call(Source{{12, 34}}, ty.ptr<i32>(builtin::AddressSpace::kFunction), ExprList(vf));
+ auto* ip = Let("ip", ty.ptr<i32>(builtin::AddressSpace::kFunction), c);
WrapInFunction(Decl(vf), Decl(ip));
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 63f2b4d..eb009c2 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -199,8 +199,8 @@
// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
bool Validator::IsPlain(const type::Type* type) const {
- return type->is_scalar() ||
- type->IsAnyOf<type::Atomic, type::Vector, type::Matrix, type::Array, type::Struct>();
+ return type->IsAnyOf<type::Scalar, type::Atomic, type::Vector, type::Matrix, type::Array,
+ type::Struct>();
}
// https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types
@@ -222,7 +222,7 @@
}
return true;
},
- [&](Default) { return type->is_scalar(); });
+ [&](Default) { return type->Is<type::Scalar>(); });
}
// https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types
@@ -432,7 +432,7 @@
// Among three host-shareable address spaces, f16 is supported in "uniform" and
// "storage" address space, but not "push_constant" address space yet.
- if (Is<type::F16>(type::Type::DeepestElementOf(store_ty)) &&
+ if (Is<type::F16>(store_ty->DeepestElement()) &&
address_space == builtin::AddressSpace::kPushConstant) {
AddError("using f16 types in 'push_constant' address space is not implemented yet", source);
return false;
@@ -522,7 +522,7 @@
// Since WGSL has no stride attribute, try to provide a useful hint for how the
// shader author can resolve the issue.
std::string hint;
- if (arr->ElemType()->is_scalar()) {
+ if (arr->ElemType()->Is<type::Scalar>()) {
hint = "Consider using a vector or struct as the element type instead.";
} else if (auto* vec = arr->ElemType()->As<type::Vector>();
vec && vec->type()->Size() == 4) {
@@ -754,7 +754,7 @@
}
}
- if (!storage_ty->is_scalar()) {
+ if (!storage_ty->Is<type::Scalar>()) {
AddError(sem_.TypeNameOf(storage_ty) + " cannot be used as the type of a 'override'",
decl->source);
return false;
@@ -1842,7 +1842,7 @@
}
bool Validator::Vector(const type::Type* el_ty, const Source& source) const {
- if (!el_ty->is_scalar()) {
+ if (!el_ty->Is<type::Scalar>()) {
AddError("vector element type must be 'bool', 'f32', 'f16', 'i32' or 'u32'", source);
return false;
}
diff --git a/src/tint/resolver/validator_is_storeable_test.cc b/src/tint/resolver/validator_is_storeable_test.cc
index d01d21c..826ac57 100644
--- a/src/tint/resolver/validator_is_storeable_test.cc
+++ b/src/tint/resolver/validator_is_storeable_test.cc
@@ -78,7 +78,7 @@
}
TEST_F(ValidatorIsStorableTest, Pointer) {
- auto* ptr = create<type::Pointer>(create<type::I32>(), builtin::AddressSpace::kPrivate,
+ auto* ptr = create<type::Pointer>(builtin::AddressSpace::kPrivate, create<type::I32>(),
builtin::Access::kReadWrite);
EXPECT_FALSE(v()->IsStorable(ptr));
}
diff --git a/src/tint/resolver/value_constructor_validation_test.cc b/src/tint/resolver/value_constructor_validation_test.cc
index 772b9c1..3b7ad64 100644
--- a/src/tint/resolver/value_constructor_validation_test.cc
+++ b/src/tint/resolver/value_constructor_validation_test.cc
@@ -99,7 +99,7 @@
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* got = TypeOf(a_ident);
auto* expected =
- create<type::Reference>(params.create_rhs_sem_type(*this), builtin::AddressSpace::kFunction,
+ create<type::Reference>(builtin::AddressSpace::kFunction, params.create_rhs_sem_type(*this),
builtin::Access::kReadWrite);
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n";
@@ -154,7 +154,7 @@
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* got = TypeOf(a_ident);
auto* expected =
- create<type::Reference>(params.create_rhs_sem_type(*this), builtin::AddressSpace::kFunction,
+ create<type::Reference>(builtin::AddressSpace::kFunction, params.create_rhs_sem_type(*this),
builtin::Access::kReadWrite);
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n";
@@ -203,7 +203,7 @@
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* got = TypeOf(a_ident);
auto* expected =
- create<type::Reference>(params.create_rhs_sem_type(*this), builtin::AddressSpace::kFunction,
+ create<type::Reference>(builtin::AddressSpace::kFunction, params.create_rhs_sem_type(*this),
builtin::Access::kReadWrite);
ASSERT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"
<< "expected: " << FriendlyName(expected) << "\n";
diff --git a/src/tint/resolver/variable_test.cc b/src/tint/resolver/variable_test.cc
index 5daee08..7095cfe 100644
--- a/src/tint/resolver/variable_test.cc
+++ b/src/tint/resolver/variable_test.cc
@@ -421,7 +421,7 @@
auto* b = Let("b", ty.bool_(), b_c);
auto* s = Let("s", ty.Of(S), s_c);
auto* a = Let("a", ty.Of(A), a_c);
- auto* p = Let("p", ty.pointer<i32>(builtin::AddressSpace::kFunction), p_c);
+ auto* p = Let("p", ty.ptr<i32>(builtin::AddressSpace::kFunction), p_c);
Func("F", utils::Empty, ty.void_(),
utils::Vector{
diff --git a/src/tint/resolver/variable_validation_test.cc b/src/tint/resolver/variable_validation_test.cc
index 7f62e8e..9656574 100644
--- a/src/tint/resolver/variable_validation_test.cc
+++ b/src/tint/resolver/variable_validation_test.cc
@@ -132,7 +132,7 @@
// var i : i32;
// var p : pointer<function, i32> = &v;
auto* i = Var("i", ty.i32());
- auto* p = Var("a", ty.pointer<i32>(Source{{56, 78}}, builtin::AddressSpace::kFunction),
+ auto* p = Var("a", ty.ptr<i32>(Source{{56, 78}}, builtin::AddressSpace::kFunction),
builtin::AddressSpace::kUndefined, AddressOf(Source{{12, 34}}, "i"));
WrapInFunction(i, p);
@@ -227,7 +227,7 @@
// let b : ptr<function,f32> = a;
const auto priv = builtin::AddressSpace::kFunction;
auto* var_a = Var("a", ty.f32(), priv);
- auto* var_b = Let(Source{{12, 34}}, "b", ty.pointer<f32>(priv), Expr("a"));
+ auto* var_b = Let(Source{{12, 34}}, "b", ty.ptr<f32>(priv), Expr("a"));
WrapInFunction(var_a, var_b);
ASSERT_FALSE(r()->Resolve());
@@ -319,7 +319,7 @@
auto* expr = IndexAccessor(MemberAccessor(MemberAccessor(storage, "inner"), "arr"), 2_i);
auto* ptr = Let(Source{{12, 34}}, "p",
- ty.pointer<i32>(builtin::AddressSpace::kStorage, builtin::Access::kReadWrite),
+ ty.ptr<i32>(builtin::AddressSpace::kStorage, builtin::Access::kReadWrite),
AddressOf(expr));
WrapInFunction(ptr);
diff --git a/src/tint/sem/variable.cc b/src/tint/sem/variable.cc
index 2e3cb88..a220a80 100644
--- a/src/tint/sem/variable.cc
+++ b/src/tint/sem/variable.cc
@@ -86,13 +86,15 @@
Parameter::~Parameter() = default;
VariableUser::VariableUser(const ast::IdentifierExpression* declaration,
+ EvaluationStage stage,
Statement* statement,
+ const constant::Value* constant,
sem::Variable* variable)
: Base(declaration,
variable->Type(),
- variable->Stage(),
+ stage,
statement,
- variable->ConstantValue(),
+ constant,
/* has_side_effects */ false),
variable_(variable) {
auto* type = variable->Type();
diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h
index 4cc0b1e..7660df3 100644
--- a/src/tint/sem/variable.h
+++ b/src/tint/sem/variable.h
@@ -259,10 +259,14 @@
public:
/// Constructor
/// @param declaration the AST identifier node
+ /// @param stage the evaluation stage for an expression of this variable type
/// @param statement the statement that owns this expression
+ /// @param constant the constant value of the expression. May be null
/// @param variable the semantic variable
VariableUser(const ast::IdentifierExpression* declaration,
+ EvaluationStage stage,
Statement* statement,
+ const constant::Value* constant,
sem::Variable* variable);
~VariableUser() override;
diff --git a/src/tint/transform/manager_test.cc b/src/tint/transform/manager_test.cc
index 86c5961..a41fa87 100644
--- a/src/tint/transform/manager_test.cc
+++ b/src/tint/transform/manager_test.cc
@@ -50,7 +50,7 @@
class IR_AddFunction final : public ir::transform::Transform {
void Run(ir::Module* mod, const DataMap&, DataMap&) const override {
ir::Builder builder(*mod);
- auto* func = builder.CreateFunction("ir_func", mod->Types().Get<type::Void>());
+ auto* func = builder.Function("ir_func", mod->Types().Get<type::Void>());
func->StartTarget()->SetInstructions(utils::Vector{builder.Return(func)});
mod->functions.Push(func);
}
@@ -67,7 +67,7 @@
ir::Module MakeIR() {
ir::Module mod;
ir::Builder builder(mod);
- auto* func = builder.CreateFunction("main", mod.Types().Get<type::Void>());
+ auto* func = builder.Function("main", mod.Types().Get<type::Void>());
func->StartTarget()->SetInstructions(utils::Vector{builder.Return(func)});
builder.ir.functions.Push(func);
return mod;
diff --git a/src/tint/type/abstract_float.cc b/src/tint/type/abstract_float.cc
index 41e0430..c08ad4b 100644
--- a/src/tint/type/abstract_float.cc
+++ b/src/tint/type/abstract_float.cc
@@ -26,10 +26,6 @@
AbstractFloat::~AbstractFloat() = default;
-bool AbstractFloat::Equals(const UniqueNode& other) const {
- return other.Is<AbstractFloat>();
-}
-
std::string AbstractFloat::FriendlyName() const {
return "abstract-float";
}
diff --git a/src/tint/type/abstract_float.h b/src/tint/type/abstract_float.h
index 10d6f3a..d3a9f2b 100644
--- a/src/tint/type/abstract_float.h
+++ b/src/tint/type/abstract_float.h
@@ -31,10 +31,6 @@
/// Destructor
~AbstractFloat() override;
- /// @param other the other type to compare against
- /// @returns true if this type is equal to the given type
- bool Equals(const UniqueNode& other) const override;
-
/// @returns the name for this type when printed in diagnostics.
std::string FriendlyName() const override;
diff --git a/src/tint/type/abstract_int.cc b/src/tint/type/abstract_int.cc
index 6b97622..88750c6 100644
--- a/src/tint/type/abstract_int.cc
+++ b/src/tint/type/abstract_int.cc
@@ -25,10 +25,6 @@
AbstractInt::~AbstractInt() = default;
-bool AbstractInt::Equals(const UniqueNode& other) const {
- return other.Is<AbstractInt>();
-}
-
std::string AbstractInt::FriendlyName() const {
return "abstract-int";
}
diff --git a/src/tint/type/abstract_int.h b/src/tint/type/abstract_int.h
index aaaaadd..35430d1 100644
--- a/src/tint/type/abstract_int.h
+++ b/src/tint/type/abstract_int.h
@@ -31,10 +31,6 @@
/// Destructor
~AbstractInt() override;
- /// @param other the other node to compare against
- /// @returns true if the this type is equal to @p other
- bool Equals(const UniqueNode& other) const override;
-
/// @returns the name for this type when printed in diagnostics.
std::string FriendlyName() const override;
diff --git a/src/tint/type/abstract_numeric.h b/src/tint/type/abstract_numeric.h
index 217b2a2..128bc34 100644
--- a/src/tint/type/abstract_numeric.h
+++ b/src/tint/type/abstract_numeric.h
@@ -17,13 +17,13 @@
#include <string>
-#include "src/tint/type/type.h"
+#include "src/tint/type/numeric_scalar.h"
namespace tint::type {
/// The base class for abstract-int and abstract-float types.
/// @see https://www.w3.org/TR/WGSL/#types-for-creation-time-constants
-class AbstractNumeric : public utils::Castable<AbstractNumeric, Type> {
+class AbstractNumeric : public utils::Castable<AbstractNumeric, NumericScalar> {
public:
/// Constructor
/// @param hash the unique hash of the node
diff --git a/src/tint/type/array.cc b/src/tint/type/array.cc
index 31893f8..47e1ea2 100644
--- a/src/tint/type/array.cc
+++ b/src/tint/type/array.cc
@@ -105,6 +105,22 @@
return size_;
}
+TypeAndCount Array::Elements(const Type* /* type_if_invalid = nullptr */,
+ uint32_t count_if_invalid /* = 0 */) const {
+ uint32_t n = count_if_invalid;
+ if (auto* const_count = count_->As<ConstantArrayCount>()) {
+ n = const_count->value;
+ }
+ return {element_, n};
+}
+
+const Type* Array::Element(uint32_t index) const {
+ if (auto* count = count_->As<ConstantArrayCount>()) {
+ return index < count->value ? element_ : nullptr;
+ }
+ return element_;
+}
+
Array* Array::Clone(CloneContext& ctx) const {
auto* elem_ty = element_->Clone(ctx);
auto* count = count_->Clone(ctx);
diff --git a/src/tint/type/array.h b/src/tint/type/array.h
index c400475..72a7479 100644
--- a/src/tint/type/array.h
+++ b/src/tint/type/array.h
@@ -97,6 +97,13 @@
/// declared in WGSL.
std::string FriendlyName() const override;
+ /// @copydoc Type::Elements
+ TypeAndCount Elements(const Type* type_if_invalid = nullptr,
+ uint32_t count_if_invalid = 0) const override;
+
+ /// @copydoc Type::Element
+ const Type* Element(uint32_t index) const override;
+
/// @param ctx the clone context
/// @returns a clone of this type
Array* Clone(CloneContext& ctx) const override;
diff --git a/src/tint/type/bool.cc b/src/tint/type/bool.cc
index 26741c5..382f894 100644
--- a/src/tint/type/bool.cc
+++ b/src/tint/type/bool.cc
@@ -30,10 +30,6 @@
Bool::~Bool() = default;
-bool Bool::Equals(const UniqueNode& other) const {
- return other.Is<Bool>();
-}
-
std::string Bool::FriendlyName() const {
return "bool";
}
diff --git a/src/tint/type/bool.h b/src/tint/type/bool.h
index 33906b8..19c6bf8 100644
--- a/src/tint/type/bool.h
+++ b/src/tint/type/bool.h
@@ -17,7 +17,7 @@
#include <string>
-#include "src/tint/type/type.h"
+#include "src/tint/type/scalar.h"
// X11 likes to #define Bool leading to confusing error messages.
// If its defined, undefine it.
@@ -28,7 +28,7 @@
namespace tint::type {
/// A boolean type
-class Bool final : public utils::Castable<Bool, Type> {
+class Bool final : public utils::Castable<Bool, Scalar> {
public:
/// Constructor
Bool();
@@ -36,10 +36,6 @@
/// Destructor
~Bool() override;
- /// @param other the other node to compare against
- /// @returns true if the this type is equal to @p other
- bool Equals(const UniqueNode& other) const override;
-
/// @returns the name for this type that closely resembles how it would be
/// declared in WGSL.
std::string FriendlyName() const override;
diff --git a/src/tint/type/f16.cc b/src/tint/type/f16.cc
index 288f14a..485666f 100644
--- a/src/tint/type/f16.cc
+++ b/src/tint/type/f16.cc
@@ -30,10 +30,6 @@
F16::~F16() = default;
-bool F16::Equals(const UniqueNode& other) const {
- return other.Is<F16>();
-}
-
std::string F16::FriendlyName() const {
return "f16";
}
diff --git a/src/tint/type/f16.h b/src/tint/type/f16.h
index c7fc3eb..31fe336 100644
--- a/src/tint/type/f16.h
+++ b/src/tint/type/f16.h
@@ -17,12 +17,12 @@
#include <string>
-#include "src/tint/type/type.h"
+#include "src/tint/type/numeric_scalar.h"
namespace tint::type {
/// A float 16 type
-class F16 final : public utils::Castable<F16, Type> {
+class F16 final : public utils::Castable<F16, NumericScalar> {
public:
/// Constructor
F16();
@@ -30,10 +30,6 @@
/// Destructor
~F16() override;
- /// @param other the other node to compare against
- /// @returns true if the this type is equal to @p other
- bool Equals(const UniqueNode& other) const override;
-
/// @returns the name for this type that closely resembles how it would be
/// declared in WGSL.
std::string FriendlyName() const override;
diff --git a/src/tint/type/f32.cc b/src/tint/type/f32.cc
index e3afcc7..0b34816 100644
--- a/src/tint/type/f32.cc
+++ b/src/tint/type/f32.cc
@@ -30,10 +30,6 @@
F32::~F32() = default;
-bool F32::Equals(const UniqueNode& other) const {
- return other.Is<F32>();
-}
-
std::string F32::FriendlyName() const {
return "f32";
}
diff --git a/src/tint/type/f32.h b/src/tint/type/f32.h
index 68b3c77..c22da37 100644
--- a/src/tint/type/f32.h
+++ b/src/tint/type/f32.h
@@ -17,12 +17,12 @@
#include <string>
-#include "src/tint/type/type.h"
+#include "src/tint/type/numeric_scalar.h"
namespace tint::type {
/// A float 32 type
-class F32 final : public utils::Castable<F32, Type> {
+class F32 final : public utils::Castable<F32, NumericScalar> {
public:
/// Constructor
F32();
@@ -30,10 +30,6 @@
/// Destructor
~F32() override;
- /// @param other the other node to compare against
- /// @returns true if the this type is equal to @p other
- bool Equals(const UniqueNode& other) const override;
-
/// @returns the name for this type that closely resembles how it would be
/// declared in WGSL.
std::string FriendlyName() const override;
diff --git a/src/tint/type/i32.cc b/src/tint/type/i32.cc
index 66da527..f509b00 100644
--- a/src/tint/type/i32.cc
+++ b/src/tint/type/i32.cc
@@ -30,10 +30,6 @@
I32::~I32() = default;
-bool I32::Equals(const UniqueNode& other) const {
- return other.Is<I32>();
-}
-
std::string I32::FriendlyName() const {
return "i32";
}
diff --git a/src/tint/type/i32.h b/src/tint/type/i32.h
index 563ee69..9563349 100644
--- a/src/tint/type/i32.h
+++ b/src/tint/type/i32.h
@@ -17,12 +17,12 @@
#include <string>
-#include "src/tint/type/type.h"
+#include "src/tint/type/numeric_scalar.h"
namespace tint::type {
/// A signed int 32 type.
-class I32 final : public utils::Castable<I32, Type> {
+class I32 final : public utils::Castable<I32, NumericScalar> {
public:
/// Constructor
I32();
@@ -30,10 +30,6 @@
/// Destructor
~I32() override;
- /// @param other the other node to compare against
- /// @returns true if the this type is equal to @p other
- bool Equals(const UniqueNode& other) const override;
-
/// @returns the name for this type that closely resembles how it would be
/// declared in WGSL.
std::string FriendlyName() const override;
diff --git a/src/tint/type/manager.cc b/src/tint/type/manager.cc
index d9c8667..f5f382b 100644
--- a/src/tint/type/manager.cc
+++ b/src/tint/type/manager.cc
@@ -156,10 +156,10 @@
/* implicit stride */ elem_ty->Align());
}
-const type::Pointer* Manager::pointer(const type::Type* subtype,
- builtin::AddressSpace address_space,
- builtin::Access access) {
- return Get<type::Pointer>(subtype, address_space, access);
+const type::Pointer* Manager::ptr(builtin::AddressSpace address_space,
+ const type::Type* subtype,
+ builtin::Access access) {
+ return Get<type::Pointer>(address_space, subtype, access);
}
} // namespace tint::type
diff --git a/src/tint/type/manager.h b/src/tint/type/manager.h
index fa29ae0..a9bdbeb 100644
--- a/src/tint/type/manager.h
+++ b/src/tint/type/manager.h
@@ -19,6 +19,7 @@
#include "src/tint/builtin/access.h"
#include "src/tint/builtin/address_space.h"
+#include "src/tint/number.h"
#include "src/tint/type/type.h"
#include "src/tint/type/unique_node.h"
#include "src/tint/utils/hash.h"
@@ -42,6 +43,9 @@
namespace tint::type {
+template <typename T>
+struct CppToType;
+
/// The type manager holds all the pointers to the known types.
class Manager final {
public:
@@ -78,18 +82,23 @@
return out;
}
+ /// Constructs or returns an existing type, unique node or node
/// @param args the arguments used to construct the type, unique node or node.
+ /// @tparam T a class deriving from type::Node, or a C-like type that's automatically translated
+ /// to the equivalent type node type. For example `Get<i32>()` is equivalent to
+ /// `Get<type::I32>()`
/// @return a pointer to an instance of `T` with the provided arguments.
- /// If NODE derives from UniqueNode and an existing instance of `T` has been
- /// constructed, then the same pointer is returned.
- template <typename NODE, typename... ARGS>
- NODE* Get(ARGS&&... args) {
- if constexpr (utils::traits::IsTypeOrDerived<NODE, Type>) {
- return types_.Get<NODE>(std::forward<ARGS>(args)...);
- } else if constexpr (utils::traits::IsTypeOrDerived<NODE, UniqueNode>) {
- return unique_nodes_.Get<NODE>(std::forward<ARGS>(args)...);
+ /// If `T` derives from UniqueNode and an existing instance of `T` has been constructed, then
+ /// the same pointer is returned.
+ template <typename T, typename... ARGS>
+ auto* Get(ARGS&&... args) {
+ using N = ToType<T>;
+ if constexpr (utils::traits::IsTypeOrDerived<N, Type>) {
+ return types_.Get<N>(std::forward<ARGS>(args)...);
+ } else if constexpr (utils::traits::IsTypeOrDerived<N, UniqueNode>) {
+ return unique_nodes_.Get<T>(std::forward<ARGS>(args)...);
} else {
- return nodes_.Create<NODE>(std::forward<ARGS>(args)...);
+ return nodes_.Create<T>(std::forward<ARGS>(args)...);
}
}
@@ -99,8 +108,8 @@
template <typename TYPE,
typename _ = std::enable_if<utils::traits::IsTypeOrDerived<TYPE, Type>>,
typename... ARGS>
- TYPE* Find(ARGS&&... args) const {
- return types_.Find<TYPE>(std::forward<ARGS>(args)...);
+ auto* Find(ARGS&&... args) const {
+ return types_.Find<ToType<TYPE>>(std::forward<ARGS>(args)...);
}
/// @returns a void type
@@ -133,17 +142,54 @@
const type::Vector* vec(const type::Type* inner, uint32_t size);
/// @param inner the inner type
- /// @returns the vector type
+ /// @returns a vec2 type with the element type @p inner
const type::Vector* vec2(const type::Type* inner);
/// @param inner the inner type
- /// @returns the vector type
+ /// @returns a vec3 type with the element type @p inner
const type::Vector* vec3(const type::Type* inner);
/// @param inner the inner type
- /// @returns the vector type
+ /// @returns a vec4 type with the element type @p inner
const type::Vector* vec4(const type::Type* inner);
+ /// @tparam T the element type
+ /// @tparam N the vector width
+ /// @returns the vector type
+ template <typename T, size_t N>
+ const type::Vector* vec() {
+ static_assert(N >= 2 && N <= 4);
+ switch (N) {
+ case 2:
+ return vec2<T>();
+ case 3:
+ return vec3<T>();
+ case 4:
+ return vec4<T>();
+ }
+ }
+
+ /// @tparam T the element type
+ /// @returns a vec2 with the element type `T`
+ template <typename T>
+ const type::Vector* vec2() {
+ return vec2(Get<T>());
+ }
+
+ /// @tparam T the element type
+ /// @returns a vec2 with the element type `T`
+ template <typename T>
+ const type::Vector* vec3() {
+ return vec3(Get<T>());
+ }
+
+ /// @tparam T the element type
+ /// @returns a vec2 with the element type `T`
+ template <typename T>
+ const type::Vector* vec4() {
+ return vec4(Get<T>());
+ }
+
/// @param inner the inner type
/// @param cols the number of columns
/// @param rows the number of rows
@@ -151,41 +197,104 @@
const type::Matrix* mat(const type::Type* inner, uint32_t cols, uint32_t rows);
/// @param inner the inner type
- /// @returns the matrix type
+ /// @returns a mat2x2 with the element @p inner
const type::Matrix* mat2x2(const type::Type* inner);
+ /// @tparam T the element type
+ /// @returns a mat2x2 with the element type `T`
+ template <typename T>
+ const type::Matrix* mat2x2() {
+ return mat2x2(Get<T>());
+ }
+
/// @param inner the inner type
- /// @returns the matrix type
+ /// @returns a mat2x3 with the element @p inner
const type::Matrix* mat2x3(const type::Type* inner);
+ /// @tparam T the element type
+ /// @returns a mat2x3 with the element type `T`
+ template <typename T>
+ const type::Matrix* mat2x3() {
+ return mat2x3(Get<T>());
+ }
+
/// @param inner the inner type
- /// @returns the matrix type
+ /// @returns a mat2x4 with the element @p inner
const type::Matrix* mat2x4(const type::Type* inner);
+ /// @tparam T the element type
+ /// @returns a mat2x4 with the element type `T`
+ template <typename T>
+ const type::Matrix* mat2x4() {
+ return mat2x4(Get<T>());
+ }
+
/// @param inner the inner type
- /// @returns the matrix type
+ /// @returns a mat3x2 with the element @p inner
const type::Matrix* mat3x2(const type::Type* inner);
+ /// @tparam T the element type
+ /// @returns a mat3x2 with the element type `T`
+ template <typename T>
+ const type::Matrix* mat3x2() {
+ return mat3x2(Get<T>());
+ }
+
/// @param inner the inner type
- /// @returns the matrix type
+ /// @returns a mat3x3 with the element @p inner
const type::Matrix* mat3x3(const type::Type* inner);
+ /// @tparam T the element type
+ /// @returns a mat3x3 with the element type `T`
+ template <typename T>
+ const type::Matrix* mat3x3() {
+ return mat3x3(Get<T>());
+ }
+
/// @param inner the inner type
- /// @returns the matrix type
+ /// @returns a mat3x4 with the element @p inner
const type::Matrix* mat3x4(const type::Type* inner);
+ /// @tparam T the element type
+ /// @returns a mat3x4 with the element type `T`
+ template <typename T>
+ const type::Matrix* mat3x4() {
+ return mat3x4(Get<T>());
+ }
+
/// @param inner the inner type
- /// @returns the matrix type
+ /// @returns a mat4x2 with the element @p inner
const type::Matrix* mat4x2(const type::Type* inner);
- /// @param inner the inner type
- /// @returns the matrix type
- const type::Matrix* mat4x3(const type::Type* inner);
+ /// @tparam T the element type
+ /// @returns a mat4x2 with the element type `T`
+ template <typename T>
+ const type::Matrix* mat4x2() {
+ return mat4x2(Get<T>());
+ }
/// @param inner the inner type
- /// @returns the matrix type
+ /// @returns a mat4x3 with the element @p inner
+ const type::Matrix* mat4x3(const type::Type* inner);
+
+ /// @tparam T the element type
+ /// @returns a mat4x3 with the element type `T`
+ template <typename T>
+ const type::Matrix* mat4x3() {
+ return mat4x3(Get<T>());
+ }
+
+ /// @param inner the inner type
+ /// @returns a mat4x4 with the element @p inner
const type::Matrix* mat4x4(const type::Type* inner);
+ /// @tparam T the element type
+ /// @returns a mat4x4 with the element type `T`
+ template <typename T>
+ const type::Matrix* mat4x4() {
+ return mat4x4(Get<T>());
+ }
+
/// @param elem_ty the array element type
/// @param count the array element count
/// @param stride the optional array element stride
@@ -197,13 +306,37 @@
/// @returns the runtime array type
const type::Array* runtime_array(const type::Type* elem_ty, uint32_t stride = 0);
- /// @param subtype the pointer subtype
+ /// @returns an array type with the element type `T` and size `N`.
+ /// @tparam T the element type
+ /// @tparam N the array length. If zero, then constructs a runtime-sized array.
+ /// @param stride the optional array element stride
+ template <typename T, size_t N = 0>
+ const type::Array* array(uint32_t stride = 0) {
+ if constexpr (N == 0) {
+ return runtime_array(Get<T>(), stride);
+ } else {
+ return array(Get<T>(), N, stride);
+ }
+ }
+
/// @param address_space the address space
+ /// @param subtype the pointer subtype
/// @param access the access settings
/// @returns the pointer type
- const type::Pointer* pointer(const type::Type* subtype,
- builtin::AddressSpace address_space,
- builtin::Access access);
+ const type::Pointer* ptr(builtin::AddressSpace address_space,
+ const type::Type* subtype,
+ builtin::Access access);
+
+ /// @tparam SPACE the address space
+ /// @tparam T the storage type
+ /// @tparam ACCESS the access mode
+ /// @returns the pointer type with the templated address space, storage type and access.
+ template <builtin::AddressSpace SPACE,
+ typename T,
+ builtin::Access ACCESS = builtin::Access::kReadWrite>
+ const type::Pointer* ptr() {
+ return ptr(SPACE, Get<T>(), ACCESS);
+ }
/// @returns an iterator to the beginning of the types
TypeIterator begin() const { return types_.begin(); }
@@ -211,6 +344,16 @@
TypeIterator end() const { return types_.end(); }
private:
+ /// ToType<T> is specialized for various `T` types and each specialization contains a single
+ /// `type` alias to the corresponding type deriving from `type::Type`.
+ template <typename T>
+ struct ToTypeImpl {
+ using type = T;
+ };
+
+ template <typename T>
+ using ToType = typename ToTypeImpl<T>::type;
+
/// Unique types owned by the manager
utils::UniqueAllocator<Type> types_;
/// Unique nodes (excluding types) owned by the manager
@@ -219,6 +362,46 @@
utils::BlockAllocator<Node> nodes_;
};
+//! @cond Doxygen_Suppress
+// Various template specializations for Manager::ToTypeImpl.
+template <>
+struct Manager::ToTypeImpl<AInt> {
+ using type = type::AbstractInt;
+};
+template <>
+struct Manager::ToTypeImpl<AFloat> {
+ using type = type::AbstractFloat;
+};
+template <>
+struct Manager::ToTypeImpl<i32> {
+ using type = type::I32;
+};
+template <>
+struct Manager::ToTypeImpl<u32> {
+ using type = type::U32;
+};
+template <>
+struct Manager::ToTypeImpl<f32> {
+ using type = type::F32;
+};
+template <>
+struct Manager::ToTypeImpl<f16> {
+ using type = type::F16;
+};
+template <>
+struct Manager::ToTypeImpl<bool> {
+ using type = type::Bool;
+};
+template <typename T>
+struct Manager::ToTypeImpl<const T> {
+ using type = const Manager::ToType<T>;
+};
+template <typename T>
+struct Manager::ToTypeImpl<T*> {
+ using type = Manager::ToType<T>*;
+};
+//! @endcond
+
} // namespace tint::type
#endif // SRC_TINT_TYPE_MANAGER_H_
diff --git a/src/tint/type/manager_test.cc b/src/tint/type/manager_test.cc
index 4fb2179..62d3044 100644
--- a/src/tint/type/manager_test.cc
+++ b/src/tint/type/manager_test.cc
@@ -15,6 +15,9 @@
#include "src/tint/type/manager.h"
#include "gtest/gtest.h"
+#include "src/tint/type/bool.h"
+#include "src/tint/type/f16.h"
+#include "src/tint/type/f32.h"
#include "src/tint/type/i32.h"
#include "src/tint/type/u32.h"
@@ -62,6 +65,29 @@
EXPECT_TRUE(t2->Is<U32>());
}
+TEST_F(ManagerTest, CppToType) {
+ Manager tm;
+ const Type* b1 = tm.Get<bool>();
+ const Type* b2 = tm.Get<Bool>();
+ ASSERT_EQ(b1, b2);
+
+ const Type* i1 = tm.Get<i32>();
+ const Type* i2 = tm.Get<I32>();
+ ASSERT_EQ(i1, i2);
+
+ const Type* u1 = tm.Get<u32>();
+ const Type* u2 = tm.Get<U32>();
+ ASSERT_EQ(u1, u2);
+
+ const Type* f1 = tm.Get<f32>();
+ const Type* f2 = tm.Get<F32>();
+ ASSERT_EQ(f1, f2);
+
+ const Type* h1 = tm.Get<f16>();
+ const Type* h2 = tm.Get<F16>();
+ ASSERT_EQ(h1, h2);
+}
+
TEST_F(ManagerTest, Find) {
Manager tm;
auto* created = tm.Get<I32>();
diff --git a/src/tint/type/matrix.cc b/src/tint/type/matrix.cc
index 35664f3..bce7335 100644
--- a/src/tint/type/matrix.cc
+++ b/src/tint/type/matrix.cc
@@ -69,6 +69,15 @@
return column_type_->Align();
}
+TypeAndCount Matrix::Elements(const Type* /* type_if_invalid = nullptr */,
+ uint32_t /* count_if_invalid = 0 */) const {
+ return {column_type_, columns_};
+}
+
+const Vector* Matrix::Element(uint32_t index) const {
+ return index < columns_ ? column_type_ : nullptr;
+}
+
Matrix* Matrix::Clone(CloneContext& ctx) const {
auto* col_ty = column_type_->Clone(ctx);
return ctx.dst.mgr->Get<Matrix>(col_ty, columns_);
diff --git a/src/tint/type/matrix.h b/src/tint/type/matrix.h
index 1519d6a..bce61ef 100644
--- a/src/tint/type/matrix.h
+++ b/src/tint/type/matrix.h
@@ -17,7 +17,7 @@
#include <string>
-#include "src/tint/type/type.h"
+#include "src/tint/type/vector.h"
// Forward declarations
namespace tint::type {
@@ -65,6 +65,13 @@
/// @returns the number of bytes between columns of the matrix
uint32_t ColumnStride() const;
+ /// @copydoc Type::Elements
+ TypeAndCount Elements(const Type* type_if_invalid = nullptr,
+ uint32_t count_if_invalid = 0) const override;
+
+ /// @copydoc Type::Element
+ const Vector* Element(uint32_t index) const override;
+
/// @param ctx the clone context
/// @returns a clone of this type
Matrix* Clone(CloneContext& ctx) const override;
diff --git a/src/tint/type/numeric_scalar.cc b/src/tint/type/numeric_scalar.cc
new file mode 100644
index 0000000..40f1e14
--- /dev/null
+++ b/src/tint/type/numeric_scalar.cc
@@ -0,0 +1,25 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/type/numeric_scalar.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::type::NumericScalar);
+
+namespace tint::type {
+
+NumericScalar::NumericScalar(size_t hash, type::Flags flags) : Base(hash, flags) {}
+
+NumericScalar::~NumericScalar() = default;
+
+} // namespace tint::type
diff --git a/src/tint/type/numeric_scalar.h b/src/tint/type/numeric_scalar.h
new file mode 100644
index 0000000..4772cdc
--- /dev/null
+++ b/src/tint/type/numeric_scalar.h
@@ -0,0 +1,38 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_TYPE_NUMERIC_SCALAR_H_
+#define SRC_TINT_TYPE_NUMERIC_SCALAR_H_
+
+#include "src/tint/type/scalar.h"
+
+namespace tint::type {
+
+/// Base class for all numeric-scalar types
+/// @see https://www.w3.org/TR/WGSL/#scalar-types
+class NumericScalar : public utils::Castable<NumericScalar, Scalar> {
+ public:
+ /// Destructor
+ ~NumericScalar() override;
+
+ protected:
+ /// Constructor
+ /// @param hash the immutable hash for the node
+ /// @param flags the flags of this type
+ NumericScalar(size_t hash, type::Flags flags);
+};
+
+} // namespace tint::type
+
+#endif // SRC_TINT_TYPE_NUMERIC_SCALAR_H_
diff --git a/src/tint/type/pointer.cc b/src/tint/type/pointer.cc
index 636f8e5..95f4556 100644
--- a/src/tint/type/pointer.cc
+++ b/src/tint/type/pointer.cc
@@ -25,7 +25,7 @@
namespace tint::type {
-Pointer::Pointer(const Type* subtype, builtin::AddressSpace address_space, builtin::Access access)
+Pointer::Pointer(builtin::AddressSpace address_space, const Type* subtype, builtin::Access access)
: Base(
utils::Hash(utils::TypeInfo::Of<Pointer>().full_hashcode, address_space, subtype, access),
type::Flags{}),
@@ -59,7 +59,7 @@
Pointer* Pointer::Clone(CloneContext& ctx) const {
auto* ty = subtype_->Clone(ctx);
- return ctx.dst.mgr->Get<Pointer>(ty, address_space_, access_);
+ return ctx.dst.mgr->Get<Pointer>(address_space_, ty, access_);
}
} // namespace tint::type
diff --git a/src/tint/type/pointer.h b/src/tint/type/pointer.h
index 80626e9..e22db02 100644
--- a/src/tint/type/pointer.h
+++ b/src/tint/type/pointer.h
@@ -27,10 +27,10 @@
class Pointer final : public utils::Castable<Pointer, Type> {
public:
/// Constructor
- /// @param subtype the pointee type
/// @param address_space the address space of the pointer
+ /// @param subtype the pointee type
/// @param access the resolved access control of the reference
- Pointer(const Type* subtype, builtin::AddressSpace address_space, builtin::Access access);
+ Pointer(builtin::AddressSpace address_space, const Type* subtype, builtin::Access access);
/// Destructor
~Pointer() override;
diff --git a/src/tint/type/pointer_test.cc b/src/tint/type/pointer_test.cc
index 6399242..322fac5 100644
--- a/src/tint/type/pointer_test.cc
+++ b/src/tint/type/pointer_test.cc
@@ -22,16 +22,16 @@
using PointerTest = TestHelper;
TEST_F(PointerTest, Creation) {
- auto* a = create<Pointer>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* a = create<Pointer>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
- auto* b = create<Pointer>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* b = create<Pointer>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
- auto* c = create<Pointer>(create<F32>(), builtin::AddressSpace::kStorage,
+ auto* c = create<Pointer>(builtin::AddressSpace::kStorage, create<F32>(),
builtin::Access::kReadWrite);
- auto* d = create<Pointer>(create<I32>(), builtin::AddressSpace::kPrivate,
+ auto* d = create<Pointer>(builtin::AddressSpace::kPrivate, create<I32>(),
builtin::Access::kReadWrite);
auto* e =
- create<Pointer>(create<I32>(), builtin::AddressSpace::kStorage, builtin::Access::kRead);
+ create<Pointer>(builtin::AddressSpace::kStorage, create<I32>(), builtin::Access::kRead);
EXPECT_TRUE(a->StoreType()->Is<I32>());
EXPECT_EQ(a->AddressSpace(), builtin::AddressSpace::kStorage);
@@ -44,25 +44,25 @@
}
TEST_F(PointerTest, Hash) {
- auto* a = create<Pointer>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* a = create<Pointer>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
- auto* b = create<Pointer>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* b = create<Pointer>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
EXPECT_EQ(a->unique_hash, b->unique_hash);
}
TEST_F(PointerTest, Equals) {
- auto* a = create<Pointer>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* a = create<Pointer>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
- auto* b = create<Pointer>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* b = create<Pointer>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
- auto* c = create<Pointer>(create<F32>(), builtin::AddressSpace::kStorage,
+ auto* c = create<Pointer>(builtin::AddressSpace::kStorage, create<F32>(),
builtin::Access::kReadWrite);
- auto* d = create<Pointer>(create<I32>(), builtin::AddressSpace::kPrivate,
+ auto* d = create<Pointer>(builtin::AddressSpace::kPrivate, create<I32>(),
builtin::Access::kReadWrite);
auto* e =
- create<Pointer>(create<I32>(), builtin::AddressSpace::kStorage, builtin::Access::kRead);
+ create<Pointer>(builtin::AddressSpace::kStorage, create<I32>(), builtin::Access::kRead);
EXPECT_TRUE(a->Equals(*b));
EXPECT_FALSE(a->Equals(*c));
@@ -73,18 +73,18 @@
TEST_F(PointerTest, FriendlyName) {
auto* r =
- create<Pointer>(create<I32>(), builtin::AddressSpace::kUndefined, builtin::Access::kRead);
+ create<Pointer>(builtin::AddressSpace::kUndefined, create<I32>(), builtin::Access::kRead);
EXPECT_EQ(r->FriendlyName(), "ptr<i32, read>");
}
TEST_F(PointerTest, FriendlyNameWithAddressSpace) {
auto* r =
- create<Pointer>(create<I32>(), builtin::AddressSpace::kWorkgroup, builtin::Access::kRead);
+ create<Pointer>(builtin::AddressSpace::kWorkgroup, create<I32>(), builtin::Access::kRead);
EXPECT_EQ(r->FriendlyName(), "ptr<workgroup, i32, read>");
}
TEST_F(PointerTest, Clone) {
- auto* a = create<Pointer>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* a = create<Pointer>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
type::Manager mgr;
diff --git a/src/tint/type/reference.cc b/src/tint/type/reference.cc
index e0fd92a..a8d26a5 100644
--- a/src/tint/type/reference.cc
+++ b/src/tint/type/reference.cc
@@ -24,8 +24,8 @@
namespace tint::type {
-Reference::Reference(const Type* subtype,
- builtin::AddressSpace address_space,
+Reference::Reference(builtin::AddressSpace address_space,
+ const Type* subtype,
builtin::Access access)
: Base(utils::Hash(utils::TypeInfo::Of<Reference>().full_hashcode,
address_space,
@@ -62,7 +62,7 @@
Reference* Reference::Clone(CloneContext& ctx) const {
auto* ty = subtype_->Clone(ctx);
- return ctx.dst.mgr->Get<Reference>(ty, address_space_, access_);
+ return ctx.dst.mgr->Get<Reference>(address_space_, ty, access_);
}
} // namespace tint::type
diff --git a/src/tint/type/reference.h b/src/tint/type/reference.h
index 617b2ca..4a1de4d 100644
--- a/src/tint/type/reference.h
+++ b/src/tint/type/reference.h
@@ -27,10 +27,10 @@
class Reference final : public utils::Castable<Reference, Type> {
public:
/// Constructor
- /// @param subtype the pointee type
/// @param address_space the address space of the reference
+ /// @param subtype the pointee type
/// @param access the resolved access control of the reference
- Reference(const Type* subtype, builtin::AddressSpace address_space, builtin::Access access);
+ Reference(builtin::AddressSpace address_space, const Type* subtype, builtin::Access access);
/// Destructor
~Reference() override;
diff --git a/src/tint/type/reference_test.cc b/src/tint/type/reference_test.cc
index a34ffa1..608e2ee 100644
--- a/src/tint/type/reference_test.cc
+++ b/src/tint/type/reference_test.cc
@@ -22,16 +22,16 @@
using ReferenceTest = TestHelper;
TEST_F(ReferenceTest, Creation) {
- auto* a = create<Reference>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* a = create<Reference>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
- auto* b = create<Reference>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* b = create<Reference>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
- auto* c = create<Reference>(create<F32>(), builtin::AddressSpace::kStorage,
+ auto* c = create<Reference>(builtin::AddressSpace::kStorage, create<F32>(),
builtin::Access::kReadWrite);
- auto* d = create<Reference>(create<I32>(), builtin::AddressSpace::kPrivate,
+ auto* d = create<Reference>(builtin::AddressSpace::kPrivate, create<I32>(),
builtin::Access::kReadWrite);
auto* e =
- create<Reference>(create<I32>(), builtin::AddressSpace::kStorage, builtin::Access::kRead);
+ create<Reference>(builtin::AddressSpace::kStorage, create<I32>(), builtin::Access::kRead);
EXPECT_TRUE(a->StoreType()->Is<I32>());
EXPECT_EQ(a->AddressSpace(), builtin::AddressSpace::kStorage);
@@ -44,25 +44,25 @@
}
TEST_F(ReferenceTest, Hash) {
- auto* a = create<Reference>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* a = create<Reference>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
- auto* b = create<Reference>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* b = create<Reference>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
EXPECT_EQ(a->unique_hash, b->unique_hash);
}
TEST_F(ReferenceTest, Equals) {
- auto* a = create<Reference>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* a = create<Reference>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
- auto* b = create<Reference>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* b = create<Reference>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
- auto* c = create<Reference>(create<F32>(), builtin::AddressSpace::kStorage,
+ auto* c = create<Reference>(builtin::AddressSpace::kStorage, create<F32>(),
builtin::Access::kReadWrite);
- auto* d = create<Reference>(create<I32>(), builtin::AddressSpace::kPrivate,
+ auto* d = create<Reference>(builtin::AddressSpace::kPrivate, create<I32>(),
builtin::Access::kReadWrite);
auto* e =
- create<Reference>(create<I32>(), builtin::AddressSpace::kStorage, builtin::Access::kRead);
+ create<Reference>(builtin::AddressSpace::kStorage, create<I32>(), builtin::Access::kRead);
EXPECT_TRUE(a->Equals(*b));
EXPECT_FALSE(a->Equals(*c));
@@ -73,18 +73,18 @@
TEST_F(ReferenceTest, FriendlyName) {
auto* r =
- create<Reference>(create<I32>(), builtin::AddressSpace::kUndefined, builtin::Access::kRead);
+ create<Reference>(builtin::AddressSpace::kUndefined, create<I32>(), builtin::Access::kRead);
EXPECT_EQ(r->FriendlyName(), "ref<i32, read>");
}
TEST_F(ReferenceTest, FriendlyNameWithAddressSpace) {
auto* r =
- create<Reference>(create<I32>(), builtin::AddressSpace::kWorkgroup, builtin::Access::kRead);
+ create<Reference>(builtin::AddressSpace::kWorkgroup, create<I32>(), builtin::Access::kRead);
EXPECT_EQ(r->FriendlyName(), "ref<workgroup, i32, read>");
}
TEST_F(ReferenceTest, Clone) {
- auto* a = create<Reference>(create<I32>(), builtin::AddressSpace::kStorage,
+ auto* a = create<Reference>(builtin::AddressSpace::kStorage, create<I32>(),
builtin::Access::kReadWrite);
type::Manager mgr;
diff --git a/src/tint/type/scalar.cc b/src/tint/type/scalar.cc
new file mode 100644
index 0000000..8e8d4b3
--- /dev/null
+++ b/src/tint/type/scalar.cc
@@ -0,0 +1,29 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/type/scalar.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::type::Scalar);
+
+namespace tint::type {
+
+Scalar::Scalar(size_t hash, type::Flags flags) : Base(hash, flags) {}
+
+Scalar::~Scalar() = default;
+
+bool Scalar::Equals(const UniqueNode& other) const {
+ return &other.TypeInfo() == &TypeInfo();
+}
+
+} // namespace tint::type
diff --git a/src/tint/type/scalar.h b/src/tint/type/scalar.h
new file mode 100644
index 0000000..5c1a1fc
--- /dev/null
+++ b/src/tint/type/scalar.h
@@ -0,0 +1,42 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_TYPE_SCALAR_H_
+#define SRC_TINT_TYPE_SCALAR_H_
+
+#include "src/tint/type/type.h"
+
+namespace tint::type {
+
+/// Base class for all scalar types
+/// @see https://www.w3.org/TR/WGSL/#scalar-types
+class Scalar : public utils::Castable<Scalar, Type> {
+ public:
+ /// Destructor
+ ~Scalar() override;
+
+ /// @param other the other node to compare against
+ /// @returns true if the this type is equal to @p other
+ bool Equals(const UniqueNode& other) const override;
+
+ protected:
+ /// Constructor
+ /// @param hash the immutable hash for the node
+ /// @param flags the flags of this type
+ Scalar(size_t hash, type::Flags flags);
+};
+
+} // namespace tint::type
+
+#endif // SRC_TINT_TYPE_SCALAR_H_
diff --git a/src/tint/type/struct.cc b/src/tint/type/struct.cc
index c9864cb..9b3bf49 100644
--- a/src/tint/type/struct.cc
+++ b/src/tint/type/struct.cc
@@ -160,6 +160,15 @@
return ss.str();
}
+TypeAndCount Struct::Elements(const Type* type_if_invalid /* = nullptr */,
+ uint32_t /* count_if_invalid = 0 */) const {
+ return {type_if_invalid, static_cast<uint32_t>(members_.Length())};
+}
+
+const Type* Struct::Element(uint32_t index) const {
+ return index < members_.Length() ? members_[index]->Type() : nullptr;
+}
+
Struct* Struct::Clone(CloneContext& ctx) const {
auto sym = ctx.dst.st->Register(name_.Name());
diff --git a/src/tint/type/struct.h b/src/tint/type/struct.h
index dd2f7da..720beca 100644
--- a/src/tint/type/struct.h
+++ b/src/tint/type/struct.h
@@ -158,6 +158,13 @@
/// @note only structures returned by builtins may be abstract (e.g. modf, frexp)
utils::VectorRef<const Struct*> ConcreteTypes() const { return concrete_types_; }
+ /// @copydoc Type::Elements
+ TypeAndCount Elements(const Type* type_if_invalid = nullptr,
+ uint32_t count_if_invalid = 0) const override;
+
+ /// @copydoc Type::Element
+ const Type* Element(uint32_t index) const override;
+
/// @param ctx the clone context
/// @returns a clone of this type
Struct* Clone(CloneContext& ctx) const override;
diff --git a/src/tint/type/type.cc b/src/tint/type/type.cc
index 7ed62fa..47f8bb4 100644
--- a/src/tint/type/type.cc
+++ b/src/tint/type/type.cc
@@ -67,16 +67,8 @@
return 0;
}
-bool Type::is_scalar() const {
- return IsAnyOf<F16, F32, U32, I32, AbstractNumeric, Bool>();
-}
-
-bool Type::is_numeric_scalar() const {
- return IsAnyOf<F16, F32, U32, I32, AbstractNumeric>();
-}
-
bool Type::is_float_scalar() const {
- return IsAnyOf<F16, F32, AbstractNumeric>();
+ return IsAnyOf<F16, F32, AbstractFloat>();
}
bool Type::is_float_matrix() const {
@@ -157,15 +149,15 @@
}
bool Type::is_numeric_vector() const {
- return Is([](const Vector* v) { return v->type()->is_numeric_scalar(); });
+ return Is([](const Vector* v) { return v->type()->Is<type::NumericScalar>(); });
}
bool Type::is_scalar_vector() const {
- return Is([](const Vector* v) { return v->type()->is_scalar(); });
+ return Is([](const Vector* v) { return v->type()->Is<type::Scalar>(); });
}
bool Type::is_numeric_scalar_or_vector() const {
- return is_numeric_scalar() || is_numeric_vector();
+ return Is<type::NumericScalar>() || is_numeric_vector();
}
bool Type::is_handle() const {
@@ -249,55 +241,24 @@
[&](Default) { return kNoConversion; });
}
-const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) {
- if (ty->is_scalar()) {
- if (count) {
- *count = 1;
- }
- return ty;
- }
- return Switch(
- ty, //
- [&](const Vector* v) {
- if (count) {
- *count = v->Width();
- }
- return v->type();
- },
- [&](const Matrix* m) {
- if (count) {
- *count = m->columns();
- }
- return m->ColumnType();
- },
- [&](const Array* a) {
- if (count) {
- if (auto* const_count = a->Count()->As<ConstantArrayCount>()) {
- *count = const_count->value;
- }
- }
- return a->ElemType();
- },
- [&](Default) {
- if (count) {
- *count = 1;
- }
- return ty;
- });
+TypeAndCount Type::Elements(const Type* type_if_invalid /* = nullptr */,
+ uint32_t count_if_invalid /* = 0 */) const {
+ return {type_if_invalid, count_if_invalid};
}
-const Type* Type::DeepestElementOf(const Type* ty, uint32_t* count /* = nullptr */) {
- auto el_ty = ElementOf(ty, count);
- while (el_ty && ty != el_ty) {
- ty = el_ty;
+const Type* Type::Element(uint32_t /* index */) const {
+ return nullptr;
+}
- uint32_t n = 0;
- el_ty = ElementOf(ty, &n);
- if (count) {
- *count *= n;
+const Type* Type::DeepestElement() const {
+ const Type* ty = this;
+ while (true) {
+ auto [el, n] = ty->Elements();
+ if (!el) {
+ return ty;
}
+ ty = el;
}
- return el_ty;
}
const Type* Type::Common(utils::VectorRef<const Type*> types) {
diff --git a/src/tint/type/type.h b/src/tint/type/type.h
index b6f34a6..d7c0e0a 100644
--- a/src/tint/type/type.h
+++ b/src/tint/type/type.h
@@ -28,6 +28,9 @@
class ProgramBuilder;
class SymbolTable;
} // namespace tint
+namespace tint::type {
+class Type;
+} // namespace tint::type
namespace tint::type {
@@ -46,6 +49,22 @@
/// An alias to utils::EnumSet<Flag>
using Flags = utils::EnumSet<Flag>;
+/// TypeAndCount holds a type and count
+struct TypeAndCount {
+ /// The type
+ const Type* type = nullptr;
+ /// The count
+ uint32_t count = 0;
+};
+
+/// Equality operator.
+/// @param lhs the LHS TypeAndCount
+/// @param rhs the RHS TypeAndCount
+/// @returns true if the two TypeAndCounts have the same type and count
+inline bool operator==(TypeAndCount lhs, TypeAndCount rhs) {
+ return lhs.type == rhs.type && lhs.count == rhs.count;
+}
+
/// Base class for a type in the system
class Type : public utils::Castable<Type, UniqueNode> {
public:
@@ -93,10 +112,6 @@
/// @see https://www.w3.org/TR/WGSL/#fixed-footprint-types
inline bool HasFixedFootprint() const { return flags_.Contains(Flag::kFixedFootprint); }
- /// @returns true if this type is a scalar
- bool is_scalar() const;
- /// @returns true if this type is a numeric scalar
- bool is_numeric_scalar() const;
/// @returns true if this type is a float scalar
bool is_float_scalar() const;
/// @returns true if this type is a float matrix
@@ -163,24 +178,44 @@
/// @see https://www.w3.org/TR/WGSL/#conversion-rank
static uint32_t ConversionRank(const Type* from, const Type* to);
- /// @param ty the type to obtain the element type from
- /// @param count if not null, then this is assigned the number of child elements in the type.
- /// For example, the count of an `array<vec3<f32>, 5>` type would be 5.
- /// @returns
- /// * the element type if `ty` is a vector or array
- /// * the column type if `ty` is a matrix
- /// * `ty` if `ty` is none of the above
- static const Type* ElementOf(const Type* ty, uint32_t* count = nullptr);
+ /// @param type_if_invalid the type to return if this type has no child elements.
+ /// @param count_if_invalid the count to return if this type has no child elements, or the
+ /// number is unbounded.
+ /// @returns The child element type and the the number of child elements held by this type.
+ /// If this type has no child element types, then @p invalid is returned.
+ /// If this type can hold a mix of different elements types (like a Struct), then
+ /// `[type_if_invalid, N]` is returned, where `N` is the number of elements.
+ /// If this type is unbounded in size (e.g. runtime sized arrays), then the returned count will
+ /// equal `count_if_invalid`.
+ ///
+ /// Examples:
+ /// * Elements() of `array<vec3<f32>, 5>` returns `[vec3<f32>, 5]`.
+ /// * Elements() of `array<f32>` returns `[f32, count_if_invalid]`.
+ /// * Elements() of `struct S { a : f32, b : i32 }` returns `[count_if_invalid, 2]`.
+ /// * Elements() of `struct S { a : i32, b : i32 }` also returns `[count_if_invalid, 2]`.
+ virtual TypeAndCount Elements(const Type* type_if_invalid = nullptr,
+ uint32_t count_if_invalid = 0) const;
- /// @param ty the type to obtain the deepest element type from
- /// @param count if not null, then this is assigned the full number of most deeply nested
- /// elements in the type. For example, the count of an `array<vec3<f32>, 5>` type would be 15.
- /// @returns
- /// * the element type if `ty` is a vector
- /// * the matrix element type if `ty` is a matrix
- /// * the deepest element type if `ty` is an array
- /// * `ty` if `ty` is none of the above
- static const Type* DeepestElementOf(const Type* ty, uint32_t* count = nullptr);
+ /// @param index the i'th element index to return
+ /// @returns The child element with the given index, or nullptr if the element does not exist.
+ ///
+ /// Examples:
+ /// * Element(1) of `mat3x2<f32>` returns `vec2<f32>`.
+ /// * Element(1) of `array<vec3<f32>, 5>` returns `vec3<f32>`.
+ /// * Element(0) of `struct S { a : f32, b : i32 }` returns `f32`.
+ /// * Element(0) of `f32` returns `nullptr`.
+ /// * Element(3) of `vec3<f32>` returns `nullptr`.
+ /// * Element(3) of `struct S { a : f32, b : i32 }` returns `nullptr`.
+ virtual const Type* Element(uint32_t index) const;
+
+ /// @returns the most deeply nested element of the type. For non-composite types,
+ /// DeepestElement() will return this type. Examples:
+ /// * Element() of `f32` returns `f32`.
+ /// * Element() of `vec3<f32>` returns `f32`.
+ /// * Element() of `mat3x2<f32>` returns `f32`.
+ /// * Element() of `array<vec3<f32>, 5>` returns `f32`.
+ /// * Element() of `struct S { a : f32, b : i32 }` returns `S`.
+ const Type* DeepestElement() const;
/// @param types the list of types
/// @returns the lowest-ranking type that all types in `types` can be implicitly converted to,
diff --git a/src/tint/type/type_test.cc b/src/tint/type/type_test.cc
index b76a43a..f260675 100644
--- a/src/tint/type/type_test.cc
+++ b/src/tint/type/type_test.cc
@@ -44,7 +44,7 @@
const Matrix* mat4x3_f16 = create<Matrix>(vec3_f16, 4u);
const Matrix* mat4x3_af = create<Matrix>(vec3_af, 4u);
const Reference* ref_u32 =
- create<Reference>(u32, builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite);
+ create<Reference>(builtin::AddressSpace::kPrivate, u32, builtin::Access::kReadWrite);
const Struct* str_f32 = create<Struct>(Sym("str_f32"),
utils::Vector{
create<StructMember>(
@@ -220,160 +220,104 @@
EXPECT_EQ(Type::ConversionRank(str_f16, str_af), Type::kNoConversion);
}
-TEST_F(TypeTest, ElementOf) {
- // No count
- EXPECT_TYPE(Type::ElementOf(f32), f32);
- EXPECT_TYPE(Type::ElementOf(f16), f16);
- EXPECT_TYPE(Type::ElementOf(i32), i32);
- EXPECT_TYPE(Type::ElementOf(u32), u32);
- EXPECT_TYPE(Type::ElementOf(vec2_f32), f32);
- EXPECT_TYPE(Type::ElementOf(vec3_f16), f16);
- EXPECT_TYPE(Type::ElementOf(vec4_f32), f32);
- EXPECT_TYPE(Type::ElementOf(vec3_u32), u32);
- EXPECT_TYPE(Type::ElementOf(vec3_i32), i32);
- EXPECT_TYPE(Type::ElementOf(mat2x4_f32), vec4_f32);
- EXPECT_TYPE(Type::ElementOf(mat4x2_f32), vec2_f32);
- EXPECT_TYPE(Type::ElementOf(mat4x3_f16), vec3_f16);
- EXPECT_TYPE(Type::ElementOf(str_f16), str_f16);
- EXPECT_TYPE(Type::ElementOf(arr_i32), i32);
- EXPECT_TYPE(Type::ElementOf(arr_vec3_i32), vec3_i32);
- EXPECT_TYPE(Type::ElementOf(arr_mat4x3_f16), mat4x3_f16);
- EXPECT_TYPE(Type::ElementOf(arr_mat4x3_af), mat4x3_af);
- EXPECT_TYPE(Type::ElementOf(arr_str_f16), str_f16);
-
- // With count
- uint32_t count = 42;
- EXPECT_TYPE(Type::ElementOf(f32, &count), f32);
- EXPECT_EQ(count, 1u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(f16, &count), f16);
- EXPECT_EQ(count, 1u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(i32, &count), i32);
- EXPECT_EQ(count, 1u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(u32, &count), u32);
- EXPECT_EQ(count, 1u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(vec2_f32, &count), f32);
- EXPECT_EQ(count, 2u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(vec3_f16, &count), f16);
- EXPECT_EQ(count, 3u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(vec4_f32, &count), f32);
- EXPECT_EQ(count, 4u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(vec3_u32, &count), u32);
- EXPECT_EQ(count, 3u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(vec3_i32, &count), i32);
- EXPECT_EQ(count, 3u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), vec4_f32);
- EXPECT_EQ(count, 2u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), vec2_f32);
- EXPECT_EQ(count, 4u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), vec3_f16);
- EXPECT_EQ(count, 4u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(str_f16, &count), str_f16);
- EXPECT_EQ(count, 1u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(arr_i32, &count), i32);
- EXPECT_EQ(count, 5u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(arr_vec3_i32, &count), vec3_i32);
- EXPECT_EQ(count, 5u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(arr_mat4x3_f16, &count), mat4x3_f16);
- EXPECT_EQ(count, 5u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(arr_mat4x3_af, &count), mat4x3_af);
- EXPECT_EQ(count, 5u);
- count = 42;
- EXPECT_TYPE(Type::ElementOf(arr_str_f16, &count), str_f16);
- EXPECT_EQ(count, 5u);
+TEST_F(TypeTest, Elements) {
+ EXPECT_EQ(f32->Elements(), (TypeAndCount{nullptr, 0u}));
+ EXPECT_EQ(f16->Elements(), (TypeAndCount{nullptr, 0u}));
+ EXPECT_EQ(i32->Elements(), (TypeAndCount{nullptr, 0u}));
+ EXPECT_EQ(u32->Elements(), (TypeAndCount{nullptr, 0u}));
+ EXPECT_EQ(vec2_f32->Elements(), (TypeAndCount{f32, 2u}));
+ EXPECT_EQ(vec3_f16->Elements(), (TypeAndCount{f16, 3u}));
+ EXPECT_EQ(vec4_f32->Elements(), (TypeAndCount{f32, 4u}));
+ EXPECT_EQ(vec3_u32->Elements(), (TypeAndCount{u32, 3u}));
+ EXPECT_EQ(vec3_i32->Elements(), (TypeAndCount{i32, 3u}));
+ EXPECT_EQ(mat2x4_f32->Elements(), (TypeAndCount{vec4_f32, 2u}));
+ EXPECT_EQ(mat4x2_f32->Elements(), (TypeAndCount{vec2_f32, 4u}));
+ EXPECT_EQ(mat4x3_f16->Elements(), (TypeAndCount{vec3_f16, 4u}));
+ EXPECT_EQ(str_f16->Elements(), (TypeAndCount{nullptr, 1u}));
+ EXPECT_EQ(arr_i32->Elements(), (TypeAndCount{i32, 5u}));
+ EXPECT_EQ(arr_vec3_i32->Elements(), (TypeAndCount{vec3_i32, 5u}));
+ EXPECT_EQ(arr_mat4x3_f16->Elements(), (TypeAndCount{mat4x3_f16, 5u}));
+ EXPECT_EQ(arr_mat4x3_af->Elements(), (TypeAndCount{mat4x3_af, 5u}));
+ EXPECT_EQ(arr_str_f16->Elements(), (TypeAndCount{str_f16, 5u}));
}
-TEST_F(TypeTest, DeepestElementOf) {
- // No count
- EXPECT_TYPE(Type::DeepestElementOf(f32), f32);
- EXPECT_TYPE(Type::DeepestElementOf(f16), f16);
- EXPECT_TYPE(Type::DeepestElementOf(i32), i32);
- EXPECT_TYPE(Type::DeepestElementOf(u32), u32);
- EXPECT_TYPE(Type::DeepestElementOf(vec2_f32), f32);
- EXPECT_TYPE(Type::DeepestElementOf(vec3_f16), f16);
- EXPECT_TYPE(Type::DeepestElementOf(vec4_f32), f32);
- EXPECT_TYPE(Type::DeepestElementOf(vec3_u32), u32);
- EXPECT_TYPE(Type::DeepestElementOf(vec3_i32), i32);
- EXPECT_TYPE(Type::DeepestElementOf(mat2x4_f32), f32);
- EXPECT_TYPE(Type::DeepestElementOf(mat4x2_f32), f32);
- EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16), f16);
- EXPECT_TYPE(Type::DeepestElementOf(str_f16), str_f16);
- EXPECT_TYPE(Type::DeepestElementOf(arr_i32), i32);
- EXPECT_TYPE(Type::DeepestElementOf(arr_vec3_i32), i32);
- EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_f16), f16);
- EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_af), af);
- EXPECT_TYPE(Type::DeepestElementOf(arr_str_f16), str_f16);
+TEST_F(TypeTest, ElementsWithCustomInvalid) {
+ EXPECT_EQ(f32->Elements(f32, 42), (TypeAndCount{f32, 42}));
+ EXPECT_EQ(f16->Elements(f16, 42), (TypeAndCount{f16, 42}));
+ EXPECT_EQ(i32->Elements(i32, 42), (TypeAndCount{i32, 42}));
+ EXPECT_EQ(u32->Elements(u32, 42), (TypeAndCount{u32, 42}));
+ EXPECT_EQ(vec2_f32->Elements(vec2_f32, 42), (TypeAndCount{f32, 2u}));
+ EXPECT_EQ(vec3_f16->Elements(vec3_f16, 42), (TypeAndCount{f16, 3u}));
+ EXPECT_EQ(vec4_f32->Elements(vec4_f32, 42), (TypeAndCount{f32, 4u}));
+ EXPECT_EQ(vec3_u32->Elements(vec3_u32, 42), (TypeAndCount{u32, 3u}));
+ EXPECT_EQ(vec3_i32->Elements(vec3_i32, 42), (TypeAndCount{i32, 3u}));
+ EXPECT_EQ(mat2x4_f32->Elements(mat2x4_f32, 42), (TypeAndCount{vec4_f32, 2u}));
+ EXPECT_EQ(mat4x2_f32->Elements(mat4x2_f32, 42), (TypeAndCount{vec2_f32, 4u}));
+ EXPECT_EQ(mat4x3_f16->Elements(mat4x3_f16, 42), (TypeAndCount{vec3_f16, 4u}));
+ EXPECT_EQ(str_f16->Elements(str_f16, 42), (TypeAndCount{str_f16, 1}));
+ EXPECT_EQ(arr_i32->Elements(arr_i32, 42), (TypeAndCount{i32, 5u}));
+ EXPECT_EQ(arr_vec3_i32->Elements(arr_vec3_i32, 42), (TypeAndCount{vec3_i32, 5u}));
+ EXPECT_EQ(arr_mat4x3_f16->Elements(arr_mat4x3_f16, 42), (TypeAndCount{mat4x3_f16, 5u}));
+ EXPECT_EQ(arr_mat4x3_af->Elements(arr_mat4x3_af, 42), (TypeAndCount{mat4x3_af, 5u}));
+ EXPECT_EQ(arr_str_f16->Elements(arr_str_f16, 42), (TypeAndCount{str_f16, 5u}));
+}
- // With count
- uint32_t count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(f32, &count), f32);
- EXPECT_EQ(count, 1u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(f16, &count), f16);
- EXPECT_EQ(count, 1u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(i32, &count), i32);
- EXPECT_EQ(count, 1u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(u32, &count), u32);
- EXPECT_EQ(count, 1u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(vec2_f32, &count), f32);
- EXPECT_EQ(count, 2u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(vec3_f16, &count), f16);
- EXPECT_EQ(count, 3u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(vec4_f32, &count), f32);
- EXPECT_EQ(count, 4u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(vec3_u32, &count), u32);
- EXPECT_EQ(count, 3u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(vec3_i32, &count), i32);
- EXPECT_EQ(count, 3u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(mat2x4_f32, &count), f32);
- EXPECT_EQ(count, 8u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(mat4x2_f32, &count), f32);
- EXPECT_EQ(count, 8u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16, &count), f16);
- EXPECT_EQ(count, 12u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(str_f16, &count), str_f16);
- EXPECT_EQ(count, 1u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(arr_i32, &count), i32);
- EXPECT_EQ(count, 5u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(arr_vec3_i32, &count), i32);
- EXPECT_EQ(count, 15u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_f16, &count), f16);
- EXPECT_EQ(count, 60u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_af, &count), af);
- EXPECT_EQ(count, 60u);
- count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(arr_str_f16, &count), str_f16);
- EXPECT_EQ(count, 5u);
+TEST_F(TypeTest, Element) {
+ EXPECT_TYPE(f32->Element(0), nullptr);
+ EXPECT_TYPE(f16->Element(1), nullptr);
+ EXPECT_TYPE(i32->Element(2), nullptr);
+ EXPECT_TYPE(u32->Element(3), nullptr);
+ EXPECT_TYPE(vec2_f32->Element(0), f32);
+ EXPECT_TYPE(vec2_f32->Element(1), f32);
+ EXPECT_TYPE(vec2_f32->Element(2), nullptr);
+ EXPECT_TYPE(vec3_f16->Element(0), f16);
+ EXPECT_TYPE(vec4_f32->Element(3), f32);
+ EXPECT_TYPE(vec4_f32->Element(4), nullptr);
+ EXPECT_TYPE(vec3_u32->Element(2), u32);
+ EXPECT_TYPE(vec3_u32->Element(3), nullptr);
+ EXPECT_TYPE(vec3_i32->Element(1), i32);
+ EXPECT_TYPE(vec3_i32->Element(4), nullptr);
+ EXPECT_TYPE(mat2x4_f32->Element(1), vec4_f32);
+ EXPECT_TYPE(mat2x4_f32->Element(2), nullptr);
+ EXPECT_TYPE(mat4x2_f32->Element(3), vec2_f32);
+ EXPECT_TYPE(mat4x2_f32->Element(4), nullptr);
+ EXPECT_TYPE(mat4x3_f16->Element(1), vec3_f16);
+ EXPECT_TYPE(mat4x3_f16->Element(5), nullptr);
+ EXPECT_TYPE(str_f16->Element(0), f16);
+ EXPECT_TYPE(str_f16->Element(1), nullptr);
+ EXPECT_TYPE(arr_i32->Element(0), i32);
+ EXPECT_TYPE(arr_i32->Element(4), i32);
+ EXPECT_TYPE(arr_i32->Element(5), nullptr);
+ EXPECT_TYPE(arr_vec3_i32->Element(4), vec3_i32);
+ EXPECT_TYPE(arr_vec3_i32->Element(5), nullptr);
+ EXPECT_TYPE(arr_mat4x3_f16->Element(1), mat4x3_f16);
+ EXPECT_TYPE(arr_mat4x3_f16->Element(10), nullptr);
+ EXPECT_TYPE(arr_mat4x3_af->Element(2), mat4x3_af);
+ EXPECT_TYPE(arr_mat4x3_af->Element(6), nullptr);
+ EXPECT_TYPE(arr_str_f16->Element(0), str_f16);
+ EXPECT_TYPE(arr_str_f16->Element(1), str_f16);
+ EXPECT_TYPE(arr_str_f16->Element(10), nullptr);
+}
+
+TEST_F(TypeTest, DeepestElement) {
+ EXPECT_TYPE(f32->DeepestElement(), f32);
+ EXPECT_TYPE(f16->DeepestElement(), f16);
+ EXPECT_TYPE(i32->DeepestElement(), i32);
+ EXPECT_TYPE(u32->DeepestElement(), u32);
+ EXPECT_TYPE(vec2_f32->DeepestElement(), f32);
+ EXPECT_TYPE(vec3_f16->DeepestElement(), f16);
+ EXPECT_TYPE(vec4_f32->DeepestElement(), f32);
+ EXPECT_TYPE(vec3_u32->DeepestElement(), u32);
+ EXPECT_TYPE(vec3_i32->DeepestElement(), i32);
+ EXPECT_TYPE(mat2x4_f32->DeepestElement(), f32);
+ EXPECT_TYPE(mat4x2_f32->DeepestElement(), f32);
+ EXPECT_TYPE(mat4x3_f16->DeepestElement(), f16);
+ EXPECT_TYPE(str_f16->DeepestElement(), str_f16);
+ EXPECT_TYPE(arr_i32->DeepestElement(), i32);
+ EXPECT_TYPE(arr_vec3_i32->DeepestElement(), i32);
+ EXPECT_TYPE(arr_mat4x3_f16->DeepestElement(), f16);
+ EXPECT_TYPE(arr_mat4x3_af->DeepestElement(), af);
+ EXPECT_TYPE(arr_str_f16->DeepestElement(), str_f16);
}
TEST_F(TypeTest, Common2) {
diff --git a/src/tint/type/u32.cc b/src/tint/type/u32.cc
index 4e40c73..79cd0d4 100644
--- a/src/tint/type/u32.cc
+++ b/src/tint/type/u32.cc
@@ -30,10 +30,6 @@
U32::~U32() = default;
-bool U32::Equals(const UniqueNode& other) const {
- return other.Is<U32>();
-}
-
std::string U32::FriendlyName() const {
return "u32";
}
diff --git a/src/tint/type/u32.h b/src/tint/type/u32.h
index de9b550..67ab20e 100644
--- a/src/tint/type/u32.h
+++ b/src/tint/type/u32.h
@@ -17,12 +17,12 @@
#include <string>
-#include "src/tint/type/type.h"
+#include "src/tint/type/numeric_scalar.h"
namespace tint::type {
/// A unsigned int 32 type.
-class U32 final : public utils::Castable<U32, Type> {
+class U32 final : public utils::Castable<U32, NumericScalar> {
public:
/// Constructor
U32();
@@ -30,10 +30,6 @@
/// Destructor
~U32() override;
- /// @param other the other node to compare against
- /// @returns true if the this type is equal to @p other
- bool Equals(const UniqueNode& other) const override;
-
/// @returns the name for this type that closely resembles how it would be
/// declared in WGSL.
std::string FriendlyName() const override;
diff --git a/src/tint/type/vector.cc b/src/tint/type/vector.cc
index 9c5bebd..49f0f1b 100644
--- a/src/tint/type/vector.cc
+++ b/src/tint/type/vector.cc
@@ -77,4 +77,13 @@
return ctx.dst.mgr->Get<Vector>(subtype, width_, packed_);
}
+TypeAndCount Vector::Elements(const Type* /* type_if_invalid = nullptr */,
+ uint32_t /* count_if_invalid = 0 */) const {
+ return {subtype_, width_};
+}
+
+const Type* Vector::Element(uint32_t index) const {
+ return index < width_ ? subtype_ : nullptr;
+}
+
} // namespace tint::type
diff --git a/src/tint/type/vector.h b/src/tint/type/vector.h
index 47b5c40..e83d917 100644
--- a/src/tint/type/vector.h
+++ b/src/tint/type/vector.h
@@ -64,6 +64,13 @@
/// @returns the alignment in bytes of a vector of the given width.
static uint32_t AlignOf(uint32_t width);
+ /// @copydoc Type::Elements
+ TypeAndCount Elements(const Type* type_if_invalid = nullptr,
+ uint32_t count_if_invalid = 0) const override;
+
+ /// @copydoc Type::Element
+ const Type* Element(uint32_t index) const override;
+
/// @param ctx the clone context
/// @returns a clone of this type
Vector* Clone(CloneContext& ctx) const override;
diff --git a/src/tint/utils/result.h b/src/tint/utils/result.h
index 2b433fb..0254149 100644
--- a/src/tint/utils/result.h
+++ b/src/tint/utils/result.h
@@ -104,6 +104,13 @@
/// @returns the success value
/// @warning attempting to call this when the Result holds an failure value will result in UB.
+ SUCCESS_TYPE& Get() {
+ Validate();
+ return std::get<SUCCESS_TYPE>(value);
+ }
+
+ /// @returns the success value
+ /// @warning attempting to call this when the Result holds an failure value will result in UB.
SUCCESS_TYPE&& Move() {
Validate();
return std::get<SUCCESS_TYPE>(std::move(value));
diff --git a/src/tint/utils/vector.h b/src/tint/utils/vector.h
index 5cb0ff3..8b8fc0f 100644
--- a/src/tint/utils/vector.h
+++ b/src/tint/utils/vector.h
@@ -766,6 +766,34 @@
return out;
}
+namespace detail {
+
+/// IsVectorLike<T>::value is true if T is a utils::Vector or utils::VectorRef.
+template <typename T>
+struct IsVectorLike {
+ /// Non-specialized form of IsVectorLike defaults to false
+ static constexpr bool value = false;
+};
+
+/// IsVectorLike specialization for utils::Vector
+template <typename T, size_t N>
+struct IsVectorLike<utils::Vector<T, N>> {
+ /// True for the IsVectorLike specialization of utils::Vector
+ static constexpr bool value = true;
+};
+
+/// IsVectorLike specialization for utils::VectorRef
+template <typename T>
+struct IsVectorLike<utils::VectorRef<T>> {
+ /// True for the IsVectorLike specialization of utils::VectorRef
+ static constexpr bool value = true;
+};
+} // namespace detail
+
+/// True if T is a Vector<T, N> or VectorRef<T>
+template <typename T>
+static constexpr bool IsVectorLike = detail::IsVectorLike<T>::value;
+
} // namespace tint::utils
#endif // SRC_TINT_UTILS_VECTOR_H_
diff --git a/src/tint/utils/vector_test.cc b/src/tint/utils/vector_test.cc
index 31e1a86..d65f570 100644
--- a/src/tint/utils/vector_test.cc
+++ b/src/tint/utils/vector_test.cc
@@ -81,6 +81,10 @@
static_assert(std::is_same_v<VectorCommonType<C2a*, const C2b*>, const C1*>);
static_assert(std::is_same_v<VectorCommonType<const C2a*, const C2b*>, const C1*>);
+static_assert(IsVectorLike<Vector<int, 3>>);
+static_assert(IsVectorLike<VectorRef<int>>);
+static_assert(!IsVectorLike<int>);
+
////////////////////////////////////////////////////////////////////////////////
// TintVectorTest
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index 549f366..630e24f 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -505,7 +505,7 @@
}
void GeneratorImpl::EmitBinary(utils::StringStream& out, const ast::BinaryExpression* expr) {
- if (IsRelational(expr->op) && !TypeOf(expr->lhs)->UnwrapRef()->is_scalar()) {
+ if (IsRelational(expr->op) && !TypeOf(expr->lhs)->UnwrapRef()->Is<type::Scalar>()) {
EmitVectorRelational(out, expr);
return;
}
@@ -719,7 +719,7 @@
EmitExpression(out, expr->args[0]);
} else if ((builtin->Type() == builtin::Function::kAny ||
builtin->Type() == builtin::Function::kAll) &&
- TypeOf(expr->args[0])->UnwrapRef()->is_scalar()) {
+ TypeOf(expr->args[0])->UnwrapRef()->Is<type::Scalar>()) {
// GLSL does not support any() or all() on scalar arguments. It's a no-op.
EmitExpression(out, expr->args[0]);
} else if (builtin->IsBarrier()) {
@@ -1082,7 +1082,7 @@
void GeneratorImpl::EmitDegreesCall(utils::StringStream& out,
const ast::CallExpression* expr,
const sem::Builtin* builtin) {
- auto* return_elem_type = type::Type::DeepestElementOf(builtin->ReturnType());
+ auto* return_elem_type = builtin->ReturnType()->DeepestElement();
const std::string suffix = Is<type::F16>(return_elem_type) ? "hf" : "f";
CallBuiltinHelper(out, expr, builtin,
[&](TextBuffer* b, const std::vector<std::string>& params) {
@@ -1094,7 +1094,7 @@
void GeneratorImpl::EmitRadiansCall(utils::StringStream& out,
const ast::CallExpression* expr,
const sem::Builtin* builtin) {
- auto* return_elem_type = type::Type::DeepestElementOf(builtin->ReturnType());
+ auto* return_elem_type = builtin->ReturnType()->DeepestElement();
const std::string suffix = Is<type::F16>(return_elem_type) ? "hf" : "f";
CallBuiltinHelper(out, expr, builtin,
[&](TextBuffer* b, const std::vector<std::string>& params) {
@@ -1181,8 +1181,7 @@
auto* texture_type = TypeOf(texture)->UnwrapRef()->As<type::Texture>();
auto emit_signed_int_type = [&](const type::Type* ty) {
- uint32_t width = 0;
- type::Type::ElementOf(ty, &width);
+ uint32_t width = ty->Elements().count;
if (width > 1) {
out << "ivec" << width;
} else {
@@ -1191,8 +1190,7 @@
};
auto emit_unsigned_int_type = [&](const type::Type* ty) {
- uint32_t width = 0;
- type::Type::ElementOf(ty, &width);
+ uint32_t width = ty->Elements().count;
if (width > 1) {
out << "uvec" << width;
} else {
@@ -2706,7 +2704,7 @@
out << "~";
break;
case ast::UnaryOp::kNot:
- if (TypeOf(expr)->UnwrapRef()->is_scalar()) {
+ if (TypeOf(expr)->UnwrapRef()->Is<type::Scalar>()) {
out << "!";
} else {
out << "not";
diff --git a/src/tint/writer/glsl/generator_impl_function_test.cc b/src/tint/writer/glsl/generator_impl_function_test.cc
index a0bfef9..34375f7 100644
--- a/src/tint/writer/glsl/generator_impl_function_test.cc
+++ b/src/tint/writer/glsl/generator_impl_function_test.cc
@@ -110,8 +110,8 @@
// fn f(foo : ptr<function, f32>) -> f32 {
// return *foo;
// }
- Func("f", utils::Vector{Param("foo", ty.pointer<f32>(builtin::AddressSpace::kFunction))},
- ty.f32(), utils::Vector{Return(Deref("foo"))});
+ Func("f", utils::Vector{Param("foo", ty.ptr<f32>(builtin::AddressSpace::kFunction))}, ty.f32(),
+ utils::Vector{Return(Deref("foo"))});
GeneratorImpl& gen = SanitizeAndBuild();
gen.Generate();
diff --git a/src/tint/writer/glsl/generator_impl_sanitizer_test.cc b/src/tint/writer/glsl/generator_impl_sanitizer_test.cc
index 5a690a5..fdd216f 100644
--- a/src/tint/writer/glsl/generator_impl_sanitizer_test.cc
+++ b/src/tint/writer/glsl/generator_impl_sanitizer_test.cc
@@ -235,7 +235,7 @@
// let p : ptr<function, i32> = &v;
// let x : i32 = *p;
auto* v = Var("v", ty.i32());
- auto* p = Let("p", ty.pointer<i32>(builtin::AddressSpace::kFunction), AddressOf(v));
+ auto* p = Let("p", ty.ptr<i32>(builtin::AddressSpace::kFunction), AddressOf(v));
auto* x = Var("x", ty.i32(), Deref(p));
Func("main", utils::Empty, ty.void_(),
@@ -276,12 +276,11 @@
// let vp : ptr<function, vec4<f32>> = &(*mp)[2i];
// let v : vec4<f32> = *vp;
auto* a = Var("a", ty.array(ty.mat4x4<f32>(), 4_u));
- auto* ap =
- Let("ap", ty.pointer(ty.array(ty.mat4x4<f32>(), 4_u), builtin::AddressSpace::kFunction),
- AddressOf(a));
- auto* mp = Let("mp", ty.pointer(ty.mat4x4<f32>(), builtin::AddressSpace::kFunction),
+ auto* ap = Let("ap", ty.ptr(builtin::AddressSpace::kFunction, ty.array(ty.mat4x4<f32>(), 4_u)),
+ AddressOf(a));
+ auto* mp = Let("mp", ty.ptr(builtin::AddressSpace::kFunction, ty.mat4x4<f32>()),
AddressOf(IndexAccessor(Deref(ap), 3_i)));
- auto* vp = Let("vp", ty.pointer(ty.vec4<f32>(), builtin::AddressSpace::kFunction),
+ auto* vp = Let("vp", ty.ptr(builtin::AddressSpace::kFunction, ty.vec4<f32>()),
AddressOf(IndexAccessor(Deref(mp), 2_i)));
auto* v = Var("v", ty.vec4<f32>(), Deref(vp));
diff --git a/src/tint/writer/hlsl/generator.h b/src/tint/writer/hlsl/generator.h
index 804b449..9fbc01c 100644
--- a/src/tint/writer/hlsl/generator.h
+++ b/src/tint/writer/hlsl/generator.h
@@ -80,6 +80,9 @@
/// Set to `true` to generate polyfill for `reflect` builtin for vec2<f32>
bool polyfill_reflect_vec2_f32 = false;
+ /// The binding points that will be ignored in the rebustness transform.
+ std::vector<sem::BindingPoint> binding_points_ignored_in_robustness_transform;
+
/// Reflect the fields of this class so that it can be used by tint::ForeachField()
TINT_REFLECT(disable_robustness,
root_constant_binding_point,
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index eb07e8e..def946a 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -192,6 +192,12 @@
// Robustness must come after PromoteSideEffectsToDecl
// Robustness must come before BuiltinPolyfill and CanonicalizeEntryPointIO
manager.Add<ast::transform::Robustness>();
+
+ ast::transform::Robustness::Config config = {};
+ config.bindings_ignored = std::unordered_set<sem::BindingPoint>(
+ options.binding_points_ignored_in_robustness_transform.cbegin(),
+ options.binding_points_ignored_in_robustness_transform.cend());
+ data.Add<ast::transform::Robustness::Config>(config);
}
// Note: it is more efficient for MultiplanarExternalTexture to come after Robustness
@@ -1101,7 +1107,7 @@
// vector dimension using .x
const bool is_single_value_vector_init = type->is_scalar_vector() &&
call->Arguments().Length() == 1 &&
- ctor->Parameters()[0]->Type()->is_scalar();
+ ctor->Parameters()[0]->Type()->Is<type::Scalar>();
if (brackets) {
out << "{";
@@ -1732,6 +1738,10 @@
return true;
}
case Op::kAtomicCompareExchangeWeak: {
+ if (!EmitStructType(&helpers_, result_ty->As<type::Struct>())) {
+ return false;
+ }
+
auto* const value_ty = sem_func->Parameters()[1]->Type()->UnwrapRef();
// NOTE: We don't need to emit the return type struct here as DecomposeMemoryAccess
// already added it to the AST, and it should have already been emitted by now.
@@ -2043,7 +2053,7 @@
}
std::string member_type;
- if (Is<type::F16>(type::Type::DeepestElementOf(ty))) {
+ if (Is<type::F16>(ty->DeepestElement())) {
member_type = width.empty() ? "float16_t" : ("vector<float16_t, " + width + ">");
} else {
member_type = "float" + width;
diff --git a/src/tint/writer/hlsl/generator_impl_function_test.cc b/src/tint/writer/hlsl/generator_impl_function_test.cc
index 17df895..37d94c3 100644
--- a/src/tint/writer/hlsl/generator_impl_function_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_function_test.cc
@@ -101,8 +101,8 @@
// fn f(foo : ptr<function, f32>) -> f32 {
// return *foo;
// }
- Func("f", utils::Vector{Param("foo", ty.pointer<f32>(builtin::AddressSpace::kFunction))},
- ty.f32(), utils::Vector{Return(Deref("foo"))});
+ Func("f", utils::Vector{Param("foo", ty.ptr<f32>(builtin::AddressSpace::kFunction))}, ty.f32(),
+ utils::Vector{Return(Deref("foo"))});
GeneratorImpl& gen = SanitizeAndBuild();
diff --git a/src/tint/writer/hlsl/generator_impl_sanitizer_test.cc b/src/tint/writer/hlsl/generator_impl_sanitizer_test.cc
index c511a83..b1fef10 100644
--- a/src/tint/writer/hlsl/generator_impl_sanitizer_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_sanitizer_test.cc
@@ -242,7 +242,7 @@
// let p : ptr<function, i32> = &v;
// let x : i32 = *p;
auto* v = Var("v", ty.i32());
- auto* p = Let("p", ty.pointer<i32>(builtin::AddressSpace::kFunction), AddressOf(v));
+ auto* p = Let("p", ty.ptr<i32>(builtin::AddressSpace::kFunction), AddressOf(v));
auto* x = Var("x", ty.i32(), Deref(p));
Func("main", utils::Empty, ty.void_(),
@@ -276,12 +276,11 @@
// let vp : ptr<function, vec4<f32>> = &(*mp)[2i];
// let v : vec4<f32> = *vp;
auto* a = Var("a", ty.array(ty.mat4x4<f32>(), 4_u));
- auto* ap =
- Let("ap", ty.pointer(ty.array(ty.mat4x4<f32>(), 4_u), builtin::AddressSpace::kFunction),
- AddressOf(a));
- auto* mp = Let("mp", ty.pointer(ty.mat4x4<f32>(), builtin::AddressSpace::kFunction),
+ auto* ap = Let("ap", ty.ptr(builtin::AddressSpace::kFunction, ty.array(ty.mat4x4<f32>(), 4_u)),
+ AddressOf(a));
+ auto* mp = Let("mp", ty.ptr(builtin::AddressSpace::kFunction, ty.mat4x4<f32>()),
AddressOf(IndexAccessor(Deref(ap), 3_i)));
- auto* vp = Let("vp", ty.pointer(ty.vec4<f32>(), builtin::AddressSpace::kFunction),
+ auto* vp = Let("vp", ty.ptr(builtin::AddressSpace::kFunction, ty.vec4<f32>()),
AddressOf(IndexAccessor(Deref(mp), 2_i)));
auto* v = Var("v", ty.vec4<f32>(), Deref(vp));
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index 4c51b17..28101df 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -144,7 +144,7 @@
// If we need to promote from scalar to vector, bitcast the scalar to the
// vector element type.
- if (curr_type->is_scalar() && target_vec_type) {
+ if (curr_type->Is<type::Scalar>() && target_vec_type) {
target_type = target_vec_type->type();
}
@@ -741,7 +741,7 @@
case builtin::Function::kLength: {
auto* sem = builder_.Sem().GetVal(expr->args[0]);
- if (sem->Type()->UnwrapRef()->is_scalar()) {
+ if (sem->Type()->UnwrapRef()->Is<type::Scalar>()) {
// Emulate scalar overload using fabs(x).
name = "fabs";
}
@@ -750,7 +750,7 @@
case builtin::Function::kDistance: {
auto* sem = builder_.Sem().GetVal(expr->args[0]);
- if (sem->Type()->UnwrapRef()->is_scalar()) {
+ if (sem->Type()->UnwrapRef()->Is<type::Scalar>()) {
// Emulate scalar overload using fabs(x - y);
out << "fabs";
ScopedParen sp(out);
diff --git a/src/tint/writer/msl/generator_impl_type_test.cc b/src/tint/writer/msl/generator_impl_type_test.cc
index 3b9b652..52d02ee 100644
--- a/src/tint/writer/msl/generator_impl_type_test.cc
+++ b/src/tint/writer/msl/generator_impl_type_test.cc
@@ -214,7 +214,7 @@
TEST_F(MslGeneratorImplTest, EmitType_Pointer) {
auto* f32 = create<type::F32>();
auto* p =
- create<type::Pointer>(f32, builtin::AddressSpace::kWorkgroup, builtin::Access::kReadWrite);
+ create<type::Pointer>(builtin::AddressSpace::kWorkgroup, f32, builtin::Access::kReadWrite);
GeneratorImpl& gen = Build();
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index 58bc069..a0ce6f0 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -1241,13 +1241,13 @@
return 0;
}
- bool can_cast_or_copy = result_type->is_scalar();
+ bool can_cast_or_copy = result_type->Is<type::Scalar>();
if (auto* res_vec = result_type->As<type::Vector>()) {
- if (res_vec->type()->is_scalar()) {
+ if (res_vec->type()->Is<type::Scalar>()) {
auto* value_type = args[0]->Type()->UnwrapRef();
if (auto* val_vec = value_type->As<type::Vector>()) {
- if (val_vec->type()->is_scalar()) {
+ if (val_vec->type()->Is<type::Scalar>()) {
can_cast_or_copy = res_vec->Width() == val_vec->Width();
}
}
@@ -1304,7 +1304,7 @@
// Both scalars, but not the same type so we need to generate a conversion
// of the value.
- if (value_type->is_scalar() && result_type->is_scalar()) {
+ if (value_type->Is<type::Scalar>() && result_type->Is<type::Scalar>()) {
id = GenerateCastOrCopyOrPassthrough(result_type, args[0]->Declaration(), global_var);
ops.push_back(Operand(id));
continue;
@@ -1366,7 +1366,7 @@
// For a single-value vector initializer, splat the initializer value.
auto* const init_result_type = call->Type()->UnwrapRef();
if (args.Length() == 1 && init_result_type->is_scalar_vector() &&
- args[0]->Type()->UnwrapRef()->is_scalar()) {
+ args[0]->Type()->UnwrapRef()->Is<type::Scalar>()) {
size_t vec_size = init_result_type->As<type::Vector>()->Width();
for (size_t i = 0; i < (vec_size - 1); ++i) {
ops.push_back(ops[kOpsFirstValueIdx]);
@@ -1408,7 +1408,7 @@
}
auto elem_type_of = [](const type::Type* t) -> const type::Type* {
- if (t->is_scalar()) {
+ if (t->Is<type::Scalar>()) {
return t;
}
if (auto* v = t->As<type::Vector>()) {
@@ -1464,7 +1464,7 @@
(from_type->is_unsigned_integer_vector() &&
to_type->is_integer_scalar_or_vector())) {
op = spv::Op::OpBitcast;
- } else if ((from_type->is_numeric_scalar() && to_type->Is<type::Bool>()) ||
+ } else if ((from_type->Is<type::NumericScalar>() && to_type->Is<type::Bool>()) ||
(from_type->is_numeric_vector() && to_type->is_bool_vector())) {
// Convert scalar (vector) to bool (vector)
@@ -1855,8 +1855,8 @@
uint32_t Builder::GenerateSplat(uint32_t scalar_id, const type::Type* vec_type) {
// Create a new vector to splat scalar into
auto splat_vector = result_op();
- auto* splat_vector_type = builder_.create<type::Pointer>(
- vec_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ auto* splat_vector_type = builder_.create<type::Pointer>(builtin::AddressSpace::kFunction,
+ vec_type, builtin::Access::kReadWrite);
push_function_var({Operand(GenerateTypeIfNeeded(splat_vector_type)), splat_vector,
U32Operand(ConvertAddressSpace(builtin::AddressSpace::kFunction)),
Operand(GenerateConstantNullIfNeeded(vec_type))});
@@ -1990,7 +1990,7 @@
(lhs_type->is_float_vector() && rhs_type->is_float_scalar()));
if (expr->IsArithmetic() && !is_float_scalar_vector_multiply) {
- if (lhs_type->Is<type::Vector>() && rhs_type->is_numeric_scalar()) {
+ if (lhs_type->Is<type::Vector>() && rhs_type->Is<type::NumericScalar>()) {
uint32_t splat_vector_id = GenerateSplat(rhs_id, lhs_type);
if (splat_vector_id == 0) {
return 0;
@@ -1998,7 +1998,7 @@
rhs_id = splat_vector_id;
rhs_type = lhs_type;
- } else if (lhs_type->is_numeric_scalar() && rhs_type->Is<type::Vector>()) {
+ } else if (lhs_type->Is<type::NumericScalar>() && rhs_type->Is<type::Vector>()) {
uint32_t splat_vector_id = GenerateSplat(lhs_id, rhs_type);
if (splat_vector_id == 0) {
return 0;
@@ -2449,7 +2449,7 @@
// If the interpolant is scalar but the objects are vectors, we need to
// splat the interpolant into a vector of the same size.
auto* result_vector_type = builtin->ReturnType()->As<type::Vector>();
- if (result_vector_type && builtin->Parameters()[2]->Type()->is_scalar()) {
+ if (result_vector_type && builtin->Parameters()[2]->Type()->Is<type::Scalar>()) {
f_id = GenerateSplat(f_id, builtin->Parameters()[0]->Type());
if (f_id == 0) {
return 0;
@@ -2482,7 +2482,7 @@
// splat the condition into a vector of the same size.
// TODO(jrprice): If we're targeting SPIR-V 1.4, we don't need to do this.
auto* result_vector_type = builtin->ReturnType()->As<type::Vector>();
- if (result_vector_type && builtin->Parameters()[2]->Type()->is_scalar()) {
+ if (result_vector_type && builtin->Parameters()[2]->Type()->Is<type::Scalar>()) {
auto* bool_vec_ty = builder_.create<type::Vector>(builder_.create<type::Bool>(),
result_vector_type->Width());
if (!GenerateTypeIfNeeded(bool_vec_ty)) {
@@ -3620,10 +3620,10 @@
// references are not legal in WGSL, so only considering the top-level type is
// fine.
if (auto* ptr = type->As<type::Pointer>()) {
- type = builder_.create<type::Pointer>(ptr->StoreType(), ptr->AddressSpace(),
+ type = builder_.create<type::Pointer>(ptr->AddressSpace(), ptr->StoreType(),
builtin::Access::kReadWrite);
} else if (auto* ref = type->As<type::Reference>()) {
- type = builder_.create<type::Pointer>(ref->StoreType(), ref->AddressSpace(),
+ type = builder_.create<type::Pointer>(ref->AddressSpace(), ref->StoreType(),
builtin::Access::kReadWrite);
}
diff --git a/src/tint/writer/spirv/builder_type_test.cc b/src/tint/writer/spirv/builder_type_test.cc
index 87c4cf3..0bf5280 100644
--- a/src/tint/writer/spirv/builder_type_test.cc
+++ b/src/tint/writer/spirv/builder_type_test.cc
@@ -297,7 +297,7 @@
TEST_F(BuilderTest_Type, GeneratePtr) {
auto* i32 = create<type::I32>();
auto* ptr =
- create<type::Pointer>(i32, builtin::AddressSpace::kOut, builtin::Access::kReadWrite);
+ create<type::Pointer>(builtin::AddressSpace::kOut, i32, builtin::Access::kReadWrite);
spirv::Builder& b = Build();
@@ -313,7 +313,7 @@
TEST_F(BuilderTest_Type, ReturnsGeneratedPtr) {
auto* i32 = create<type::I32>();
auto* ptr =
- create<type::Pointer>(i32, builtin::AddressSpace::kOut, builtin::Access::kReadWrite);
+ create<type::Pointer>(builtin::AddressSpace::kOut, i32, builtin::Access::kReadWrite);
spirv::Builder& b = Build();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index a173b49..ad1b180 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -22,8 +22,9 @@
#include "src/tint/ir/access.h"
#include "src/tint/ir/binary.h"
#include "src/tint/ir/block.h"
+#include "src/tint/ir/block_param.h"
#include "src/tint/ir/break_if.h"
-#include "src/tint/ir/builtin.h"
+#include "src/tint/ir/builtin_call.h"
#include "src/tint/ir/continue.h"
#include "src/tint/ir/exit_if.h"
#include "src/tint/ir/exit_loop.h"
@@ -32,6 +33,7 @@
#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/module.h"
+#include "src/tint/ir/multi_in_block.h"
#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
@@ -136,7 +138,7 @@
return true;
}
-uint32_t GeneratorImplIr::Constant(const ir::Constant* constant) {
+uint32_t GeneratorImplIr::Constant(ir::Constant* constant) {
return Constant(constant->Value());
}
@@ -254,16 +256,14 @@
});
}
-uint32_t GeneratorImplIr::Value(const ir::Value* value) {
+uint32_t GeneratorImplIr::Value(ir::Value* value) {
return Switch(
value, //
- [&](const ir::Constant* constant) { return Constant(constant); },
- [&](const ir::Value*) {
- return values_.GetOrCreate(value, [&] { return module_.NextId(); });
- });
+ [&](ir::Constant* constant) { return Constant(constant); },
+ [&](ir::Value*) { return values_.GetOrCreate(value, [&] { return module_.NextId(); }); });
}
-uint32_t GeneratorImplIr::Label(const ir::Block* block) {
+uint32_t GeneratorImplIr::Label(ir::Block* block) {
return block_labels_.GetOrCreate(block, [&]() { return module_.NextId(); });
}
@@ -313,7 +313,7 @@
}
}
-void GeneratorImplIr::EmitFunction(const ir::Function* func) {
+void GeneratorImplIr::EmitFunction(ir::Function* func) {
auto id = Value(func);
// Emit the function name.
@@ -368,7 +368,7 @@
module_.PushFunction(current_function_);
}
-void GeneratorImplIr::EmitEntryPoint(const ir::Function* func, uint32_t id) {
+void GeneratorImplIr::EmitEntryPoint(ir::Function* func, uint32_t id) {
SpvExecutionModel stage = SpvExecutionModelMax;
switch (func->Stage()) {
case ir::Function::PipelineStage::kCompute: {
@@ -400,11 +400,11 @@
{U32Operand(stage), id, ir_->NameOf(func).Name()});
}
-void GeneratorImplIr::EmitRootBlock(const ir::Block* root_block) {
+void GeneratorImplIr::EmitRootBlock(ir::Block* root_block) {
for (auto* inst : *root_block) {
Switch(
inst, //
- [&](const ir::Var* v) { return EmitVar(v); },
+ [&](ir::Var* v) { return EmitVar(v); },
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "unimplemented root block instruction: " << inst->TypeInfo().name;
@@ -412,7 +412,7 @@
}
}
-void GeneratorImplIr::EmitBlock(const ir::Block* block) {
+void GeneratorImplIr::EmitBlock(ir::Block* block) {
// Emit the label.
// Skip if this is the function's entry block, as it will be emitted by the function object.
if (!current_function_.instructions().empty()) {
@@ -426,20 +426,22 @@
return;
}
- // Emit all OpPhi nodes for incoming branches to block.
- EmitIncomingPhis(block);
+ if (auto* mib = block->As<ir::MultiInBlock>()) {
+ // Emit all OpPhi nodes for incoming branches to block.
+ EmitIncomingPhis(mib);
+ }
// Emit the block's statements.
EmitBlockInstructions(block);
}
-void GeneratorImplIr::EmitIncomingPhis(const ir::Block* block) {
+void GeneratorImplIr::EmitIncomingPhis(ir::MultiInBlock* block) {
// Emit Phi nodes for all the incoming block parameters
for (size_t param_idx = 0; param_idx < block->Params().Length(); param_idx++) {
auto* param = block->Params()[param_idx];
OperandList ops{Type(param->Type()), Value(param)};
- for (auto* incoming : block->InboundBranches()) {
+ for (auto* incoming : block->InboundSiblingBranches()) {
auto* arg = incoming->Args()[param_idx];
ops.push_back(Value(arg));
ops.push_back(Label(incoming->Block()));
@@ -449,21 +451,21 @@
}
}
-void GeneratorImplIr::EmitBlockInstructions(const ir::Block* block) {
+void GeneratorImplIr::EmitBlockInstructions(ir::Block* block) {
for (auto* inst : *block) {
Switch(
inst, //
- [&](const ir::Access* a) { EmitAccess(a); }, //
- [&](const ir::Binary* b) { EmitBinary(b); }, //
- [&](const ir::Builtin* b) { EmitBuiltin(b); }, //
- [&](const ir::Load* l) { EmitLoad(l); }, //
- [&](const ir::Loop* l) { EmitLoop(l); }, //
- [&](const ir::Switch* sw) { EmitSwitch(sw); }, //
- [&](const ir::Store* s) { EmitStore(s); }, //
- [&](const ir::UserCall* c) { EmitUserCall(c); }, //
- [&](const ir::Var* v) { EmitVar(v); }, //
- [&](const ir::If* i) { EmitIf(i); }, //
- [&](const ir::Branch* b) { EmitBranch(b); }, //
+ [&](ir::Access* a) { EmitAccess(a); }, //
+ [&](ir::Binary* b) { EmitBinary(b); }, //
+ [&](ir::BuiltinCall* b) { EmitBuiltinCall(b); }, //
+ [&](ir::Load* l) { EmitLoad(l); }, //
+ [&](ir::Loop* l) { EmitLoop(l); }, //
+ [&](ir::Switch* sw) { EmitSwitch(sw); }, //
+ [&](ir::Store* s) { EmitStore(s); }, //
+ [&](ir::UserCall* c) { EmitUserCall(c); }, //
+ [&](ir::Var* v) { EmitVar(v); }, //
+ [&](ir::If* i) { EmitIf(i); }, //
+ [&](ir::Branch* b) { EmitBranch(b); }, //
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "unimplemented instruction: " << inst->TypeInfo().name;
@@ -471,10 +473,10 @@
}
}
-void GeneratorImplIr::EmitBranch(const ir::Branch* b) {
+void GeneratorImplIr::EmitBranch(ir::Branch* b) {
tint::Switch( //
b, //
- [&](const ir::Return*) {
+ [&](ir::Return*) {
if (!b->Args().IsEmpty()) {
TINT_ASSERT(Writer, b->Args().Length() == 1u);
OperandList operands;
@@ -485,7 +487,7 @@
}
return;
},
- [&](const ir::BreakIf* breakif) {
+ [&](ir::BreakIf* breakif) {
current_function_.push_inst(spv::Op::OpBranchConditional,
{
Value(breakif->Condition()),
@@ -493,19 +495,19 @@
Label(breakif->Loop()->Body()),
});
},
- [&](const ir::Continue* cont) {
+ [&](ir::Continue* cont) {
current_function_.push_inst(spv::Op::OpBranch, {Label(cont->Loop()->Continuing())});
},
- [&](const ir::ExitIf* if_) {
+ [&](ir::ExitIf* if_) {
current_function_.push_inst(spv::Op::OpBranch, {Label(if_->If()->Merge())});
},
- [&](const ir::ExitLoop* loop) {
+ [&](ir::ExitLoop* loop) {
current_function_.push_inst(spv::Op::OpBranch, {Label(loop->Loop()->Merge())});
},
- [&](const ir::ExitSwitch* swtch) {
+ [&](ir::ExitSwitch* swtch) {
current_function_.push_inst(spv::Op::OpBranch, {Label(swtch->Switch()->Merge())});
},
- [&](const ir::NextIteration* loop) {
+ [&](ir::NextIteration* loop) {
current_function_.push_inst(spv::Op::OpBranch, {Label(loop->Loop()->Body())});
},
[&](Default) {
@@ -513,7 +515,7 @@
});
}
-void GeneratorImplIr::EmitIf(const ir::If* i) {
+void GeneratorImplIr::EmitIf(ir::If* i) {
auto* merge_block = i->Merge();
auto* true_block = i->True();
auto* false_block = i->False();
@@ -553,7 +555,7 @@
EmitBlock(merge_block);
}
-void GeneratorImplIr::EmitAccess(const ir::Access* access) {
+void GeneratorImplIr::EmitAccess(ir::Access* access) {
auto id = Value(access);
OperandList operands = {Type(access->Type()), id, Value(access->Object())};
@@ -568,30 +570,24 @@
// For non-pointer types, we assume that the indices are constants and use OpCompositeExtract.
// If we hit a non-constant index into a vector type, use OpVectorExtractDynamic for it.
- auto* ty = access->Object()->Type();
+ auto* source_ty = access->Object()->Type();
for (auto* idx : access->Indices()) {
if (auto* constant = idx->As<ir::Constant>()) {
// Push the index to the chain and update the current type.
auto i = constant->Value()->ValueAs<u32>();
operands.push_back(i);
- ty = Switch(
- ty, //
- [&](const type::Array* arr) { return arr->ElemType(); },
- [&](const type::Matrix* mat) { return mat->ColumnType(); },
- [&](const type::Struct* str) { return str->Members()[i]->Type(); },
- [&](const type::Vector* vec) { return vec->type(); },
- [&](Default) { return nullptr; });
+ source_ty = source_ty->Element(i);
} else {
// The VarForDynamicIndex transform ensures that only value types that are vectors
// will be dynamically indexed, as we can use OpVectorExtractDynamic for this case.
- TINT_ASSERT(Writer, ty->Is<type::Vector>());
+ TINT_ASSERT(Writer, source_ty->Is<type::Vector>());
// If this wasn't the first access in the chain then emit the chain so far as an
// OpCompositeExtract, creating a new result ID for the resulting vector.
auto vec_id = Value(access->Object());
if (operands.size() > 3) {
vec_id = module_.NextId();
- operands[0] = Type(ty);
+ operands[0] = Type(source_ty);
operands[1] = vec_id;
current_function_.push_inst(spv::Op::OpCompositeExtract, std::move(operands));
}
@@ -605,7 +601,7 @@
current_function_.push_inst(spv::Op::OpCompositeExtract, std::move(operands));
}
-void GeneratorImplIr::EmitBinary(const ir::Binary* binary) {
+void GeneratorImplIr::EmitBinary(ir::Binary* binary) {
auto id = Value(binary);
auto* lhs_ty = binary->LHS()->Type();
@@ -706,7 +702,7 @@
op, {Type(binary->Type()), id, Value(binary->LHS()), Value(binary->RHS())});
}
-void GeneratorImplIr::EmitBuiltin(const ir::Builtin* builtin) {
+void GeneratorImplIr::EmitBuiltinCall(ir::BuiltinCall* builtin) {
auto* result_ty = builtin->Type();
if (builtin->Func() == builtin::Function::kAbs &&
@@ -775,12 +771,12 @@
current_function_.push_inst(op, operands);
}
-void GeneratorImplIr::EmitLoad(const ir::Load* load) {
+void GeneratorImplIr::EmitLoad(ir::Load* load) {
current_function_.push_inst(spv::Op::OpLoad,
{Type(load->Type()), Value(load), Value(load->From())});
}
-void GeneratorImplIr::EmitLoop(const ir::Loop* loop) {
+void GeneratorImplIr::EmitLoop(ir::Loop* loop) {
auto init_label = loop->HasInitializer() ? Label(loop->Initializer()) : 0;
auto header_label = Label(loop->Body()); // Back-edge needs to branch to the loop header
auto body_label = module_.NextId();
@@ -821,7 +817,7 @@
EmitBlock(loop->Merge());
}
-void GeneratorImplIr::EmitSwitch(const ir::Switch* swtch) {
+void GeneratorImplIr::EmitSwitch(ir::Switch* swtch) {
// Find the default selector. There must be exactly one.
uint32_t default_label = 0u;
for (auto& c : swtch->Cases()) {
@@ -860,11 +856,11 @@
EmitBlock(swtch->Merge());
}
-void GeneratorImplIr::EmitStore(const ir::Store* store) {
+void GeneratorImplIr::EmitStore(ir::Store* store) {
current_function_.push_inst(spv::Op::OpStore, {Value(store->To()), Value(store->From())});
}
-void GeneratorImplIr::EmitUserCall(const ir::UserCall* call) {
+void GeneratorImplIr::EmitUserCall(ir::UserCall* call) {
auto id = Value(call);
OperandList operands = {Type(call->Type()), id, Value(call->Func())};
for (auto* arg : call->Args()) {
@@ -873,10 +869,9 @@
current_function_.push_inst(spv::Op::OpFunctionCall, operands);
}
-void GeneratorImplIr::EmitVar(const ir::Var* var) {
+void GeneratorImplIr::EmitVar(ir::Var* var) {
auto id = Value(var);
- auto* ptr = var->Type()->As<type::Pointer>();
- TINT_ASSERT(Writer, ptr);
+ auto* ptr = var->Type();
auto ty = Type(ptr);
switch (ptr->AddressSpace()) {
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index d3abeb6..dccb23a 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -34,12 +34,13 @@
class Block;
class BlockParam;
class Branch;
-class Builtin;
-class If;
+class BuiltinCall;
class Function;
+class If;
class Load;
class Loop;
class Module;
+class MultiInBlock;
class Store;
class Switch;
class UserCall;
@@ -77,7 +78,7 @@
/// Get the result ID of the constant `constant`, emitting its instruction if necessary.
/// @param constant the constant to get the ID for
/// @returns the result ID of the constant
- uint32_t Constant(const ir::Constant* constant);
+ uint32_t Constant(ir::Constant* constant);
/// Get the result ID of the OpConstantNull instruction for `type`, emitting it if necessary.
/// @param type the type to get the ID for
@@ -92,12 +93,12 @@
/// Get the result ID of the value `value`, emitting its instruction if necessary.
/// @param value the value to get the ID for
/// @returns the result ID of the value
- uint32_t Value(const ir::Value* value);
+ uint32_t Value(ir::Value* value);
/// Get the ID of the label for `block`.
/// @param block the block to get the label ID for
/// @returns the ID of the block's label
- uint32_t Label(const ir::Block* block);
+ uint32_t Label(ir::Block* block);
/// Emit a struct type.
/// @param id the result ID to use
@@ -106,72 +107,72 @@
/// Emit a function.
/// @param func the function to emit
- void EmitFunction(const ir::Function* func);
+ void EmitFunction(ir::Function* func);
/// Emit entry point declarations for a function.
/// @param func the function to emit entry point declarations for
/// @param id the result ID of the function declaration
- void EmitEntryPoint(const ir::Function* func, uint32_t id);
+ void EmitEntryPoint(ir::Function* func, uint32_t id);
/// Emit a block, including the initial OpLabel, OpPhis and instructions.
/// @param block the block to emit
- void EmitBlock(const ir::Block* block);
+ void EmitBlock(ir::Block* block);
/// Emit all OpPhi nodes for incoming branches to @p block.
/// @param block the block to emit the OpPhis for
- void EmitIncomingPhis(const ir::Block* block);
+ void EmitIncomingPhis(ir::MultiInBlock* block);
/// Emit all instructions of @p block.
/// @param block the block's instructions to emit
- void EmitBlockInstructions(const ir::Block* block);
+ void EmitBlockInstructions(ir::Block* block);
/// Emit the root block.
/// @param root_block the root block to emit
- void EmitRootBlock(const ir::Block* root_block);
+ void EmitRootBlock(ir::Block* root_block);
/// Emit an `if` flow node.
/// @param i the if node to emit
- void EmitIf(const ir::If* i);
+ void EmitIf(ir::If* i);
/// Emit an access instruction
/// @param access the access instruction to emit
- void EmitAccess(const ir::Access* access);
+ void EmitAccess(ir::Access* access);
/// Emit a binary instruction.
/// @param binary the binary instruction to emit
- void EmitBinary(const ir::Binary* binary);
+ void EmitBinary(ir::Binary* binary);
/// Emit a builtin function call instruction.
/// @param call the builtin call instruction to emit
- void EmitBuiltin(const ir::Builtin* call);
+ void EmitBuiltinCall(ir::BuiltinCall* call);
/// Emit a load instruction.
/// @param load the load instruction to emit
- void EmitLoad(const ir::Load* load);
+ void EmitLoad(ir::Load* load);
/// Emit a loop instruction.
/// @param loop the loop instruction to emit
- void EmitLoop(const ir::Loop* loop);
+ void EmitLoop(ir::Loop* loop);
/// Emit a store instruction.
/// @param store the store instruction to emit
- void EmitStore(const ir::Store* store);
+ void EmitStore(ir::Store* store);
/// Emit a switch instruction.
/// @param swtch the switch instruction to emit
- void EmitSwitch(const ir::Switch* swtch);
+ void EmitSwitch(ir::Switch* swtch);
/// Emit a user call instruction.
/// @param call the user call instruction to emit
- void EmitUserCall(const ir::UserCall* call);
+ void EmitUserCall(ir::UserCall* call);
/// Emit a var instruction.
/// @param var the var instruction to emit
- void EmitVar(const ir::Var* var);
+ void EmitVar(ir::Var* var);
/// Emit a branch instruction.
/// @param b the branch instruction to emit
- void EmitBranch(const ir::Branch* b);
+ void EmitBranch(ir::Branch* b);
private:
/// Get the result ID of the constant `constant`, emitting its instruction if necessary.
@@ -222,10 +223,10 @@
utils::Hashmap<const type::Type*, uint32_t, 4> constant_nulls_;
/// The map of non-constant values to their result IDs.
- utils::Hashmap<const ir::Value*, uint32_t, 8> values_;
+ utils::Hashmap<ir::Value*, uint32_t, 8> values_;
/// The map of blocks to the IDs of their label instructions.
- utils::Hashmap<const ir::Block*, uint32_t, 8> block_labels_;
+ utils::Hashmap<ir::Block*, uint32_t, 8> block_labels_;
/// The map of extended instruction set names to their result IDs.
utils::Hashmap<std::string_view, uint32_t, 2> imports_;
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
index 62dd469..4318147 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
@@ -21,17 +21,17 @@
class SpvGeneratorImplTest_Access : public SpvGeneratorImplTest {
protected:
- const type::Type* ptr(const type::Type* elem) {
- return ty.pointer(elem, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ const type::Pointer* ptr(const type::Type* elem) {
+ return ty.ptr(builtin::AddressSpace::kFunction, elem, builtin::Access::kReadWrite);
}
};
TEST_F(SpvGeneratorImplTest_Access, Array_Value_ConstantIndex) {
auto* arr_val = b.FunctionParam(ty.array(ty.i32(), 4));
- auto* access = b.Access(ty.i32(), arr_val, utils::Vector{b.Constant(1_u)});
- auto* func = b.CreateFunction("foo", ty.void_());
- func->SetParams(utils::Vector{arr_val});
- func->StartTarget()->SetInstructions(utils::Vector{access, b.Return(func)});
+ auto* access = b.Access(ty.i32(), arr_val, 1_u);
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({arr_val});
+ func->StartTarget()->SetInstructions({access, b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -54,10 +54,10 @@
}
TEST_F(SpvGeneratorImplTest_Access, Array_Pointer_ConstantIndex) {
- auto* arr_var = b.Declare(ptr(ty.array(ty.i32(), 4)));
- auto* access = b.Access(ptr(ty.i32()), arr_var, utils::Vector{b.Constant(1_u)});
- auto* func = b.CreateFunction("foo", ty.void_());
- func->StartTarget()->SetInstructions(utils::Vector{arr_var, access, b.Return(func)});
+ auto* arr_var = b.Var(ptr(ty.array(ty.i32(), 4)));
+ auto* access = b.Access(ptr(ty.i32()), arr_var, 1_u);
+ auto* func = b.Function("foo", ty.void_());
+ func->StartTarget()->SetInstructions({arr_var, access, b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -83,11 +83,11 @@
}
TEST_F(SpvGeneratorImplTest_Access, Array_Pointer_DynamicIndex) {
- auto* arr_var = b.Declare(ptr(ty.array(ty.i32(), 4)));
- auto* idx_var = b.Declare(ptr(ty.i32()));
+ auto* arr_var = b.Var(ptr(ty.array(ty.i32(), 4)));
+ auto* idx_var = b.Var(ptr(ty.i32()));
auto* idx = b.Load(idx_var);
- auto* access = b.Access(ptr(ty.i32()), arr_var, utils::Vector{idx});
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* access = b.Access(ptr(ty.i32()), arr_var, idx);
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->SetInstructions(
utils::Vector{idx_var, idx, arr_var, access, b.Return(func)});
@@ -117,11 +117,11 @@
TEST_F(SpvGeneratorImplTest_Access, Matrix_Value_ConstantIndex) {
auto* mat_val = b.FunctionParam(ty.mat2x2(ty.f32()));
- auto* access_vec = b.Access(ty.vec2(ty.f32()), mat_val, utils::Vector{b.Constant(1_u)});
- auto* access_el = b.Access(ty.f32(), mat_val, utils::Vector{b.Constant(1_u), b.Constant(0_u)});
- auto* func = b.CreateFunction("foo", ty.void_());
- func->SetParams(utils::Vector{mat_val});
- func->StartTarget()->SetInstructions(utils::Vector{access_vec, access_el, b.Return(func)});
+ auto* access_vec = b.Access(ty.vec2(ty.f32()), mat_val, 1_u);
+ auto* access_el = b.Access(ty.f32(), mat_val, 1_u, 0_u);
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({mat_val});
+ func->StartTarget()->SetInstructions({access_vec, access_el, b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -143,12 +143,11 @@
}
TEST_F(SpvGeneratorImplTest_Access, Matrix_Pointer_ConstantIndex) {
- auto* mat_var = b.Declare(ptr(ty.mat2x2(ty.f32())));
- auto* access_vec = b.Access(ptr(ty.vec2(ty.f32())), mat_var, utils::Vector{b.Constant(1_u)});
- auto* access_el =
- b.Access(ptr(ty.f32()), mat_var, utils::Vector{b.Constant(1_u), b.Constant(0_u)});
- auto* func = b.CreateFunction("foo", ty.void_());
- func->StartTarget()->SetInstructions(utils::Vector{access_vec, access_el, b.Return(func)});
+ auto* mat_var = b.Var(ptr(ty.mat2x2(ty.f32())));
+ auto* access_vec = b.Access(ptr(ty.vec2(ty.f32())), mat_var, 1_u);
+ auto* access_el = b.Access(ptr(ty.f32()), mat_var, 1_u, 0_u);
+ auto* func = b.Function("foo", ty.void_());
+ func->StartTarget()->SetInstructions({access_vec, access_el, b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -173,12 +172,12 @@
}
TEST_F(SpvGeneratorImplTest_Access, Matrix_Pointer_DynamicIndex) {
- auto* mat_var = b.Declare(ptr(ty.mat2x2(ty.f32())));
- auto* idx_var = b.Declare(ptr(ty.i32()));
+ auto* mat_var = b.Var(ptr(ty.mat2x2(ty.f32())));
+ auto* idx_var = b.Var(ptr(ty.i32()));
auto* idx = b.Load(idx_var);
- auto* access_vec = b.Access(ptr(ty.vec2(ty.f32())), mat_var, utils::Vector{idx});
- auto* access_el = b.Access(ptr(ty.f32()), mat_var, utils::Vector{idx, idx});
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* access_vec = b.Access(ptr(ty.vec2(ty.f32())), mat_var, idx);
+ auto* access_el = b.Access(ptr(ty.f32()), mat_var, idx, idx);
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->SetInstructions(
utils::Vector{idx_var, idx, mat_var, access_vec, access_el, b.Return(func)});
@@ -210,10 +209,10 @@
TEST_F(SpvGeneratorImplTest_Access, Vector_Value_ConstantIndex) {
auto* vec_val = b.FunctionParam(ty.vec4(ty.i32()));
- auto* access = b.Access(ty.i32(), vec_val, utils::Vector{b.Constant(1_u)});
- auto* func = b.CreateFunction("foo", ty.void_());
- func->SetParams(utils::Vector{vec_val});
- func->StartTarget()->SetInstructions(utils::Vector{access, b.Return(func)});
+ auto* access = b.Access(ty.i32(), vec_val, 1_u);
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({vec_val});
+ func->StartTarget()->SetInstructions({access, b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -234,12 +233,12 @@
TEST_F(SpvGeneratorImplTest_Access, Vector_Value_DynamicIndex) {
auto* vec_val = b.FunctionParam(ty.vec4(ty.i32()));
- auto* idx_var = b.Declare(ptr(ty.i32()));
+ auto* idx_var = b.Var(ptr(ty.i32()));
auto* idx = b.Load(idx_var);
- auto* access = b.Access(ty.i32(), vec_val, utils::Vector{idx});
- auto* func = b.CreateFunction("foo", ty.void_());
- func->SetParams(utils::Vector{vec_val});
- func->StartTarget()->SetInstructions(utils::Vector{idx_var, idx, access, b.Return(func)});
+ auto* access = b.Access(ty.i32(), vec_val, idx);
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({vec_val});
+ func->StartTarget()->SetInstructions({idx_var, idx, access, b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -262,10 +261,10 @@
}
TEST_F(SpvGeneratorImplTest_Access, Vector_Pointer_ConstantIndex) {
- auto* vec_var = b.Declare(ptr(ty.vec4(ty.i32())));
- auto* access = b.Access(ptr(ty.i32()), vec_var, utils::Vector{b.Constant(1_u)});
- auto* func = b.CreateFunction("foo", ty.void_());
- func->StartTarget()->SetInstructions(utils::Vector{vec_var, access, b.Return(func)});
+ auto* vec_var = b.Var(ptr(ty.vec4(ty.i32())));
+ auto* access = b.Access(ptr(ty.i32()), vec_var, 1_u);
+ auto* func = b.Function("foo", ty.void_());
+ func->StartTarget()->SetInstructions({vec_var, access, b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -289,11 +288,11 @@
}
TEST_F(SpvGeneratorImplTest_Access, Vector_Pointer_DynamicIndex) {
- auto* vec_var = b.Declare(ptr(ty.vec4(ty.i32())));
- auto* idx_var = b.Declare(ptr(ty.i32()));
+ auto* vec_var = b.Var(ptr(ty.vec4(ty.i32())));
+ auto* idx_var = b.Var(ptr(ty.i32()));
auto* idx = b.Load(idx_var);
- auto* access = b.Access(ptr(ty.i32()), vec_var, utils::Vector{idx});
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* access = b.Access(ptr(ty.i32()), vec_var, idx);
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->SetInstructions(
utils::Vector{idx_var, idx, vec_var, access, b.Return(func)});
@@ -320,12 +319,12 @@
TEST_F(SpvGeneratorImplTest_Access, NestedVector_Value_DynamicIndex) {
auto* val = b.FunctionParam(ty.array(ty.array(ty.vec4(ty.i32()), 4), 4));
- auto* idx_var = b.Declare(ptr(ty.i32()));
+ auto* idx_var = b.Var(ptr(ty.i32()));
auto* idx = b.Load(idx_var);
- auto* access = b.Access(ty.i32(), val, utils::Vector{b.Constant(1_u), b.Constant(2_u), idx});
- auto* func = b.CreateFunction("foo", ty.void_());
- func->SetParams(utils::Vector{val});
- func->StartTarget()->SetInstructions(utils::Vector{idx_var, idx, access, b.Return(func)});
+ auto* access = b.Access(ty.i32(), val, 1_u, 2_u, idx);
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({val});
+ func->StartTarget()->SetInstructions({idx_var, idx, access, b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -365,11 +364,11 @@
},
16u, 32u, 32u);
auto* str_val = b.FunctionParam(str);
- auto* access_vec = b.Access(ty.i32(), str_val, utils::Vector{b.Constant(1_u)});
- auto* access_el = b.Access(ty.i32(), str_val, utils::Vector{b.Constant(1_u), b.Constant(2_u)});
- auto* func = b.CreateFunction("foo", ty.void_());
- func->SetParams(utils::Vector{str_val});
- func->StartTarget()->SetInstructions(utils::Vector{access_vec, access_el, b.Return(func)});
+ auto* access_vec = b.Access(ty.i32(), str_val, 1_u);
+ auto* access_el = b.Access(ty.i32(), str_val, 1_u, 2_u);
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({str_val});
+ func->StartTarget()->SetInstructions({access_vec, access_el, b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -406,11 +405,10 @@
16u, type::StructMemberAttributes{}),
},
16u, 32u, 32u);
- auto* str_var = b.Declare(ptr(str));
- auto* access_vec = b.Access(ptr(ty.i32()), str_var, utils::Vector{b.Constant(1_u)});
- auto* access_el =
- b.Access(ptr(ty.i32()), str_var, utils::Vector{b.Constant(1_u), b.Constant(2_u)});
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* str_var = b.Var(ptr(str));
+ auto* access_vec = b.Access(ptr(ty.i32()), str_var, 1_u);
+ auto* access_el = b.Access(ptr(ty.i32()), str_var, 1_u, 2_u);
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->SetInstructions(
utils::Vector{str_var, access_vec, access_el, b.Return(func)});
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
index 5481e03..6ca13f7 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
@@ -36,10 +36,10 @@
TEST_P(Arithmetic, Scalar) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.CreateBinary(params.kind, MakeScalarType(params.type),
- MakeScalarValue(params.type), MakeScalarValue(params.type)),
+ utils::Vector{b.Binary(params.kind, MakeScalarType(params.type),
+ MakeScalarValue(params.type), MakeScalarValue(params.type)),
b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -50,10 +50,10 @@
TEST_P(Arithmetic, Vector) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.CreateBinary(params.kind, MakeVectorType(params.type),
- MakeVectorValue(params.type), MakeVectorValue(params.type)),
+ utils::Vector{b.Binary(params.kind, MakeVectorType(params.type),
+ MakeVectorValue(params.type), MakeVectorValue(params.type)),
b.Return(func)});
@@ -87,10 +87,10 @@
TEST_P(Bitwise, Scalar) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.CreateBinary(params.kind, MakeScalarType(params.type),
- MakeScalarValue(params.type), MakeScalarValue(params.type)),
+ utils::Vector{b.Binary(params.kind, MakeScalarType(params.type),
+ MakeScalarValue(params.type), MakeScalarValue(params.type)),
b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -101,10 +101,10 @@
TEST_P(Bitwise, Vector) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.CreateBinary(params.kind, MakeVectorType(params.type),
- MakeVectorValue(params.type), MakeVectorValue(params.type)),
+ utils::Vector{b.Binary(params.kind, MakeVectorType(params.type),
+ MakeVectorValue(params.type), MakeVectorValue(params.type)),
b.Return(func)});
@@ -130,10 +130,10 @@
TEST_P(Comparison, Scalar) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.CreateBinary(params.kind, ty.bool_(), MakeScalarValue(params.type),
- MakeScalarValue(params.type)),
+ utils::Vector{b.Binary(params.kind, ty.bool_(), MakeScalarValue(params.type),
+ MakeScalarValue(params.type)),
b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -144,10 +144,10 @@
TEST_P(Comparison, Vector) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.CreateBinary(params.kind, ty.vec2(ty.bool_()), MakeVectorValue(params.type),
- MakeVectorValue(params.type)),
+ utils::Vector{b.Binary(params.kind, ty.vec2(ty.bool_()), MakeVectorValue(params.type),
+ MakeVectorValue(params.type)),
b.Return(func)});
@@ -203,9 +203,9 @@
BinaryTestCase{kBool, ir::Binary::Kind::kNotEqual, "OpLogicalNotEqual"}));
TEST_F(SpvGeneratorImplTest, Binary_Chain) {
- auto* func = b.CreateFunction("foo", ty.void_());
- auto* a = b.Subtract(ty.i32(), b.Constant(1_i), b.Constant(2_i));
- func->StartTarget()->SetInstructions(utils::Vector{a, b.Add(ty.i32(), a, a), b.Return(func)});
+ auto* func = b.Function("foo", ty.void_());
+ auto* a = b.Subtract(ty.i32(), 1_i, 2_i);
+ func->StartTarget()->SetInstructions({a, b.Add(ty.i32(), a, a), b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
index 04d89fd..44d0505 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
@@ -37,11 +37,11 @@
TEST_P(Builtin_1arg, Scalar) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", ty.void_());
- func->StartTarget()->SetInstructions(
- utils::Vector{b.Builtin(MakeScalarType(params.type), params.function,
- utils::Vector{MakeScalarValue(params.type)}),
- b.Return(func)});
+ auto* func = b.Function("foo", ty.void_());
+ func->StartTarget()->SetInstructions({
+ b.Call(MakeScalarType(params.type), params.function, MakeScalarValue(params.type)),
+ b.Return(func),
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -51,12 +51,11 @@
TEST_P(Builtin_1arg, Vector) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", ty.void_());
- func->StartTarget()->SetInstructions(
- utils::Vector{b.Builtin(MakeVectorType(params.type), params.function,
- utils::Vector{MakeVectorValue(params.type)}),
-
- b.Return(func)});
+ auto* func = b.Function("foo", ty.void_());
+ func->StartTarget()->SetInstructions({
+ b.Call(MakeVectorType(params.type), params.function, MakeVectorValue(params.type)),
+ b.Return(func),
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -70,11 +69,12 @@
// Test that abs of an unsigned value just folds away.
TEST_F(SpvGeneratorImplTest, Builtin_Abs_u32) {
- auto* result = b.Builtin(MakeScalarType(kU32), builtin::Function::kAbs,
- utils::Vector{MakeScalarValue(kU32)});
- auto* func = b.CreateFunction("foo", MakeScalarType(kU32));
- func->StartTarget()->SetInstructions(
- utils::Vector{result, b.Return(func, utils::Vector{result})});
+ auto* result = b.Call(MakeScalarType(kU32), builtin::Function::kAbs, MakeScalarValue(kU32));
+ auto* func = b.Function("foo", MakeScalarType(kU32));
+ func->StartTarget()->SetInstructions({
+ result,
+ b.Return(func, result),
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -89,12 +89,14 @@
OpFunctionEnd
)");
}
+
TEST_F(SpvGeneratorImplTest, Builtin_Abs_vec2u) {
- auto* result = b.Builtin(MakeVectorType(kU32), builtin::Function::kAbs,
- utils::Vector{MakeVectorValue(kU32)});
- auto* func = b.CreateFunction("foo", MakeVectorType(kU32));
- func->StartTarget()->SetInstructions(
- utils::Vector{result, b.Return(func, utils::Vector{result})});
+ auto* result = b.Call(MakeVectorType(kU32), builtin::Function::kAbs, MakeVectorValue(kU32));
+ auto* func = b.Function("foo", MakeVectorType(kU32));
+ func->StartTarget()->SetInstructions({
+ result,
+ b.Return(func, result),
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -118,11 +120,12 @@
TEST_P(Builtin_2arg, Scalar) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", ty.void_());
- func->StartTarget()->SetInstructions(utils::Vector{
- b.Builtin(MakeScalarType(params.type), params.function,
- utils::Vector{MakeScalarValue(params.type), MakeScalarValue(params.type)}),
- b.Return(func)});
+ auto* func = b.Function("foo", ty.void_());
+ func->StartTarget()->SetInstructions({
+ b.Call(MakeScalarType(params.type), params.function, MakeScalarValue(params.type),
+ MakeScalarValue(params.type)),
+ b.Return(func),
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -132,12 +135,13 @@
TEST_P(Builtin_2arg, Vector) {
auto params = GetParam();
- auto* func = b.CreateFunction("foo", ty.void_());
- func->StartTarget()->SetInstructions(utils::Vector{
- b.Builtin(MakeVectorType(params.type), params.function,
- utils::Vector{MakeVectorValue(params.type), MakeVectorValue(params.type)}),
+ auto* func = b.Function("foo", ty.void_());
+ func->StartTarget()->SetInstructions({
+ b.Call(MakeVectorType(params.type), params.function, MakeVectorValue(params.type),
+ MakeVectorValue(params.type)),
- b.Return(func)});
+ b.Return(func),
+ });
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
index 0879e58..9c67418 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
@@ -18,8 +18,8 @@
namespace {
TEST_F(SpvGeneratorImplTest, Function_Empty) {
- auto* func = b.CreateFunction("foo", ty.void_());
- func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
+ auto* func = b.Function("foo", ty.void_());
+ func->StartTarget()->SetInstructions({b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -36,8 +36,8 @@
// Test that we do not emit the same function type more than once.
TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
- auto* func = b.CreateFunction("foo", ty.void_());
- func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
+ auto* func = b.Function("foo", ty.void_());
+ func->StartTarget()->SetInstructions({b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -51,8 +51,8 @@
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) {
auto* func =
- b.CreateFunction("main", ty.void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
- func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
+ b.Function("main", ty.void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
+ func->StartTarget()->SetInstructions({b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -70,8 +70,8 @@
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) {
- auto* func = b.CreateFunction("main", ty.void_(), ir::Function::PipelineStage::kFragment);
- func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
+ auto* func = b.Function("main", ty.void_(), ir::Function::PipelineStage::kFragment);
+ func->StartTarget()->SetInstructions({b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -89,8 +89,8 @@
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) {
- auto* func = b.CreateFunction("main", ty.void_(), ir::Function::PipelineStage::kVertex);
- func->StartTarget()->SetInstructions(utils::Vector{b.Return(func)});
+ auto* func = b.Function("main", ty.void_(), ir::Function::PipelineStage::kVertex);
+ func->StartTarget()->SetInstructions({b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -107,16 +107,14 @@
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) {
- auto* f1 =
- b.CreateFunction("main1", ty.void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
- f1->StartTarget()->SetInstructions(utils::Vector{b.Return(f1)});
+ auto* f1 = b.Function("main1", ty.void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
+ f1->StartTarget()->SetInstructions({b.Return(f1)});
- auto* f2 =
- b.CreateFunction("main2", ty.void_(), ir::Function::PipelineStage::kCompute, {{8, 2, 16}});
- f2->StartTarget()->SetInstructions(utils::Vector{b.Return(f2)});
+ auto* f2 = b.Function("main2", ty.void_(), ir::Function::PipelineStage::kCompute, {{8, 2, 16}});
+ f2->StartTarget()->SetInstructions({b.Return(f2)});
- auto* f3 = b.CreateFunction("main3", ty.void_(), ir::Function::PipelineStage::kFragment);
- f3->StartTarget()->SetInstructions(utils::Vector{b.Return(f3)});
+ auto* f3 = b.Function("main3", ty.void_(), ir::Function::PipelineStage::kFragment);
+ f3->StartTarget()->SetInstructions({b.Return(f3)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -150,9 +148,8 @@
}
TEST_F(SpvGeneratorImplTest, Function_ReturnValue) {
- auto* func = b.CreateFunction("foo", ty.i32());
- func->StartTarget()->SetInstructions(
- utils::Vector{b.Return(func, utils::Vector{b.Constant(i32(42))})});
+ auto* func = b.Function("foo", ty.i32());
+ func->StartTarget()->SetInstructions({b.Return(func, i32(42))});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -173,10 +170,9 @@
auto* x = b.FunctionParam(i32);
auto* y = b.FunctionParam(i32);
auto* result = b.Add(i32, x, y);
- auto* func = b.CreateFunction("foo", i32);
- func->SetParams(utils::Vector{x, y});
- func->StartTarget()->SetInstructions(
- utils::Vector{result, b.Return(func, utils::Vector{result})});
+ auto* func = b.Function("foo", i32);
+ func->SetParams({x, y});
+ func->StartTarget()->SetInstructions({result, b.Return(func, result)});
mod.SetName(x, "x");
mod.SetName(y, "y");
@@ -203,15 +199,13 @@
auto* x = b.FunctionParam(i32_ty);
auto* y = b.FunctionParam(i32_ty);
auto* result = b.Add(i32_ty, x, y);
- auto* foo = b.CreateFunction("foo", i32_ty);
- foo->SetParams(utils::Vector{x, y});
- foo->StartTarget()->SetInstructions(
- utils::Vector{result, b.Return(foo, utils::Vector{result})});
+ auto* foo = b.Function("foo", i32_ty);
+ foo->SetParams({x, y});
+ foo->StartTarget()->SetInstructions({result, b.Return(foo, result)});
- auto* bar = b.CreateFunction("bar", ty.void_());
- bar->StartTarget()->SetInstructions(utils::Vector{
- b.UserCall(i32_ty, foo, utils::Vector{b.Constant(i32(2)), b.Constant(i32(3))}),
- b.Return(bar)});
+ auto* bar = b.Function("bar", ty.void_());
+ bar->StartTarget()->SetInstructions(
+ utils::Vector{b.Call(i32_ty, foo, i32(2), i32(3)), b.Return(bar)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -241,12 +235,12 @@
}
TEST_F(SpvGeneratorImplTest, Function_Call_Void) {
- auto* foo = b.CreateFunction("foo", ty.void_());
- foo->StartTarget()->SetInstructions(utils::Vector{b.Return(foo)});
+ auto* foo = b.Function("foo", ty.void_());
+ foo->StartTarget()->SetInstructions({b.Return(foo)});
- auto* bar = b.CreateFunction("bar", ty.void_());
+ auto* bar = b.Function("bar", ty.void_());
bar->StartTarget()->SetInstructions(
- utils::Vector{b.UserCall(ty.void_(), foo, utils::Empty), b.Return(bar)});
+ utils::Vector{b.Call(ty.void_(), foo, utils::Empty), b.Return(bar)});
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
index 4c4e5c3..6534a56 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
@@ -20,14 +20,14 @@
namespace {
TEST_F(SpvGeneratorImplTest, If_TrueEmpty_FalseEmpty) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* i = b.CreateIf(b.Constant(true));
- i->True()->SetInstructions(utils::Vector{b.ExitIf(i)});
- i->False()->SetInstructions(utils::Vector{b.ExitIf(i)});
- i->Merge()->SetInstructions(utils::Vector{b.Return(func)});
+ auto* i = b.If(true);
+ i->True()->SetInstructions({b.ExitIf(i)});
+ i->False()->SetInstructions({b.ExitIf(i)});
+ i->Merge()->SetInstructions({b.Return(func)});
- func->StartTarget()->SetInstructions(utils::Vector{i});
+ func->StartTarget()->SetInstructions({i});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -48,17 +48,16 @@
}
TEST_F(SpvGeneratorImplTest, If_FalseEmpty) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* i = b.CreateIf(b.Constant(true));
- i->False()->SetInstructions(utils::Vector{b.ExitIf(i)});
- i->Merge()->SetInstructions(utils::Vector{b.Return(func)});
+ auto* i = b.If(true);
+ i->False()->SetInstructions({b.ExitIf(i)});
+ i->Merge()->SetInstructions({b.Return(func)});
auto* true_block = i->True();
- true_block->SetInstructions(
- utils::Vector{b.Add(ty.i32(), b.Constant(1_i), b.Constant(1_i)), b.ExitIf(i)});
+ true_block->SetInstructions({b.Add(ty.i32(), 1_i, 1_i), b.ExitIf(i)});
- func->StartTarget()->SetInstructions(utils::Vector{i});
+ func->StartTarget()->SetInstructions({i});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -84,17 +83,16 @@
}
TEST_F(SpvGeneratorImplTest, If_TrueEmpty) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* i = b.CreateIf(b.Constant(true));
- i->True()->SetInstructions(utils::Vector{b.ExitIf(i)});
- i->Merge()->SetInstructions(utils::Vector{b.Return(func)});
+ auto* i = b.If(true);
+ i->True()->SetInstructions({b.ExitIf(i)});
+ i->Merge()->SetInstructions({b.Return(func)});
auto* false_block = i->False();
- false_block->SetInstructions(
- utils::Vector{b.Add(ty.i32(), b.Constant(1_i), b.Constant(1_i)), b.ExitIf(i)});
+ false_block->SetInstructions({b.Add(ty.i32(), 1_i, 1_i), b.ExitIf(i)});
- func->StartTarget()->SetInstructions(utils::Vector{i});
+ func->StartTarget()->SetInstructions({i});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -120,13 +118,13 @@
}
TEST_F(SpvGeneratorImplTest, If_BothBranchesReturn) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* i = b.CreateIf(b.Constant(true));
- i->True()->SetInstructions(utils::Vector{b.Return(func)});
- i->False()->SetInstructions(utils::Vector{b.Return(func)});
+ auto* i = b.If(true);
+ i->True()->SetInstructions({b.Return(func)});
+ i->False()->SetInstructions({b.Return(func)});
- func->StartTarget()->SetInstructions(utils::Vector{i});
+ func->StartTarget()->SetInstructions({i});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -151,17 +149,17 @@
}
TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
auto* merge_param = b.BlockParam(b.ir.Types().i32());
- auto* i = b.CreateIf(b.Constant(true));
- i->True()->SetInstructions(utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(10_i)})});
- i->False()->SetInstructions(utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(20_i)})});
- i->Merge()->SetParams(utils::Vector{merge_param});
- i->Merge()->SetInstructions(utils::Vector{b.Return(func, utils::Vector{merge_param})});
+ auto* i = b.If(true);
+ i->True()->SetInstructions({b.ExitIf(i, 10_i)});
+ i->False()->SetInstructions({b.ExitIf(i, 20_i)});
+ i->Merge()->SetParams({merge_param});
+ i->Merge()->SetInstructions({b.Return(func, merge_param)});
- func->StartTarget()->SetInstructions(utils::Vector{i});
+ func->StartTarget()->SetInstructions({i});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -190,17 +188,17 @@
}
TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue_TrueReturn) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
auto* merge_param = b.BlockParam(b.ir.Types().i32());
- auto* i = b.CreateIf(b.Constant(true));
- i->True()->SetInstructions(utils::Vector{b.Return(func, utils::Vector{b.Constant(42_i)})});
- i->False()->SetInstructions(utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(20_i)})});
- i->Merge()->SetParams(utils::Vector{merge_param});
- i->Merge()->SetInstructions(utils::Vector{b.Return(func, utils::Vector{merge_param})});
+ auto* i = b.If(true);
+ i->True()->SetInstructions({b.Return(func, 42_i)});
+ i->False()->SetInstructions({b.ExitIf(i, 20_i)});
+ i->Merge()->SetParams({merge_param});
+ i->Merge()->SetInstructions({b.Return(func, merge_param)});
- func->StartTarget()->SetInstructions(utils::Vector{i});
+ func->StartTarget()->SetInstructions({i});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -229,17 +227,17 @@
}
TEST_F(SpvGeneratorImplTest, If_Phi_SingleValue_FalseReturn) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
auto* merge_param = b.BlockParam(b.ir.Types().i32());
- auto* i = b.CreateIf(b.Constant(true));
- i->True()->SetInstructions(utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(10_i)})});
- i->False()->SetInstructions(utils::Vector{b.Return(func, utils::Vector{b.Constant(42_i)})});
- i->Merge()->SetParams(utils::Vector{merge_param});
- i->Merge()->SetInstructions(utils::Vector{b.Return(func, utils::Vector{merge_param})});
+ auto* i = b.If(true);
+ i->True()->SetInstructions({b.ExitIf(i, 10_i)});
+ i->False()->SetInstructions({b.Return(func, 42_i)});
+ i->Merge()->SetParams({merge_param});
+ i->Merge()->SetInstructions({b.Return(func, merge_param)});
- func->StartTarget()->SetInstructions(utils::Vector{i});
+ func->StartTarget()->SetInstructions({i});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -268,22 +266,18 @@
}
TEST_F(SpvGeneratorImplTest, If_Phi_MultipleValue) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
auto* merge_param_0 = b.BlockParam(b.ir.Types().i32());
auto* merge_param_1 = b.BlockParam(b.ir.Types().bool_());
- auto* i = b.CreateIf(b.Constant(true));
- i->True()->SetInstructions(
- utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(10_i), b.Constant(true)})});
- i->False()->SetInstructions(
- utils::Vector{b.ExitIf(i, utils::Vector{b.Constant(20_i), b.Constant(false)})});
- i->Merge()->SetParams(utils::Vector{merge_param_0, merge_param_1});
- i->Merge()->SetInstructions(utils::Vector{
- b.Return(func, utils::Vector{merge_param_0}),
- });
+ auto* i = b.If(true);
+ i->True()->SetInstructions({b.ExitIf(i, 10_i, true)});
+ i->False()->SetInstructions({b.ExitIf(i, 20_i, false)});
+ i->Merge()->SetParams({merge_param_0, merge_param_1});
+ i->Merge()->SetInstructions({b.Return(func, merge_param_0)});
- func->StartTarget()->SetInstructions(utils::Vector{i});
+ func->StartTarget()->SetInstructions({i});
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
index 807833f..c1e3f00 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
@@ -20,12 +20,12 @@
namespace {
TEST_F(SpvGeneratorImplTest, Loop_BreakIf) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* loop = b.CreateLoop();
+ auto* loop = b.Loop();
loop->Body()->Append(b.Continue(loop));
- loop->Continuing()->Append(b.BreakIf(b.Constant(true), loop));
+ loop->Continuing()->Append(b.BreakIf(true, loop));
loop->Merge()->Append(b.Return(func));
func->StartTarget()->Append(loop);
@@ -56,9 +56,9 @@
// Test that we still emit the continuing block with a back-edge, even when it is unreachable.
TEST_F(SpvGeneratorImplTest, Loop_UnconditionalBreakInBody) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* loop = b.CreateLoop();
+ auto* loop = b.Loop();
loop->Body()->Append(b.ExitLoop(loop));
loop->Merge()->Append(b.Return(func));
@@ -88,11 +88,11 @@
}
TEST_F(SpvGeneratorImplTest, Loop_ConditionalBreakInBody) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* loop = b.CreateLoop();
+ auto* loop = b.Loop();
- auto* cond_break = b.CreateIf(b.Constant(true));
+ auto* cond_break = b.If(true);
cond_break->True()->Append(b.ExitLoop(loop));
cond_break->False()->Append(b.ExitIf(cond_break));
cond_break->Merge()->Append(b.Continue(loop));
@@ -133,11 +133,11 @@
}
TEST_F(SpvGeneratorImplTest, Loop_ConditionalContinueInBody) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* loop = b.CreateLoop();
+ auto* loop = b.Loop();
- auto* cond_break = b.CreateIf(b.Constant(true));
+ auto* cond_break = b.If(true);
cond_break->True()->Append(b.Continue(loop));
cond_break->False()->Append(b.ExitIf(cond_break));
cond_break->Merge()->Append(b.ExitLoop(loop));
@@ -180,9 +180,9 @@
// Test that we still emit the continuing block with a back-edge, and the merge block, even when
// they are unreachable.
TEST_F(SpvGeneratorImplTest, Loop_UnconditionalReturnInBody) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* loop = b.CreateLoop();
+ auto* loop = b.Loop();
loop->Body()->Append(b.Return(func));
@@ -211,11 +211,11 @@
}
TEST_F(SpvGeneratorImplTest, Loop_UseResultFromBodyInContinuing) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* loop = b.CreateLoop();
+ auto* loop = b.Loop();
- auto* result = b.Equal(ty.i32(), b.Constant(1_i), b.Constant(2_i));
+ auto* result = b.Equal(ty.i32(), 1_i, 2_i);
loop->Body()->Append(result);
loop->Continuing()->Append(b.BreakIf(result, loop));
@@ -249,17 +249,17 @@
}
TEST_F(SpvGeneratorImplTest, Loop_NestedLoopInBody) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* outer_loop = b.CreateLoop();
- auto* inner_loop = b.CreateLoop();
+ auto* outer_loop = b.Loop();
+ auto* inner_loop = b.Loop();
inner_loop->Body()->Append(b.ExitLoop(inner_loop));
inner_loop->Continuing()->Append(b.NextIteration(inner_loop));
inner_loop->Merge()->Append(b.Continue(outer_loop));
outer_loop->Body()->Append(inner_loop);
- outer_loop->Continuing()->Append(b.BreakIf(b.Constant(true), outer_loop));
+ outer_loop->Continuing()->Append(b.BreakIf(true, outer_loop));
outer_loop->Merge()->Append(b.Return(func));
func->StartTarget()->Append(outer_loop);
@@ -298,14 +298,14 @@
}
TEST_F(SpvGeneratorImplTest, Loop_NestedLoopInContinuing) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* outer_loop = b.CreateLoop();
- auto* inner_loop = b.CreateLoop();
+ auto* outer_loop = b.Loop();
+ auto* inner_loop = b.Loop();
inner_loop->Body()->Append(b.Continue(inner_loop));
- inner_loop->Continuing()->Append(b.BreakIf(b.Constant(true), inner_loop));
- inner_loop->Merge()->Append(b.BreakIf(b.Constant(true), outer_loop));
+ inner_loop->Continuing()->Append(b.BreakIf(true, inner_loop));
+ inner_loop->Merge()->Append(b.BreakIf(true, outer_loop));
outer_loop->Body()->Append(b.Continue(outer_loop));
outer_loop->Continuing()->Append(inner_loop);
@@ -347,25 +347,24 @@
}
TEST_F(SpvGeneratorImplTest, Loop_Phi_SingleValue) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* l = b.CreateLoop();
+ auto* l = b.Loop();
func->StartTarget()->Append(l);
- l->Initializer()->AddInboundBranch(l);
- l->Initializer()->Append(b.NextIteration(l, utils::Vector{b.Constant(1_i)}));
+ l->Initializer()->Append(b.NextIteration(l, 1_i));
auto* loop_param = b.BlockParam(b.ir.Types().i32());
- l->Body()->SetParams(utils::Vector{loop_param});
- auto* inc = b.Add(b.ir.Types().i32(), loop_param, b.Constant(1_i));
+ l->Body()->SetParams({loop_param});
+ auto* inc = b.Add(b.ir.Types().i32(), loop_param, 1_i);
l->Body()->Append(inc);
- l->Body()->Append(b.Continue(l, utils::Vector{inc}));
+ l->Body()->Append(b.Continue(l, inc));
auto* cont_param = b.BlockParam(b.ir.Types().i32());
- l->Continuing()->SetParams(utils::Vector{cont_param});
- auto* cmp = b.GreaterThan(b.ir.Types().bool_(), cont_param, b.Constant(5_i));
+ l->Continuing()->SetParams({cont_param});
+ auto* cmp = b.GreaterThan(b.ir.Types().bool_(), cont_param, 5_i);
l->Continuing()->Append(cmp);
- l->Continuing()->Append(b.BreakIf(cmp, l, utils::Vector{cont_param}));
+ l->Continuing()->Append(b.BreakIf(cmp, l, cont_param));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -400,29 +399,28 @@
}
TEST_F(SpvGeneratorImplTest, Loop_Phi_MultipleValue) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* l = b.CreateLoop();
+ auto* l = b.Loop();
func->StartTarget()->Append(l);
- l->Initializer()->AddInboundBranch(l);
- l->Initializer()->Append(b.NextIteration(l, utils::Vector{b.Constant(1_i), b.Constant(false)}));
+ l->Initializer()->Append(b.NextIteration(l, 1_i, false));
auto* loop_param_a = b.BlockParam(b.ir.Types().i32());
auto* loop_param_b = b.BlockParam(b.ir.Types().bool_());
- l->Body()->SetParams(utils::Vector{loop_param_a, loop_param_b});
- auto* inc = b.Add(b.ir.Types().i32(), loop_param_a, b.Constant(1_i));
+ l->Body()->SetParams({loop_param_a, loop_param_b});
+ auto* inc = b.Add(b.ir.Types().i32(), loop_param_a, 1_i);
l->Body()->Append(inc);
- l->Body()->Append(b.Continue(l, utils::Vector{inc, loop_param_b}));
+ l->Body()->Append(b.Continue(l, inc, loop_param_b));
auto* cont_param_a = b.BlockParam(b.ir.Types().i32());
auto* cont_param_b = b.BlockParam(b.ir.Types().bool_());
- l->Continuing()->SetParams(utils::Vector{cont_param_a, cont_param_b});
- auto* cmp = b.GreaterThan(b.ir.Types().bool_(), cont_param_a, b.Constant(5_i));
+ l->Continuing()->SetParams({cont_param_a, cont_param_b});
+ auto* cmp = b.GreaterThan(b.ir.Types().bool_(), cont_param_a, 5_i);
l->Continuing()->Append(cmp);
auto* not_b = b.Not(b.ir.Types().bool_(), cont_param_b);
l->Continuing()->Append(not_b);
- l->Continuing()->Append(b.BreakIf(cmp, l, utils::Vector{cont_param_a, not_b}));
+ l->Continuing()->Append(b.BreakIf(cmp, l, cont_param_a, not_b));
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
index 64e9491..215aad8 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
@@ -20,11 +20,11 @@
namespace {
TEST_F(SpvGeneratorImplTest, Switch_Basic) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* swtch = b.CreateSwitch(b.Constant(42_i));
+ auto* swtch = b.Switch(42_i);
- auto* def_case = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector()});
+ auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
def_case->Append(b.ExitSwitch(swtch));
swtch->Merge()->Append(b.Return(func));
@@ -52,17 +52,17 @@
}
TEST_F(SpvGeneratorImplTest, Switch_MultipleCases) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* swtch = b.CreateSwitch(b.Constant(42_i));
+ auto* swtch = b.Switch(42_i);
- auto* case_a = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+ auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
case_a->Append(b.ExitSwitch(swtch));
- auto* case_b = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ auto* case_b = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
case_b->Append(b.ExitSwitch(swtch));
- auto* def_case = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector()});
+ auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
def_case->Append(b.ExitSwitch(swtch));
swtch->Merge()->Append(b.Return(func));
@@ -94,20 +94,20 @@
}
TEST_F(SpvGeneratorImplTest, Switch_MultipleSelectorsPerCase) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* swtch = b.CreateSwitch(b.Constant(42_i));
+ auto* swtch = b.Switch(42_i);
- auto* case_a = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
- ir::Switch::CaseSelector{b.Constant(3_i)}});
+ auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{b.Constant(3_i)}});
case_a->Append(b.ExitSwitch(swtch));
- auto* case_b = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)},
- ir::Switch::CaseSelector{b.Constant(4_i)}});
+ auto* case_b = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)},
+ ir::Switch::CaseSelector{b.Constant(4_i)}});
case_b->Append(b.ExitSwitch(swtch));
- auto* def_case = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(5_i)},
- ir::Switch::CaseSelector()});
+ auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(5_i)},
+ ir::Switch::CaseSelector()});
def_case->Append(b.ExitSwitch(swtch));
swtch->Merge()->Append(b.Return(func));
@@ -139,17 +139,17 @@
}
TEST_F(SpvGeneratorImplTest, Switch_AllCasesReturn) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* swtch = b.CreateSwitch(b.Constant(42_i));
+ auto* swtch = b.Switch(42_i);
- auto* case_a = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+ auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
case_a->Append(b.Return(func));
- auto* case_b = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ auto* case_b = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
case_b->Append(b.Return(func));
- auto* def_case = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector()});
+ auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
def_case->Append(b.Return(func));
func->StartTarget()->Append(swtch);
@@ -179,19 +179,19 @@
}
TEST_F(SpvGeneratorImplTest, Switch_ConditionalBreak) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* swtch = b.CreateSwitch(b.Constant(42_i));
+ auto* swtch = b.Switch(42_i);
- auto* cond_break = b.CreateIf(b.Constant(true));
+ auto* cond_break = b.If(true);
cond_break->True()->Append(b.ExitSwitch(swtch));
cond_break->False()->Append(b.ExitIf(cond_break));
cond_break->Merge()->Append(b.Return(func));
- auto* case_a = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+ auto* case_a = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
case_a->Append(cond_break);
- auto* def_case = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector()});
+ auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
def_case->Append(b.ExitSwitch(swtch));
swtch->Merge()->Append(b.Return(func));
@@ -228,22 +228,22 @@
}
TEST_F(SpvGeneratorImplTest, Switch_Phi_SingleValue) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
auto* merge_param = b.BlockParam(b.ir.Types().i32());
- auto* s = b.CreateSwitch(b.Constant(42_i));
- auto* case_a = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
- ir::Switch::CaseSelector{nullptr}});
- case_a->Append(b.ExitSwitch(s, utils::Vector{b.Constant(10_i)}));
+ auto* s = b.Switch(42_i);
+ auto* case_a = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{nullptr}});
+ case_a->Append(b.ExitSwitch(s, 10_i));
- auto* case_b = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
- case_b->Append(b.ExitSwitch(s, utils::Vector{b.Constant(20_i)}));
+ auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ case_b->Append(b.ExitSwitch(s, 20_i));
- s->Merge()->SetParams(utils::Vector{merge_param});
+ s->Merge()->SetParams({merge_param});
s->Merge()->Append(b.Return(func));
- func->StartTarget()->SetInstructions(utils::Vector{s});
+ func->StartTarget()->SetInstructions({s});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -271,20 +271,20 @@
}
TEST_F(SpvGeneratorImplTest, Switch_Phi_SingleValue_CaseReturn) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* s = b.CreateSwitch(b.Constant(42_i));
- auto* case_a = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
- ir::Switch::CaseSelector{nullptr}});
- case_a->Append(b.Return(func, utils::Vector{b.Constant(10_i)}));
+ auto* s = b.Switch(42_i);
+ auto* case_a = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{nullptr}});
+ case_a->Append(b.Return(func, 10_i));
- auto* case_b = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
- case_b->Append(b.ExitSwitch(s, utils::Vector{b.Constant(20_i)}));
+ auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ case_b->Append(b.ExitSwitch(s, 20_i));
- s->Merge()->SetParams(utils::Vector{b.BlockParam(b.ir.Types().i32())});
+ s->Merge()->SetParams({b.BlockParam(b.ir.Types().i32())});
s->Merge()->Append(b.Return(func));
- func->StartTarget()->SetInstructions(utils::Vector{s});
+ func->StartTarget()->SetInstructions({s});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -312,23 +312,23 @@
}
TEST_F(SpvGeneratorImplTest, Switch_Phi_MultipleValue) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
auto* merge_param_0 = b.BlockParam(b.ir.Types().i32());
auto* merge_param_1 = b.BlockParam(b.ir.Types().bool_());
- auto* s = b.CreateSwitch(b.Constant(42_i));
- auto* case_a = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
- ir::Switch::CaseSelector{nullptr}});
- case_a->Append(b.ExitSwitch(s, utils::Vector{b.Constant(10_i), b.Constant(true)}));
+ auto* s = b.Switch(42_i);
+ auto* case_a = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{nullptr}});
+ case_a->Append(b.ExitSwitch(s, 10_i, true));
- auto* case_b = b.CreateCase(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
- case_b->Append(b.ExitSwitch(s, utils::Vector{b.Constant(20_i), b.Constant(false)}));
+ auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ case_b->Append(b.ExitSwitch(s, 20_i, false));
- s->Merge()->SetParams(utils::Vector{merge_param_0, merge_param_1});
- s->Merge()->Append(b.Return(func, utils::Vector{merge_param_0}));
+ s->Merge()->SetParams({merge_param_0, merge_param_1});
+ s->Merge()->Append(b.Return(func, merge_param_0));
- func->StartTarget()->SetInstructions(utils::Vector{s});
+ func->StartTarget()->SetInstructions({s});
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
index 70c1712..5f161be 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
@@ -21,12 +21,11 @@
namespace {
TEST_F(SpvGeneratorImplTest, FunctionVar_NoInit) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
func->StartTarget()->SetInstructions(
- utils::Vector{b.Declare(ty.pointer(ty.i32(), builtin::AddressSpace::kFunction,
- builtin::Access::kReadWrite)),
- b.Return(func)});
+ {b.Var(ty.ptr(builtin::AddressSpace::kFunction, ty.i32(), builtin::Access::kReadWrite)),
+ b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -45,13 +44,13 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_WithInit) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
+ auto* v =
+ b.Var(ty.ptr(builtin::AddressSpace::kFunction, ty.i32(), builtin::Access::kReadWrite));
v->SetInitializer(b.Constant(42_i));
- func->StartTarget()->SetInstructions(utils::Vector{v, b.Return(func)});
+ func->StartTarget()->SetInstructions({v, b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -72,11 +71,11 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_Name) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
- func->StartTarget()->SetInstructions(utils::Vector{v, b.Return(func)});
+ auto* v =
+ b.Var(ty.ptr(builtin::AddressSpace::kFunction, ty.i32(), builtin::Access::kReadWrite));
+ func->StartTarget()->SetInstructions({v, b.Return(func)});
mod.SetName(v, "myvar");
ASSERT_TRUE(IRIsValid()) << Error();
@@ -97,18 +96,18 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_DeclInsideBlock) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
+ auto* v =
+ b.Var(ty.ptr(builtin::AddressSpace::kFunction, ty.i32(), builtin::Access::kReadWrite));
v->SetInitializer(b.Constant(42_i));
- auto* i = b.CreateIf(b.Constant(true));
- i->True()->SetInstructions(utils::Vector{v, b.ExitIf(i)});
- i->False()->SetInstructions(utils::Vector{b.Return(func)});
- i->Merge()->SetInstructions(utils::Vector{b.Return(func)});
+ auto* i = b.If(true);
+ i->True()->SetInstructions({v, b.ExitIf(i)});
+ i->False()->SetInstructions({b.Return(func)});
+ i->Merge()->SetInstructions({b.Return(func)});
- func->StartTarget()->SetInstructions(utils::Vector{i});
+ func->StartTarget()->SetInstructions({i});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -138,12 +137,12 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_Load) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
auto* store_ty = ty.i32();
- auto* v = b.Declare(
- ty.pointer(store_ty, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
- func->StartTarget()->SetInstructions(utils::Vector{v, b.Load(v), b.Return(func)});
+ auto* v =
+ b.Var(ty.ptr(builtin::AddressSpace::kFunction, store_ty, builtin::Access::kReadWrite));
+ func->StartTarget()->SetInstructions({v, b.Load(v), b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -163,12 +162,11 @@
}
TEST_F(SpvGeneratorImplTest, FunctionVar_Store) {
- auto* func = b.CreateFunction("foo", ty.void_());
+ auto* func = b.Function("foo", ty.void_());
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite));
- func->StartTarget()->SetInstructions(
- utils::Vector{v, b.Store(v, b.Constant(42_i)), b.Return(func)});
+ auto* v =
+ b.Var(ty.ptr(builtin::AddressSpace::kFunction, ty.i32(), builtin::Access::kReadWrite));
+ func->StartTarget()->SetInstructions({v, b.Store(v, 42_i), b.Return(func)});
ASSERT_TRUE(IRIsValid()) << Error();
@@ -189,8 +187,8 @@
}
TEST_F(SpvGeneratorImplTest, PrivateVar_NoInit) {
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite))});
+ b.RootBlock()->SetInstructions(
+ {b.Var(ty.ptr(builtin::AddressSpace::kPrivate, ty.i32(), builtin::Access::kReadWrite))});
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -211,9 +209,8 @@
}
TEST_F(SpvGeneratorImplTest, PrivateVar_WithInit) {
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite));
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ auto* v = b.Var(ty.ptr(builtin::AddressSpace::kPrivate, ty.i32(), builtin::Access::kReadWrite));
+ b.RootBlock()->SetInstructions({v});
v->SetInitializer(b.Constant(42_i));
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
@@ -236,9 +233,8 @@
}
TEST_F(SpvGeneratorImplTest, PrivateVar_Name) {
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite));
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ auto* v = b.Var(ty.ptr(builtin::AddressSpace::kPrivate, ty.i32(), builtin::Access::kReadWrite));
+ b.RootBlock()->SetInstructions({v});
v->SetInitializer(b.Constant(42_i));
mod.SetName(v, "myvar");
@@ -263,19 +259,18 @@
}
TEST_F(SpvGeneratorImplTest, PrivateVar_LoadAndStore) {
- auto* func = b.CreateFunction("foo", ty.void_(), ir::Function::PipelineStage::kFragment);
+ auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kFragment);
mod.functions.Push(func);
auto* store_ty = ty.i32();
- auto* v = b.Declare(
- ty.pointer(store_ty, builtin::AddressSpace::kPrivate, builtin::Access::kReadWrite));
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ auto* v = b.Var(ty.ptr(builtin::AddressSpace::kPrivate, store_ty, builtin::Access::kReadWrite));
+ b.RootBlock()->SetInstructions({v});
v->SetInitializer(b.Constant(42_i));
auto* load = b.Load(v);
- auto* add = b.Add(store_ty, v, b.Constant(1_i));
+ auto* add = b.Add(store_ty, v, 1_i);
auto* store = b.Store(v, add);
- func->StartTarget()->SetInstructions(utils::Vector{load, add, store, b.Return(func)});
+ func->StartTarget()->SetInstructions({load, add, store, b.Return(func)});
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -301,8 +296,8 @@
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar) {
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kWorkgroup, builtin::Access::kReadWrite))});
+ b.RootBlock()->SetInstructions(
+ {b.Var(ty.ptr(builtin::AddressSpace::kWorkgroup, ty.i32(), builtin::Access::kReadWrite))});
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -323,9 +318,9 @@
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar_Name) {
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kWorkgroup, builtin::Access::kReadWrite));
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ auto* v =
+ b.Var(ty.ptr(builtin::AddressSpace::kWorkgroup, ty.i32(), builtin::Access::kReadWrite));
+ b.RootBlock()->SetInstructions({v});
mod.SetName(v, "myvar");
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
@@ -348,19 +343,19 @@
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar_LoadAndStore) {
- auto* func = b.CreateFunction("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
- std::array{1u, 1u, 1u});
+ auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
+ std::array{1u, 1u, 1u});
mod.functions.Push(func);
auto* store_ty = ty.i32();
- auto* v = b.Declare(
- ty.pointer(store_ty, builtin::AddressSpace::kWorkgroup, builtin::Access::kReadWrite));
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ auto* v =
+ b.Var(ty.ptr(builtin::AddressSpace::kWorkgroup, store_ty, builtin::Access::kReadWrite));
+ b.RootBlock()->SetInstructions({v});
auto* load = b.Load(v);
- auto* add = b.Add(store_ty, v, b.Constant(1_i));
+ auto* add = b.Add(store_ty, v, 1_i);
auto* store = b.Store(v, add);
- func->StartTarget()->SetInstructions(utils::Vector{load, add, store, b.Return(func)});
+ func->StartTarget()->SetInstructions({load, add, store, b.Return(func)});
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -385,8 +380,8 @@
}
TEST_F(SpvGeneratorImplTest, WorkgroupVar_ZeroInitializeWithExtension) {
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kWorkgroup, builtin::Access::kReadWrite))});
+ b.RootBlock()->SetInstructions(
+ {b.Var(ty.ptr(builtin::AddressSpace::kWorkgroup, ty.i32(), builtin::Access::kReadWrite))});
// Create a generator with the zero_init_workgroup_memory flag set to `true`.
spirv::GeneratorImplIr gen(&mod, true);
@@ -410,10 +405,9 @@
}
TEST_F(SpvGeneratorImplTest, StorageVar) {
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ auto* v = b.Var(ty.ptr(builtin::AddressSpace::kStorage, ty.i32(), builtin::Access::kReadWrite));
v->SetBindingPoint(0, 0);
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ b.RootBlock()->SetInstructions({v});
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -441,10 +435,9 @@
}
TEST_F(SpvGeneratorImplTest, StorageVar_Name) {
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ auto* v = b.Var(ty.ptr(builtin::AddressSpace::kStorage, ty.i32(), builtin::Access::kReadWrite));
v->SetBindingPoint(0, 0);
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ b.RootBlock()->SetInstructions({v});
mod.SetName(v, "myvar");
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
@@ -473,19 +466,18 @@
}
TEST_F(SpvGeneratorImplTest, StorageVar_LoadAndStore) {
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite));
+ auto* v = b.Var(ty.ptr(builtin::AddressSpace::kStorage, ty.i32(), builtin::Access::kReadWrite));
v->SetBindingPoint(0, 0);
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ b.RootBlock()->SetInstructions({v});
- auto* func = b.CreateFunction("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
- std::array{1u, 1u, 1u});
+ auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
+ std::array{1u, 1u, 1u});
mod.functions.Push(func);
auto* load = b.Load(v);
- auto* add = b.Add(ty.i32(), v, b.Constant(1_i));
+ auto* add = b.Add(ty.i32(), v, 1_i);
auto* store = b.Store(v, add);
- func->StartTarget()->SetInstructions(utils::Vector{load, add, store, b.Return(func)});
+ func->StartTarget()->SetInstructions({load, add, store, b.Return(func)});
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -523,10 +515,9 @@
}
TEST_F(SpvGeneratorImplTest, UniformVar) {
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kUniform, builtin::Access::kReadWrite));
+ auto* v = b.Var(ty.ptr(builtin::AddressSpace::kUniform, ty.i32(), builtin::Access::kReadWrite));
v->SetBindingPoint(0, 0);
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ b.RootBlock()->SetInstructions({v});
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -554,10 +545,9 @@
}
TEST_F(SpvGeneratorImplTest, UniformVar_Name) {
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kUniform, builtin::Access::kReadWrite));
+ auto* v = b.Var(ty.ptr(builtin::AddressSpace::kUniform, ty.i32(), builtin::Access::kReadWrite));
v->SetBindingPoint(0, 0);
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ b.RootBlock()->SetInstructions({v});
mod.SetName(v, "myvar");
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
@@ -586,17 +576,16 @@
}
TEST_F(SpvGeneratorImplTest, UniformVar_Load) {
- auto* v = b.Declare(
- ty.pointer(ty.i32(), builtin::AddressSpace::kUniform, builtin::Access::kReadWrite));
+ auto* v = b.Var(ty.ptr(builtin::AddressSpace::kUniform, ty.i32(), builtin::Access::kReadWrite));
v->SetBindingPoint(0, 0);
- b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ b.RootBlock()->SetInstructions({v});
- auto* func = b.CreateFunction("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
- std::array{1u, 1u, 1u});
+ auto* func = b.Function("foo", ty.void_(), ir::Function::PipelineStage::kCompute,
+ std::array{1u, 1u, 1u});
mod.functions.Push(func);
auto* load = b.Load(v);
- func->StartTarget()->SetInstructions(utils::Vector{load, b.Return(func)});
+ func->StartTarget()->SetInstructions({load, b.Return(func)});
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
diff --git a/src/tint/writer/wgsl/generator_impl_type_test.cc b/src/tint/writer/wgsl/generator_impl_type_test.cc
index 9bf9a4c..cb790cb 100644
--- a/src/tint/writer/wgsl/generator_impl_type_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_type_test.cc
@@ -146,8 +146,7 @@
}
TEST_F(WgslGeneratorImplTest, EmitType_Pointer) {
- auto type =
- Alias("make_type_reachable", ty.pointer<f32>(builtin::AddressSpace::kWorkgroup))->type;
+ auto type = Alias("make_type_reachable", ty.ptr<f32>(builtin::AddressSpace::kWorkgroup))->type;
GeneratorImpl& gen = Build();
@@ -159,7 +158,7 @@
TEST_F(WgslGeneratorImplTest, EmitType_PointerAccessMode) {
auto type = Alias("make_type_reachable",
- ty.pointer<f32>(builtin::AddressSpace::kStorage, builtin::Access::kReadWrite))
+ ty.ptr<f32>(builtin::AddressSpace::kStorage, builtin::Access::kReadWrite))
->type;
GeneratorImpl& gen = Build();