Import Tint changes from Dawn
Changes:
- c0f0ab761bb9b9e12cc73c124604a0a4eb334d5e [ir] Split `TextGenerator` apart. by dan sinclair <dsinclair@chromium.org>
- 6dfc38b8a59d5aa70396fccc100260e4b0b091d9 [ir][validation] Validate `binary` by dan sinclair <dsinclair@chromium.org>
- a0aabafb16d5f4cbab607ab2b67b6700b691a85c Remove redundant parentheses on lambdas by Ben Clayton <bclayton@google.com>
- fbc2738baf84af039ae98097e89f3a24c40923fd Add More Dual Source Blending Tint Tests + Validation by Brandon Jones <brandon1.jones@intel.com>
- 965b0b65aa1466f26645d5023693d772b1f1a129 [tint][ir][ToProgram] Add additional tests by Ben Clayton <bclayton@google.com>
- 26fbdaa51384cbc8274b5adca452332afbfa5b92 [tint][ir][builder] Add callback based With() by Ben Clayton <bclayton@google.com>
- a48e9b4dda0c127fa4e4e100811905414013a808 [tint][ir] Fix build by Ben Clayton <bclayton@google.com>
- 7afd2c4147e89ae339b254439f8864e81d78473e [tint][ir][ToProgram] Add non-roundtrip tests by Ben Clayton <bclayton@google.com>
- 56cb75e81cb0831406eb26ce05d827df2b7cffbb [ir][validation] Validate `var` by dan sinclair <dsinclair@chromium.org>
- 5acd44cc94f5dcd6bd72b8be0a6070a2bb14c80e [tint][ir] Rename Function::StartTarget() to Block() by Ben Clayton <bclayton@google.com>
- 1d94ba00a58affdc6fb2ee18cbb0234c8a2ee3c6 [ir] Enable disabled `for` test. by dan sinclair <dsinclair@chromium.org>
- 91307895c1d4b1a13a7f0e966b975ce58318895f [tint][ir][ToProgram] Implement loops by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: c0f0ab761bb9b9e12cc73c124604a0a4eb334d5e
Change-Id: I801e55fc6d3eaa10690af5f09b22762a3693154e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/138661
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index ae14e47..3901d5d 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -988,6 +988,8 @@
"writer/append_vector.h",
"writer/array_length_from_uniform_options.cc",
"writer/array_length_from_uniform_options.h",
+ "writer/ast_text_generator.cc",
+ "writer/ast_text_generator.h",
"writer/binding_point.h",
"writer/binding_remapper_options.cc",
"writer/binding_remapper_options.h",
@@ -1894,10 +1896,10 @@
tint_unittests_source_set("tint_unittests_writer_src") {
sources = [
"writer/append_vector_test.cc",
+ "writer/ast_text_generator_test.cc",
"writer/check_supported_extensions_test.cc",
"writer/flatten_bindings_test.cc",
"writer/float_to_string_test.cc",
- "writer/text_generator_test.cc",
]
deps = [
":libtint_unittests_ast_helper",
@@ -2363,6 +2365,7 @@
"ir/switch_test.cc",
"ir/swizzle_test.cc",
"ir/to_program_roundtrip_test.cc",
+ "ir/to_program_test.cc",
"ir/unary_test.cc",
"ir/user_call_test.cc",
"ir/validate_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 0b4cd81..a9c5a8d 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -562,6 +562,8 @@
writer/append_vector.h
writer/array_length_from_uniform_options.cc
writer/array_length_from_uniform_options.h
+ writer/ast_text_generator.cc
+ writer/ast_text_generator.h
writer/binding_point.h
writer/binding_remapper_options.cc
writer/binding_remapper_options.h
@@ -1110,10 +1112,10 @@
utils/unique_vector_test.cc
utils/vector_test.cc
writer/append_vector_test.cc
+ writer/ast_text_generator_test.cc
writer/check_supported_extensions_test.cc
writer/flatten_bindings_test.cc
writer/float_to_string_test.cc
- writer/text_generator_test.cc
)
# Noet, the source files are included here otherwise the cmd sources would not be included in the
@@ -1578,6 +1580,11 @@
)
endif()
+ if (${TINT_BUILD_IR} AND ${TINT_BUILD_WGSL_WRITER})
+ list(APPEND TINT_TEST_SRCS
+ ir/to_program_test.cc
+ )
+ endif()
if (${TINT_BUILD_FUZZERS})
list(APPEND TINT_TEST_SRCS
diff --git a/src/tint/ast/transform/array_length_from_uniform.cc b/src/tint/ast/transform/array_length_from_uniform.cc
index 0101658..eafd805 100644
--- a/src/tint/ast/transform/array_length_from_uniform.cc
+++ b/src/tint/ast/transform/array_length_from_uniform.cc
@@ -97,7 +97,7 @@
// Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module.
const Variable* buffer_size_ubo = nullptr;
- auto get_ubo = [&]() {
+ auto get_ubo = [&] {
if (!buffer_size_ubo) {
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
// We do this because UBOs require an element stride that is 16-byte
diff --git a/src/tint/ast/transform/builtin_polyfill.cc b/src/tint/ast/transform/builtin_polyfill.cc
index 75d91b3..01b9f65 100644
--- a/src/tint/ast/transform/builtin_polyfill.cc
+++ b/src/tint/ast/transform/builtin_polyfill.cc
@@ -285,7 +285,7 @@
uint32_t width = WidthOf(ty);
// Returns either u32 or vecN<u32>
- auto U = [&]() {
+ auto U = [&] {
if (width == 1) {
return b.ty.u32();
}
@@ -343,7 +343,7 @@
uint32_t width = WidthOf(ty);
// Returns either u32 or vecN<u32>
- auto U = [&]() {
+ auto U = [&] {
if (width == 1) {
return b.ty.u32();
}
@@ -460,7 +460,7 @@
uint32_t width = WidthOf(ty);
// Returns either u32 or vecN<u32>
- auto U = [&]() {
+ auto U = [&] {
if (width == 1) {
return b.ty.u32();
}
@@ -532,7 +532,7 @@
uint32_t width = WidthOf(ty);
// Returns either u32 or vecN<u32>
- auto U = [&]() {
+ auto U = [&] {
if (width == 1) {
return b.ty.u32();
}
diff --git a/src/tint/ast/transform/demote_to_helper.cc b/src/tint/ast/transform/demote_to_helper.cc
index 8e67ddb..83bfee9 100644
--- a/src/tint/ast/transform/demote_to_helper.cc
+++ b/src/tint/ast/transform/demote_to_helper.cc
@@ -188,7 +188,7 @@
auto* result_struct = sem_call->Type()->As<type::Struct>();
auto* atomic_ty = result_struct->Members()[0]->Type();
result_ty = b.ty(
- utils::GetOrCreate(atomic_cmpxchg_result_types, atomic_ty, [&]() {
+ utils::GetOrCreate(atomic_cmpxchg_result_types, atomic_ty, [&] {
auto name = b.Sym();
b.Structure(
name,
diff --git a/src/tint/ast/transform/direct_variable_access.cc b/src/tint/ast/transform/direct_variable_access.cc
index 3267ae2..dd250a5 100644
--- a/src/tint/ast/transform/direct_variable_access.cc
+++ b/src/tint/ast/transform/direct_variable_access.cc
@@ -868,7 +868,7 @@
/// indices.
void TransformCall(const sem::Call* call) {
// Register a custom handler for the specific call expression
- ctx.Replace(call->Declaration(), [this, call]() {
+ ctx.Replace(call->Declaration(), [this, call] {
auto target_variant = clone_state->current_variant->calls.Find(call);
if (!target_variant) {
// The current variant does not need to transform this call.
diff --git a/src/tint/ast/transform/expand_compound_assignment.cc b/src/tint/ast/transform/expand_compound_assignment.cc
index 3862d2b..f793418 100644
--- a/src/tint/ast/transform/expand_compound_assignment.cc
+++ b/src/tint/ast/transform/expand_compound_assignment.cc
@@ -103,7 +103,7 @@
// foo.bar += rhs;
// After:
// foo.bar = foo.bar + rhs;
- new_lhs = [&]() { return ctx.Clone(lhs); };
+ new_lhs = [&] { return ctx.Clone(lhs); };
} else if (index_accessor && is_vec(index_accessor->object)) {
// This is the case for vector component via an array accessor. We need
// to capture a pointer to the vector and also the index value.
@@ -115,7 +115,7 @@
// (*vec_ptr)[index] = (*vec_ptr)[index] + rhs;
auto lhs_ptr = hoist_pointer_to(index_accessor->object);
auto index = hoist_expr_to_let(index_accessor->index);
- new_lhs = [&, lhs_ptr, index]() { return b.IndexAccessor(b.Deref(lhs_ptr), index); };
+ new_lhs = [&, lhs_ptr, index] { return b.IndexAccessor(b.Deref(lhs_ptr), index); };
} else if (member_accessor && is_vec(member_accessor->object)) {
// This is the case for vector component via a member accessor. We just
// need to capture a pointer to the vector.
@@ -125,7 +125,7 @@
// let vec_ptr = &a[idx()];
// (*vec_ptr).y = (*vec_ptr).y + rhs;
auto lhs_ptr = hoist_pointer_to(member_accessor->object);
- new_lhs = [&, lhs_ptr]() {
+ new_lhs = [&, lhs_ptr] {
return b.MemberAccessor(b.Deref(lhs_ptr), ctx.Clone(member_accessor->member));
};
} else {
@@ -137,7 +137,7 @@
// let lhs_ptr = &a[idx()];
// (*lhs_ptr) = (*lhs_ptr) + rhs;
auto lhs_ptr = hoist_pointer_to(lhs);
- new_lhs = [&, lhs_ptr]() { return b.Deref(lhs_ptr); };
+ new_lhs = [&, lhs_ptr] { return b.Deref(lhs_ptr); };
}
// Replace the statement with a regular assignment statement.
diff --git a/src/tint/ast/transform/module_scope_var_to_entry_point_param.cc b/src/tint/ast/transform/module_scope_var_to_entry_point_param.cc
index 14112f0..1392794 100644
--- a/src/tint/ast/transform/module_scope_var_to_entry_point_param.cc
+++ b/src/tint/ast/transform/module_scope_var_to_entry_point_param.cc
@@ -122,7 +122,7 @@
auto* ty = var->Type()->UnwrapRef();
// Helper to create an AST node for the store type of the variable.
- auto store_type = [&]() { return CreateASTTypeFor(ctx, ty); };
+ auto store_type = [&] { return CreateASTTypeFor(ctx, ty); };
builtin::AddressSpace sc = var->AddressSpace();
switch (sc) {
@@ -324,7 +324,7 @@
// Create a statement to assign the initializer if present.
if (var->initializer) {
- private_initializers.Push([&, name, var]() {
+ private_initializers.Push([&, name, var] {
return ctx.dst->Assign(
ctx.dst->MemberAccessor(PrivateStructVariableName(), name),
ctx.Clone(var->initializer));
@@ -401,7 +401,7 @@
// threadgroup memory arguments.
Symbol workgroup_parameter_symbol;
StructMemberList workgroup_parameter_members;
- auto workgroup_param = [&]() {
+ auto workgroup_param = [&] {
if (!workgroup_parameter_symbol.IsValid()) {
workgroup_parameter_symbol = ctx.dst->Sym();
}
diff --git a/src/tint/ast/transform/num_workgroups_from_uniform.cc b/src/tint/ast/transform/num_workgroups_from_uniform.cc
index 6f7daad..54ecec5 100644
--- a/src/tint/ast/transform/num_workgroups_from_uniform.cc
+++ b/src/tint/ast/transform/num_workgroups_from_uniform.cc
@@ -127,7 +127,7 @@
// Get (or create, on first call) the uniform buffer that will receive the
// number of workgroups.
const Variable* num_workgroups_ubo = nullptr;
- auto get_ubo = [&]() {
+ auto get_ubo = [&] {
if (!num_workgroups_ubo) {
auto* num_workgroups_struct =
b.Structure(b.Sym(), utils::Vector{
diff --git a/src/tint/ast/transform/packed_vec3.cc b/src/tint/ast/transform/packed_vec3.cc
index 62f82d2..126bb7b 100644
--- a/src/tint/ast/transform/packed_vec3.cc
+++ b/src/tint/ast/transform/packed_vec3.cc
@@ -118,7 +118,7 @@
// Create a struct with a single `__packed_vec3` member.
// Give the struct member the same alignment as the original unpacked vec3
// type, to avoid changing the array element stride.
- return b.ty(packed_vec3_wrapper_struct_names.GetOrCreate(vec, [&]() {
+ return b.ty(packed_vec3_wrapper_struct_names.GetOrCreate(vec, [&] {
auto name = b.Symbols().New(
"tint_packed_vec3_" + vec->type()->FriendlyName() +
(array_element ? "_array_element" : "_struct_member"));
@@ -161,7 +161,7 @@
},
[&](const type::Struct* str) -> Type {
if (ContainsVec3(str)) {
- auto name = rewritten_structs.GetOrCreate(str, [&]() {
+ auto name = rewritten_structs.GetOrCreate(str, [&] {
utils::Vector<const StructMember*, 4> members;
for (auto* member : str->Members()) {
// If the member type contains a vec3, rewrite it.
@@ -281,7 +281,7 @@
/// @param ty the unpacked type
/// @returns an expression that holds the unpacked value
const Expression* UnpackComposite(const Expression* expr, const type::Type* ty) {
- auto helper = unpack_helpers.GetOrCreate(ty, [&]() {
+ auto helper = unpack_helpers.GetOrCreate(ty, [&] {
return MakePackUnpackHelper(
"tint_unpack_vec3_in_composite", ty,
[&](const Expression* element,
@@ -297,8 +297,8 @@
return UnpackComposite(element, element_type);
}
},
- [&]() { return RewriteType(ty); }, //
- [&]() { return CreateASTTypeFor(ctx, ty); });
+ [&] { return RewriteType(ty); }, //
+ [&] { return CreateASTTypeFor(ctx, ty); });
});
return b.Call(helper, expr);
}
@@ -309,7 +309,7 @@
/// @param ty the unpacked type
/// @returns an expression that holds the packed value
const Expression* PackComposite(const Expression* expr, const type::Type* ty) {
- auto helper = pack_helpers.GetOrCreate(ty, [&]() {
+ auto helper = pack_helpers.GetOrCreate(ty, [&] {
return MakePackUnpackHelper(
"tint_pack_vec3_in_composite", ty,
[&](const Expression* element,
@@ -326,8 +326,8 @@
return PackComposite(element, element_type);
}
},
- [&]() { return CreateASTTypeFor(ctx, ty); }, //
- [&]() { return RewriteType(ty); });
+ [&] { return CreateASTTypeFor(ctx, ty); }, //
+ [&] { return RewriteType(ty); });
});
return b.Call(helper, expr);
}
diff --git a/src/tint/ast/transform/preserve_padding.cc b/src/tint/ast/transform/preserve_padding.cc
index 3f51fa4..99908e1 100644
--- a/src/tint/ast/transform/preserve_padding.cc
+++ b/src/tint/ast/transform/preserve_padding.cc
@@ -119,7 +119,7 @@
const char* kValueParamName = "value";
auto call_helper = [&](auto&& body) {
EnableExtension();
- auto helper = helpers.GetOrCreate(ty, [&]() {
+ auto helper = helpers.GetOrCreate(ty, [&] {
auto helper_name = b.Symbols().New("assign_and_preserve_padding");
utils::Vector<const Parameter*, 2> params = {
b.Param(kDestParamName,
@@ -136,7 +136,7 @@
ty, //
[&](const type::Array* arr) {
// Call a helper function that uses a loop to assigns each element separately.
- return call_helper([&]() {
+ return call_helper([&] {
utils::Vector<const Statement*, 8> body;
auto* idx = b.Var("i", b.Expr(0_u));
body.Push(
@@ -150,7 +150,7 @@
},
[&](const type::Matrix* mat) {
// Call a helper function that assigns each column separately.
- return call_helper([&]() {
+ return call_helper([&] {
utils::Vector<const Statement*, 4> body;
for (uint32_t i = 0; i < mat->columns(); i++) {
body.Push(MakeAssignment(mat->ColumnType(),
@@ -162,7 +162,7 @@
},
[&](const type::Struct* str) {
// Call a helper function that assigns each member separately.
- return call_helper([&]() {
+ return call_helper([&] {
utils::Vector<const Statement*, 8> body;
for (auto member : str->Members()) {
auto name = member->Name().Name();
diff --git a/src/tint/ast/transform/remove_continue_in_switch.cc b/src/tint/ast/transform/remove_continue_in_switch.cc
index 0ae266d..7843efc 100644
--- a/src/tint/ast/transform/remove_continue_in_switch.cc
+++ b/src/tint/ast/transform/remove_continue_in_switch.cc
@@ -59,7 +59,7 @@
made_changes = true;
auto cont_var_name =
- tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&]() {
+ tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&] {
// Create and insert 'var tint_continue : bool = false;' before the
// switch.
auto var_name = b.Symbols().New("tint_continue");
diff --git a/src/tint/ast/transform/vectorize_matrix_conversions.cc b/src/tint/ast/transform/vectorize_matrix_conversions.cc
index 66c8251..f012936 100644
--- a/src/tint/ast/transform/vectorize_matrix_conversions.cc
+++ b/src/tint/ast/transform/vectorize_matrix_conversions.cc
@@ -115,7 +115,7 @@
// Replace the matrix conversion to column vector conversions and a matrix construction.
if (!matrix->HasSideEffects()) {
// Simply use the argument's declaration if it has no side effects.
- return build_vectorized_conversion_expression([&]() { //
+ return build_vectorized_conversion_expression([&] { //
return ctx.Clone(matrix->Declaration());
});
} else {
@@ -132,7 +132,7 @@
},
CreateASTTypeFor(ctx, dst_type),
utils::Vector{
- b.Return(build_vectorized_conversion_expression([&]() { //
+ b.Return(build_vectorized_conversion_expression([&] { //
return b.Expr("value");
})),
});
diff --git a/src/tint/ast/transform/vertex_pulling.cc b/src/tint/ast/transform/vertex_pulling.cc
index 54cb8c0..cd9337a 100644
--- a/src/tint/ast/transform/vertex_pulling.cc
+++ b/src/tint/ast/transform/vertex_pulling.cc
@@ -766,7 +766,7 @@
ctx.InsertFront(func->body->statements, b.Decl(func_var));
// Capture mapping from location to the new variable.
LocationInfo info;
- info.expr = [this, func_var]() { return b.Expr(func_var); };
+ info.expr = [this, func_var] { return b.Expr(func_var); };
auto* sem = src->Sem().Get<sem::Parameter>(param);
info.type = sem->Type();
@@ -785,11 +785,11 @@
auto builtin = src->Sem().Get(builtin_attr)->Value();
// Check for existing vertex_index and instance_index builtins.
if (builtin == builtin::BuiltinValue::kVertexIndex) {
- vertex_index_expr = [this, param]() {
+ vertex_index_expr = [this, param] {
return b.Expr(ctx.Clone(param->name->symbol));
};
} else if (builtin == builtin::BuiltinValue::kInstanceIndex) {
- instance_index_expr = [this, param]() {
+ instance_index_expr = [this, param] {
return b.Expr(ctx.Clone(param->name->symbol));
};
}
@@ -815,7 +815,7 @@
utils::Vector<const StructMember*, 8> members_to_clone;
for (auto* member : struct_ty->members) {
auto member_sym = ctx.Clone(member->name->symbol);
- std::function<const Expression*()> member_expr = [this, param_sym, member_sym]() {
+ std::function<const Expression*()> member_expr = [this, param_sym, member_sym] {
return b.MemberAccessor(param_sym, member_sym);
};
@@ -907,7 +907,7 @@
new_function_parameters.Push(
b.Param(name, b.ty.u32(),
utils::Vector{b.Builtin(builtin::BuiltinValue::kVertexIndex)}));
- vertex_index_expr = [this, name]() { return b.Expr(name); };
+ vertex_index_expr = [this, name] { return b.Expr(name); };
break;
}
}
@@ -919,7 +919,7 @@
new_function_parameters.Push(
b.Param(name, b.ty.u32(),
utils::Vector{b.Builtin(builtin::BuiltinValue::kInstanceIndex)}));
- instance_index_expr = [this, name]() { return b.Expr(name); };
+ instance_index_expr = [this, name] { return b.Expr(name); };
break;
}
}
diff --git a/src/tint/inspector/test_inspector_builder.cc b/src/tint/inspector/test_inspector_builder.cc
index c351ec3..6511ef6 100644
--- a/src/tint/inspector/test_inspector_builder.cc
+++ b/src/tint/inspector/test_inspector_builder.cc
@@ -273,7 +273,7 @@
case type::TextureDimension::kCubeArray:
return ty.vec3(scalar);
default:
- [=]() {
+ [=] {
utils::StringStream str;
str << dim;
FAIL() << "Unsupported texture dimension: " << str.str();
@@ -313,19 +313,19 @@
std::function<ast::Type()> func;
switch (component) {
case ComponentType::kF32:
- func = [this]() { return ty.f32(); };
+ func = [this] { return ty.f32(); };
break;
case ComponentType::kI32:
- func = [this]() { return ty.i32(); };
+ func = [this] { return ty.i32(); };
break;
case ComponentType::kU32:
- func = [this]() { return ty.u32(); };
+ func = [this] { return ty.u32(); };
break;
case ComponentType::kF16:
- func = [this]() { return ty.f16(); };
+ func = [this] { return ty.f16(); };
break;
case ComponentType::kUnknown:
- return []() { return ast::Type{}; };
+ return [] { return ast::Type{}; };
}
uint32_t n;
@@ -342,10 +342,10 @@
n = 4;
break;
default:
- return []() { return ast::Type{}; };
+ return [] { return ast::Type{}; };
}
- return [this, func, n]() { return ty.vec(func(), n); };
+ return [this, func, n] { return ty.vec(func(), n); };
}
Inspector& InspectorBuilder::Build() {
@@ -353,7 +353,7 @@
return *inspector_;
}
program_ = std::make_unique<Program>(std::move(*this));
- [&]() {
+ [&] {
ASSERT_TRUE(program_->IsValid()) << diag::Formatter().format(program_->Diagnostics());
}();
inspector_ = std::make_unique<Inspector>(program_.get());
diff --git a/src/tint/inspector/test_inspector_runner.cc b/src/tint/inspector/test_inspector_runner.cc
index 16e4196..9b88ac8 100644
--- a/src/tint/inspector/test_inspector_runner.cc
+++ b/src/tint/inspector/test_inspector_runner.cc
@@ -26,7 +26,7 @@
file_ = std::make_unique<Source::File>("test", shader);
program_ = std::make_unique<Program>(reader::wgsl::Parse(file_.get()));
- [&]() {
+ [&] {
ASSERT_TRUE(program_->IsValid()) << diag::Formatter().format(program_->Diagnostics());
}();
inspector_ = std::make_unique<Inspector>(program_.get());
diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc
index 2bb295b..2b3c9a4 100644
--- a/src/tint/ir/builder.cc
+++ b/src/tint/ir/builder.cc
@@ -48,7 +48,7 @@
Function::PipelineStage stage,
std::optional<std::array<uint32_t, 3>> wg_size) {
auto* ir_func = ir.values.Create<ir::Function>(return_type, stage, wg_size);
- ir_func->SetStartTarget(Block());
+ ir_func->SetBlock(Block());
ir.SetName(ir_func, name);
return ir_func;
}
diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h
index 8db4098..3ad87a6 100644
--- a/src/tint/ir/builder.h
+++ b/src/tint/ir/builder.h
@@ -59,6 +59,7 @@
#include "src/tint/type/u32.h"
#include "src/tint/type/vector.h"
#include "src/tint/type/void.h"
+#include "src/tint/utils/scoped_assignment.h"
namespace tint::ir {
@@ -76,14 +77,6 @@
using DisableIfVectorLike = utils::traits::EnableIf<
!utils::IsVectorLike<utils::traits::Decay<utils::traits::NthTypeOf<0, TYPES..., void>>>>;
- template <typename T>
- T* Append(T* val) {
- if (current_block_) {
- current_block_->Append(val);
- }
- return val;
- }
-
/// If set, any created instruction will be auto-appended to the block.
ir::Block* current_block_ = nullptr;
@@ -101,7 +94,28 @@
/// Creates a new builder wrapping the given block
/// @param b the block to set as the current block
/// @returns the builder
- Builder With(Block* b) { return Builder(ir, b); }
+ Builder With(ir::Block* b) { return Builder(ir, b); }
+
+ /// Calls @p cb with the builder appending to block @p b
+ /// @param b the block to set as the block to append to
+ /// @param cb the function to call with the builder appending to block @p b
+ template <typename FUNCTION>
+ void With(ir::Block* b, FUNCTION&& cb) {
+ TINT_SCOPED_ASSIGNMENT(current_block_, b);
+ cb();
+ }
+
+ /// Appends and returns the instruction @p val to the current block. If there is no current
+ /// block bound, then @p val is just returned.
+ /// @param val the instruction to append
+ /// @returns the instruction
+ template <typename T>
+ T* Append(T* val) {
+ if (current_block_) {
+ current_block_->Append(val);
+ }
+ return val;
+ }
/// @returns a new block
ir::Block* Block();
@@ -157,7 +171,7 @@
/// @param val the constant value
/// @returns the new constant
ir::Constant* Constant(const constant::Value* val) {
- return ir.constants.GetOrCreate(val, [&]() { return ir.values.Create<ir::Constant>(val); });
+ return ir.constants.GetOrCreate(val, [&] { return ir.values.Create<ir::Constant>(val); });
}
/// Creates a ir::Constant for an i32 Scalar
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 629146a..03c88ee 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -196,7 +196,7 @@
out_ << " [";
bool need_comma = false;
- auto comma = [&]() {
+ auto comma = [&] {
if (need_comma) {
out_ << ", ";
}
@@ -233,7 +233,7 @@
out_ << " [";
bool need_comma = false;
- auto comma = [&]() {
+ auto comma = [&] {
if (need_comma) {
out_ << ", ";
}
@@ -283,19 +283,23 @@
EmitReturnAttributes(func);
- out_ << " -> %b" << IdOf(func->StartTarget()) << " {";
+ out_ << " -> %b" << IdOf(func->Block()) << " {";
EmitLine();
{
ScopedIndent si(indent_size_);
- EmitBlock(func->StartTarget());
+ EmitBlock(func->Block());
}
Indent() << "}";
EmitLine();
}
void Disassembler::EmitValueWithType(Instruction* val) {
- EmitValueWithType(val->Result());
+ if (val->Result()) {
+ EmitValueWithType(val->Result());
+ } else {
+ out_ << "undef";
+ }
}
void Disassembler::EmitValueWithType(Value* val) {
@@ -374,7 +378,7 @@
[&](If* i) { EmitIf(i); }, //
[&](Loop* l) { EmitLoop(l); }, //
[&](Binary* b) { EmitBinary(b); }, //
- [&](Unary* u) { EmitUnary(u); },
+ [&](Unary* u) { EmitUnary(u); }, //
[&](Bitcast* b) {
EmitValueWithType(b);
out_ << " = ";
@@ -444,7 +448,7 @@
EmitInstructionName("var", v);
if (v->Initializer()) {
out_ << ", ";
- EmitValue(v->Initializer());
+ EmitOperand(v, v->Initializer(), Var::kInitializerOperandOffset);
}
if (v->BindingPoint().has_value()) {
out_ << " ";
@@ -681,6 +685,7 @@
}
void Disassembler::EmitBinary(Binary* b) {
+ SourceMarker sm(this);
EmitValueWithType(b);
out_ << " = ";
switch (b->Kind()) {
@@ -737,10 +742,13 @@
EmitValue(b->LHS());
out_ << ", ";
EmitValue(b->RHS());
+
+ sm.Store(b);
EmitLine();
}
void Disassembler::EmitUnary(Unary* u) {
+ SourceMarker sm(this);
EmitValueWithType(u);
out_ << " = ";
switch (u->Kind()) {
@@ -753,6 +761,8 @@
}
out_ << " ";
EmitValue(u->Val());
+
+ sm.Store(u);
EmitLine();
}
diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc
index 3ff8811..b6c6586 100644
--- a/src/tint/ir/from_program.cc
+++ b/src/tint/ir/from_program.cc
@@ -417,7 +417,7 @@
}
ir_func->SetParams(params);
- TINT_SCOPED_ASSIGNMENT(current_block_, ir_func->StartTarget());
+ TINT_SCOPED_ASSIGNMENT(current_block_, ir_func->Block());
EmitBlock(ast_func->body);
// Add a terminator if one was not already created.
@@ -701,6 +701,9 @@
SetTerminator(builder_.ExitLoop(loop_inst));
}
+ EmitStatements(stmt->body->statements);
+
+ // The current block didn't `break`, `return` or `continue`, go to the continuing block.
if (NeedTerminator()) {
SetTerminator(builder_.Continue(loop_inst));
}
diff --git a/src/tint/ir/from_program_test.cc b/src/tint/ir/from_program_test.cc
index 81738d6..ecee83b 100644
--- a/src/tint/ir/from_program_test.cc
+++ b/src/tint/ir/from_program_test.cc
@@ -62,7 +62,7 @@
ASSERT_EQ(1u, m->functions.Length());
auto* f = m->functions[0];
- ASSERT_NE(f->StartTarget(), nullptr);
+ ASSERT_NE(f->Block(), nullptr);
EXPECT_EQ(m->functions[0]->Stage(), Function::PipelineStage::kUndefined);
@@ -83,7 +83,7 @@
ASSERT_EQ(1u, m->functions.Length());
auto* f = m->functions[0];
- ASSERT_NE(f->StartTarget(), nullptr);
+ ASSERT_NE(f->Block(), nullptr);
EXPECT_EQ(m->functions[0]->Stage(), Function::PipelineStage::kUndefined);
@@ -105,7 +105,7 @@
ASSERT_EQ(1u, m->functions.Length());
auto* f = m->functions[0];
- ASSERT_NE(f->StartTarget(), nullptr);
+ ASSERT_NE(f->Block(), nullptr);
EXPECT_EQ(m->functions[0]->Stage(), Function::PipelineStage::kUndefined);
@@ -681,7 +681,7 @@
ASSERT_EQ(1u, m.functions.Length());
EXPECT_EQ(1u, loop->Body()->InboundSiblingBranches().Length());
- EXPECT_EQ(1u, loop->Continuing()->InboundSiblingBranches().Length());
+ EXPECT_EQ(0u, loop->Continuing()->InboundSiblingBranches().Length());
EXPECT_EQ(Disassemble(m),
R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
@@ -696,7 +696,7 @@
exit_loop # loop_1
}
}
- continue %b3
+ ret
}
%b3 = block { # continuing
next_iteration %b2
@@ -708,20 +708,7 @@
)");
}
-// TODO(dsinclair): Enable when variable declarations and increment are supported
-TEST_F(IR_FromProgramTest, DISABLED_For) {
- // for(var i: 0; i < 10; i++) {
- // }
- //
- // func -> loop -> loop start -> if true
- // -> if false
- //
- // [if true] -> if merge
- // [if false] -> loop merge
- // [if merge] -> loop continuing
- // [loop continuing] -> loop start
- // [loop merge] -> func end
- //
+TEST_F(IR_FromProgramTest, For) {
auto* ast_for = For(Decl(Var("i", ty.i32())), LessThan("i", 10_a), Increment("i"), Block());
WrapInFunction(ast_for);
@@ -736,7 +723,38 @@
EXPECT_EQ(2u, loop->Body()->InboundSiblingBranches().Length());
EXPECT_EQ(1u, loop->Continuing()->InboundSiblingBranches().Length());
- EXPECT_EQ(Disassemble(m), R"()");
+ EXPECT_EQ(Disassemble(m),
+ R"(%test_function = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ loop [i: %b2, b: %b3, c: %b4] { # loop_1
+ %b2 = block { # initializer
+ %i:ptr<function, i32, read_write> = var
+ next_iteration %b3
+ }
+ %b3 = block { # body
+ %3:i32 = load %i
+ %4:bool = lt %3, 10i
+ if %4 [t: %b5, f: %b6] { # if_1
+ %b5 = block { # true
+ exit_if # if_1
+ }
+ %b6 = block { # false
+ exit_loop # loop_1
+ }
+ }
+ continue %b4
+ }
+ %b4 = block { # continuing
+ %5:i32 = load %i
+ %6:i32 = add %5, 1i
+ store %i, %6
+ next_iteration %b3
+ }
+ }
+ ret
+ }
+}
+)");
}
TEST_F(IR_FromProgramTest, For_Init_NoCondOrContinuing) {
diff --git a/src/tint/ir/function.h b/src/tint/ir/function.h
index a9e1801..0aadc64 100644
--- a/src/tint/ir/function.h
+++ b/src/tint/ir/function.h
@@ -120,14 +120,14 @@
/// @returns the function parameters
const utils::VectorRef<FunctionParam*> Params() { return params_; }
- /// Sets the start target for the function
- /// @param target the start target
- void SetStartTarget(Block* target) {
+ /// Sets the root block for the function
+ /// @param target the root block
+ void SetBlock(Block* target) {
TINT_ASSERT(IR, target != nullptr);
- start_target_ = target;
+ block_ = target;
}
- /// @returns the function start target
- Block* StartTarget() { return start_target_; }
+ /// @returns the function root block
+ ir::Block* Block() { return block_; }
private:
PipelineStage pipeline_stage_;
@@ -141,7 +141,7 @@
} return_;
utils::Vector<FunctionParam*, 1> params_;
- Block* start_target_ = nullptr;
+ ir::Block* block_ = nullptr;
};
utils::StringStream& operator<<(utils::StringStream& out, Function::PipelineStage value);
diff --git a/src/tint/ir/function_test.cc b/src/tint/ir/function_test.cc
index 35076df..f5e55d0 100644
--- a/src/tint/ir/function_test.cc
+++ b/src/tint/ir/function_test.cc
@@ -55,13 +55,13 @@
"");
}
-TEST_F(IR_FunctionTest, Fail_NullStartTarget) {
+TEST_F(IR_FunctionTest, Fail_NullBlock) {
EXPECT_FATAL_FAILURE(
{
Module mod;
Builder b{mod};
auto* f = b.Function("my_func", mod.Types().void_());
- f->SetStartTarget(nullptr);
+ f->SetBlock(nullptr);
},
"");
}
diff --git a/src/tint/ir/program_test_helper.h b/src/tint/ir/program_test_helper.h
index d982366..5f20862 100644
--- a/src/tint/ir/program_test_helper.h
+++ b/src/tint/ir/program_test_helper.h
@@ -41,7 +41,7 @@
SetResolveOnBuild(true);
auto program = std::make_unique<Program>(std::move(*this));
- [&]() {
+ [&] {
diag::Formatter formatter;
ASSERT_TRUE(program->IsValid()) << formatter.format(program->Diagnostics());
}();
diff --git a/src/tint/ir/to_program.cc b/src/tint/ir/to_program.cc
index 495f608..e3a9604 100644
--- a/src/tint/ir/to_program.cc
+++ b/src/tint/ir/to_program.cc
@@ -23,13 +23,17 @@
#include "src/tint/ir/block.h"
#include "src/tint/ir/call.h"
#include "src/tint/ir/constant.h"
+#include "src/tint/ir/continue.h"
#include "src/tint/ir/exit_if.h"
+#include "src/tint/ir/exit_loop.h"
#include "src/tint/ir/exit_switch.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/load.h"
+#include "src/tint/ir/loop.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/multi_in_block.h"
+#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
#include "src/tint/ir/switch.h"
@@ -127,7 +131,7 @@
auto name = BindName(fn);
auto ret_ty = Type(fn->ReturnType());
- auto* body = Block(fn->StartTarget());
+ auto* body = Block(fn->Block());
utils::Vector<const ast::Attribute*, 1> attrs{};
utils::Vector<const ast::Attribute*, 1> ret_attrs{};
return b.Func(name, std::move(params), ret_ty, body, std::move(attrs),
@@ -159,13 +163,17 @@
[&](ir::Call* i) { Call(i); }, //
[&](ir::ExitIf*) {}, //
[&](ir::ExitSwitch* i) { ExitSwitch(i); }, //
+ [&](ir::ExitLoop* i) { ExitLoop(i); }, //
[&](ir::If* i) { If(i); }, //
[&](ir::Load* l) { Load(l); }, //
+ [&](ir::Loop* l) { Loop(l); }, //
[&](ir::Return* i) { Return(i); }, //
[&](ir::Store* i) { Store(i); }, //
[&](ir::Switch* i) { Switch(i); }, //
[&](ir::Unary* u) { Unary(u); }, //
[&](ir::Var* i) { Var(i); }, //
+ [&](ir::NextIteration*) {}, //
+ [&](ir::Continue*) {}, //
[&](Default) { UNHANDLED_CASE(inst); });
}
@@ -174,7 +182,7 @@
auto true_stmts = Statements(if_->True());
auto false_stmts = Statements(if_->False());
- if (IsShortCircuit(if_, true_stmts, false_stmts)) {
+ if (AsShortCircuit(if_, true_stmts, false_stmts)) {
return;
}
@@ -197,6 +205,57 @@
Append(b.If(cond, true_block, b.Else(false_block)));
}
+ void Loop(ir::Loop* l) {
+ auto init_stmts = Statements(l->Initializer());
+ auto* init = init_stmts.Length() == 1 ? init_stmts.Front()->As<ast::VariableDeclStatement>()
+ : nullptr;
+
+ const ast::Expression* cond = nullptr;
+
+ StatementList body_stmts;
+ {
+ TINT_SCOPED_ASSIGNMENT(statements_, &body_stmts);
+ for (auto* inst : *l->Body()) {
+ if (body_stmts.IsEmpty()) {
+ if (auto* if_ = inst->As<ir::If>()) {
+ if (!if_->HasResults() && //
+ if_->True()->Length() == 1 && //
+ if_->False()->Length() == 1 && //
+ tint::Is<ir::ExitIf>(if_->True()->Front()) && //
+ tint::Is<ir::ExitLoop>(if_->False()->Front())) {
+ cond = Expr(if_->Condition());
+ continue;
+ }
+ }
+ }
+
+ Instruction(inst);
+ }
+ }
+
+ auto cont_stmts = Statements(l->Continuing());
+ auto* cont = cont_stmts.Length() == 1 ? cont_stmts.Front() : nullptr;
+
+ auto* body = b.Block(std::move(body_stmts));
+
+ const ast::Statement* loop = nullptr;
+ if (cond) {
+ if (init || cont) {
+ loop = b.For(init, cond, cont, body);
+ } else {
+ loop = b.While(cond, body);
+ }
+ } else {
+ loop = cont_stmts.IsEmpty() ? b.Loop(body) //
+ : b.Loop(body, b.Block(std::move(cont_stmts)));
+ if (!init_stmts.IsEmpty()) {
+ init_stmts.Push(loop);
+ loop = b.Block(std::move(init_stmts));
+ }
+ }
+ statements_->Push(loop);
+ }
+
void Switch(ir::Switch* s) {
SCOPED_NESTING();
@@ -232,6 +291,8 @@
Append(b.Break());
}
+ void ExitLoop(const ir::ExitLoop*) { Append(b.Break()); }
+
void Return(ir::Return* ret) {
if (ret->Args().IsEmpty()) {
// Return has no arguments.
@@ -575,7 +636,7 @@
////////////////////////////////////////////////////////////////////////////////////////////////
// Helpers
////////////////////////////////////////////////////////////////////////////////////////////////
- bool IsShortCircuit(ir::If* i,
+ bool AsShortCircuit(ir::If* i,
const StatementList& true_stmts,
const StatementList& false_stmts) {
if (!i->HasResults()) {
diff --git a/src/tint/ir/to_program_roundtrip_test.cc b/src/tint/ir/to_program_roundtrip_test.cc
index 2d91267..c8552bc 100644
--- a/src/tint/ir/to_program_roundtrip_test.cc
+++ b/src/tint/ir/to_program_roundtrip_test.cc
@@ -264,7 +264,6 @@
////////////////////////////////////////////////////////////////////////////////
// Short-circuiting binary ops
////////////////////////////////////////////////////////////////////////////////
-
TEST_F(IRToProgramRoundtripTest, BinaryOp_LogicalAnd_Param_2) {
Test(R"(
fn f(a : bool, b : bool) -> bool {
@@ -499,7 +498,6 @@
////////////////////////////////////////////////////////////////////////////////
// Compound assignment
////////////////////////////////////////////////////////////////////////////////
-
TEST_F(IRToProgramRoundtripTest, CompoundAssign_Increment) {
Test(R"(
fn f() {
@@ -663,8 +661,7 @@
fn a() {
}
-fn f() {
- var cond : bool = true;
+fn f(cond : bool) {
if (cond) {
a();
}
@@ -674,8 +671,7 @@
TEST_F(IRToProgramRoundtripTest, If_Return) {
Test(R"(
-fn f() {
- var cond : bool = true;
+fn f(cond : bool) {
if (cond) {
return;
}
@@ -703,8 +699,7 @@
fn b() {
}
-fn f() {
- var cond : bool = true;
+fn f(cond : bool) {
if (cond) {
a();
} else {
@@ -760,8 +755,8 @@
}
fn f() {
- var cond_a : bool = true;
- if (cond_a) {
+ var cond : bool = true;
+ if (cond) {
a();
} else if (false) {
b();
@@ -915,5 +910,298 @@
)");
}
+////////////////////////////////////////////////////////////////////////////////
+// For
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, For_Empty) {
+ Test(R"(
+fn f() {
+ for(var i : i32 = 0i; (i < 5i); i = (i + 1i)) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_Empty_NoInit) {
+ Test(R"(
+fn f() {
+ var i : i32 = 0i;
+ for(; (i < 5i); i = (i + 1i)) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_Empty_NoCond) {
+ Test(R"(
+fn f() {
+ for(var i : i32 = 0i; ; i = (i + 1i)) {
+ break;
+ }
+}
+)",
+ R"(
+fn f() {
+ {
+ var i : i32 = 0i;
+ loop {
+ break;
+
+ continuing {
+ i = (i + 1i);
+ }
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_Empty_NoCont) {
+ Test(R"(
+fn f() {
+ for(var i : i32 = 0i; (i < 5i); ) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_ComplexBody) {
+ Test(R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ for(var i : i32 = 0i; (i < 5i); i = (i + 1i)) {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_ComplexBody_NoInit) {
+ Test(R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ var i : i32 = 0i;
+ for(; (i < 5i); i = (i + 1i)) {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_ComplexBody_NoCond) {
+ Test(R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ for(var i : i32 = 0i; ; i = (i + 1i)) {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+ }
+}
+)",
+ R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ {
+ var i : i32 = 0i;
+ loop {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+
+ continuing {
+ i = (i + 1i);
+ }
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_ComplexBody_NoCont) {
+ Test(R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ for(var i : i32 = 0i; (i < 5i); ) {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, For_CallInInitCondCont) {
+ Test(R"(
+fn n(v : i32) -> i32 {
+ return (v + 1i);
+}
+
+fn f() {
+ for(var i : i32 = n(0i); (i < n(1i)); i = n(i)) {
+ }
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// While
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, While_Empty) {
+ Test(R"(
+fn f() {
+ while(true) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, While_Cond) {
+ Test(R"(
+fn f(cond : bool) {
+ while(cond) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, While_Break) {
+ Test(R"(
+fn f() {
+ while(true) {
+ break;
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, While_IfBreak) {
+ Test(R"(
+fn f(cond : bool) {
+ while(true) {
+ if (cond) {
+ break;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, While_IfReturn) {
+ Test(R"(
+fn f(cond : bool) {
+ while(true) {
+ if (cond) {
+ return;
+ }
+ }
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Loop
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramRoundtripTest, Loop_Break) {
+ Test(R"(
+fn f() {
+ loop {
+ break;
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Loop_IfBreak) {
+ Test(R"(
+fn f(cond : bool) {
+ loop {
+ if (cond) {
+ break;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Loop_IfReturn) {
+ Test(R"(
+fn f(cond : bool) {
+ loop {
+ if (cond) {
+ return;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Loop_IfContinuing) {
+ Test(R"(
+fn f() {
+ var cond : bool = false;
+ loop {
+ if (cond) {
+ return;
+ }
+
+ continuing {
+ cond = true;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramRoundtripTest, Loop_VarsDeclaredOutsideAndInside) {
+ Test(R"(
+fn f() {
+ var b : i32 = 1i;
+ loop {
+ var a : i32 = 2i;
+ if ((a == b)) {
+ return;
+ }
+
+ continuing {
+ b = (a + b);
+ }
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/ir/to_program_test.cc b/src/tint/ir/to_program_test.cc
new file mode 100644
index 0000000..8cfae5c
--- /dev/null
+++ b/src/tint/ir/to_program_test.cc
@@ -0,0 +1,2513 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "src/tint/ir/disassembler.h"
+#include "src/tint/ir/ir_test_helper.h"
+#include "src/tint/ir/to_program.h"
+#include "src/tint/utils/string.h"
+#include "src/tint/writer/wgsl/generator.h"
+
+#if !TINT_BUILD_WGSL_WRITER
+#error "to_program_test.cc requires both the WGSL writer to be enabled"
+#endif
+
+namespace tint::ir {
+namespace {
+
+using namespace tint::number_suffixes; // NOLINT
+using namespace tint::builtin::fluent_types; // NOLINT
+
+class IRToProgramTest : public IRTestHelper {
+ public:
+ void Test(std::string_view expected_wgsl) {
+ tint::ir::Disassembler d{mod};
+ auto disassembly = d.Disassemble();
+
+ auto output_program = ToProgram(mod);
+ if (!output_program.IsValid()) {
+ FAIL() << output_program.Diagnostics().str() << std::endl //
+ << "IR:" << std::endl //
+ << disassembly << std::endl //
+ << "AST:" << std::endl //
+ << Program::printer(&output_program) << std::endl;
+ }
+
+ ASSERT_TRUE(output_program.IsValid()) << output_program.Diagnostics().str();
+
+ auto output = writer::wgsl::Generate(&output_program, {});
+ ASSERT_TRUE(output.success) << output.error;
+
+ auto expected = std::string(utils::TrimSpace(expected_wgsl));
+ if (!expected.empty()) {
+ expected = "\n" + expected + "\n";
+ }
+ auto got = std::string(utils::TrimSpace(output.wgsl));
+ if (!got.empty()) {
+ got = "\n" + got + "\n";
+ }
+ EXPECT_EQ(expected, got) << "IR:" << std::endl << disassembly;
+ }
+};
+
+TEST_F(IRToProgramTest, EmptyModule) {
+ Test("");
+}
+
+TEST_F(IRToProgramTest, SingleFunction_Empty) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ Test(R"(
+fn f() {
+}
+)");
+}
+
+TEST_F(IRToProgramTest, SingleFunction_Return) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ fn->Block()->Append(b.Return(fn));
+
+ Test(R"(
+fn f() {
+}
+)");
+}
+
+TEST_F(IRToProgramTest, SingleFunction_Return_i32) {
+ auto* fn = b.Function("f", ty.i32());
+ mod.functions.Push(fn);
+
+ fn->Block()->Append(b.Return(fn, 42_i));
+
+ Test(R"(
+fn f() -> i32 {
+ return 42i;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, SingleFunction_Parameters) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* i = b.FunctionParam(ty.i32());
+ auto* u = b.FunctionParam(ty.u32());
+ mod.SetName(i, "i");
+ mod.SetName(u, "u");
+ fn->SetParams({i, u});
+ mod.functions.Push(fn);
+
+ fn->Block()->Append(b.Return(fn, i));
+
+ Test(R"(
+fn f(i : i32, u : u32) -> i32 {
+ return i;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Unary ops
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, UnaryOp_Negate) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* i = b.FunctionParam(ty.i32());
+ mod.SetName(i, "i");
+ fn->SetParams({i});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.Negation(ty.i32(), i)); });
+
+ Test(R"(
+fn f(i : i32) -> i32 {
+ return -(i);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, UnaryOp_Complement) {
+ auto* fn = b.Function("f", ty.u32());
+ auto* i = b.FunctionParam(ty.u32());
+ mod.SetName(i, "i");
+ fn->SetParams({i});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.Complement(ty.u32(), i)); });
+
+ Test(R"(
+fn f(i : u32) -> u32 {
+ return ~(i);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, UnaryOp_Not) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* i = b.FunctionParam(ty.bool_());
+ mod.SetName(i, "b");
+ fn->SetParams({i});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.Not(ty.bool_(), i)); });
+
+ Test(R"(
+fn f(b : bool) -> bool {
+ return !(b);
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Binary ops
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, BinaryOp_Add) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.Add(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a + b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_Subtract) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.Subtract(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a - b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_Multiply) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.Multiply(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a * b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_Divide) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.Divide(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a / b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_Modulo) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.Modulo(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a % b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_And) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.And(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a & b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_Or) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.Or(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a | b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_Xor) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.Xor(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> i32 {
+ return (a ^ b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_Equal) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.Equal(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a == b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_NotEqual) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.NotEqual(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a != b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LessThan) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.LessThan(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a < b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_GreaterThan) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.GreaterThan(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a > b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LessThanEqual) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.LessThanEqual(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a <= b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_GreaterThanEqual) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.i32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.GreaterThanEqual(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : i32) -> bool {
+ return (a >= b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_ShiftLeft) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.u32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.ShiftLeft(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : u32) -> i32 {
+ return (a << b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_ShiftRight) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* pa = b.FunctionParam(ty.i32());
+ auto* pb = b.FunctionParam(ty.u32());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] { b.Return(fn, b.ShiftRight(ty.i32(), pa, pb)); });
+
+ Test(R"(
+fn f(a : i32, b : u32) -> i32 {
+ return (a >> b);
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Short-circuiting binary ops
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Param_2) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if_ = b.If(pa);
+ if_->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if_->True(), [&] { b.ExitIf(if_, pb); });
+ b.With(if_->False(), [&] { b.ExitIf(if_, false); });
+
+ b.Return(fn, if_->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool) -> bool {
+ return (a && b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Param_3_ab_c) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pb, pc});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(pa);
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, pb); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, false); });
+
+ auto* if2 = b.If(if1->Result(0));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, pc); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool, c : bool) -> bool {
+ return ((a && b) && c);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Param_3_a_bc) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pb, pc});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(pb);
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, pc); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, false); });
+
+ auto* if2 = b.If(pa);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, if1->Result(0)); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool, c : bool) -> bool {
+ return (a && (b && c));
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Let_2) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if_ = b.If(pa);
+ if_->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if_->True(), [&] { b.ExitIf(if_, pb); });
+ b.With(if_->False(), [&] { b.ExitIf(if_, false); });
+
+ mod.SetName(if_->Result(0), "l");
+ b.Return(fn, if_->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool) -> bool {
+ let l = (a && b);
+ return l;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Let_3_ab_c) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pb, pc});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(pa);
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, pb); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, false); });
+
+ auto* if2 = b.If(if1->Result(0));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, pc); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+
+ mod.SetName(if2->Result(0), "l");
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool, c : bool) -> bool {
+ let l = ((a && b) && c);
+ return l;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Let_3_a_bc) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pb, pc});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(pb);
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, pc); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, false); });
+
+ auto* if2 = b.If(pa);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, if1->Result(0)); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+
+ mod.SetName(if2->Result(0), "l");
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool, c : bool) -> bool {
+ let l = (a && (b && c));
+ return l;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Call_2) {
+ auto* fn_a = b.Function("a", ty.bool_());
+ mod.functions.Push(fn_a);
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
+
+ auto* fn_b = b.Function("b", ty.bool_());
+ mod.functions.Push(fn_b);
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
+
+ auto* fn = b.Function("f", ty.bool_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if_ = b.If(b.Call(ty.bool_(), fn_a));
+ if_->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if_->True(), [&] { b.ExitIf(if_, b.Call(ty.bool_(), fn_b)); });
+ b.With(if_->False(), [&] { b.ExitIf(if_, false); });
+
+ b.Return(fn, if_->Result(0));
+ });
+
+ Test(R"(
+fn a() -> bool {
+ return true;
+}
+
+fn b() -> bool {
+ return true;
+}
+
+fn f() -> bool {
+ return (a() && b());
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Call_3_ab_c) {
+ auto* fn_a = b.Function("a", ty.bool_());
+ mod.functions.Push(fn_a);
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
+
+ auto* fn_b = b.Function("b", ty.bool_());
+ mod.functions.Push(fn_b);
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
+
+ auto* fn_c = b.Function("c", ty.bool_());
+ mod.functions.Push(fn_c);
+ b.With(fn_c->Block(), [&] { b.Return(fn_c, true); });
+
+ auto* fn = b.Function("f", ty.bool_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(b.Call(ty.bool_(), fn_a));
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, b.Call(ty.bool_(), fn_b)); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, false); });
+
+ auto* if2 = b.If(if1->Result(0));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, b.Call(ty.bool_(), fn_c)); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn a() -> bool {
+ return true;
+}
+
+fn b() -> bool {
+ return true;
+}
+
+fn c() -> bool {
+ return true;
+}
+
+fn f() -> bool {
+ return ((a() && b()) && c());
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalAnd_Call_3_a_bc) {
+ auto* fn_a = b.Function("a", ty.bool_());
+ mod.functions.Push(fn_a);
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
+
+ auto* fn_b = b.Function("b", ty.bool_());
+ mod.functions.Push(fn_b);
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
+
+ auto* fn_c = b.Function("c", ty.bool_());
+ mod.functions.Push(fn_c);
+ b.With(fn_c->Block(), [&] { b.Return(fn_c, true); });
+
+ auto* fn = b.Function("f", ty.bool_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(b.Call(ty.bool_(), fn_b));
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, b.Call(ty.bool_(), fn_c)); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, false); });
+
+ auto* if2 = b.If(b.Call(ty.bool_(), fn_a));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, if1->Result(0)); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn a() -> bool {
+ return true;
+}
+
+fn b() -> bool {
+ return true;
+}
+
+fn c() -> bool {
+ return true;
+}
+
+fn f() -> bool {
+ return (a() && (b() && c()));
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Param_2) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if_ = b.If(pa);
+ if_->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if_->True(), [&] { b.ExitIf(if_, true); });
+ b.With(if_->False(), [&] { b.ExitIf(if_, pb); });
+
+ b.Return(fn, if_->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool) -> bool {
+ return (a || b);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Param_3_ab_c) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pb, pc});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(pa);
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, true); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, pb); });
+
+ auto* if2 = b.If(if1->Result(0));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, pc); });
+
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool, c : bool) -> bool {
+ return ((a || b) || c);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Param_3_a_bc) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pb, pc});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(pb);
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, true); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, pc); });
+
+ auto* if2 = b.If(pa);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, if1->Result(0)); });
+
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool, c : bool) -> bool {
+ return (a || (b || c));
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Let_2) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ fn->SetParams({pa, pb});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if_ = b.If(pa);
+ if_->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if_->True(), [&] { b.ExitIf(if_, true); });
+ b.With(if_->False(), [&] { b.ExitIf(if_, pb); });
+
+ mod.SetName(if_->Result(0), "l");
+ b.Return(fn, if_->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool) -> bool {
+ let l = (a || b);
+ return l;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Let_3_ab_c) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pb, pc});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(pa);
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, true); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, pb); });
+
+ auto* if2 = b.If(if1->Result(0));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, pc); });
+
+ mod.SetName(if2->Result(0), "l");
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool, c : bool) -> bool {
+ let l = ((a || b) || c);
+ return l;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Let_3_a_bc) {
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pb, pc});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(pb);
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, true); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, pc); });
+
+ auto* if2 = b.If(pa);
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, if1->Result(0)); });
+
+ mod.SetName(if2->Result(0), "l");
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn f(a : bool, b : bool, c : bool) -> bool {
+ let l = (a || (b || c));
+ return l;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Call_2) {
+ auto* fn_a = b.Function("a", ty.bool_());
+ mod.functions.Push(fn_a);
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
+
+ auto* fn_b = b.Function("b", ty.bool_());
+ mod.functions.Push(fn_b);
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
+
+ auto* fn = b.Function("f", ty.bool_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if_ = b.If(b.Call(ty.bool_(), fn_a));
+ if_->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if_->True(), [&] { b.ExitIf(if_, true); });
+ b.With(if_->False(), [&] { b.ExitIf(if_, b.Call(ty.bool_(), fn_b)); });
+
+ b.Return(fn, if_->Result(0));
+ });
+
+ Test(R"(
+fn a() -> bool {
+ return true;
+}
+
+fn b() -> bool {
+ return true;
+}
+
+fn f() -> bool {
+ return (a() || b());
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Call_3_ab_c) {
+ auto* fn_a = b.Function("a", ty.bool_());
+ mod.functions.Push(fn_a);
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
+
+ auto* fn_b = b.Function("b", ty.bool_());
+ mod.functions.Push(fn_b);
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
+
+ auto* fn_c = b.Function("c", ty.bool_());
+ mod.functions.Push(fn_c);
+ b.With(fn_c->Block(), [&] { b.Return(fn_c, true); });
+
+ auto* fn = b.Function("f", ty.bool_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(b.Call(ty.bool_(), fn_a));
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, true); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, b.Call(ty.bool_(), fn_b)); });
+
+ auto* if2 = b.If(if1->Result(0));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, b.Call(ty.bool_(), fn_c)); });
+
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn a() -> bool {
+ return true;
+}
+
+fn b() -> bool {
+ return true;
+}
+
+fn c() -> bool {
+ return true;
+}
+
+fn f() -> bool {
+ return ((a() || b()) || c());
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalOr_Call_3_a_bc) {
+ auto* fn_a = b.Function("a", ty.bool_());
+ mod.functions.Push(fn_a);
+ b.With(fn_a->Block(), [&] { b.Return(fn_a, true); });
+
+ auto* fn_b = b.Function("b", ty.bool_());
+ mod.functions.Push(fn_b);
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
+
+ auto* fn_c = b.Function("c", ty.bool_());
+ mod.functions.Push(fn_c);
+ b.With(fn_c->Block(), [&] { b.Return(fn_c, true); });
+
+ auto* fn = b.Function("f", ty.bool_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(b.Call(ty.bool_(), fn_b));
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, true); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, b.Call(ty.bool_(), fn_c)); });
+
+ auto* if2 = b.If(b.Call(ty.bool_(), fn_a));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] { b.ExitIf(if2, true); });
+ b.With(if2->False(), [&] { b.ExitIf(if2, if1->Result(0)); });
+
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn a() -> bool {
+ return true;
+}
+
+fn b() -> bool {
+ return true;
+}
+
+fn c() -> bool {
+ return true;
+}
+
+fn f() -> bool {
+ return (a() || (b() || c()));
+}
+)");
+}
+
+TEST_F(IRToProgramTest, BinaryOp_LogicalMixed) {
+ auto* fn_b = b.Function("b", ty.bool_());
+ mod.functions.Push(fn_b);
+ b.With(fn_b->Block(), [&] { b.Return(fn_b, true); });
+
+ auto* fn_d = b.Function("d", ty.bool_());
+ mod.functions.Push(fn_d);
+ b.With(fn_d->Block(), [&] { b.Return(fn_d, true); });
+
+ auto* fn = b.Function("f", ty.bool_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pc, "c");
+ fn->SetParams({pa, pc});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if1 = b.If(pa);
+ if1->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if1->True(), [&] { b.ExitIf(if1, true); });
+ b.With(if1->False(), [&] { b.ExitIf(if1, b.Call(ty.bool_(), fn_b)); });
+
+ auto* if2 = b.If(if1->Result(0));
+ if2->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if2->True(), [&] {
+ auto* if3 = b.If(pc);
+ if3->SetResults(b.InstructionResult(ty.bool_()));
+ b.With(if3->True(), [&] { b.ExitIf(if3, true); });
+ b.With(if3->False(), [&] { b.ExitIf(if3, b.Call(ty.bool_(), fn_d)); });
+
+ b.ExitIf(if2, if3->Result(0));
+ });
+ b.With(if2->False(), [&] { b.ExitIf(if2, false); });
+
+ mod.SetName(if2->Result(0), "l");
+ b.Return(fn, if2->Result(0));
+ });
+
+ Test(R"(
+fn b() -> bool {
+ return true;
+}
+
+fn d() -> bool {
+ return true;
+}
+
+fn f(a : bool, c : bool) -> bool {
+ let l = ((a || b()) && (c || d()));
+ return l;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Compound assignment
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, CompoundAssign_Increment) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ b.Store(v, b.Add(ty.i32(), b.Load(v), 1_i));
+ mod.SetName(v, "v");
+ });
+
+ Test(R"(
+fn f() {
+ var v : i32;
+ v = (v + 1i);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, CompoundAssign_Decrement) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ b.Store(v, b.Subtract(ty.i32(), b.Load(v), 1_i));
+ mod.SetName(v, "v");
+ });
+
+ Test(R"(
+fn f() {
+ var v : i32;
+ v = (v - 1i);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, CompoundAssign_Add) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ b.Store(v, b.Add(ty.i32(), b.Load(v), 8_i));
+ mod.SetName(v, "v");
+ });
+
+ Test(R"(
+fn f() {
+ var v : i32;
+ v = (v + 8i);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, CompoundAssign_Subtract) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ b.Store(v, b.Subtract(ty.i32(), b.Load(v), 8_i));
+ mod.SetName(v, "v");
+ });
+
+ Test(R"(
+fn f() {
+ var v : i32;
+ v = (v - 8i);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, CompoundAssign_Multiply) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ b.Store(v, b.Multiply(ty.i32(), b.Load(v), 8_i));
+ mod.SetName(v, "v");
+ });
+
+ Test(R"(
+fn f() {
+ var v : i32;
+ v = (v * 8i);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, CompoundAssign_Divide) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ b.Store(v, b.Divide(ty.i32(), b.Load(v), 8_i));
+ mod.SetName(v, "v");
+ });
+
+ Test(R"(
+fn f() {
+ var v : i32;
+ v = (v / 8i);
+}
+)");
+}
+
+TEST_F(IRToProgramTest, CompoundAssign_Xor) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ b.Store(v, b.Xor(ty.i32(), b.Load(v), 8_i));
+ mod.SetName(v, "v");
+ });
+
+ Test(R"(
+fn f() {
+ var v : i32;
+ v = (v ^ 8i);
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// let
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, LetUsedOnce) {
+ auto* fn = b.Function("f", ty.u32());
+ auto* i = b.FunctionParam(ty.u32());
+ mod.SetName(i, "i");
+ fn->SetParams({i});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Complement(ty.u32(), i);
+ b.Return(fn, v);
+ mod.SetName(v, "v");
+ });
+
+ Test(R"(
+fn f(i : u32) -> u32 {
+ let v = ~(i);
+ return v;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, LetUsedTwice) {
+ auto* fn = b.Function("f", ty.i32());
+ auto* i = b.FunctionParam(ty.i32());
+ mod.SetName(i, "i");
+ fn->SetParams({i});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Multiply(ty.i32(), i, 2_i);
+ b.Return(fn, b.Add(ty.i32(), v, v));
+ mod.SetName(v, "v");
+ });
+
+ Test(R"(
+fn f(i : i32) -> i32 {
+ let v = (i * 2i);
+ return (v + v);
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Function-scope var
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, FunctionScopeVar_i32) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* i = b.Var(ty.ptr<function, i32>());
+ mod.SetName(i, "i");
+ });
+
+ Test(R"(
+fn f() {
+ var i : i32;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, FunctionScopeVar_i32_InitLiteral) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* i = b.Var(ty.ptr<function, i32>());
+ i->SetInitializer(b.Constant(42_i));
+ mod.SetName(i, "i");
+ });
+
+ Test(R"(
+fn f() {
+ var i : i32 = 42i;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, FunctionScopeVar_Chained) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* va = b.Var(ty.ptr<function, i32>());
+ va->SetInitializer(b.Constant(42_i));
+
+ auto* la = b.Load(va)->Result();
+ auto* vb = b.Var(ty.ptr<function, i32>());
+ vb->SetInitializer(la);
+
+ auto* lb = b.Load(vb)->Result();
+ auto* vc = b.Var(ty.ptr<function, i32>());
+ vc->SetInitializer(lb);
+
+ mod.SetName(va, "a");
+ mod.SetName(vb, "b");
+ mod.SetName(vc, "c");
+ });
+
+ Test(R"(
+fn f() {
+ var a : i32 = 42i;
+ var b : i32 = a;
+ var c : i32 = b;
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// If
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, If_CallFn) {
+ auto* a = b.Function("a", ty.void_());
+ mod.functions.Push(a);
+
+ auto* fn = b.Function("f", ty.void_());
+ auto* cond = b.FunctionParam(ty.bool_());
+ mod.SetName(cond, "cond");
+ fn->SetParams({cond});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* if_ = b.If(cond);
+ b.With(if_->True(), [&] {
+ b.Call(ty.void_(), a);
+ b.ExitIf(if_);
+ });
+ });
+
+ Test(R"(
+fn a() {
+}
+
+fn f(cond : bool) {
+ if (cond) {
+ a();
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, If_Return) {
+ auto* fn = b.Function("f", ty.void_());
+ auto* cond = b.FunctionParam(ty.bool_());
+ mod.SetName(cond, "cond");
+ fn->SetParams({cond});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto if_ = b.If(cond);
+ b.With(if_->True(), [&] { b.Return(fn); });
+ });
+
+ Test(R"(
+fn f(cond : bool) {
+ if (cond) {
+ return;
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, If_Return_i32) {
+ auto* fn = b.Function("f", ty.i32());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* cond = b.Var(ty.ptr<function, bool>());
+ mod.SetName(cond, "cond");
+ cond->SetInitializer(b.Constant(true));
+ auto if_ = b.If(b.Load(cond));
+ b.With(if_->True(), [&] { b.Return(fn, 42_i); });
+ b.Return(fn, 10_i);
+ });
+
+ Test(R"(
+fn f() -> i32 {
+ var cond : bool = true;
+ if (cond) {
+ return 42i;
+ }
+ return 10i;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, If_CallFn_Else_CallFn) {
+ auto* fn_a = b.Function("a", ty.void_());
+ mod.functions.Push(fn_a);
+
+ auto* fn_b = b.Function("b", ty.void_());
+ mod.functions.Push(fn_b);
+
+ auto* fn = b.Function("f", ty.void_());
+ auto* cond = b.FunctionParam(ty.bool_());
+ mod.SetName(cond, "cond");
+ fn->SetParams({cond});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto if_ = b.If(cond);
+ b.With(if_->True(), [&] {
+ b.Call(ty.void_(), fn_a);
+ b.ExitIf(if_);
+ });
+ b.With(if_->False(), [&] {
+ b.Call(ty.void_(), fn_b);
+ b.ExitIf(if_);
+ });
+ });
+
+ Test(R"(
+fn a() {
+}
+
+fn b() {
+}
+
+fn f(cond : bool) {
+ if (cond) {
+ a();
+ } else {
+ b();
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, If_Return_f32_Else_Return_f32) {
+ auto* fn = b.Function("f", ty.f32());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* cond = b.Var(ty.ptr<function, bool>());
+ mod.SetName(cond, "cond");
+ cond->SetInitializer(b.Constant(true));
+ auto if_ = b.If(b.Load(cond));
+ b.With(if_->True(), [&] { b.Return(fn, 1.0_f); });
+ b.With(if_->False(), [&] { b.Return(fn, 2.0_f); });
+ });
+
+ Test(R"(
+fn f() -> f32 {
+ var cond : bool = true;
+ if (cond) {
+ return 1.0f;
+ } else {
+ return 2.0f;
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, If_Return_u32_Else_CallFn) {
+ auto* fn_a = b.Function("a", ty.void_());
+ mod.functions.Push(fn_a);
+
+ auto* fn_b = b.Function("b", ty.void_());
+ mod.functions.Push(fn_b);
+
+ auto* fn = b.Function("f", ty.u32());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* cond = b.Var(ty.ptr<function, bool>());
+ mod.SetName(cond, "cond");
+ cond->SetInitializer(b.Constant(true));
+ auto if_ = b.If(b.Load(cond));
+ b.With(if_->True(), [&] { b.Return(fn, 1_u); });
+ b.With(if_->False(), [&] {
+ b.Call(ty.void_(), fn_a);
+ b.ExitIf(if_);
+ });
+ b.Call(ty.void_(), fn_b);
+ b.Return(fn, 2_u);
+ });
+
+ Test(R"(
+fn a() {
+}
+
+fn b() {
+}
+
+fn f() -> u32 {
+ var cond : bool = true;
+ if (cond) {
+ return 1u;
+ } else {
+ a();
+ }
+ b();
+ return 2u;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, If_CallFn_ElseIf_CallFn) {
+ auto* fn_a = b.Function("a", ty.void_());
+ mod.functions.Push(fn_a);
+
+ auto* fn_b = b.Function("b", ty.void_());
+ mod.functions.Push(fn_b);
+
+ auto* fn_c = b.Function("c", ty.void_());
+ mod.functions.Push(fn_c);
+
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* cond = b.Var(ty.ptr<function, bool>());
+ mod.SetName(cond, "cond");
+ cond->SetInitializer(b.Constant(true));
+ auto if1 = b.If(b.Load(cond));
+ b.With(if1->True(), [&] {
+ b.Call(ty.void_(), fn_a);
+ b.ExitIf(if1);
+ });
+ b.With(if1->False(), [&] {
+ auto* if2 = b.If(b.Constant(false));
+ b.With(if2->True(), [&] {
+ b.Call(ty.void_(), fn_b);
+ b.ExitIf(if2);
+ });
+ b.ExitIf(if1);
+ });
+ b.Call(ty.void_(), fn_c);
+ });
+
+ Test(R"(
+fn a() {
+}
+
+fn b() {
+}
+
+fn c() {
+}
+
+fn f() {
+ var cond : bool = true;
+ if (cond) {
+ a();
+ } else if (false) {
+ b();
+ }
+ c();
+}
+)");
+}
+
+TEST_F(IRToProgramTest, If_Else_Chain) {
+ auto* x = b.Function("x", ty.bool_());
+ auto* i = b.FunctionParam(ty.i32());
+ mod.SetName(i, "i");
+ x->SetParams({i});
+ mod.functions.Push(x);
+ b.With(x->Block(), [&] { b.Return(x, true); });
+
+ auto* fn = b.Function("f", ty.void_());
+ auto* pa = b.FunctionParam(ty.bool_());
+ auto* pb = b.FunctionParam(ty.bool_());
+ auto* pc = b.FunctionParam(ty.bool_());
+ auto* pd = b.FunctionParam(ty.bool_());
+ mod.SetName(pa, "a");
+ mod.SetName(pb, "b");
+ mod.SetName(pc, "c");
+ mod.SetName(pd, "d");
+ fn->SetParams({pa, pb, pc, pd});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto if1 = b.If(pa);
+ b.With(if1->True(), [&] {
+ b.Call(ty.void_(), x, 0_i);
+ b.ExitIf(if1);
+ });
+ b.With(if1->False(), [&] {
+ auto* if2 = b.If(pb);
+ b.With(if2->True(), [&] {
+ b.Call(ty.void_(), x, 1_i);
+ b.ExitIf(if2);
+ b.With(if2->False(), [&] {
+ auto* if3 = b.If(pc);
+ b.With(if3->True(), [&] {
+ b.Call(ty.void_(), x, 2_i);
+ b.ExitIf(if3);
+ });
+ b.With(if3->False(), [&] {
+ b.Call(ty.void_(), x, 3_i);
+ b.ExitIf(if3);
+ });
+ });
+ });
+ });
+ });
+ Test(R"(
+fn x(i : i32) -> bool {
+ return true;
+}
+
+fn f(a : bool, b : bool, c : bool, d : bool) {
+ if (a) {
+ x(0i);
+ } else if (b) {
+ x(1i);
+ } else if (c) {
+ x(2i);
+ } else {
+ x(3i);
+ }
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Switch
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, Switch_Default) {
+ auto* fn_a = b.Function("a", ty.void_());
+ mod.functions.Push(fn_a);
+
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ mod.SetName(v, "v");
+ v->SetInitializer(b.Constant(42_i));
+
+ auto s = b.Switch(b.Load(v));
+ b.With(b.Case(s, {Switch::CaseSelector{}}), [&] {
+ b.Call(ty.void_(), fn_a);
+ b.ExitSwitch(s);
+ });
+ });
+
+ Test(R"(
+fn a() {
+}
+
+fn f() {
+ var v : i32 = 42i;
+ switch(v) {
+ default: {
+ a();
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, Switch_3_Cases) {
+ auto* fn_a = b.Function("a", ty.void_());
+ mod.functions.Push(fn_a);
+
+ auto* fn_b = b.Function("b", ty.void_());
+ mod.functions.Push(fn_b);
+
+ auto* fn_c = b.Function("c", ty.void_());
+ mod.functions.Push(fn_c);
+
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ mod.SetName(v, "v");
+ v->SetInitializer(b.Constant(42_i));
+
+ auto s = b.Switch(b.Load(v));
+ b.With(b.Case(s, {Switch::CaseSelector{b.Constant(0_i)}}), [&] {
+ b.Call(ty.void_(), fn_a);
+ b.ExitSwitch(s);
+ });
+ b.With(b.Case(s,
+ {
+ Switch::CaseSelector{b.Constant(1_i)},
+ Switch::CaseSelector{},
+ }),
+ [&] {
+ b.Call(ty.void_(), fn_b);
+ b.ExitSwitch(s);
+ });
+ b.With(b.Case(s, {Switch::CaseSelector{b.Constant(2_i)}}), [&] {
+ b.Call(ty.void_(), fn_c);
+ b.ExitSwitch(s);
+ });
+ });
+
+ Test(R"(
+fn a() {
+}
+
+fn b() {
+}
+
+fn c() {
+}
+
+fn f() {
+ var v : i32 = 42i;
+ switch(v) {
+ case 0i: {
+ a();
+ }
+ case 1i, default: {
+ b();
+ }
+ case 2i: {
+ c();
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, Switch_3_Cases_AllReturn) {
+ auto* fn_a = b.Function("a", ty.void_());
+ mod.functions.Push(fn_a);
+
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ mod.SetName(v, "v");
+ v->SetInitializer(b.Constant(42_i));
+
+ auto s = b.Switch(b.Load(v));
+ b.With(b.Case(s, {Switch::CaseSelector{b.Constant(0_i)}}), [&] { b.Return(fn); });
+ b.With(b.Case(s,
+ {
+ Switch::CaseSelector{b.Constant(1_i)},
+ Switch::CaseSelector{},
+ }),
+ [&] { b.Return(fn); });
+ b.With(b.Case(s, {Switch::CaseSelector{b.Constant(2_i)}}), [&] { b.Return(fn); });
+
+ b.Call(ty.void_(), fn_a);
+ b.Return(fn);
+ });
+
+ Test(R"(
+fn a() {
+}
+
+fn f() {
+ var v : i32 = 42i;
+ switch(v) {
+ case 0i: {
+ return;
+ }
+ case 1i, default: {
+ return;
+ }
+ case 2i: {
+ return;
+ }
+ }
+ a();
+}
+)");
+}
+
+TEST_F(IRToProgramTest, Switch_Nested) {
+ auto* fn_a = b.Function("a", ty.void_());
+ mod.functions.Push(fn_a);
+
+ auto* fn_b = b.Function("b", ty.void_());
+ mod.functions.Push(fn_b);
+
+ auto* fn_c = b.Function("c", ty.void_());
+ mod.functions.Push(fn_c);
+
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* v1 = b.Var(ty.ptr<function, i32>());
+ mod.SetName(v1, "v1");
+ v1->SetInitializer(b.Constant(42_i));
+
+ auto* v2 = b.Var(ty.ptr<function, i32>());
+ mod.SetName(v2, "v2");
+ v2->SetInitializer(b.Constant(24_i));
+
+ auto s1 = b.Switch(b.Load(v1));
+ b.With(b.Case(s1, {Switch::CaseSelector{b.Constant(0_i)}}), [&] {
+ b.Call(ty.void_(), fn_a);
+ b.ExitSwitch(s1);
+ });
+ b.With(b.Case(s1,
+ {
+ Switch::CaseSelector{b.Constant(1_i)},
+ Switch::CaseSelector{},
+ }),
+ [&] {
+ auto s2 = b.Switch(b.Load(v2));
+ b.With(b.Case(s2, {Switch::CaseSelector{b.Constant(0_i)}}),
+ [&] { b.ExitSwitch(s2); });
+ b.With(b.Case(s2,
+ {
+ Switch::CaseSelector{b.Constant(1_i)},
+ Switch::CaseSelector{},
+ }),
+ [&] { b.Return(fn); });
+ });
+ b.With(b.Case(s1, {Switch::CaseSelector{b.Constant(2_i)}}), [&] {
+ b.Call(ty.void_(), fn_c);
+ b.ExitSwitch(s1);
+ });
+ });
+ Test(R"(
+fn a() {
+}
+
+fn b() {
+}
+
+fn c() {
+}
+
+fn f() {
+ var v1 : i32 = 42i;
+ var v2 : i32 = 24i;
+ switch(v1) {
+ case 0i: {
+ a();
+ }
+ case 1i, default: {
+ switch(v2) {
+ case 0i: {
+ }
+ case 1i, default: {
+ return;
+ }
+ }
+ }
+ case 2i: {
+ c();
+ }
+ }
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// For
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, For_Empty) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Initializer(), [&] {
+ auto* i = b.Var(ty.ptr<function, i32>());
+ mod.SetName(i, "i");
+ i->SetInitializer(b.Constant(0_i));
+
+ b.With(loop->Body(), [&] {
+ auto* if_ = b.If(b.LessThan(ty.bool_(), b.Load(i), 5_i));
+ b.With(if_->True(), [&] { b.ExitIf(if_); });
+ b.With(if_->False(), [&] { b.ExitLoop(loop); });
+ });
+
+ b.With(loop->Continuing(), [&] { b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i)); });
+ });
+ });
+
+ Test(R"(
+fn f() {
+ for(var i : i32 = 0i; (i < 5i); i = (i + 1i)) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, For_Empty_NoInit) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* i = b.Var(ty.ptr<function, i32>());
+ mod.SetName(i, "i");
+ i->SetInitializer(b.Constant(0_i));
+
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] {
+ auto* if_ = b.If(b.LessThan(ty.bool_(), b.Load(i), 5_i));
+ b.With(if_->True(), [&] { b.ExitIf(if_); });
+ b.With(if_->False(), [&] { b.ExitLoop(loop); });
+ });
+
+ b.With(loop->Continuing(), [&] { b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i)); });
+ });
+
+ Test(R"(
+fn f() {
+ var i : i32 = 0i;
+ for(; (i < 5i); i = (i + 1i)) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, For_Empty_NoCont) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Initializer(), [&] {
+ auto* i = b.Var(ty.ptr<function, i32>());
+ mod.SetName(i, "i");
+ i->SetInitializer(b.Constant(0_i));
+
+ b.With(loop->Body(), [&] {
+ auto* if_ = b.If(b.LessThan(ty.bool_(), b.Load(i), 5_i));
+ b.With(if_->True(), [&] { b.ExitIf(if_); });
+ b.With(if_->False(), [&] { b.ExitLoop(loop); });
+ });
+ });
+ });
+
+ Test(R"(
+fn f() {
+ for(var i : i32 = 0i; (i < 5i); ) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, For_ComplexBody) {
+ auto* a = b.Function("a", ty.bool_());
+ auto* v = b.FunctionParam(ty.i32());
+ mod.SetName(v, "v");
+ a->SetParams({v});
+ b.With(a->Block(), [&] { b.Return(a, b.Equal(ty.bool_(), v, 1_i)); });
+ mod.functions.Push(a);
+
+ auto* fn = b.Function("f", ty.i32());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Initializer(), [&] {
+ auto* i = b.Var(ty.ptr<function, i32>());
+ mod.SetName(i, "i");
+ i->SetInitializer(b.Constant(0_i));
+
+ b.With(loop->Body(), [&] {
+ auto* if1 = b.If(b.LessThan(ty.bool_(), b.Load(i), 5_i));
+ b.With(if1->True(), [&] { b.ExitIf(if1); });
+ b.With(if1->False(), [&] { b.ExitLoop(loop); });
+
+ auto* if2 = b.If(b.Call(ty.bool_(), a, 42_i));
+ b.With(if2->True(), [&] { b.Return(fn, 1_i); });
+ b.With(if2->False(), [&] { b.Return(fn, 2_i); });
+ });
+
+ b.With(loop->Continuing(), [&] { b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i)); });
+ });
+
+ b.Return(fn, 3_i);
+ });
+
+ Test(R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ for(var i : i32 = 0i; (i < 5i); i = (i + 1i)) {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, For_ComplexBody_NoInit) {
+ auto* a = b.Function("a", ty.bool_());
+ auto* v = b.FunctionParam(ty.i32());
+ mod.SetName(v, "v");
+ a->SetParams({v});
+ b.With(a->Block(), [&] { b.Return(a, b.Equal(ty.bool_(), v, 1_i)); });
+ mod.functions.Push(a);
+
+ auto* fn = b.Function("f", ty.i32());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* i = b.Var(ty.ptr<function, i32>());
+ mod.SetName(i, "i");
+ i->SetInitializer(b.Constant(0_i));
+
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] {
+ auto* if1 = b.If(b.LessThan(ty.bool_(), b.Load(i), 5_i));
+ b.With(if1->True(), [&] { b.ExitIf(if1); });
+ b.With(if1->False(), [&] { b.ExitLoop(loop); });
+
+ auto* if2 = b.If(b.Call(ty.bool_(), a, 42_i));
+ b.With(if2->True(), [&] { b.Return(fn, 1_i); });
+ b.With(if2->False(), [&] { b.Return(fn, 2_i); });
+ });
+
+ b.With(loop->Continuing(), [&] { b.Store(i, b.Add(ty.i32(), b.Load(i), 1_i)); });
+
+ b.Return(fn, 3_i);
+ });
+
+ Test(R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ var i : i32 = 0i;
+ for(; (i < 5i); i = (i + 1i)) {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, For_ComplexBody_NoCont) {
+ auto* a = b.Function("a", ty.bool_());
+ auto* v = b.FunctionParam(ty.i32());
+ mod.SetName(v, "v");
+ a->SetParams({v});
+ b.With(a->Block(), [&] { b.Return(a, b.Equal(ty.bool_(), v, 1_i)); });
+ mod.functions.Push(a);
+
+ auto* fn = b.Function("f", ty.i32());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Initializer(), [&] {
+ auto* i = b.Var(ty.ptr<function, i32>());
+ mod.SetName(i, "i");
+ i->SetInitializer(b.Constant(0_i));
+
+ b.With(loop->Body(), [&] {
+ auto* if1 = b.If(b.LessThan(ty.bool_(), b.Load(i), 5_i));
+ b.With(if1->True(), [&] { b.ExitIf(if1); });
+ b.With(if1->False(), [&] { b.ExitLoop(loop); });
+
+ auto* if2 = b.If(b.Call(ty.bool_(), a, 42_i));
+ b.With(if2->True(), [&] { b.Return(fn, 1_i); });
+ b.With(if2->False(), [&] { b.Return(fn, 2_i); });
+ });
+ });
+ b.Return(fn, 3_i);
+ });
+
+ Test(R"(
+fn a(v : i32) -> bool {
+ return (v == 1i);
+}
+
+fn f() -> i32 {
+ for(var i : i32 = 0i; (i < 5i); ) {
+ if (a(42i)) {
+ return 1i;
+ } else {
+ return 2i;
+ }
+ }
+ return 3i;
+}
+)");
+}
+
+TEST_F(IRToProgramTest, For_CallInInitCondCont) {
+ auto* fn_n = b.Function("n", ty.i32());
+ auto* v = b.FunctionParam(ty.i32());
+ mod.SetName(v, "v");
+ fn_n->SetParams({v});
+ b.With(fn_n->Block(), [&] { b.Return(fn_n, b.Add(ty.i32(), v, 1_i)); });
+ mod.functions.Push(fn_n);
+
+ auto* fn_f = b.Function("f", ty.void_());
+ mod.functions.Push(fn_f);
+
+ b.With(fn_f->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Initializer(), [&] {
+ auto* n_0 = b.Call(ty.i32(), fn_n, 0_i)->Result();
+ auto* i = b.Var(ty.ptr<function, i32>());
+ mod.SetName(i, "i");
+ i->SetInitializer(n_0);
+
+ b.With(loop->Body(), [&] {
+ auto* if_ = b.If(b.LessThan(ty.bool_(), b.Load(i), b.Call(ty.i32(), fn_n, 1_i)));
+ b.With(if_->True(), [&] { b.ExitIf(if_); });
+ b.With(if_->False(), [&] { b.ExitLoop(loop); });
+ });
+
+ b.With(loop->Continuing(), [&] { b.Store(i, b.Call(ty.i32(), fn_n, b.Load(i))); });
+ });
+ });
+
+ Test(R"(
+fn n(v : i32) -> i32 {
+ return (v + 1i);
+}
+
+fn f() {
+ for(var i : i32 = n(0i); (i < n(1i)); i = n(i)) {
+ }
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// While
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, While_Empty) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] {
+ auto* cond = b.If(true);
+ b.With(cond->True(), [&] { b.ExitIf(cond); });
+ b.With(cond->False(), [&] { b.ExitLoop(loop); });
+ });
+ });
+
+ Test(R"(
+fn f() {
+ while(true) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, While_Cond) {
+ auto* fn = b.Function("f", ty.void_());
+ auto* cond = b.FunctionParam(ty.bool_());
+ mod.SetName(cond, "cond");
+ fn->SetParams({cond});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] {
+ auto* if_ = b.If(cond);
+ b.With(if_->True(), [&] { b.ExitIf(if_); });
+ b.With(if_->False(), [&] { b.ExitLoop(loop); });
+ });
+ });
+
+ Test(R"(
+fn f(cond : bool) {
+ while(cond) {
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, While_Break) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] {
+ auto* cond = b.If(true);
+ b.With(cond->True(), [&] { b.ExitIf(cond); });
+ b.With(cond->False(), [&] { b.ExitLoop(loop); });
+ b.ExitLoop(loop);
+ });
+ });
+
+ Test(R"(
+fn f() {
+ while(true) {
+ break;
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, While_IfBreak) {
+ auto* fn = b.Function("f", ty.void_());
+ auto* cond = b.FunctionParam(ty.bool_());
+ mod.SetName(cond, "cond");
+ fn->SetParams({cond});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] {
+ auto* if1 = b.If(true);
+ b.With(if1->True(), [&] { b.ExitIf(if1); });
+ b.With(if1->False(), [&] { b.ExitLoop(loop); });
+
+ auto* if2 = b.If(cond);
+ b.With(if2->True(), [&] { b.ExitLoop(loop); });
+ });
+ });
+
+ Test(R"(
+fn f(cond : bool) {
+ while(true) {
+ if (cond) {
+ break;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, While_IfReturn) {
+ auto* fn = b.Function("f", ty.void_());
+ auto* cond = b.FunctionParam(ty.bool_());
+ mod.SetName(cond, "cond");
+ fn->SetParams({cond});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] {
+ auto* if1 = b.If(true);
+ b.With(if1->True(), [&] { b.ExitIf(if1); });
+ b.With(if1->False(), [&] { b.ExitLoop(loop); });
+
+ auto* if2 = b.If(cond);
+ b.With(if2->True(), [&] { b.Return(fn); });
+ });
+ });
+
+ Test(R"(
+fn f(cond : bool) {
+ while(true) {
+ if (cond) {
+ return;
+ }
+ }
+}
+)");
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Loop
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(IRToProgramTest, Loop_Break) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] { b.ExitLoop(loop); });
+ });
+
+ Test(R"(
+fn f() {
+ loop {
+ break;
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, Loop_IfBreak) {
+ auto* fn = b.Function("f", ty.void_());
+ auto* cond = b.FunctionParam(ty.bool_());
+ mod.SetName(cond, "cond");
+ fn->SetParams({cond});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] {
+ auto* if_ = b.If(cond);
+ b.With(if_->True(), [&] { b.ExitLoop(loop); });
+ });
+ });
+ Test(R"(
+fn f(cond : bool) {
+ loop {
+ if (cond) {
+ break;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, Loop_IfReturn) {
+ auto* fn = b.Function("f", ty.void_());
+ auto* cond = b.FunctionParam(ty.bool_());
+ mod.SetName(cond, "cond");
+ fn->SetParams({cond});
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] {
+ auto* if_ = b.If(cond);
+ b.With(if_->True(), [&] { b.Return(fn); });
+ });
+ });
+
+ Test(R"(
+fn f(cond : bool) {
+ loop {
+ if (cond) {
+ return;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, Loop_IfContinuing) {
+ auto* fn = b.Function("f", ty.void_());
+ mod.functions.Push(fn);
+
+ b.With(fn->Block(), [&] {
+ auto* cond = b.Var(ty.ptr<function, bool>());
+ cond->SetInitializer(b.Constant(false));
+ mod.SetName(cond, "cond");
+
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] {
+ auto* if_ = b.If(cond);
+ b.With(if_->True(), [&] { b.Return(fn); });
+ });
+
+ b.With(loop->Continuing(), [&] { b.Store(cond, true); });
+ });
+
+ Test(R"(
+fn f() {
+ var cond : bool = false;
+ loop {
+ if (cond) {
+ return;
+ }
+
+ continuing {
+ cond = true;
+ }
+ }
+}
+)");
+}
+
+TEST_F(IRToProgramTest, Loop_VarsDeclaredOutsideAndInside) {
+ auto* f = b.Function("f", ty.void_());
+ mod.functions.Push(f);
+
+ b.With(f->Block(), [&] {
+ auto* var_b = b.Var(ty.ptr<function, i32>());
+ var_b->SetInitializer(b.Constant(1_i));
+ mod.SetName(var_b, "b");
+
+ auto* loop = b.Loop();
+
+ b.With(loop->Body(), [&] {
+ auto* var_a = b.Var(ty.ptr<function, i32>());
+ var_a->SetInitializer(b.Constant(2_i));
+ mod.SetName(var_a, "a");
+
+ auto* if_ = b.If(b.Equal(ty.bool_(), b.Load(var_a), b.Load(var_b)));
+ b.With(if_->True(), [&] { b.Return(f); });
+ b.With(if_->False(), [&] { b.ExitIf(if_); });
+
+ b.With(loop->Continuing(),
+ [&] { b.Store(var_b, b.Add(ty.i32(), b.Load(var_a), b.Load(var_b))); });
+ });
+ });
+
+ Test(R"(
+fn f() {
+ var b : i32 = 1i;
+ loop {
+ var a : i32 = 2i;
+ if ((a == b)) {
+ return;
+ }
+
+ continuing {
+ b = (a + b);
+ }
+ }
+}
+)");
+}
+
+} // namespace
+} // namespace tint::ir
diff --git a/src/tint/ir/transform/add_empty_entry_point.cc b/src/tint/ir/transform/add_empty_entry_point.cc
index 3c28e59..5781e93 100644
--- a/src/tint/ir/transform/add_empty_entry_point.cc
+++ b/src/tint/ir/transform/add_empty_entry_point.cc
@@ -37,7 +37,7 @@
ir::Builder builder(*ir);
auto* ep = builder.Function("unused_entry_point", ir->Types().void_(),
Function::PipelineStage::kCompute, std::array{1u, 1u, 1u});
- ep->StartTarget()->Append(builder.Return(ep));
+ ep->Block()->Append(builder.Return(ep));
ir->functions.Push(ep);
}
diff --git a/src/tint/ir/transform/add_empty_entry_point_test.cc b/src/tint/ir/transform/add_empty_entry_point_test.cc
index c5c60a1..567feae 100644
--- a/src/tint/ir/transform/add_empty_entry_point_test.cc
+++ b/src/tint/ir/transform/add_empty_entry_point_test.cc
@@ -39,7 +39,7 @@
TEST_F(IR_AddEmptyEntryPointTest, ExistingEntryPoint) {
auto* ep = b.Function("main", mod.Types().void_(), Function::PipelineStage::kFragment);
- ep->StartTarget()->Append(b.Return(ep));
+ ep->Block()->Append(b.Return(ep));
mod.functions.Push(ep);
auto* expect = R"(
diff --git a/src/tint/ir/transform/block_decorated_structs_test.cc b/src/tint/ir/transform/block_decorated_structs_test.cc
index 2211f99..2d2a856 100644
--- a/src/tint/ir/transform/block_decorated_structs_test.cc
+++ b/src/tint/ir/transform/block_decorated_structs_test.cc
@@ -31,7 +31,7 @@
TEST_F(IR_BlockDecoratedStructsTest, NoRootBlock) {
auto* func = b.Function("foo", ty.void_());
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(b.Return(func));
mod.functions.Push(func);
auto* expect = R"(
@@ -54,7 +54,7 @@
auto* func = b.Function("foo", ty.i32());
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* load = block->Append(b.Load(buffer));
block->Append(b.Return(func, load));
mod.functions.Push(func);
@@ -88,8 +88,8 @@
b.RootBlock()->Append(buffer);
auto* func = b.Function("foo", ty.void_());
- func->StartTarget()->Append(b.Store(buffer, 42_i));
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(b.Store(buffer, 42_i));
+ func->Block()->Append(b.Return(func));
mod.functions.Push(func);
auto* expect = R"(
@@ -122,10 +122,11 @@
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* access = sb.Access(ty.ptr<storage, i32>(), buffer, 1_u);
- sb.Store(access, 42_i);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* access = b.Access(ty.ptr<storage, i32>(), buffer, 1_u);
+ b.Store(access, 42_i);
+ b.Return(func);
+ });
mod.functions.Push(func);
@@ -168,12 +169,13 @@
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* val_ptr = sb.Access(i32_ptr, buffer, 0_u);
- auto* load = sb.Load(val_ptr);
- auto* elem_ptr = sb.Access(i32_ptr, buffer, 1_u, 3_u);
- sb.Store(elem_ptr, load);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* val_ptr = b.Access(i32_ptr, buffer, 0_u);
+ auto* load = b.Load(val_ptr);
+ auto* elem_ptr = b.Access(i32_ptr, buffer, 1_u, 3_u);
+ b.Store(elem_ptr, load);
+ b.Return(func);
+ });
mod.functions.Push(func);
@@ -222,8 +224,8 @@
b.RootBlock()->Append(private_var);
auto* func = b.Function("foo", ty.void_());
- func->StartTarget()->Append(b.Store(buffer, private_var));
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(b.Store(buffer, private_var));
+ func->Block()->Append(b.Return(func));
mod.functions.Push(func);
auto* expect = R"(
@@ -268,7 +270,7 @@
root->Append(buffer_c);
auto* func = b.Function("foo", ty.void_());
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* load_b = block->Append(b.Load(buffer_b));
auto* load_c = block->Append(b.Load(buffer_c));
block->Append(b.Store(buffer_a, b.Add(ty.i32(), load_b, load_c)));
diff --git a/src/tint/ir/transform/merge_return.cc b/src/tint/ir/transform/merge_return.cc
index 4817d04..d2d4259 100644
--- a/src/tint/ir/transform/merge_return.cc
+++ b/src/tint/ir/transform/merge_return.cc
@@ -80,23 +80,23 @@
// Create a boolean variable that can be used to check whether the function is returning.
continue_execution = b.Var(ty.ptr<function, bool>());
continue_execution->SetInitializer(b.Constant(true));
- fn->StartTarget()->Prepend(continue_execution);
+ fn->Block()->Prepend(continue_execution);
ir->SetName(continue_execution, "continue_execution");
// Create a variable to hold the return value if needed.
if (!fn->ReturnType()->Is<type::Void>()) {
return_val = b.Var(ty.ptr(function, fn->ReturnType()));
- fn->StartTarget()->Prepend(return_val);
+ fn->Block()->Prepend(return_val);
ir->SetName(return_val, "return_value");
}
// Look to see if the function ends with a return
- fn_return = tint::As<Return>(fn->StartTarget()->Terminator());
+ fn_return = tint::As<Return>(fn->Block()->Terminator());
// Process the function's block.
// This will traverse into control instructions that hold returns, and apply the necessary
// changes to remove returns.
- ProcessBlock(fn->StartTarget());
+ ProcessBlock(fn->Block());
// If the function didn't end with a return, add one
if (!fn_return) {
@@ -279,12 +279,13 @@
/// Adds a final return instruction to the end of @p fn
/// @param fn the function
void AppendFinalReturn(Function* fn) {
- auto fb = b.With(fn->StartTarget());
- if (return_val) {
- fb.Return(fn, fb.Load(return_val));
- } else {
- fb.Return(fn);
- }
+ b.With(fn->Block(), [&] {
+ if (return_val) {
+ b.Return(fn, b.Load(return_val));
+ } else {
+ b.Return(fn);
+ }
+ });
}
};
diff --git a/src/tint/ir/transform/merge_return_test.cc b/src/tint/ir/transform/merge_return_test.cc
index bf68eae..c54f809 100644
--- a/src/tint/ir/transform/merge_return_test.cc
+++ b/src/tint/ir/transform/merge_return_test.cc
@@ -32,8 +32,7 @@
func->SetParams({in});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
- sb.Return(func, sb.Add(ty.i32(), in, 1_i));
+ b.With(func->Block(), [&] { b.Return(func, b.Add(ty.i32(), in, 1_i)); });
auto* src = R"(
%foo = func(%2:i32):i32 -> %b1 {
@@ -59,17 +58,14 @@
func->SetParams({in});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse = b.If(cond);
+ ifelse->SetResults(b.InstructionResult(ty.i32()));
+ b.With(ifelse->True(), [&] { b.ExitIf(ifelse, b.Add(ty.i32(), in, 1_i)); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse, b.Add(ty.i32(), in, 2_i)); });
- auto* ifelse = sb.If(cond);
- ifelse->SetResults(b.InstructionResult(ty.i32()));
- auto tb = b.With(ifelse->True());
- tb.ExitIf(ifelse, tb.Add(ty.i32(), in, 1_i));
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse, fb.Add(ty.i32(), in, 2_i));
-
- sb.Return(func, ifelse->Result(0));
-
+ b.Return(func, ifelse->Result(0));
+ });
auto* src = R"(
%foo = func(%2:i32):i32 -> %b1 {
%b1 = block {
@@ -103,21 +99,19 @@
func->SetParams({in});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* swtch = b.Switch(in);
+ b.With(b.Case(swtch, {Switch::CaseSelector{}}), [&] { b.ExitSwitch(swtch); });
- auto* swtch = sb.Switch(in);
- b.Case(swtch, {Switch::CaseSelector{}})->Append(b.ExitSwitch(swtch));
+ b.Loop();
- sb.Loop();
+ auto* ifelse = b.If(cond);
+ ifelse->SetResults(b.InstructionResult(ty.i32()));
+ b.With(ifelse->True(), [&] { b.ExitIf(ifelse, b.Add(ty.i32(), in, 1_i)); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse, b.Add(ty.i32(), in, 2_i)); });
- auto* ifelse = sb.If(cond);
- ifelse->SetResults(b.InstructionResult(ty.i32()));
- auto tb = b.With(ifelse->True());
- tb.ExitIf(ifelse, tb.Add(ty.i32(), in, 1_i));
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse, fb.Add(ty.i32(), in, 2_i));
-
- sb.Return(func, ifelse->Result(0));
+ b.Return(func, ifelse->Result(0));
+ });
auto* src = R"(
%foo = func(%2:i32):i32 -> %b1 {
@@ -158,15 +152,13 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse = b.If(cond);
+ b.With(ifelse->True(), [&] { b.Return(func); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse); });
- auto* ifelse = sb.If(cond);
- auto tb = b.With(ifelse->True());
- tb.Return(func);
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse);
-
- sb.Return(func);
+ b.Return(func);
+ });
auto* src = R"(
%foo = func(%2:bool):void -> %b1 {
@@ -214,15 +206,13 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse = b.If(cond);
+ b.Return(func);
- auto* ifelse = sb.If(cond);
- sb.Return(func);
-
- auto tb = b.With(ifelse->True());
- tb.Return(func);
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse);
+ b.With(ifelse->True(), [&] { b.Return(func); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse); });
+ });
auto* src = R"(
%foo = func(%2:bool):void -> %b1 {
@@ -268,15 +258,13 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse = b.If(cond);
+ b.With(ifelse->True(), [&] { b.Return(func, 1_i); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse); });
- auto* ifelse = sb.If(cond);
- auto tb = b.With(ifelse->True());
- tb.Return(func, 1_i);
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse);
-
- sb.Return(func, 2_i);
+ b.Return(func, 2_i);
+ });
auto* src = R"(
%foo = func(%2:bool):i32 -> %b1 {
@@ -334,16 +322,14 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse = b.If(cond);
+ ifelse->SetResults(b.InstructionResult(ty.i32()));
+ b.With(ifelse->True(), [&] { b.Return(func, 1_i); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse, 2_i); });
- auto* ifelse = sb.If(cond);
- ifelse->SetResults(b.InstructionResult(ty.i32()));
- auto tb = b.With(ifelse->True());
- tb.Return(func, 1_i);
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse, 2_i);
-
- sb.Return(func, ifelse->Result(0));
+ b.Return(func, ifelse->Result(0));
+ });
auto* src = R"(
%foo = func(%2:bool):i32 -> %b1 {
@@ -401,16 +387,14 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse = b.If(cond);
+ ifelse->SetResults(b.InstructionResult(ty.i32()));
+ b.With(ifelse->True(), [&] { b.Return(func, 1_i); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse, nullptr); });
- auto* ifelse = sb.If(cond);
- ifelse->SetResults(b.InstructionResult(ty.i32()));
- auto tb = b.With(ifelse->True());
- tb.Return(func, 1_i);
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse, nullptr);
-
- sb.Return(func, ifelse->Result(0));
+ b.Return(func, ifelse->Result(0));
+ });
auto* src = R"(
%foo = func(%2:bool):i32 -> %b1 {
@@ -468,15 +452,13 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse = b.If(cond);
+ b.With(ifelse->True(), [&] { b.Return(func); });
+ b.With(ifelse->False(), [&] { b.Return(func); });
- auto* ifelse = sb.If(cond);
- auto tb = b.With(ifelse->True());
- tb.Return(func);
- auto fb = b.With(ifelse->False());
- fb.Return(func);
-
- sb.Unreachable();
+ b.Unreachable();
+ });
auto* src = R"(
%foo = func(%2:bool):void -> %b1 {
@@ -526,16 +508,14 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse = b.If(cond);
+ b.With(ifelse->True(), [&] { b.Return(func); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse); });
- auto* ifelse = sb.If(cond);
- auto tb = b.With(ifelse->True());
- tb.Return(func);
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse);
-
- sb.Store(global, 42_i);
- sb.Return(func);
+ b.Store(global, 42_i);
+ b.Return(func);
+ });
auto* src = R"(
%b1 = block { # root
@@ -605,16 +585,14 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse = b.If(cond);
+ b.Store(global, 42_i);
+ b.Return(func);
- auto* ifelse = sb.If(cond);
- sb.Store(global, 42_i);
- sb.Return(func);
-
- auto tb = b.With(ifelse->True());
- tb.Return(func);
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse);
+ b.With(ifelse->True(), [&] { b.Return(func); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse); });
+ });
auto* src = R"(
%b1 = block { # root
@@ -687,32 +665,26 @@
func->SetParams({condA, condB, condC});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse_outer = b.If(condA);
+ b.With(ifelse_outer->True(), [&] { b.Return(func, 3_i); });
+ b.With(ifelse_outer->False(), [&] {
+ auto* ifelse_middle = b.If(condB);
+ b.With(ifelse_middle->True(), [&] {
+ auto* ifelse_inner = b.If(condC);
+ b.With(ifelse_inner->True(), [&] { b.Return(func, 1_i); });
+ b.With(ifelse_inner->False(), [&] { b.ExitIf(ifelse_inner); });
- auto* ifelse_outer = sb.If(condA);
- auto outer_true = b.With(ifelse_outer->True());
- outer_true.Return(func, 3_i);
- auto outer_false = b.With(ifelse_outer->False());
- auto* ifelse_middle = outer_false.If(condB);
-
- sb.Store(global, 3_i);
- sb.Return(func, sb.Add(ty.i32(), 5_i, 6_i));
-
- auto middle_true = b.With(ifelse_middle->True());
- auto* ifelse_inner = middle_true.If(condC);
- auto middle_false = b.With(ifelse_middle->False());
- middle_false.ExitIf(ifelse_middle);
-
- outer_false.Store(global, 2_i);
- outer_false.ExitIf(ifelse_outer);
-
- auto inner_true = b.With(ifelse_inner->True());
- inner_true.Return(func, 1_i);
- auto inner_false = b.With(ifelse_inner->False());
- inner_false.ExitIf(ifelse_inner);
-
- middle_true.Store(global, 1_i);
- middle_true.Return(func, 2_i);
+ b.Store(global, 1_i);
+ b.Return(func, 2_i);
+ });
+ b.With(ifelse_middle->False(), [&] { b.ExitIf(ifelse_middle); });
+ b.Store(global, 2_i);
+ b.ExitIf(ifelse_outer);
+ });
+ b.Store(global, 3_i);
+ b.Return(func, b.Add(ty.i32(), 5_i, 6_i));
+ });
auto* src = R"(
%b1 = block { # root
@@ -843,30 +815,24 @@
func->SetParams({condA, condB, condC});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse_outer = b.If(condA);
+ b.With(ifelse_outer->True(), [&] { b.Return(func, 3_i); });
+ b.With(ifelse_outer->False(), [&] {
+ auto* ifelse_middle = b.If(condB);
+ b.With(ifelse_middle->True(), [&] {
+ auto* ifelse_inner = b.If(condC);
+ b.With(ifelse_inner->True(), [&] { b.Return(func, 1_i); });
+ b.With(ifelse_inner->False(), [&] { b.ExitIf(ifelse_inner); });
- auto* ifelse_outer = sb.If(condA);
- auto outer_true = b.With(ifelse_outer->True());
- outer_true.Return(func, 3_i);
- auto outer_false = b.With(ifelse_outer->False());
- auto* ifelse_middle = outer_false.If(condB);
+ b.ExitIf(ifelse_middle);
+ });
+ b.With(ifelse_middle->False(), [&] { b.ExitIf(ifelse_middle); });
- sb.Return(func, 3_i);
-
- auto middle_true = b.With(ifelse_middle->True());
- auto* ifelse_inner = middle_true.If(condC);
- auto middle_false = b.With(ifelse_middle->False());
- middle_false.ExitIf(ifelse_middle);
-
- outer_false.ExitIf(ifelse_outer);
-
- auto inner_true = b.With(ifelse_inner->True());
- inner_true.Return(func, 1_i);
- auto inner_false = b.With(ifelse_inner->False());
- inner_false.ExitIf(ifelse_inner);
-
- middle_true.ExitIf(ifelse_middle);
-
+ b.ExitIf(ifelse_outer);
+ });
+ b.Return(func, 3_i);
+ });
auto* src = R"(
%b1 = block { # root
%1:ptr<private, i32, read_write> = var
@@ -974,31 +940,28 @@
func->SetParams({condA, condB, condC});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* ifelse_outer = b.If(condA);
+ ifelse_outer->SetResults(b.InstructionResult(ty.i32()));
+ b.With(ifelse_outer->True(), [&] { b.Return(func, 3_i); });
+ b.With(ifelse_outer->False(), [&] {
+ auto* ifelse_middle = b.If(condB);
+ ifelse_middle->SetResults(b.InstructionResult(ty.i32()));
+ b.With(ifelse_middle->True(), [&] {
+ auto* ifelse_inner = b.If(condC);
- auto* ifelse_outer = sb.If(condA);
- ifelse_outer->SetResults(b.InstructionResult(ty.i32()));
- auto outer_true = b.With(ifelse_outer->True());
- outer_true.Return(func, 3_i);
- auto outer_false = b.With(ifelse_outer->False());
- auto* ifelse_middle = outer_false.If(condB);
- ifelse_middle->SetResults(b.InstructionResult(ty.i32()));
+ b.With(ifelse_inner->True(), [&] { b.Return(func, 1_i); });
+ b.With(ifelse_inner->False(), [&] { b.ExitIf(ifelse_inner); });
- sb.Return(func, sb.Add(ty.i32(), ifelse_outer->Result(0), 1_i));
+ b.ExitIf(ifelse_middle, b.Add(ty.i32(), 42_i, 1_i));
+ });
+ b.With(ifelse_middle->False(),
+ [&] { b.ExitIf(ifelse_middle, b.Add(ty.i32(), 43_i, 2_i)); });
+ b.ExitIf(ifelse_outer, b.Add(ty.i32(), ifelse_middle->Result(0), 1_i));
+ });
- auto middle_true = b.With(ifelse_middle->True());
- auto* ifelse_inner = middle_true.If(condC);
- auto middle_false = b.With(ifelse_middle->False());
- middle_false.ExitIf(ifelse_middle, middle_false.Add(ty.i32(), 43_i, 2_i));
-
- outer_false.ExitIf(ifelse_outer, outer_false.Add(ty.i32(), ifelse_middle->Result(0), 1_i));
-
- auto inner_true = b.With(ifelse_inner->True());
- inner_true.Return(func, 1_i);
- auto inner_false = b.With(ifelse_inner->False());
- inner_false.ExitIf(ifelse_inner);
-
- middle_true.ExitIf(ifelse_middle, middle_true.Add(ty.i32(), 42_i, 1_i));
+ b.Return(func, b.Add(ty.i32(), ifelse_outer->Result(0), 1_i));
+ });
auto* src = R"(
%b1 = block { # root
@@ -1119,13 +1082,12 @@
auto* func = b.Function("foo", ty.i32());
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] { b.Return(func, 42_i); });
- auto* loop = sb.Loop();
- loop->Body()->Append(b.Return(func, 42_i));
-
- sb.Unreachable();
-
+ b.Unreachable();
+ });
auto* src = R"(
%foo = func():i32 -> %b1 {
%b1 = block {
@@ -1170,26 +1132,25 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] {
+ auto* ifelse = b.If(cond);
+ b.With(ifelse->True(), [&] { b.Return(func, 42_i); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse); });
- auto* loop = sb.Loop();
- auto lb = b.With(loop->Body());
- auto* ifelse = lb.If(cond);
- {
- auto tb = b.With(ifelse->True());
- tb.Return(func, 42_i);
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse);
- }
- lb.Store(global, 2_i);
- lb.Continue(loop);
+ b.Store(global, 2_i);
+ b.Continue(loop);
+ });
- auto cb = b.With(loop->Continuing());
- cb.Store(global, 1_i);
- cb.BreakIf(true, loop);
+ b.With(loop->Continuing(), [&] {
+ b.Store(global, 1_i);
+ b.BreakIf(true, loop);
+ });
- sb.Store(global, 3_i);
- sb.Return(func, 43_i);
+ b.Store(global, 3_i);
+ b.Return(func, 43_i);
+ });
auto* src = R"(
%b1 = block { # root
@@ -1286,25 +1247,25 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] {
+ auto* ifelse = b.If(cond);
- auto* loop = sb.Loop();
- auto lb = b.With(loop->Body());
- auto* ifelse = lb.If(cond);
- {
- auto tb = b.With(ifelse->True());
- tb.Return(func, 42_i);
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse);
- }
- lb.Store(global, 2_i);
- lb.Continue(loop);
+ b.With(ifelse->True(), [&] { b.Return(func, 42_i); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse); });
- auto cb = b.With(loop->Continuing());
- cb.Store(global, 1_i);
- cb.NextIteration(loop);
+ b.Store(global, 2_i);
+ b.Continue(loop);
+ });
- sb.Unreachable();
+ b.With(loop->Continuing(), [&] {
+ b.Store(global, 1_i);
+ b.NextIteration(loop);
+ });
+
+ b.Unreachable();
+ });
auto* src = R"(
%b1 = block { # root
@@ -1392,28 +1353,26 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ loop->SetResults(b.InstructionResult(ty.i32()));
+ b.With(loop->Body(), [&] {
+ auto* ifelse = b.If(cond);
+ b.With(ifelse->True(), [&] { b.Return(func, 42_i); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse); });
- auto* loop = sb.Loop();
- loop->SetResults(b.InstructionResult(ty.i32()));
- auto lb = b.With(loop->Body());
- auto* ifelse = lb.If(cond);
- {
- auto tb = b.With(ifelse->True());
- tb.Return(func, 42_i);
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse);
- }
- lb.Store(global, 2_i);
- lb.Continue(loop);
+ b.Store(global, 2_i);
+ b.Continue(loop);
+ });
- auto cb = b.With(loop->Continuing());
- cb.Store(global, 1_i);
- cb.BreakIf(true, loop, 4_i);
+ b.With(loop->Continuing(), [&] {
+ b.Store(global, 1_i);
+ b.BreakIf(true, loop, 4_i);
+ });
- sb.Store(global, 3_i);
- sb.Return(func, loop->Result(0));
-
+ b.Store(global, 3_i);
+ b.Return(func, loop->Result(0));
+ });
auto* src = R"(
%b1 = block { # root
%1:ptr<private, i32, read_write> = var
@@ -1506,15 +1465,13 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* sw = b.Switch(cond);
+ b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(1_i)}}), [&] { b.Return(func, 42_i); });
+ b.With(b.Case(sw, {Switch::CaseSelector{}}), [&] { b.ExitSwitch(sw); });
- auto* sw = sb.Switch(cond);
- auto caseA = b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(1_i)}}));
- caseA.Return(func, 42_i);
- auto caseB = b.With(b.Case(sw, {Switch::CaseSelector{}}));
- caseB.ExitSwitch(sw);
-
- sb.Return(func, 0_i);
+ b.Return(func, 0_i);
+ });
auto* src = R"(
%foo = func(%2:i32):i32 -> %b1 {
@@ -1575,24 +1532,21 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* sw = b.Switch(cond);
+ b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(1_i)}}), [&] {
+ auto* ifelse = b.If(cond);
+ b.With(ifelse->True(), [&] { b.Return(func, 42_i); });
+ b.With(ifelse->False(), [&] { b.ExitIf(ifelse); });
- auto* sw = sb.Switch(cond);
- auto caseA = b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(1_i)}}));
- auto* ifelse = caseA.If(cond);
- {
- auto tb = b.With(ifelse->True());
- tb.Return(func, 42_i);
- auto fb = b.With(ifelse->False());
- fb.ExitIf(ifelse);
- }
- caseA.Store(global, 2_i);
- caseA.ExitSwitch(sw);
+ b.Store(global, 2_i);
+ b.ExitSwitch(sw);
+ });
- auto caseB = b.With(b.Case(sw, {Switch::CaseSelector{}}));
- caseB.ExitSwitch(sw);
+ b.With(b.Case(sw, {Switch::CaseSelector{}}), [&] { b.ExitSwitch(sw); });
- sb.Return(func, 0_i);
+ b.Return(func, 0_i);
+ });
auto* src = R"(
%b1 = block { # root
@@ -1682,20 +1636,16 @@
func->SetParams({cond});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
+ b.With(func->Block(), [&] {
+ auto* sw = b.Switch(cond);
+ sw->SetResults(b.InstructionResult(ty.i32())); // NOLINT: false detection of std::tuple
+ b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(1_i)}}), [&] { b.Return(func, 42_i); });
+ b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(2_i)}}), [&] { b.Return(func, 99_i); });
+ b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(3_i)}}), [&] { b.ExitSwitch(sw, 1_i); });
+ b.With(b.Case(sw, {Switch::CaseSelector{}}), [&] { b.ExitSwitch(sw, 0_i); });
- auto* sw = sb.Switch(cond);
- sw->SetResults(b.InstructionResult(ty.i32())); // NOLINT: false detection of std::tuple
- auto caseA = b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(1_i)}}));
- caseA.Return(func, 42_i);
- auto caseB = b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(2_i)}}));
- caseB.Return(func, 99_i);
- auto caseC = b.With(b.Case(sw, {Switch::CaseSelector{b.Constant(3_i)}}));
- caseC.ExitSwitch(sw, 1_i);
- auto caseD = b.With(b.Case(sw, {Switch::CaseSelector{}}));
- caseD.ExitSwitch(sw, 0_i);
-
- sb.Return(func, sw->Result(0));
+ b.Return(func, sw->Result(0));
+ });
auto* src = R"(
%foo = func(%2:i32):i32 -> %b1 {
@@ -1765,17 +1715,14 @@
auto* func = b.Function("foo", ty.void_());
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
- {
- auto* loop = sb.Loop();
- auto lb = sb.With(loop->Body());
- {
- auto ib = lb.With(lb.If(true)->True());
- ib.Return(func);
- }
- lb.Continue(loop);
- }
- sb.Unreachable();
+ b.With(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.With(loop->Body(), [&] {
+ b.With(b.If(true)->True(), [&] { b.Return(func); });
+ b.Continue(loop);
+ });
+ b.Unreachable();
+ });
auto* src = R"(
%foo = func():void -> %b1 {
@@ -1831,16 +1778,13 @@
auto* func = b.Function("foo", ty.i32());
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
- {
- auto outer = b.With(sb.If(true)->True());
- {
- auto inner = b.With(outer.If(true)->True());
- inner.Return(func, 1_i);
- }
- outer.Return(func, 2_i);
- }
- sb.Return(func, 3_i);
+ b.With(func->Block(), [&] {
+ b.With(b.If(true)->True(), [&] {
+ b.With(b.If(true)->True(), [&] { b.Return(func, 1_i); });
+ b.Return(func, 2_i);
+ });
+ b.Return(func, 3_i);
+ });
auto* src = R"(
%foo = func():i32 -> %b1 {
diff --git a/src/tint/ir/transform/var_for_dynamic_index.cc b/src/tint/ir/transform/var_for_dynamic_index.cc
index 27ddec4..a41e307 100644
--- a/src/tint/ir/transform/var_for_dynamic_index.cc
+++ b/src/tint/ir/transform/var_for_dynamic_index.cc
@@ -121,7 +121,7 @@
if (to_replace.first_dynamic_index > 0) {
PartialAccess partial_access = {
access->Object(), access->Indices().Truncate(to_replace.first_dynamic_index)};
- source_object = source_object_to_value.GetOrCreate(partial_access, [&]() {
+ source_object = source_object_to_value.GetOrCreate(partial_access, [&] {
auto* intermediate_source = builder.Access(to_replace.dynamic_index_source_type,
source_object, partial_access.indices);
intermediate_source->InsertBefore(access);
@@ -130,7 +130,7 @@
}
// Declare a local variable and copy the source object to it.
- auto* local = object_to_local.GetOrCreate(source_object, [&]() {
+ auto* local = object_to_local.GetOrCreate(source_object, [&] {
auto* decl =
builder.Var(ir->Types().ptr(builtin::AddressSpace::kFunction, source_object->Type(),
builtin::Access::kReadWrite));
diff --git a/src/tint/ir/transform/var_for_dynamic_index_test.cc b/src/tint/ir/transform/var_for_dynamic_index_test.cc
index 785927d..ce0d691 100644
--- a/src/tint/ir/transform/var_for_dynamic_index_test.cc
+++ b/src/tint/ir/transform/var_for_dynamic_index_test.cc
@@ -34,7 +34,7 @@
auto* func = b.Function("foo", ty.i32());
func->SetParams({arr});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* access = block->Append(b.Access(ty.i32(), arr, 1_i));
block->Append(b.Return(func, access));
mod.functions.Push(func);
@@ -58,7 +58,7 @@
auto* func = b.Function("foo", ty.f32());
func->SetParams({mat});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* access = block->Append(b.Access(ty.f32(), mat, 1_i, 0_i));
block->Append(b.Return(func, access));
mod.functions.Push(func);
@@ -83,7 +83,7 @@
auto* func = b.Function("foo", ty.i32());
func->SetParams({arr, idx});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* access = block->Append(b.Access(ty.ptr<function, i32>(), arr, idx));
auto* load = block->Append(b.Load(access));
block->Append(b.Return(func, load));
@@ -110,7 +110,7 @@
auto* func = b.Function("foo", ty.f32());
func->SetParams({mat, idx});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* access = block->Append(b.Access(ty.ptr<function, f32>(), mat, idx, idx));
auto* load = block->Append(b.Load(access));
block->Append(b.Return(func, load));
@@ -137,7 +137,7 @@
auto* func = b.Function("foo", ty.f32());
func->SetParams({vec, idx});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* access = block->Append(b.Access(ty.f32(), vec, idx));
block->Append(b.Return(func, access));
mod.functions.Push(func);
@@ -162,7 +162,7 @@
auto* func = b.Function("foo", ty.i32());
func->SetParams({arr, idx});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* access = block->Append(b.Access(ty.i32(), arr, idx));
block->Append(b.Return(func, access));
mod.functions.Push(func);
@@ -189,7 +189,7 @@
auto* func = b.Function("foo", ty.vec2<f32>());
func->SetParams({mat, idx});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* access = block->Append(b.Access(ty.vec2<f32>(), mat, idx));
block->Append(b.Return(func, access));
mod.functions.Push(func);
@@ -216,7 +216,7 @@
auto* func = b.Function("foo", ty.i32());
func->SetParams({arr, idx});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* access = block->Append(b.Access(ty.i32(), arr, idx, 1_u, idx));
block->Append(b.Return(func, access));
mod.functions.Push(func);
@@ -243,7 +243,7 @@
auto* func = b.Function("foo", ty.i32());
func->SetParams({arr, idx});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* access = block->Append(b.Access(ty.i32(), arr, 1_u, 2_u, idx));
block->Append(b.Return(func, access));
mod.functions.Push(func);
@@ -271,7 +271,7 @@
auto* func = b.Function("foo", ty.i32());
func->SetParams({arr, idx});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* access = block->Append(b.Access(ty.i32(), arr, 1_u, idx, 2_u, idx));
block->Append(b.Return(func, access));
mod.functions.Push(func);
@@ -305,7 +305,7 @@
auto* func = b.Function("foo", ty.f32());
func->SetParams({str_val, idx});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
auto* access = block->Append(b.Access(ty.f32(), str_val, 1_u, idx, 0_u));
block->Append(b.Return(func, access));
mod.functions.Push(func);
@@ -341,7 +341,7 @@
auto* func = b.Function("foo", ty.i32());
func->SetParams({arr, idx_a, idx_b, idx_c});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
block->Append(b.Access(ty.i32(), arr, idx_a));
block->Append(b.Access(ty.i32(), arr, idx_b));
auto* access_c = block->Append(b.Access(ty.i32(), arr, idx_c));
@@ -376,7 +376,7 @@
auto* func = b.Function("foo", ty.i32());
func->SetParams({arr, idx_a, idx_b, idx_c});
- auto* block = func->StartTarget();
+ auto* block = func->Block();
block->Append(b.Access(ty.i32(), arr, 1_u, 2_u, idx_a));
block->Append(b.Access(ty.i32(), arr, 1_u, 2_u, idx_b));
auto* access_c = block->Append(b.Access(ty.i32(), arr, 1_u, 2_u, idx_c));
diff --git a/src/tint/ir/validate.cc b/src/tint/ir/validate.cc
index 6687111..d8871b9 100644
--- a/src/tint/ir/validate.cc
+++ b/src/tint/ir/validate.cc
@@ -158,10 +158,11 @@
std::string("root block: invalid instruction: ") + inst->TypeInfo().name);
continue;
}
+ CheckVar(var);
}
}
- void CheckFunction(Function* func) { CheckBlock(func->StartTarget()); }
+ void CheckFunction(Function* func) { CheckBlock(func->Block()); }
void CheckBlock(Block* blk) {
TINT_SCOPED_ASSIGNMENT(current_block_, blk);
@@ -181,10 +182,39 @@
}
void CheckInstruction(Instruction* inst) {
+ if (!inst->Alive()) {
+ AddError(inst, "destroyed instruction found in instruction list");
+ }
+ if (inst->Result()) {
+ if (inst->Result()->Source() == nullptr) {
+ AddError(inst, "instruction result source is undefined");
+ } else if (inst->Result()->Source() != inst) {
+ AddError(inst, "instruction result source has wrong instruction");
+ }
+ }
+
+ auto ops = inst->Operands();
+ for (size_t i = 0; i < ops.Length(); ++i) {
+ auto* op = ops[i];
+ if (!op) {
+ continue;
+ }
+
+ // Note, a `nullptr` is a valid operand in some cases, like `var` so we can't just check
+ // for `nullptr` here.
+ if (!op->Alive()) {
+ AddError(inst, "instruction has undefined operand");
+ }
+
+ if (!op->Usages().Contains({inst, i})) {
+ AddError(inst, i, "instruction operand missing usage");
+ }
+ }
+
tint::Switch(
inst, //
[&](Access* a) { CheckAccess(a); }, //
- [&](Binary*) {}, //
+ [&](Binary* b) { CheckBinary(b); }, //
[&](Call* c) { CheckCall(c); }, //
[&](If* if_) { CheckIf(if_); }, //
[&](Load*) {}, //
@@ -194,7 +224,7 @@
[&](Swizzle*) {}, //
[&](Terminator* b) { CheckTerminator(b); }, //
[&](Unary*) {}, //
- [&](Var*) {}, //
+ [&](Var* var) { CheckVar(var); }, //
[&](Default) {
AddError(std::string("missing validation of: ") + inst->TypeInfo().name);
});
@@ -282,6 +312,18 @@
}
}
+ void CheckBinary(ir::Binary* b) {
+ if (b->LHS() == nullptr) {
+ AddError(b, "binary: left operand is undefined");
+ }
+ if (b->RHS() == nullptr) {
+ AddError(b, "binary: right operand is undefined");
+ }
+ if (b->Result() == nullptr) {
+ AddError(b, "binary: result is undefined");
+ }
+ }
+
void CheckTerminator(ir::Terminator* b) {
tint::Switch(
b, //
@@ -304,12 +346,24 @@
void CheckIf(If* if_) {
if (!if_->Condition()) {
- AddError(if_, "if: condition is nullptr");
+ AddError(if_, "if: condition is undefined");
}
if (if_->Condition() && !if_->Condition()->Type()->Is<type::Bool>()) {
AddError(if_, If::kConditionOperandOffset, "if: condition must be a `bool` type");
}
}
+
+ void CheckVar(Var* var) {
+ if (var->Result() == nullptr) {
+ AddError(var, "var: result is undefined");
+ }
+
+ if (var->Result() && var->Initializer()) {
+ if (var->Initializer()->Type() != var->Result()->Type()->UnwrapPtr()) {
+ AddError(var, "var initializer has incorrect type");
+ }
+ }
+ }
}; // namespace
} // namespace
diff --git a/src/tint/ir/validate_test.cc b/src/tint/ir/validate_test.cc
index 080625c..8d696e1 100644
--- a/src/tint/ir/validate_test.cc
+++ b/src/tint/ir/validate_test.cc
@@ -71,7 +71,7 @@
mod.functions.Push(f);
f->SetParams({b.FunctionParam(ty.i32()), b.FunctionParam(ty.f32())});
- f->StartTarget()->Append(b.Return(f));
+ f->Block()->Append(b.Return(f));
auto res = ir::Validate(mod);
EXPECT_TRUE(res) << res.Failure().str();
@@ -101,9 +101,10 @@
f->SetParams({obj});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.f32(), obj, 1_u, 0_u);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.f32(), obj, 1_u, 0_u);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
EXPECT_TRUE(res) << res.Failure().str();
@@ -115,9 +116,10 @@
f->SetParams({obj});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.ptr<private_, f32>(), obj, 1_u, 0_u);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.ptr<private_, f32>(), obj, 1_u, 0_u);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
EXPECT_TRUE(res) << res.Failure().str();
@@ -129,9 +131,10 @@
f->SetParams({obj});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.f32(), obj, -1_i);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.f32(), obj, -1_i);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -159,9 +162,10 @@
f->SetParams({obj});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.f32(), obj, 1_u, 3_u);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.f32(), obj, 1_u, 3_u);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -193,9 +197,10 @@
f->SetParams({obj});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.ptr<private_, f32>(), obj, 1_u, 3_u);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.ptr<private_, f32>(), obj, 1_u, 3_u);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -228,9 +233,10 @@
f->SetParams({obj});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.f32(), obj, 1_u);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.f32(), obj, 1_u);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -258,9 +264,10 @@
f->SetParams({obj});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.ptr<private_, f32>(), obj, 1_u);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.ptr<private_, f32>(), obj, 1_u);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -294,9 +301,10 @@
f->SetParams({obj, idx});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.i32(), obj, idx);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.i32(), obj, idx);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -336,9 +344,10 @@
f->SetParams({obj, idx});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.i32(), obj, idx);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.i32(), obj, idx);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -372,9 +381,10 @@
f->SetParams({obj});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.i32(), obj, 1_u, 1_u);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.i32(), obj, 1_u, 1_u);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -403,9 +413,10 @@
f->SetParams({obj});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.ptr<private_, i32>(), obj, 1_u, 1_u);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.ptr<private_, i32>(), obj, 1_u, 1_u);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -435,9 +446,10 @@
f->SetParams({obj});
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Access(ty.f32(), obj, 1_u, 1_u);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Access(ty.f32(), obj, 1_u, 1_u);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -465,9 +477,10 @@
auto* f = b.Function("my_func", ty.void_());
mod.functions.Push(f);
- auto sb = b.With(f->StartTarget());
- sb.Return(f);
- sb.Return(f);
+ b.With(f->Block(), [&] {
+ b.Return(f);
+ b.Return(f);
+ });
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -498,8 +511,8 @@
if_->True()->Append(b.Return(f));
if_->False()->Append(b.Return(f));
- f->StartTarget()->Append(if_);
- f->StartTarget()->Append(b.Return(f));
+ f->Block()->Append(if_);
+ f->Block()->Append(b.Return(f));
auto res = ir::Validate(mod);
ASSERT_FALSE(res);
@@ -528,5 +541,308 @@
)");
}
+TEST_F(IR_ValidateTest, Var_RootBlock_NullResult) {
+ auto* v = mod.instructions.Create<ir::Var>(nullptr);
+ b.RootBlock()->Append(v);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:2:11 error: var: result is undefined
+ undef = var
+ ^^^
+
+:1:1 note: In block
+%b1 = block { # root
+^^^^^^^^^^^
+
+note: # Disassembly
+%b1 = block { # root
+ undef = var
+}
+
+)");
+}
+
+TEST_F(IR_ValidateTest, Var_Function_NullResult) {
+ auto* v = mod.instructions.Create<ir::Var>(nullptr);
+
+ auto* f = b.Function("my_func", ty.void_());
+ mod.functions.Push(f);
+
+ auto sb = b.With(f->Block());
+ sb.Append(v);
+ sb.Return(f);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:13 error: var: result is undefined
+ undef = var
+ ^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ undef = var
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Var_Init_WrongType) {
+ auto* f = b.Function("my_func", ty.void_());
+ mod.functions.Push(f);
+
+ auto sb = b.With(f->Block());
+ auto* v = sb.Var(ty.ptr<function, f32>());
+ sb.Return(f);
+
+ auto* result = sb.InstructionResult(ty.i32());
+ v->SetInitializer(result);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:41 error: var initializer has incorrect type
+ %2:ptr<function, f32, read_write> = var, %3
+ ^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:ptr<function, f32, read_write> = var, %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Instruction_AppendedDead) {
+ auto* f = b.Function("my_func", ty.void_());
+ mod.functions.Push(f);
+
+ auto sb = b.With(f->Block());
+ auto* v = sb.Var(ty.ptr<function, f32>());
+ auto* ret = sb.Return(f);
+
+ v->Destroy();
+ v->InsertBefore(ret);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:41 error: destroyed instruction found in instruction list
+ %2:ptr<function, f32, read_write> = var
+ ^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+:3:41 error: instruction result source is undefined
+ %2:ptr<function, f32, read_write> = var
+ ^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:ptr<function, f32, read_write> = var
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Instruction_NullSource) {
+ auto* f = b.Function("my_func", ty.void_());
+ mod.functions.Push(f);
+
+ auto sb = b.With(f->Block());
+ auto* v = sb.Var(ty.ptr<function, f32>());
+ sb.Return(f);
+
+ v->Result()->SetSource(nullptr);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:41 error: instruction result source is undefined
+ %2:ptr<function, f32, read_write> = var
+ ^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:ptr<function, f32, read_write> = var
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Instruction_DeadOperand) {
+ auto* f = b.Function("my_func", ty.void_());
+ mod.functions.Push(f);
+
+ auto sb = b.With(f->Block());
+ auto* v = sb.Var(ty.ptr<function, f32>());
+ sb.Return(f);
+
+ auto* result = sb.InstructionResult(ty.f32());
+ result->Destroy();
+ v->SetInitializer(result);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:41 error: instruction has undefined operand
+ %2:ptr<function, f32, read_write> = var, %3
+ ^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:ptr<function, f32, read_write> = var, %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Instruction_OperandUsageRemoved) {
+ auto* f = b.Function("my_func", ty.void_());
+ mod.functions.Push(f);
+
+ auto sb = b.With(f->Block());
+ auto* v = sb.Var(ty.ptr<function, f32>());
+ sb.Return(f);
+
+ auto* result = sb.InstructionResult(ty.f32());
+ v->SetInitializer(result);
+ result->RemoveUsage({v, 0u});
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:46 error: instruction operand missing usage
+ %2:ptr<function, f32, read_write> = var, %3
+ ^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:ptr<function, f32, read_write> = var, %3
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Binary_LHS_Nullptr) {
+ auto* f = b.Function("my_func", ty.void_());
+ mod.functions.Push(f);
+
+ auto sb = b.With(f->Block());
+ sb.Add(ty.i32(), nullptr, sb.Constant(2_i));
+ sb.Return(f);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:5 error: binary: left operand is undefined
+ %2:i32 = add undef, 2i
+ ^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:i32 = add undef, 2i
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Binary_RHS_Nullptr) {
+ auto* f = b.Function("my_func", ty.void_());
+ mod.functions.Push(f);
+
+ auto sb = b.With(f->Block());
+ sb.Add(ty.i32(), sb.Constant(2_i), nullptr);
+ sb.Return(f);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:5 error: binary: right operand is undefined
+ %2:i32 = add 2i, undef
+ ^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ %2:i32 = add 2i, undef
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidateTest, Binary_Result_Nullptr) {
+ auto* bin = mod.instructions.Create<ir::Binary>(nullptr, ir::Binary::Kind::kAdd,
+ b.Constant(3_i), b.Constant(2_i));
+
+ auto* f = b.Function("my_func", ty.void_());
+ mod.functions.Push(f);
+
+ auto sb = b.With(f->Block());
+ sb.Append(bin);
+ sb.Return(f);
+
+ auto res = ir::Validate(mod);
+ ASSERT_FALSE(res);
+ EXPECT_EQ(res.Failure().str(), R"(:3:5 error: binary: result is undefined
+ undef = add 3i, 2i
+ ^^^^^^^^^^^^^^^^^^
+
+:2:3 note: In block
+ %b1 = block {
+ ^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func():void -> %b1 {
+ %b1 = block {
+ undef = add 3i, 2i
+ ret
+ }
+}
+)");
+}
+
} // namespace
} // namespace tint::ir
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index 6f56081..dd4f35e 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -2899,7 +2899,7 @@
// Push statement blocks for the then-clause and the else-clause.
// But make sure we do it in the right order.
- auto push_else = [this, builder, else_end, construct, false_is_break, false_is_continue]() {
+ auto push_else = [this, builder, else_end, construct, false_is_break, false_is_continue] {
// Push the else clause onto the stack first.
PushNewStatementBlock(construct, else_end, [=](const StatementList& stmts) {
// Only set the else-clause if there are statements to fill it.
@@ -3413,7 +3413,7 @@
auto copy_name = namer_.MakeDerivedName(namer_.Name(phi_id) + "_c" +
std::to_string(block_info.id));
auto copy_sym = builder_.Symbols().Register(copy_name);
- copied_phis.GetOrCreate(phi_id, [copy_sym]() { return copy_sym; });
+ copied_phis.GetOrCreate(phi_id, [copy_sym] { return copy_sym; });
AddStatement(builder_.WrapInStatement(
builder_.Let(copy_sym, builder_.Expr(namer_.Name(phi_id)))));
}
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index 835c305..5b5c684 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -881,7 +881,7 @@
return builder_.ty(builder_.Ident(source.Source(), ident.to_str()));
}
- auto args = expect_template_arg_block("type template arguments", [&]() {
+ auto args = expect_template_arg_block("type template arguments", [&] {
return expect_expression_list("type template argument list",
Token::Type::kTemplateArgsRight);
});
@@ -2075,7 +2075,7 @@
const ast::Identifier* ident = nullptr;
if (peek_is(Token::Type::kTemplateArgsLeft)) {
- auto tmpl_args = expect_template_arg_block("template arguments", [&]() {
+ auto tmpl_args = expect_template_arg_block("template arguments", [&] {
return expect_expression_list("template argument list",
Token::Type::kTemplateArgsRight);
});
diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc
index 5997b71..c64a325 100644
--- a/src/tint/resolver/attribute_validation_test.cc
+++ b/src/tint/resolver/attribute_validation_test.cc
@@ -48,6 +48,7 @@
kDiagnostic,
kGroup,
kId,
+ kIndex,
kInterpolate,
kInvariant,
kLocation,
@@ -76,7 +77,13 @@
AttributeKind kind;
bool should_pass;
};
-struct TestWithParams : ResolverTestWithParam<TestParams> {};
+struct TestWithParams : ResolverTestWithParam<TestParams> {
+ void EnableExtensionIfNecessary(AttributeKind attributeKind) {
+ if (attributeKind == AttributeKind::kIndex) {
+ Enable(builtin::Extension::kChromiumInternalDualSourceBlending);
+ }
+ }
+};
static utils::Vector<const ast::Attribute*, 2> createAttributes(const Source& source,
ProgramBuilder& builder,
@@ -95,6 +102,8 @@
return {builder.Group(source, 1_a)};
case AttributeKind::kId:
return {builder.Id(source, 0_a)};
+ case AttributeKind::kIndex:
+ return {builder.Index(source, 0_a)};
case AttributeKind::kInterpolate:
return {builder.Interpolate(source, builtin::InterpolationType::kLinear,
builtin::InterpolationSampling::kCenter)};
@@ -134,6 +143,8 @@
return "@group";
case AttributeKind::kId:
return "@id";
+ case AttributeKind::kIndex:
+ return "@index";
case AttributeKind::kInterpolate:
return "@interpolate";
case AttributeKind::kInvariant:
@@ -161,6 +172,7 @@
using FunctionParameterAttributeTest = TestWithParams;
TEST_P(FunctionParameterAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
Func("main",
utils::Vector{
@@ -190,6 +202,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -204,6 +217,7 @@
using FunctionReturnTypeAttributeTest = TestWithParams;
TEST_P(FunctionReturnTypeAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
Func("main", utils::Empty, ty.f32(),
utils::Vector{
@@ -228,6 +242,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -244,6 +259,7 @@
using ComputeShaderParameterAttributeTest = TestWithParams;
TEST_P(ComputeShaderParameterAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
Func("main",
utils::Vector{
Param("a", ty.vec4<f32>(), createAttributes(Source{{12, 34}}, *this, params.kind)),
@@ -281,6 +297,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -295,6 +312,7 @@
using FragmentShaderParameterAttributeTest = TestWithParams;
TEST_P(FragmentShaderParameterAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
if (params.kind != AttributeKind::kBuiltin && params.kind != AttributeKind::kLocation) {
attrs.Push(Builtin(Source{{34, 56}}, builtin::BuiltinValue::kPosition));
@@ -321,6 +339,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
// kInterpolate tested separately (requires @location)
TestParams{AttributeKind::kInvariant, true},
TestParams{AttributeKind::kLocation, true},
@@ -335,6 +354,8 @@
using VertexShaderParameterAttributeTest = TestWithParams;
TEST_P(VertexShaderParameterAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
+
auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
if (params.kind != AttributeKind::kLocation) {
attrs.Push(Location(Source{{34, 56}}, 2_a));
@@ -377,6 +398,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, true},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, true},
@@ -391,6 +413,8 @@
using ComputeShaderReturnTypeAttributeTest = TestWithParams;
TEST_P(ComputeShaderReturnTypeAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
+
Func("main", utils::Empty, ty.vec4<f32>(),
utils::Vector{
Return(Call<vec4<f32>>(1_f)),
@@ -411,7 +435,8 @@
R"(12:34 error: @builtin(position) cannot be used in output of compute pipeline stage)");
} else if (params.kind == AttributeKind::kInterpolate ||
params.kind == AttributeKind::kLocation ||
- params.kind == AttributeKind::kInvariant) {
+ params.kind == AttributeKind::kInvariant ||
+ params.kind == AttributeKind::kIndex) {
EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) +
" is not valid for compute shader output");
} else {
@@ -429,6 +454,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -443,6 +469,8 @@
using FragmentShaderReturnTypeAttributeTest = TestWithParams;
TEST_P(FragmentShaderReturnTypeAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
+
auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
attrs.Push(Location(Source{{34, 56}}, 2_a));
Func("frag_main", utils::Empty, ty.vec4<f32>(), utils::Vector{Return(Call<vec4<f32>>())},
@@ -481,6 +509,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, true},
TestParams{AttributeKind::kInterpolate, true},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -495,6 +524,7 @@
using VertexShaderReturnTypeAttributeTest = TestWithParams;
TEST_P(VertexShaderReturnTypeAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
// a vertex shader must include the 'position' builtin in its return type
if (params.kind != AttributeKind::kBuiltin) {
@@ -517,6 +547,8 @@
EXPECT_EQ(r()->error(),
R"(34:56 error: multiple entry point IO attributes
12:34 note: previously consumed @location)");
+ } else if (params.kind == AttributeKind::kIndex) {
+ EXPECT_EQ(r()->error(), R"(12:34 error: @index is not valid for vertex shader output)");
} else {
EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) +
" is not valid for entry point return types");
@@ -531,6 +563,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
// kInterpolate tested separately (requires @location)
TestParams{AttributeKind::kInvariant, true},
TestParams{AttributeKind::kLocation, false},
@@ -617,6 +650,7 @@
using SpirvBlockAttribute = ast::transform::AddBlockAttribute::BlockAttribute;
TEST_P(StructAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
Structure("mystruct", utils::Vector{Member("a", ty.f32())},
createAttributes(Source{{12, 34}}, *this, params.kind));
@@ -637,6 +671,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -651,6 +686,7 @@
using StructMemberAttributeTest = TestWithParams;
TEST_P(StructMemberAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
utils::Vector<const ast::StructMember*, 1> members;
if (params.kind == AttributeKind::kBuiltin) {
members.Push(
@@ -675,6 +711,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ // kIndex tested separately (requires @location)
// kInterpolate tested separately (requires @location)
// kInvariant tested separately (requires position builtin)
TestParams{AttributeKind::kLocation, true},
@@ -897,6 +934,7 @@
using ArrayAttributeTest = TestWithParams;
TEST_P(ArrayAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
auto arr = ty.array(ty.f32(), createAttributes(Source{{12, 34}}, *this, params.kind));
Structure("mystruct", utils::Vector{
@@ -919,6 +957,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -933,6 +972,7 @@
using VariableAttributeTest = TestWithParams;
TEST_P(VariableAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
if (IsBindingAttribute(params.kind)) {
@@ -959,6 +999,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -1001,6 +1042,7 @@
using ConstantAttributeTest = TestWithParams;
TEST_P(ConstantAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
GlobalConst("a", ty.f32(), Expr(1.23_f),
createAttributes(Source{{12, 34}}, *this, params.kind));
@@ -1021,6 +1063,7 @@
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -1045,6 +1088,7 @@
using OverrideAttributeTest = TestWithParams;
TEST_P(OverrideAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
Override("a", ty.f32(), Expr(1.23_f), createAttributes(Source{{12, 34}}, *this, params.kind));
@@ -1063,6 +1107,7 @@
TestParams{AttributeKind::kBuiltin, false},
TestParams{AttributeKind::kDiagnostic, false},
TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kId, true},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
@@ -1091,6 +1136,7 @@
using SwitchStatementAttributeTest = TestWithParams;
TEST_P(SwitchStatementAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
WrapInFunction(Switch(Expr(0_a), utils::Vector{DefaultCase()},
createAttributes(Source{{12, 34}}, *this, params.kind)));
@@ -1111,6 +1157,7 @@
TestParams{AttributeKind::kDiagnostic, true},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -1125,6 +1172,7 @@
using SwitchBodyAttributeTest = TestWithParams;
TEST_P(SwitchBodyAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
WrapInFunction(Switch(Expr(0_a), utils::Vector{DefaultCase()}, utils::Empty,
createAttributes(Source{{12, 34}}, *this, params.kind)));
@@ -1145,6 +1193,7 @@
TestParams{AttributeKind::kDiagnostic, true},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -1159,6 +1208,7 @@
using IfStatementAttributeTest = TestWithParams;
TEST_P(IfStatementAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
WrapInFunction(If(Expr(true), Block(), ElseStmt(),
createAttributes(Source{{12, 34}}, *this, params.kind)));
@@ -1179,6 +1229,7 @@
TestParams{AttributeKind::kDiagnostic, true},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
@@ -1193,6 +1244,7 @@
using ForStatementAttributeTest = TestWithParams;
TEST_P(ForStatementAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
WrapInFunction(For(nullptr, Expr(false), nullptr, Block(),
createAttributes(Source{{12, 34}}, *this, params.kind)));
@@ -1212,6 +1264,7 @@
TestParams{AttributeKind::kBuiltin, false},
TestParams{AttributeKind::kDiagnostic, true},
TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kId, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
@@ -1227,6 +1280,7 @@
using LoopStatementAttributeTest = TestWithParams;
TEST_P(LoopStatementAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
WrapInFunction(
Loop(Block(Return()), Block(), createAttributes(Source{{12, 34}}, *this, params.kind)));
@@ -1246,6 +1300,7 @@
TestParams{AttributeKind::kBuiltin, false},
TestParams{AttributeKind::kDiagnostic, true},
TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kId, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
@@ -1261,6 +1316,7 @@
using WhileStatementAttributeTest = TestWithParams;
TEST_P(WhileStatementAttributeTest, IsValid) {
auto& params = GetParam();
+ EnableExtensionIfNecessary(params.kind);
WrapInFunction(
While(Expr(false), Block(), createAttributes(Source{{12, 34}}, *this, params.kind)));
@@ -1280,6 +1336,7 @@
TestParams{AttributeKind::kBuiltin, false},
TestParams{AttributeKind::kDiagnostic, true},
TestParams{AttributeKind::kGroup, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kId, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
@@ -1304,6 +1361,9 @@
"error: " + name(GetParam().kind) + " is not valid for block statements");
}
}
+
+ public:
+ BlockStatementTest() { EnableExtensionIfNecessary(GetParam().kind); }
};
TEST_P(BlockStatementTest, CompoundStatement) {
Func("foo", utils::Empty, ty.void_(),
@@ -1383,6 +1443,7 @@
TestParams{AttributeKind::kDiagnostic, true},
TestParams{AttributeKind::kGroup, false},
TestParams{AttributeKind::kId, false},
+ TestParams{AttributeKind::kIndex, false},
TestParams{AttributeKind::kInterpolate, false},
TestParams{AttributeKind::kInvariant, false},
TestParams{AttributeKind::kLocation, false},
diff --git a/src/tint/resolver/builtin_test.cc b/src/tint/resolver/builtin_test.cc
index 07abb21..04e23f6 100644
--- a/src/tint/resolver/builtin_test.cc
+++ b/src/tint/resolver/builtin_test.cc
@@ -2113,7 +2113,7 @@
case type::TextureDimension::kCubeArray:
return ty.vec3(scalar);
default:
- [=]() {
+ [=] {
utils::StringStream str;
str << dim;
FAIL() << "Unsupported texture dimension: " << str.str();
diff --git a/src/tint/resolver/dual_source_blending_extension_test.cc b/src/tint/resolver/dual_source_blending_extension_test.cc
index ebaf61c..4974940 100644
--- a/src/tint/resolver/dual_source_blending_extension_test.cc
+++ b/src/tint/resolver/dual_source_blending_extension_test.cc
@@ -1,4 +1,4 @@
-// Copyright 2022 The Tint Authors.
+// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -37,9 +37,15 @@
"'chromium_internal_dual_source_blending'");
}
-TEST_F(DualSourceBlendingExtensionTest, IndexF32Error) {
- Enable(builtin::Extension::kChromiumInternalDualSourceBlending);
+class DualSourceBlendingExtensionTests : public ResolverTest {
+ public:
+ DualSourceBlendingExtensionTests() {
+ Enable(builtin::Extension::kChromiumInternalDualSourceBlending);
+ }
+};
+// Using an F32 as an index value should fail.
+TEST_F(DualSourceBlendingExtensionTests, IndexF32Error) {
Structure("Output", utils::Vector{
Member(Source{{12, 34}}, "a", ty.vec4<f32>(),
utils::Vector{Location(0_a), Index(Source{{12, 34}}, 0_f)}),
@@ -49,9 +55,8 @@
EXPECT_EQ(r()->error(), "12:34 error: @location must be an i32 or u32 value");
}
-TEST_F(DualSourceBlendingExtensionTest, IndexFloatValueError) {
- Enable(builtin::Extension::kChromiumInternalDualSourceBlending);
-
+// Using a floating point number as an index value should fail.
+TEST_F(DualSourceBlendingExtensionTests, IndexFloatValueError) {
Structure("Output", utils::Vector{
Member(Source{{12, 34}}, "a", ty.vec4<f32>(),
utils::Vector{Location(0_a), Index(Source{{12, 34}}, 1.0_a)}),
@@ -60,9 +65,8 @@
EXPECT_EQ(r()->error(), "12:34 error: @location must be an i32 or u32 value");
}
-TEST_F(DualSourceBlendingExtensionTest, IndexNegativeValue) {
- Enable(builtin::Extension::kChromiumInternalDualSourceBlending);
-
+// Using a number less than zero as an index value should fail.
+TEST_F(DualSourceBlendingExtensionTests, IndexNegativeValue) {
Structure("Output", utils::Vector{
Member(Source{{12, 34}}, "a", ty.vec4<f32>(),
utils::Vector{Location(0_a), Index(Source{{12, 34}}, -1_a)}),
@@ -72,9 +76,8 @@
EXPECT_EQ(r()->error(), "12:34 error: @index value must be zero or one");
}
-TEST_F(DualSourceBlendingExtensionTest, IndexValueAboveOne) {
- Enable(builtin::Extension::kChromiumInternalDualSourceBlending);
-
+// Using a number greater than one as an index value should fail.
+TEST_F(DualSourceBlendingExtensionTests, IndexValueAboveOne) {
Structure("Output", utils::Vector{
Member(Source{{12, 34}}, "a", ty.vec4<f32>(),
utils::Vector{Location(0_a), Index(Source{{12, 34}}, 2_a)}),
@@ -84,5 +87,38 @@
EXPECT_EQ(r()->error(), "12:34 error: @index value must be zero or one");
}
+// Using an index value at the same location multiple times should fail.
+TEST_F(DualSourceBlendingExtensionTests, DuplicateIndexes) {
+ Structure("Output", utils::Vector{
+ Member("a", ty.vec4<f32>(), utils::Vector{Location(0_a), Index(0_a)}),
+ Member(Source{{12, 34}}, "b", ty.vec4<f32>(),
+ utils::Vector{Location(0_a), Index(Source{{12, 34}}, 0_a)}),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: @location(0) @index(0) appears multiple times");
+}
+
+// Using the index attribute without a location attribute should fail.
+TEST_F(DualSourceBlendingExtensionTests, IndexWithMissingLocationAttribute) {
+ Structure("Output", utils::Vector{
+ Member(Source{{12, 34}}, "a", ty.vec4<f32>(),
+ utils::Vector{Index(Source{{12, 34}}, 1_a)}),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: index attribute must only be used with @location");
+}
+
+// Using an index attribute on a struct member should pass.
+TEST_F(DualSourceBlendingExtensionTests, StructMemberIndexAttribute) {
+ Structure("Output", utils::Vector{
+ Member("a", ty.vec4<f32>(),
+ utils::Vector{Location(0_a), Index(Source{{12, 34}}, 0_a)}),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc
index 27cf5cf..9021135 100644
--- a/src/tint/resolver/intrinsic_table.cc
+++ b/src/tint/resolver/intrinsic_table.cc
@@ -1362,7 +1362,7 @@
}
auto eval_stage = match.overload->const_eval_fn ? sem::EvaluationStage::kConstant
: sem::EvaluationStage::kRuntime;
- auto* target = constructors.GetOrCreate(match, [&]() {
+ auto* target = constructors.GetOrCreate(match, [&] {
return builder.create<sem::ValueConstructor>(match.return_type, std::move(params),
eval_stage);
});
@@ -1370,7 +1370,7 @@
}
// Conversion.
- auto* target = converters.GetOrCreate(match, [&]() {
+ auto* target = converters.GetOrCreate(match, [&] {
auto param = builder.create<sem::Parameter>(
nullptr, 0u, match.parameters[0].type, builtin::AddressSpace::kUndefined,
builtin::Access::kUndefined, match.parameters[0].usage);
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 51bf27c..254c969 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1044,6 +1044,14 @@
func->SetReturnLocation(value.Get());
return kSuccess;
},
+ [&](const ast::IndexAttribute* attr) {
+ auto value = IndexAttribute(attr);
+ if (!value) {
+ return kErrored;
+ }
+ func->SetReturnIndex(value.Get());
+ return kSuccess;
+ },
[&](const ast::BuiltinAttribute* attr) {
return BuiltinAttribute(attr) ? kSuccess : kErrored;
},
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index 45c919e..007de95 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -644,7 +644,7 @@
}
// Add an edge from the variable exit node to its value at this point.
- auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
+ auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&] {
auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
@@ -680,7 +680,7 @@
}
// Add an edge from the variable exit node to its value at this point.
- auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
+ auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&] {
auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
@@ -783,7 +783,7 @@
// Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
- auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
+ auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&] {
auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
@@ -859,7 +859,7 @@
// Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
- auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
+ auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&] {
auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
@@ -1107,7 +1107,7 @@
}
// Add an edge from the variable exit node to its new value.
- auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
+ auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&] {
auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 17b75e1..3ef674e 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -1052,7 +1052,7 @@
// TODO(jrprice): This state could be stored in sem::Function instead, and then passed to
// sem::Function since it would be useful there too.
utils::Hashset<builtin::BuiltinValue, 4> builtins;
- utils::Hashset<uint32_t, 8> locations;
+ utils::Hashset<std::pair<uint32_t, uint32_t>, 8> locationsAndIndexes;
enum class ParamOrRetType {
kParameter,
kReturnType,
@@ -1063,10 +1063,13 @@
const type::Type* ty, Source source,
ParamOrRetType param_or_ret,
bool is_struct_member,
- std::optional<uint32_t> location) {
+ std::optional<uint32_t> location,
+ std::optional<uint32_t> index) {
// Scan attributes for pipeline IO attributes.
// Check for overlap with attributes that have been seen previously.
const ast::Attribute* pipeline_io_attribute = nullptr;
+ const ast::LocationAttribute* location_attribute = nullptr;
+ const ast::IndexAttribute* index_attribute = nullptr;
const ast::InterpolateAttribute* interpolate_attribute = nullptr;
const ast::InvariantAttribute* invariant_attribute = nullptr;
for (auto* attr : attrs) {
@@ -1097,6 +1100,7 @@
}
builtins.Add(builtin);
} else if (auto* loc_attr = attr->As<ast::LocationAttribute>()) {
+ location_attribute = loc_attr;
if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", attr->source);
AddNote("previously consumed " + AttrToStr(pipeline_io_attribute),
@@ -1112,12 +1116,12 @@
return false;
}
- if (!LocationAttribute(loc_attr, location.value(), ty, locations, stage, source,
- is_input)) {
+ if (!LocationAttribute(loc_attr, ty, stage, source, is_input)) {
return false;
}
} else if (auto* index_attr = attr->As<ast::IndexAttribute>()) {
- return IndexAttribute(index_attr);
+ index_attribute = index_attr;
+ return IndexAttribute(index_attr, stage);
} else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
is_invalid_compute_shader_attribute = true;
@@ -1177,6 +1181,34 @@
}
}
+ if (index_attribute) {
+ if (Is<ast::LocationAttribute>(pipeline_io_attribute)) {
+ AddError("index attribute must only be used with @location",
+ index_attribute->source);
+ return false;
+ }
+ }
+
+ if (location_attribute) {
+ uint32_t idx = 0xffffffff;
+ if (index_attribute) {
+ idx = index.value();
+ }
+
+ std::pair<uint32_t, uint32_t> locationAndIndex(location.value(), idx);
+ if (!locationsAndIndexes.Add(locationAndIndex)) {
+ utils::StringStream err;
+ if (!index_attribute) {
+ err << "@location(" << location.value() << ") appears multiple times";
+ } else {
+ err << "@location(" << location.value() << ") @index(" << index.value()
+ << ") appears multiple times";
+ }
+ AddError(err.str(), location_attribute->source);
+ return false;
+ }
+ }
+
if (interpolate_attribute) {
if (!pipeline_io_attribute ||
!pipeline_io_attribute->Is<ast::LocationAttribute>()) {
@@ -1207,9 +1239,11 @@
// Outer lambda for validating the entry point attributes for a type.
auto validate_entry_point_attributes =
[&](utils::VectorRef<const ast::Attribute*> attrs, const type::Type* ty, Source source,
- ParamOrRetType param_or_ret, std::optional<uint32_t> location) {
+ ParamOrRetType param_or_ret, std::optional<uint32_t> location,
+ std::optional<uint32_t> index) {
if (!validate_entry_point_attributes_inner(attrs, ty, source, param_or_ret,
- /*is_struct_member*/ false, location)) {
+ /*is_struct_member*/ false, location,
+ index)) {
return false;
}
@@ -1218,7 +1252,8 @@
if (!validate_entry_point_attributes_inner(
member->Declaration()->attributes, member->Type(),
member->Declaration()->source, param_or_ret,
- /*is_struct_member*/ true, member->Attributes().location)) {
+ /*is_struct_member*/ true, member->Attributes().location,
+ member->Attributes().index)) {
AddNote("while analyzing entry point '" + decl->name->symbol.Name() + "'",
decl->source);
return false;
@@ -1233,7 +1268,7 @@
auto* param_decl = param->Declaration();
if (!validate_entry_point_attributes(param_decl->attributes, param->Type(),
param_decl->source, ParamOrRetType::kParameter,
- param->Location())) {
+ param->Location(), std::nullopt)) {
return false;
}
}
@@ -1241,12 +1276,12 @@
// Clear IO sets after parameter validation. Builtin and location attributes in return types
// should be validated independently from those used in parameters.
builtins.Clear();
- locations.Clear();
+ locationsAndIndexes.Clear();
if (!func->ReturnType()->Is<type::Void>()) {
if (!validate_entry_point_attributes(decl->return_type_attributes, func->ReturnType(),
decl->source, ParamOrRetType::kReturnType,
- func->ReturnLocation())) {
+ func->ReturnLocation(), func->ReturnIndex())) {
return false;
}
}
@@ -2073,7 +2108,7 @@
return false;
}
- utils::Hashset<uint32_t, 8> locations;
+ utils::Hashset<std::pair<uint32_t, uint32_t>, 8> locationsAndIndexes;
for (auto* member : str->Members()) {
if (auto* r = member->Type()->As<type::Array>()) {
if (r->Count()->Is<type::RuntimeArrayCount>()) {
@@ -2095,8 +2130,9 @@
return false;
}
- auto has_location = false;
auto has_position = false;
+ const ast::IndexAttribute* index_attribute = nullptr;
+ const ast::LocationAttribute* location_attribute = nullptr;
const ast::InvariantAttribute* invariant_attribute = nullptr;
const ast::InterpolateAttribute* interpolate_attribute = nullptr;
for (auto* attr : member->Declaration()->attributes) {
@@ -2107,16 +2143,18 @@
return true;
},
[&](const ast::LocationAttribute* location) {
- has_location = true;
+ location_attribute = location;
TINT_ASSERT(Resolver, member->Attributes().location.has_value());
- if (!LocationAttribute(location, member->Attributes().location.value(),
- member->Type(), locations, stage,
+ if (!LocationAttribute(location, member->Type(), stage,
member->Declaration()->source)) {
return false;
}
return true;
},
- [&](const ast::IndexAttribute* index) { return IndexAttribute(index); },
+ [&](const ast::IndexAttribute* index) {
+ index_attribute = index;
+ return IndexAttribute(index, stage);
+ },
[&](const ast::BuiltinAttribute* builtin_attr) {
if (!BuiltinAttribute(builtin_attr, member->Type(), stage,
/* is_input */ false)) {
@@ -2157,20 +2195,46 @@
return false;
}
- if (interpolate_attribute && !has_location) {
+ if (index_attribute && !location_attribute) {
+ AddError("index attribute must only be used with @location", index_attribute->source);
+ return false;
+ }
+
+ if (interpolate_attribute && !location_attribute) {
AddError("interpolate attribute must only be used with @location",
interpolate_attribute->source);
return false;
}
+
+ // Ensure all locations and index pairs are unique
+ if (location_attribute) {
+ uint32_t index = 0xffffffff;
+ if (index_attribute) {
+ index = member->Attributes().index.value();
+ }
+ uint32_t location = member->Attributes().location.value();
+ std::pair<uint32_t, uint32_t> locationAndIndex(location, index);
+ if (!locationsAndIndexes.Add(locationAndIndex)) {
+ utils::StringStream err;
+ if (!index_attribute) {
+ err << "@location(" << location << ") appears multiple times";
+ AddError(err.str(), location_attribute->source);
+ } else {
+ err << "@location(" << location << ") @index(" << index
+ << ") appears multiple times";
+ AddError(err.str(), index_attribute->source);
+ }
+
+ return false;
+ }
+ }
}
return true;
}
bool Validator::LocationAttribute(const ast::LocationAttribute* loc_attr,
- uint32_t location,
const type::Type* type,
- utils::Hashset<uint32_t, 8>& locations,
ast::PipelineStage stage,
const Source& source,
const bool is_input) const {
@@ -2191,17 +2255,11 @@
return false;
}
- if (!locations.Add(location)) {
- utils::StringStream err;
- err << "@location(" << location << ") appears multiple times";
- AddError(err.str(), loc_attr->source);
- return false;
- }
-
return true;
}
-bool Validator::IndexAttribute(const ast::IndexAttribute* index_attr) const {
+bool Validator::IndexAttribute(const ast::IndexAttribute* index_attr,
+ ast::PipelineStage stage) const {
if (!enabled_extensions_.Contains(builtin::Extension::kChromiumInternalDualSourceBlending)) {
AddError(
"use of '@index' attribute requires enabling extension "
@@ -2210,7 +2268,19 @@
return false;
}
- return false;
+ if (stage == ast::PipelineStage::kCompute) {
+ AddError("@" + index_attr->Name() + " is not valid for compute shader output",
+ index_attr->source);
+ return false;
+ }
+
+ if (stage == ast::PipelineStage::kVertex) {
+ AddError("@" + index_attr->Name() + " is not valid for vertex shader output",
+ index_attr->source);
+ return false;
+ }
+
+ return true;
}
bool Validator::Return(const ast::ReturnStatement* ret,
diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h
index 19ceb5e..6f32d89 100644
--- a/src/tint/resolver/validator.h
+++ b/src/tint/resolver/validator.h
@@ -314,25 +314,22 @@
/// Validates a location attribute
/// @param loc_attr the location attribute to validate
- /// @param location the location value
/// @param type the variable type
- /// @param locations the set of locations in the module
/// @param stage the current pipeline stage
/// @param source the source of the attribute
/// @param is_input true if this is an input variable
/// @returns true on success, false otherwise.
bool LocationAttribute(const ast::LocationAttribute* loc_attr,
- uint32_t location,
const type::Type* type,
- utils::Hashset<uint32_t, 8>& locations,
ast::PipelineStage stage,
const Source& source,
const bool is_input = false) const;
/// Validates a index attribute
/// @param index_attr the index attribute to validate
+ /// @param stage the current pipeline stage
/// @returns true on success, false otherwise.
- bool IndexAttribute(const ast::IndexAttribute* index_attr) const;
+ bool IndexAttribute(const ast::IndexAttribute* index_attr, ast::PipelineStage stage) const;
/// Validates a loop statement
/// @param stmt the loop statement
diff --git a/src/tint/sem/function.h b/src/tint/sem/function.h
index 31390f7..6a27be4 100644
--- a/src/tint/sem/function.h
+++ b/src/tint/sem/function.h
@@ -63,6 +63,10 @@
/// @param return_location the location value
void SetReturnLocation(uint32_t return_location) { return_location_ = return_location; }
+ // Sets the function's return index
+ /// @param return_index the index value
+ void SetReturnIndex(uint32_t return_index) { return_index_ = return_index; }
+
/// @returns the ast::Function declaration
const ast::Function* Declaration() const { return declaration_; }
@@ -255,6 +259,9 @@
/// @return the location for the return, if provided
std::optional<uint32_t> ReturnLocation() const { return return_location_; }
+ /// @return the index for the return, if provided
+ std::optional<uint32_t> ReturnIndex() const { return return_index_; }
+
/// Modifies the severity of a specific diagnostic rule for this function.
/// @param rule the diagnostic rule
/// @param severity the new diagnostic severity
@@ -290,6 +297,7 @@
builtin::DiagnosticRuleSeverities diagnostic_severities_;
std::optional<uint32_t> return_location_;
+ std::optional<uint32_t> return_index_;
};
} // namespace tint::sem
diff --git a/src/tint/sem/test_helper.h b/src/tint/sem/test_helper.h
index 94f66e4..9433a8c 100644
--- a/src/tint/sem/test_helper.h
+++ b/src/tint/sem/test_helper.h
@@ -30,7 +30,7 @@
/// @return the built program
Program Build() {
diag::Formatter formatter;
- [&]() {
+ [&] {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< formatter.format(Diagnostics());
}();
diff --git a/src/tint/transform/manager.cc b/src/tint/transform/manager.cc
index 92d0427..def22dd 100644
--- a/src/tint/transform/manager.cc
+++ b/src/tint/transform/manager.cc
@@ -88,7 +88,7 @@
// Helper functions to get the current program state as either an AST program or IR module,
// performing a conversion if necessary.
- auto get_ast = [&]() {
+ auto get_ast = [&] {
#if TINT_BUILD_IR
if (std::holds_alternative<ir::Module*>(target)) {
// Convert the IR module to an AST program.
@@ -100,7 +100,7 @@
return std::get<const Program*>(target);
};
#if TINT_BUILD_IR
- auto get_ir = [&]() {
+ auto get_ir = [&] {
if (std::holds_alternative<const Program*>(target)) {
// Convert the AST program to an IR module.
auto converted = ir::FromProgram(std::get<const Program*>(target));
diff --git a/src/tint/transform/manager_test.cc b/src/tint/transform/manager_test.cc
index c1083b2..e578766 100644
--- a/src/tint/transform/manager_test.cc
+++ b/src/tint/transform/manager_test.cc
@@ -51,7 +51,7 @@
void Run(ir::Module* mod, const DataMap&, DataMap&) const override {
ir::Builder builder(*mod);
auto* func = builder.Function("ir_func", mod->Types().Get<type::Void>());
- func->StartTarget()->Append(builder.Return(func));
+ func->Block()->Append(builder.Return(func));
mod->functions.Push(func);
}
};
@@ -68,7 +68,7 @@
ir::Module mod;
ir::Builder builder(mod);
auto* func = builder.Function("main", mod.Types().Get<type::Void>());
- func->StartTarget()->Append(builder.Return(func));
+ func->Block()->Append(builder.Return(func));
builder.ir.functions.Push(func);
return mod;
}
diff --git a/src/tint/type/struct.cc b/src/tint/type/struct.cc
index 9b3bf49..d8424ae 100644
--- a/src/tint/type/struct.cc
+++ b/src/tint/type/struct.cc
@@ -116,7 +116,7 @@
<< ") */ struct " << struct_name << " {\n";
};
- auto print_struct_end_line = [&]() {
+ auto print_struct_end_line = [&] {
ss << "/* " << std::setw(offset_w + size_w + align_w) << " "
<< "*/ };";
};
diff --git a/src/tint/type/test_helper.h b/src/tint/type/test_helper.h
index 95827b0..c494024 100644
--- a/src/tint/type/test_helper.h
+++ b/src/tint/type/test_helper.h
@@ -30,7 +30,7 @@
/// @return the built program
Program Build() {
diag::Formatter formatter;
- [&]() {
+ [&] {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< formatter.format(Diagnostics());
}();
diff --git a/src/tint/utils/hashmap_base.h b/src/tint/utils/hashmap_base.h
index d2aaf1d..cb5f957 100644
--- a/src/tint/utils/hashmap_base.h
+++ b/src/tint/utils/hashmap_base.h
@@ -476,7 +476,7 @@
const auto hash = Hash(key);
- auto make_entry = [&]() {
+ auto make_entry = [&] {
if constexpr (ValueIsVoid) {
return std::forward<K>(key);
} else {
diff --git a/src/tint/writer/ast_text_generator.cc b/src/tint/writer/ast_text_generator.cc
new file mode 100644
index 0000000..dc0530b
--- /dev/null
+++ b/src/tint/writer/ast_text_generator.cc
@@ -0,0 +1,42 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/ast_text_generator.h"
+
+#include <algorithm>
+#include <limits>
+
+#include "src/tint/utils/map.h"
+
+namespace tint::writer {
+
+ASTTextGenerator::ASTTextGenerator(const Program* program)
+ : program_(program), builder_(ProgramBuilder::Wrap(program)) {}
+
+ASTTextGenerator::~ASTTextGenerator() = default;
+
+std::string ASTTextGenerator::UniqueIdentifier(const std::string& prefix) {
+ return builder_.Symbols().New(prefix).Name();
+}
+
+std::string ASTTextGenerator::StructName(const type::Struct* s) {
+ auto name = s->Name().Name();
+ if (name.size() > 1 && name[0] == '_' && name[1] == '_') {
+ name = utils::GetOrCreate(builtin_struct_names_, s,
+ [&] { return UniqueIdentifier(name.substr(2)); });
+ }
+ return name;
+}
+
+} // namespace tint::writer
diff --git a/src/tint/writer/ast_text_generator.h b/src/tint/writer/ast_text_generator.h
new file mode 100644
index 0000000..243458c
--- /dev/null
+++ b/src/tint/writer/ast_text_generator.h
@@ -0,0 +1,71 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_WRITER_AST_TEXT_GENERATOR_H_
+#define SRC_TINT_WRITER_AST_TEXT_GENERATOR_H_
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/writer/text_generator.h"
+
+namespace tint::writer {
+
+/// Helper methods for generators which are creating text output
+class ASTTextGenerator : public TextGenerator {
+ public:
+ /// Constructor
+ /// @param program the program used by the generator
+ explicit ASTTextGenerator(const Program* program);
+ ~ASTTextGenerator();
+
+ /// @return a new, unique identifier with the given prefix.
+ /// @param prefix optional prefix to apply to the generated identifier. If
+ /// empty "tint_symbol" will be used.
+ std::string UniqueIdentifier(const std::string& prefix = "");
+
+ /// @param s the semantic structure
+ /// @returns the name of the structure, taking special care of builtin
+ /// structures that start with double underscores. If the structure is a
+ /// builtin, then the returned name will be a unique name without the leading
+ /// underscores.
+ std::string StructName(const type::Struct* s);
+
+ protected:
+ /// @returns the resolved type of the ast::Expression `expr`
+ /// @param expr the expression
+ const type::Type* TypeOf(const ast::Expression* expr) const { return builder_.TypeOf(expr); }
+
+ /// @returns the resolved type of the ast::TypeDecl `type_decl`
+ /// @param type_decl the type
+ const type::Type* TypeOf(const ast::TypeDecl* type_decl) const {
+ return builder_.TypeOf(type_decl);
+ }
+
+ /// The program
+ Program const* const program_;
+ /// A ProgramBuilder that thinly wraps program_
+ ProgramBuilder builder_;
+
+ private:
+ /// Map of builtin structure to unique generated name
+ std::unordered_map<const type::Struct*, std::string> builtin_struct_names_;
+};
+
+} // namespace tint::writer
+
+#endif // SRC_TINT_WRITER_AST_TEXT_GENERATOR_H_
diff --git a/src/tint/writer/text_generator_test.cc b/src/tint/writer/ast_text_generator_test.cc
similarity index 84%
rename from src/tint/writer/text_generator_test.cc
rename to src/tint/writer/ast_text_generator_test.cc
index 1bfa41f..d879f68 100644
--- a/src/tint/writer/text_generator_test.cc
+++ b/src/tint/writer/ast_text_generator_test.cc
@@ -12,29 +12,29 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/tint/writer/text_generator.h"
+#include "src/tint/writer/ast_text_generator.h"
#include "gtest/gtest.h"
namespace tint::writer {
namespace {
-TEST(TextGeneratorTest, UniqueIdentifier) {
+TEST(ASTTextGeneratorTest, UniqueIdentifier) {
Program program(ProgramBuilder{});
- TextGenerator gen(&program);
+ ASTTextGenerator gen(&program);
ASSERT_EQ(gen.UniqueIdentifier("ident"), "ident");
ASSERT_EQ(gen.UniqueIdentifier("ident"), "ident_1");
}
-TEST(TextGeneratorTest, UniqueIdentifier_ConflictWithExisting) {
+TEST(ASTTextGeneratorTest, UniqueIdentifier_ConflictWithExisting) {
ProgramBuilder builder;
builder.Symbols().Register("ident_1");
builder.Symbols().Register("ident_2");
Program program(std::move(builder));
- TextGenerator gen(&program);
+ ASTTextGenerator gen(&program);
ASSERT_EQ(gen.UniqueIdentifier("ident"), "ident");
ASSERT_EQ(gen.UniqueIdentifier("ident"), "ident_3");
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index 630e24f..537d328 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -254,7 +254,7 @@
}
GeneratorImpl::GeneratorImpl(const Program* program, const Version& version)
- : TextGenerator(program), version_(version) {}
+ : ASTTextGenerator(program), version_(version) {}
GeneratorImpl::~GeneratorImpl() = default;
@@ -2244,7 +2244,7 @@
}
void GeneratorImpl::EmitLoop(const ast::LoopStatement* stmt) {
- auto emit_continuing = [this, stmt]() {
+ auto emit_continuing = [this, stmt] {
if (stmt->continuing && !stmt->continuing->Empty()) {
EmitBlock(stmt->continuing);
}
@@ -2303,7 +2303,7 @@
}
if (emit_as_loop) {
- auto emit_continuing = [&]() { current_buffer_->Append(cont_buf); };
+ auto emit_continuing = [&] { current_buffer_->Append(cont_buf); };
TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
line() << "while (true) {";
@@ -2360,7 +2360,7 @@
EmitExpression(cond_buf, cond);
}
- auto emit_continuing = [&]() {};
+ auto emit_continuing = [&] {};
TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
// If the whilehas a multi-statement conditional, then we cannot emit this
diff --git a/src/tint/writer/glsl/generator_impl.h b/src/tint/writer/glsl/generator_impl.h
index 7f9e31c..23633ae 100644
--- a/src/tint/writer/glsl/generator_impl.h
+++ b/src/tint/writer/glsl/generator_impl.h
@@ -38,9 +38,9 @@
#include "src/tint/scope_stack.h"
#include "src/tint/utils/hash.h"
#include "src/tint/utils/string_stream.h"
+#include "src/tint/writer/ast_text_generator.h"
#include "src/tint/writer/glsl/generator.h"
#include "src/tint/writer/glsl/version.h"
-#include "src/tint/writer/text_generator.h"
// Forward declarations
namespace tint::sem {
@@ -75,7 +75,7 @@
const std::string& entry_point);
/// Implementation class for GLSL generator
-class GeneratorImpl : public TextGenerator {
+class GeneratorImpl : public ASTTextGenerator {
public:
/// Constructor
/// @param program the program to generate
diff --git a/src/tint/writer/glsl/test_helper.h b/src/tint/writer/glsl/test_helper.h
index 98c7d8d..c25c98b 100644
--- a/src/tint/writer/glsl/test_helper.h
+++ b/src/tint/writer/glsl/test_helper.h
@@ -49,12 +49,12 @@
if (gen_) {
return *gen_;
}
- [&]() {
+ [&] {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< diag::Formatter().format(Diagnostics());
}();
program = std::make_unique<Program>(std::move(*this));
- [&]() {
+ [&] {
ASSERT_TRUE(program->IsValid()) << diag::Formatter().format(program->Diagnostics());
}();
gen_ = std::make_unique<GeneratorImpl>(program.get(), version);
@@ -74,15 +74,15 @@
return *gen_;
}
diag::Formatter formatter;
- [&]() {
+ [&] {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< formatter.format(Diagnostics());
}();
program = std::make_unique<Program>(std::move(*this));
- [&]() { ASSERT_TRUE(program->IsValid()) << formatter.format(program->Diagnostics()); }();
+ [&] { ASSERT_TRUE(program->IsValid()) << formatter.format(program->Diagnostics()); }();
auto sanitized_result = Sanitize(program.get(), options, /* entry_point */ "");
- [&]() {
+ [&] {
ASSERT_TRUE(sanitized_result.program.IsValid())
<< formatter.format(sanitized_result.program.Diagnostics());
}();
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 1267bea..d00859c 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -328,7 +328,7 @@
return result;
}
-GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
+GeneratorImpl::GeneratorImpl(const Program* program) : ASTTextGenerator(program) {}
GeneratorImpl::~GeneratorImpl() = default;
@@ -3621,7 +3621,7 @@
}
bool GeneratorImpl::EmitLoop(const ast::LoopStatement* stmt) {
- auto emit_continuing = [this, stmt]() {
+ auto emit_continuing = [this, stmt] {
if (stmt->continuing && !stmt->continuing->Empty()) {
if (!EmitBlock(stmt->continuing)) {
return false;
@@ -3695,7 +3695,7 @@
}
if (emit_as_loop) {
- auto emit_continuing = [&]() {
+ auto emit_continuing = [&] {
current_buffer_->Append(cont_buf);
return true;
};
@@ -3766,7 +3766,7 @@
}
}
- auto emit_continuing = [&]() { return true; };
+ auto emit_continuing = [&] { return true; };
TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
// If the while has a multi-statement conditional, then we cannot emit this
diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h
index 2ae653c..3028853 100644
--- a/src/tint/writer/hlsl/generator_impl.h
+++ b/src/tint/writer/hlsl/generator_impl.h
@@ -38,8 +38,8 @@
#include "src/tint/sem/binding_point.h"
#include "src/tint/utils/hash.h"
#include "src/tint/writer/array_length_from_uniform_options.h"
+#include "src/tint/writer/ast_text_generator.h"
#include "src/tint/writer/hlsl/generator.h"
-#include "src/tint/writer/text_generator.h"
// Forward declarations
namespace tint::sem {
@@ -74,7 +74,7 @@
SanitizedResult Sanitize(const Program* program, const Options& options);
/// Implementation class for HLSL generator
-class GeneratorImpl : public TextGenerator {
+class GeneratorImpl : public ASTTextGenerator {
public:
/// Constructor
/// @param program the program to generate
diff --git a/src/tint/writer/hlsl/test_helper.h b/src/tint/writer/hlsl/test_helper.h
index 2f8e886..35608a0 100644
--- a/src/tint/writer/hlsl/test_helper.h
+++ b/src/tint/writer/hlsl/test_helper.h
@@ -50,12 +50,12 @@
if (gen_) {
return *gen_;
}
- [&]() {
+ [&] {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< diag::Formatter().format(Diagnostics());
}();
program = std::make_unique<Program>(std::move(*this));
- [&]() {
+ [&] {
ASSERT_TRUE(program->IsValid()) << diag::Formatter().format(program->Diagnostics());
}();
gen_ = std::make_unique<GeneratorImpl>(program.get());
@@ -73,15 +73,15 @@
return *gen_;
}
diag::Formatter formatter;
- [&]() {
+ [&] {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< formatter.format(Diagnostics());
}();
program = std::make_unique<Program>(std::move(*this));
- [&]() { ASSERT_TRUE(program->IsValid()) << formatter.format(program->Diagnostics()); }();
+ [&] { ASSERT_TRUE(program->IsValid()) << formatter.format(program->Diagnostics()); }();
auto sanitized_result = Sanitize(program.get(), options);
- [&]() {
+ [&] {
ASSERT_TRUE(sanitized_result.program.IsValid())
<< formatter.format(sanitized_result.program.Diagnostics());
}();
@@ -94,7 +94,7 @@
/* preserve_unicode */ true);
transform_manager.Add<tint::ast::transform::Renamer>();
auto result = transform_manager.Run(&sanitized_result.program, transform_data, outputs);
- [&]() { ASSERT_TRUE(result.IsValid()) << formatter.format(result.Diagnostics()); }();
+ [&] { ASSERT_TRUE(result.IsValid()) << formatter.format(result.Diagnostics()); }();
*program = std::move(result);
gen_ = std::make_unique<GeneratorImpl>(program.get());
return *gen_;
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index 28101df..3cf5b31 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -271,7 +271,7 @@
return result;
}
-GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
+GeneratorImpl::GeneratorImpl(const Program* program) : ASTTextGenerator(program) {}
GeneratorImpl::~GeneratorImpl() = default;
@@ -1014,7 +1014,7 @@
// Helper to emit the texture expression, wrapped in parentheses if the
// expression includes an operator with lower precedence than the member
// accessor used for the function calls.
- auto texture_expr = [&]() {
+ auto texture_expr = [&] {
bool paren_lhs = !texture->IsAnyOf<ast::AccessorExpression, ast::CallExpression,
ast::IdentifierExpression>();
if (paren_lhs) {
@@ -2137,7 +2137,7 @@
}
bool GeneratorImpl::EmitLoop(const ast::LoopStatement* stmt) {
- auto emit_continuing = [this, stmt]() {
+ auto emit_continuing = [this, stmt] {
if (stmt->continuing && !stmt->continuing->Empty()) {
if (!EmitBlock(stmt->continuing)) {
return false;
@@ -2211,7 +2211,7 @@
});
if (emit_as_loop) {
- auto emit_continuing = [&]() {
+ auto emit_continuing = [&] {
current_buffer_->Append(cont_buf);
return true;
};
@@ -2283,7 +2283,7 @@
}
}
- auto emit_continuing = [&]() { return true; };
+ auto emit_continuing = [&] { return true; };
TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
// If the while has a multi-statement conditional, then we cannot emit this
diff --git a/src/tint/writer/msl/generator_impl.h b/src/tint/writer/msl/generator_impl.h
index 0be80be..b9bd350 100644
--- a/src/tint/writer/msl/generator_impl.h
+++ b/src/tint/writer/msl/generator_impl.h
@@ -42,8 +42,8 @@
#include "src/tint/sem/struct.h"
#include "src/tint/utils/string_stream.h"
#include "src/tint/writer/array_length_from_uniform_options.h"
+#include "src/tint/writer/ast_text_generator.h"
#include "src/tint/writer/msl/generator.h"
-#include "src/tint/writer/text_generator.h"
// Forward declarations
namespace tint::sem {
@@ -80,7 +80,7 @@
SanitizedResult Sanitize(const Program* program, const Options& options);
/// Implementation class for MSL generator
-class GeneratorImpl : public TextGenerator {
+class GeneratorImpl : public ASTTextGenerator {
public:
/// Constructor
/// @param program the program to generate
diff --git a/src/tint/writer/msl/test_helper.h b/src/tint/writer/msl/test_helper.h
index 17802d7..78be5af 100644
--- a/src/tint/writer/msl/test_helper.h
+++ b/src/tint/writer/msl/test_helper.h
@@ -49,12 +49,12 @@
if (gen_) {
return *gen_;
}
- [&]() {
+ [&] {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< diag::Formatter().format(Diagnostics());
}();
program = std::make_unique<Program>(std::move(*this));
- [&]() {
+ [&] {
ASSERT_TRUE(program->IsValid()) << diag::Formatter().format(program->Diagnostics());
}();
gen_ = std::make_unique<GeneratorImpl>(program.get());
@@ -71,17 +71,17 @@
if (gen_) {
return *gen_;
}
- [&]() {
+ [&] {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< diag::Formatter().format(Diagnostics());
}();
program = std::make_unique<Program>(std::move(*this));
- [&]() {
+ [&] {
ASSERT_TRUE(program->IsValid()) << diag::Formatter().format(program->Diagnostics());
}();
auto result = Sanitize(program.get(), options);
- [&]() {
+ [&] {
ASSERT_TRUE(result.program.IsValid())
<< diag::Formatter().format(result.program.Diagnostics());
}();
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index a0ce6f0..ccc51dc 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -2626,7 +2626,7 @@
image_operands.reserve(4); // Enough to fit most parameter lists
// Appends `result_type` and `result_id` to `spirv_params`
- auto append_result_type_and_id_to_spirv_params = [&]() {
+ auto append_result_type_and_id_to_spirv_params = [&] {
spirv_params.emplace_back(std::move(result_type));
spirv_params.emplace_back(std::move(result_id));
};
@@ -2641,7 +2641,7 @@
//
// If the texture is not a depth texture, then this function simply delegates
// to calling append_result_type_and_id_to_spirv_params().
- auto append_result_type_and_id_to_spirv_params_for_read = [&]() {
+ auto append_result_type_and_id_to_spirv_params_for_read = [&] {
if (texture_type->IsAnyOf<type::DepthTexture, type::DepthMultisampledTexture>()) {
auto* f32 = builder_.create<type::F32>();
auto* spirv_result_type = builder_.create<type::Vector>(f32, 4u);
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index fc52fe7..0166177 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -149,7 +149,7 @@
}
uint32_t GeneratorImplIr::Constant(const constant::Value* constant) {
- return constants_.GetOrCreate(constant, [&]() {
+ return constants_.GetOrCreate(constant, [&] {
auto id = module_.NextId();
auto* ty = constant->Type();
Switch(
@@ -204,7 +204,7 @@
}
uint32_t GeneratorImplIr::ConstantNull(const type::Type* type) {
- return constant_nulls_.GetOrCreate(type, [&]() {
+ return constant_nulls_.GetOrCreate(type, [&] {
auto id = module_.NextId();
module_.PushType(spv::Op::OpConstantNull, {Type(type), id});
return id;
@@ -212,7 +212,7 @@
}
uint32_t GeneratorImplIr::Type(const type::Type* ty) {
- return types_.GetOrCreate(ty, [&]() {
+ return types_.GetOrCreate(ty, [&] {
auto id = module_.NextId();
Switch(
ty, //
@@ -278,7 +278,7 @@
}
uint32_t GeneratorImplIr::Label(ir::Block* block) {
- return block_labels_.GetOrCreate(block, [&]() { return module_.NextId(); });
+ return block_labels_.GetOrCreate(block, [&] { return module_.NextId(); });
}
void GeneratorImplIr::EmitStructType(uint32_t id, const type::Struct* str) {
@@ -356,7 +356,7 @@
}
// Get the ID for the function type (creating it if needed).
- auto function_type_id = function_types_.GetOrCreate(function_type, [&]() {
+ auto function_type_id = function_types_.GetOrCreate(function_type, [&] {
auto func_ty_id = module_.NextId();
OperandList operands = {func_ty_id, return_type_id};
operands.insert(operands.end(), function_type.param_type_ids.begin(),
@@ -376,7 +376,7 @@
TINT_DEFER(current_function_ = Function());
// Emit the body of the function.
- EmitBlock(func->StartTarget());
+ EmitBlock(func->Block());
// Add the function to the module.
module_.PushFunction(current_function_);
@@ -744,7 +744,7 @@
auto glsl_ext_inst = [&](enum GLSLstd450 inst) {
constexpr const char* kGLSLstd450 = "GLSL.std.450";
op = spv::Op::OpExtInst;
- operands.push_back(imports_.GetOrCreate(kGLSLstd450, [&]() {
+ operands.push_back(imports_.GetOrCreate(kGLSLstd450, [&] {
// Import the instruction set the first time it is requested.
auto import = module_.NextId();
module_.PushExtImport(spv::Op::OpExtInstImport, {import, Operand(kGLSLstd450)});
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
index 1f33350..40f4d4d 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_access_test.cc
@@ -27,9 +27,10 @@
auto* func = b.Function("foo", ty.void_());
func->SetParams({arr_val});
- auto sb = b.With(func->StartTarget());
- sb.Access(ty.i32(), arr_val, 1_u);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Access(ty.i32(), arr_val, 1_u);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -54,10 +55,11 @@
TEST_F(SpvGeneratorImplTest_Access, Array_Pointer_ConstantIndex) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* arr_var = sb.Var(ty.ptr<function, array<i32, 4>>());
- sb.Access(ty.ptr<function, i32>(), arr_var, 1_u);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* arr_var = b.Var(ty.ptr<function, array<i32, 4>>());
+ b.Access(ty.ptr<function, i32>(), arr_var, 1_u);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -85,12 +87,13 @@
TEST_F(SpvGeneratorImplTest_Access, Array_Pointer_DynamicIndex) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* idx_var = sb.Var(ty.ptr<function, i32>());
- auto* idx = sb.Load(idx_var);
- auto* arr_var = sb.Var(ty.ptr<function, array<i32, 4>>());
- sb.Access(ty.ptr<function, i32>(), arr_var, idx);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* idx_var = b.Var(ty.ptr<function, i32>());
+ auto* idx = b.Load(idx_var);
+ auto* arr_var = b.Var(ty.ptr<function, array<i32, 4>>());
+ b.Access(ty.ptr<function, i32>(), arr_var, idx);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -121,10 +124,11 @@
auto* func = b.Function("foo", ty.void_());
func->SetParams({mat_val});
- auto sb = b.With(func->StartTarget());
- sb.Access(ty.vec2(ty.f32()), mat_val, 1_u);
- sb.Access(ty.f32(), mat_val, 1_u, 0_u);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Access(ty.vec2(ty.f32()), mat_val, 1_u);
+ b.Access(ty.f32(), mat_val, 1_u, 0_u);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -148,11 +152,12 @@
TEST_F(SpvGeneratorImplTest_Access, Matrix_Pointer_ConstantIndex) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* mat_var = sb.Var(ty.ptr<function, mat2x2<f32>>());
- sb.Access(ty.ptr<function, vec2<f32>>(), mat_var, 1_u);
- sb.Access(ty.ptr<function, f32>(), mat_var, 1_u, 0_u);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* mat_var = b.Var(ty.ptr<function, mat2x2<f32>>());
+ b.Access(ty.ptr<function, vec2<f32>>(), mat_var, 1_u);
+ b.Access(ty.ptr<function, f32>(), mat_var, 1_u, 0_u);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -182,13 +187,14 @@
TEST_F(SpvGeneratorImplTest_Access, Matrix_Pointer_DynamicIndex) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* idx_var = sb.Var(ty.ptr<function, i32>());
- auto* idx = sb.Load(idx_var);
- auto* mat_var = sb.Var(ty.ptr<function, mat2x2<f32>>());
- sb.Access(ty.ptr<function, vec2<f32>>(), mat_var, idx);
- sb.Access(ty.ptr<function, f32>(), mat_var, idx, idx);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* idx_var = b.Var(ty.ptr<function, i32>());
+ auto* idx = b.Load(idx_var);
+ auto* mat_var = b.Var(ty.ptr<function, mat2x2<f32>>());
+ b.Access(ty.ptr<function, vec2<f32>>(), mat_var, idx);
+ b.Access(ty.ptr<function, f32>(), mat_var, idx, idx);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -221,9 +227,10 @@
auto* vec_val = b.FunctionParam(ty.vec4(ty.i32()));
func->SetParams({vec_val});
- auto sb = b.With(func->StartTarget());
- sb.Access(ty.i32(), vec_val, 1_u);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Access(ty.i32(), vec_val, 1_u);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -247,11 +254,12 @@
auto* vec_val = b.FunctionParam(ty.vec4(ty.i32()));
func->SetParams({vec_val});
- auto sb = b.With(func->StartTarget());
- auto* idx_var = sb.Var(ty.ptr<function, i32>());
- auto* idx = sb.Load(idx_var);
- sb.Access(ty.i32(), vec_val, idx);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* idx_var = b.Var(ty.ptr<function, i32>());
+ auto* idx = b.Load(idx_var);
+ b.Access(ty.i32(), vec_val, idx);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -276,10 +284,11 @@
TEST_F(SpvGeneratorImplTest_Access, Vector_Pointer_ConstantIndex) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* vec_var = sb.Var(ty.ptr<function, vec4<i32>>());
- sb.Access(ty.ptr<function, i32>(), vec_var, 1_u);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* vec_var = b.Var(ty.ptr<function, vec4<i32>>());
+ b.Access(ty.ptr<function, i32>(), vec_var, 1_u);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -305,12 +314,13 @@
TEST_F(SpvGeneratorImplTest_Access, Vector_Pointer_DynamicIndex) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* idx_var = sb.Var(ty.ptr<function, i32>());
- auto* idx = sb.Load(idx_var);
- auto* vec_var = sb.Var(ty.ptr<function, vec4<i32>>());
- sb.Access(ty.ptr<function, i32>(), vec_var, idx);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* idx_var = b.Var(ty.ptr<function, i32>());
+ auto* idx = b.Load(idx_var);
+ auto* vec_var = b.Var(ty.ptr<function, vec4<i32>>());
+ b.Access(ty.ptr<function, i32>(), vec_var, idx);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -338,11 +348,12 @@
auto* func = b.Function("foo", ty.void_());
func->SetParams({val});
- auto sb = b.With(func->StartTarget());
- auto* idx_var = sb.Var(ty.ptr<function, i32>());
- auto* idx = sb.Load(idx_var);
- sb.Access(ty.i32(), val, 1_u, 2_u, idx);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* idx_var = b.Var(ty.ptr<function, i32>());
+ auto* idx = b.Load(idx_var);
+ b.Access(ty.i32(), val, 1_u, 2_u, idx);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -381,10 +392,11 @@
auto* func = b.Function("foo", ty.void_());
func->SetParams({str_val});
- auto sb = b.With(func->StartTarget());
- sb.Access(ty.i32(), str_val, 1_u);
- sb.Access(ty.i32(), str_val, 1_u, 2_u);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Access(ty.i32(), str_val, 1_u);
+ b.Access(ty.i32(), str_val, 1_u, 2_u);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -419,11 +431,12 @@
});
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* str_var = sb.Var(ty.ptr(function, str, read_write));
- sb.Access(ty.ptr<function, i32>(), str_var, 1_u);
- sb.Access(ty.ptr<function, i32>(), str_var, 1_u, 2_u);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* str_var = b.Var(ty.ptr(function, str, read_write));
+ b.Access(ty.ptr<function, i32>(), str_var, 1_u);
+ b.Access(ty.ptr<function, i32>(), str_var, 1_u, 2_u);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
index 37b702c..6d4d890 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc
@@ -37,10 +37,11 @@
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- sb.Binary(params.kind, MakeScalarType(params.type), MakeScalarValue(params.type),
- MakeScalarValue(params.type));
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Binary(params.kind, MakeScalarType(params.type), MakeScalarValue(params.type),
+ MakeScalarValue(params.type));
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -51,10 +52,11 @@
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- sb.Binary(params.kind, MakeVectorType(params.type), MakeVectorValue(params.type),
- MakeVectorValue(params.type));
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Binary(params.kind, MakeVectorType(params.type), MakeVectorValue(params.type),
+ MakeVectorValue(params.type));
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -87,10 +89,11 @@
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- sb.Binary(params.kind, MakeScalarType(params.type), MakeScalarValue(params.type),
- MakeScalarValue(params.type));
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Binary(params.kind, MakeScalarType(params.type), MakeScalarValue(params.type),
+ MakeScalarValue(params.type));
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -101,10 +104,11 @@
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- sb.Binary(params.kind, MakeVectorType(params.type), MakeVectorValue(params.type),
- MakeVectorValue(params.type));
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Binary(params.kind, MakeVectorType(params.type), MakeVectorValue(params.type),
+ MakeVectorValue(params.type));
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -129,9 +133,11 @@
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- sb.Binary(params.kind, ty.bool_(), MakeScalarValue(params.type), MakeScalarValue(params.type));
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Binary(params.kind, ty.bool_(), MakeScalarValue(params.type),
+ MakeScalarValue(params.type));
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -143,10 +149,11 @@
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- sb.Binary(params.kind, ty.vec2(ty.bool_()), MakeVectorValue(params.type),
- MakeVectorValue(params.type));
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Binary(params.kind, ty.vec2(ty.bool_()), MakeVectorValue(params.type),
+ MakeVectorValue(params.type));
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -202,10 +209,11 @@
TEST_F(SpvGeneratorImplTest, Binary_Chain) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* a = sb.Subtract(ty.i32(), 1_i, 2_i);
- sb.Add(ty.i32(), a, a);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* a = b.Subtract(ty.i32(), 1_i, 2_i);
+ b.Add(ty.i32(), a, a);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
index ae45e1c..f9bbcd6 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc
@@ -38,9 +38,10 @@
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- sb.Call(MakeScalarType(params.type), params.function, MakeScalarValue(params.type));
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Call(MakeScalarType(params.type), params.function, MakeScalarValue(params.type));
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -51,9 +52,10 @@
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- sb.Call(MakeVectorType(params.type), params.function, MakeVectorValue(params.type));
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Call(MakeVectorType(params.type), params.function, MakeVectorValue(params.type));
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -68,9 +70,10 @@
// Test that abs of an unsigned value just folds away.
TEST_F(SpvGeneratorImplTest, Builtin_Abs_u32) {
auto* func = b.Function("foo", MakeScalarType(kU32));
- auto sb = b.With(func->StartTarget());
- auto* result = sb.Call(MakeScalarType(kU32), builtin::Function::kAbs, MakeScalarValue(kU32));
- sb.Return(func, result);
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(MakeScalarType(kU32), builtin::Function::kAbs, MakeScalarValue(kU32));
+ b.Return(func, result);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -88,9 +91,10 @@
TEST_F(SpvGeneratorImplTest, Builtin_Abs_vec2u) {
auto* func = b.Function("foo", MakeVectorType(kU32));
- auto sb = b.With(func->StartTarget());
- auto* result = sb.Call(MakeVectorType(kU32), builtin::Function::kAbs, MakeVectorValue(kU32));
- sb.Return(func, result);
+ b.With(func->Block(), [&] {
+ auto* result = b.Call(MakeVectorType(kU32), builtin::Function::kAbs, MakeVectorValue(kU32));
+ b.Return(func, result);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -115,10 +119,11 @@
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- sb.Call(MakeScalarType(params.type), params.function, MakeScalarValue(params.type),
- MakeScalarValue(params.type));
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Call(MakeScalarType(params.type), params.function, MakeScalarValue(params.type),
+ MakeScalarValue(params.type));
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -129,10 +134,11 @@
auto params = GetParam();
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- sb.Call(MakeVectorType(params.type), params.function, MakeVectorValue(params.type),
- MakeVectorValue(params.type));
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Call(MakeVectorType(params.type), params.function, MakeVectorValue(params.type),
+ MakeVectorValue(params.type));
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_construct_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_construct_test.cc
index 732c6ff..7ba7cff 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_construct_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_construct_test.cc
@@ -29,8 +29,7 @@
b.FunctionParam(ty.i32()),
});
- auto sb = b.With(func->StartTarget());
- sb.Return(func, sb.Construct(ty.vec4<i32>(), func->Params()));
+ b.With(func->Block(), [&] { b.Return(func, b.Construct(ty.vec4<i32>(), func->Params())); });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -59,8 +58,7 @@
b.FunctionParam(ty.vec4<f32>()),
});
- auto sb = b.With(func->StartTarget());
- sb.Return(func, sb.Construct(ty.mat3x4<f32>(), func->Params()));
+ b.With(func->Block(), [&] { b.Return(func, b.Construct(ty.mat3x4<f32>(), func->Params())); });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -90,8 +88,7 @@
b.FunctionParam(ty.f32()),
});
- auto sb = b.With(func->StartTarget());
- sb.Return(func, sb.Construct(ty.array<f32, 4>(), func->Params()));
+ b.With(func->Block(), [&] { b.Return(func, b.Construct(ty.array<f32, 4>(), func->Params())); });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -130,8 +127,7 @@
b.FunctionParam(ty.vec4<f32>()),
});
- auto sb = b.With(func->StartTarget());
- sb.Return(func, sb.Construct(str, func->Params()));
+ b.With(func->Block(), [&] { b.Return(func, b.Construct(str, func->Params())); });
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
index 81f1091..00bfeb0 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc
@@ -19,7 +19,7 @@
TEST_F(SpvGeneratorImplTest, Function_Empty) {
auto* func = b.Function("foo", ty.void_());
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -37,7 +37,7 @@
// Test that we do not emit the same function type more than once.
TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
auto* func = b.Function("foo", ty.void_());
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -52,7 +52,7 @@
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) {
auto* func =
b.Function("main", ty.void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -71,7 +71,7 @@
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) {
auto* func = b.Function("main", ty.void_(), ir::Function::PipelineStage::kFragment);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -90,7 +90,7 @@
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) {
auto* func = b.Function("main", ty.void_(), ir::Function::PipelineStage::kVertex);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -108,13 +108,13 @@
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) {
auto* f1 = b.Function("main1", ty.void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
- f1->StartTarget()->Append(b.Return(f1));
+ f1->Block()->Append(b.Return(f1));
auto* f2 = b.Function("main2", ty.void_(), ir::Function::PipelineStage::kCompute, {{8, 2, 16}});
- f2->StartTarget()->Append(b.Return(f2));
+ f2->Block()->Append(b.Return(f2));
auto* f3 = b.Function("main3", ty.void_(), ir::Function::PipelineStage::kFragment);
- f3->StartTarget()->Append(b.Return(f3));
+ f3->Block()->Append(b.Return(f3));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -149,7 +149,7 @@
TEST_F(SpvGeneratorImplTest, Function_ReturnValue) {
auto* func = b.Function("foo", ty.i32());
- func->StartTarget()->Append(b.Return(func, i32(42)));
+ func->Block()->Append(b.Return(func, i32(42)));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -174,9 +174,10 @@
mod.SetName(x, "x");
mod.SetName(y, "y");
- auto sb = b.With(func->StartTarget());
- auto* result = sb.Add(i32, x, y);
- sb.Return(func, result);
+ b.With(func->Block(), [&] {
+ auto* result = b.Add(i32, x, y);
+ b.Return(func, result);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -203,18 +204,16 @@
auto* foo = b.Function("foo", i32_ty);
foo->SetParams({x, y});
- {
- auto sb = b.With(foo->StartTarget());
- auto* result = sb.Add(i32_ty, x, y);
- sb.Return(foo, result);
- }
+ b.With(foo->Block(), [&] {
+ auto* result = b.Add(i32_ty, x, y);
+ b.Return(foo, result);
+ });
auto* bar = b.Function("bar", ty.void_());
- {
- auto sb = b.With(bar->StartTarget());
- sb.Call(i32_ty, foo, i32(2), i32(3));
- sb.Return(bar);
- }
+ b.With(bar->Block(), [&] {
+ b.Call(i32_ty, foo, i32(2), i32(3));
+ b.Return(bar);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -245,12 +244,13 @@
TEST_F(SpvGeneratorImplTest, Function_Call_Void) {
auto* foo = b.Function("foo", ty.void_());
- foo->StartTarget()->Append(b.Return(foo));
+ foo->Block()->Append(b.Return(foo));
auto* bar = b.Function("bar", ty.void_());
- auto sb = b.With(bar->StartTarget());
- sb.Call(ty.void_(), foo, utils::Empty);
- sb.Return(bar);
+ b.With(bar->Block(), [&] {
+ b.Call(ty.void_(), foo, utils::Empty);
+ b.Return(bar);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
index 62afeda..892409d 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc
@@ -25,8 +25,8 @@
auto* i = b.If(true);
i->True()->Append(b.ExitIf(i));
i->False()->Append(b.ExitIf(i));
- func->StartTarget()->Append(i);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(i);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -52,12 +52,13 @@
auto* i = b.If(true);
i->False()->Append(b.ExitIf(i));
- auto tb = b.With(i->True());
- tb.Add(ty.i32(), 1_i, 1_i);
- tb.ExitIf(i);
+ b.With(i->True(), [&] {
+ b.Add(ty.i32(), 1_i, 1_i);
+ b.ExitIf(i);
+ });
- func->StartTarget()->Append(i);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(i);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -88,12 +89,13 @@
auto* i = b.If(true);
i->True()->Append(b.ExitIf(i));
- auto fb = b.With(i->False());
- fb.Add(ty.i32(), 1_i, 1_i);
- fb.ExitIf(i);
+ b.With(i->False(), [&] {
+ b.Add(ty.i32(), 1_i, 1_i);
+ b.ExitIf(i);
+ });
- func->StartTarget()->Append(i);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(i);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -125,8 +127,8 @@
i->True()->Append(b.Return(func));
i->False()->Append(b.Return(func));
- func->StartTarget()->Append(i);
- func->StartTarget()->Append(b.Unreachable());
+ func->Block()->Append(i);
+ func->Block()->Append(b.Unreachable());
ASSERT_TRUE(IRIsValid()) << Error();
@@ -158,8 +160,8 @@
i->True()->Append(b.ExitIf(i, 10_i));
i->False()->Append(b.ExitIf(i, 20_i));
- func->StartTarget()->Append(i);
- func->StartTarget()->Append(b.Return(func, i));
+ func->Block()->Append(i);
+ func->Block()->Append(b.Return(func, i));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -194,8 +196,8 @@
i->True()->Append(b.Return(func, 42_i));
i->False()->Append(b.ExitIf(i, 20_i));
- func->StartTarget()->Append(i);
- func->StartTarget()->Append(b.Return(func, i));
+ func->Block()->Append(i);
+ func->Block()->Append(b.Return(func, i));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -230,8 +232,8 @@
i->True()->Append(b.ExitIf(i, 10_i));
i->False()->Append(b.Return(func, 42_i));
- func->StartTarget()->Append(i);
- func->StartTarget()->Append(b.Return(func, i));
+ func->Block()->Append(i);
+ func->Block()->Append(b.Return(func, i));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -266,8 +268,8 @@
i->True()->Append(b.ExitIf(i, 10_i, true));
i->False()->Append(b.ExitIf(i, 20_i, false));
- func->StartTarget()->Append(i);
- func->StartTarget()->Append(b.Return(func, i->Result(0)));
+ func->Block()->Append(i);
+ func->Block()->Append(b.Return(func, i->Result(0)));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -304,8 +306,8 @@
i->True()->Append(b.ExitIf(i, 10_i, true));
i->False()->Append(b.ExitIf(i, 20_i, false));
- func->StartTarget()->Append(i);
- func->StartTarget()->Append(b.Return(func, i->Result(1)));
+ func->Block()->Append(i);
+ func->Block()->Append(b.Return(func, i->Result(1)));
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
index 88a25e1..63bc82c 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc
@@ -27,8 +27,8 @@
loop->Body()->Append(b.Continue(loop));
loop->Continuing()->Append(b.BreakIf(true, loop));
- func->StartTarget()->Append(loop);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(loop);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -62,8 +62,8 @@
loop->Body()->Append(b.ExitLoop(loop));
- func->StartTarget()->Append(loop);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(loop);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -100,8 +100,8 @@
loop->Body()->Append(b.Continue(loop));
loop->Continuing()->Append(b.NextIteration(loop));
- func->StartTarget()->Append(loop);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(loop);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -145,8 +145,8 @@
loop->Body()->Append(b.ExitLoop(loop));
loop->Continuing()->Append(b.NextIteration(loop));
- func->StartTarget()->Append(loop);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(loop);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -185,8 +185,8 @@
auto* loop = b.Loop();
loop->Body()->Append(b.Return(func));
- func->StartTarget()->Append(loop);
- func->StartTarget()->Append(b.Unreachable());
+ func->Block()->Append(loop);
+ func->Block()->Append(b.Unreachable());
ASSERT_TRUE(IRIsValid()) << Error();
@@ -220,8 +220,8 @@
loop->Continuing()->Append(b.BreakIf(result, loop));
- func->StartTarget()->Append(loop);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(loop);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -262,8 +262,8 @@
outer_loop->Body()->Append(b.Continue(outer_loop));
outer_loop->Continuing()->Append(b.BreakIf(true, outer_loop));
- func->StartTarget()->Append(outer_loop);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(outer_loop);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -311,8 +311,8 @@
outer_loop->Continuing()->Append(inner_loop);
outer_loop->Continuing()->Append(b.BreakIf(true, outer_loop));
- func->StartTarget()->Append(outer_loop);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(outer_loop);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -350,31 +350,28 @@
TEST_F(SpvGeneratorImplTest, Loop_Phi_SingleValue) {
auto* func = b.Function("foo", ty.void_());
- auto fb = b.With(func->StartTarget());
- auto* l = fb.Loop();
+ b.With(func->Block(), [&] {
+ auto* l = b.Loop();
- {
- auto ib = b.With(l->Initializer());
- ib.NextIteration(l, 1_i, false);
- }
+ b.With(l->Initializer(), [&] { b.NextIteration(l, 1_i, false); });
- auto* loop_param = b.BlockParam(ty.i32());
- l->Body()->SetParams({loop_param});
- {
- auto lb = b.With(l->Body());
- auto* inc = lb.Add(ty.i32(), loop_param, 1_i);
- lb.Continue(l, inc);
- }
+ auto* loop_param = b.BlockParam(ty.i32());
+ l->Body()->SetParams({loop_param});
- auto* cont_param = b.BlockParam(ty.i32());
- l->Continuing()->SetParams({cont_param});
- {
- auto cb = b.With(l->Continuing());
- auto* cmp = cb.GreaterThan(ty.bool_(), cont_param, 5_i);
- cb.BreakIf(cmp, l, cont_param);
- }
+ b.With(l->Body(), [&] {
+ auto* inc = b.Add(ty.i32(), loop_param, 1_i);
+ b.Continue(l, inc);
+ });
- fb.Return(func);
+ auto* cont_param = b.BlockParam(ty.i32());
+ l->Continuing()->SetParams({cont_param});
+ b.With(l->Continuing(), [&] {
+ auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
+ b.BreakIf(cmp, l, cont_param);
+ });
+
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -411,34 +408,31 @@
TEST_F(SpvGeneratorImplTest, Loop_Phi_MultipleValue) {
auto* func = b.Function("foo", ty.void_());
- auto fb = b.With(func->StartTarget());
- auto* l = fb.Loop();
+ b.With(func->Block(), [&] {
+ auto* l = b.Loop();
- {
- auto ib = b.With(l->Initializer());
- ib.NextIteration(l, 1_i, false);
- }
+ b.With(l->Initializer(), [&] { b.NextIteration(l, 1_i, false); });
- auto* loop_param_a = b.BlockParam(ty.i32());
- auto* loop_param_b = b.BlockParam(ty.bool_());
- l->Body()->SetParams({loop_param_a, loop_param_b});
- {
- auto lb = b.With(l->Body());
- auto* inc = lb.Add(ty.i32(), loop_param_a, 1_i);
- lb.Continue(l, inc, loop_param_b);
- }
+ auto* loop_param_a = b.BlockParam(ty.i32());
+ auto* loop_param_b = b.BlockParam(ty.bool_());
+ l->Body()->SetParams({loop_param_a, loop_param_b});
- auto* cont_param_a = b.BlockParam(ty.i32());
- auto* cont_param_b = b.BlockParam(ty.bool_());
- l->Continuing()->SetParams({cont_param_a, cont_param_b});
- {
- auto cb = b.With(l->Continuing());
- auto* cmp = cb.GreaterThan(ty.bool_(), cont_param_a, 5_i);
- auto* not_b = cb.Not(ty.bool_(), cont_param_b);
- cb.BreakIf(cmp, l, cont_param_a, not_b);
- }
+ b.With(l->Body(), [&] {
+ auto* inc = b.Add(ty.i32(), loop_param_a, 1_i);
+ b.Continue(l, inc, loop_param_b);
+ });
- fb.Return(func);
+ auto* cont_param_a = b.BlockParam(ty.i32());
+ auto* cont_param_b = b.BlockParam(ty.bool_());
+ l->Continuing()->SetParams({cont_param_a, cont_param_b});
+ b.With(l->Continuing(), [&] {
+ auto* cmp = b.GreaterThan(ty.bool_(), cont_param_a, 5_i);
+ auto* not_b = b.Not(ty.bool_(), cont_param_b);
+ b.BreakIf(cmp, l, cont_param_a, not_b);
+ });
+
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
index c0f32cf..396e0fe 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
@@ -27,8 +27,8 @@
auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
def_case->Append(b.ExitSwitch(swtch));
- func->StartTarget()->Append(swtch);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(swtch);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -64,8 +64,8 @@
auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
def_case->Append(b.ExitSwitch(swtch));
- func->StartTarget()->Append(swtch);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(swtch);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -108,8 +108,8 @@
ir::Switch::CaseSelector()});
def_case->Append(b.ExitSwitch(swtch));
- func->StartTarget()->Append(swtch);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(swtch);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -149,8 +149,8 @@
auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
def_case->Append(b.Return(func));
- func->StartTarget()->Append(swtch);
- func->StartTarget()->Append(b.Unreachable());
+ func->Block()->Append(swtch);
+ func->Block()->Append(b.Unreachable());
ASSERT_TRUE(IRIsValid()) << Error();
@@ -192,8 +192,8 @@
auto* def_case = b.Case(swtch, utils::Vector{ir::Switch::CaseSelector()});
def_case->Append(b.ExitSwitch(swtch));
- func->StartTarget()->Append(swtch);
- func->StartTarget()->Append(b.Return(func));
+ func->Block()->Append(swtch);
+ func->Block()->Append(b.Return(func));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -236,8 +236,8 @@
auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
case_b->Append(b.ExitSwitch(s, 20_i));
- func->StartTarget()->Append(s);
- func->StartTarget()->Append(b.Return(func, s));
+ func->Block()->Append(s);
+ func->Block()->Append(b.Return(func, s));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -275,8 +275,8 @@
auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
case_b->Append(b.ExitSwitch(s, 20_i));
- func->StartTarget()->Append(s);
- func->StartTarget()->Append(b.Return(func, s));
+ func->Block()->Append(s);
+ func->Block()->Append(b.Return(func, s));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -314,8 +314,8 @@
auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
case_b->Append(b.ExitSwitch(s, 20_i, false));
- func->StartTarget()->Append(s);
- func->StartTarget()->Append(b.Return(func, s->Result(0)));
+ func->Block()->Append(s);
+ func->Block()->Append(b.Return(func, s->Result(0)));
ASSERT_TRUE(IRIsValid()) << Error();
@@ -357,8 +357,8 @@
auto* case_b = b.Case(s, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
case_b->Append(b.ExitSwitch(s, 20_i, false));
- func->StartTarget()->Append(s);
- func->StartTarget()->Append(b.Return(func, s->Result(1)));
+ func->Block()->Append(s);
+ func->Block()->Append(b.Return(func, s->Result(1)));
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_test.cc
index ef33bfb..9ec2a47 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_test.cc
@@ -36,8 +36,8 @@
i->True()->Append(b.Return(func, 10_i));
i->False()->Append(b.Return(func, 20_i));
- func->StartTarget()->Append(i);
- func->StartTarget()->Append(b.Unreachable());
+ func->Block()->Append(i);
+ func->Block()->Append(b.Unreachable());
ASSERT_TRUE(IRIsValid()) << Error();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
index feb7c63..f0d3e0d 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
@@ -24,9 +24,10 @@
TEST_F(SpvGeneratorImplTest, FunctionVar_NoInit) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- sb.Var(ty.ptr<function, i32>());
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Var(ty.ptr<function, i32>());
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -47,11 +48,12 @@
TEST_F(SpvGeneratorImplTest, FunctionVar_WithInit) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* v = sb.Var(ty.ptr<function, i32>());
- v->SetInitializer(b.Constant(42_i));
+ b.With(func->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ v->SetInitializer(b.Constant(42_i));
- sb.Return(func);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -74,11 +76,12 @@
TEST_F(SpvGeneratorImplTest, FunctionVar_Name) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* v = sb.Var(ty.ptr<function, i32>());
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ b.Return(func);
- mod.SetName(v, "myvar");
+ mod.SetName(v, "myvar");
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -100,18 +103,17 @@
TEST_F(SpvGeneratorImplTest, FunctionVar_DeclInsideBlock) {
auto* func = b.Function("foo", ty.void_());
- auto* i = b.If(true);
+ b.With(func->Block(), [&] {
+ auto* i = b.If(true);
+ b.With(i->True(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ v->SetInitializer(b.Constant(42_i));
+ b.ExitIf(i);
+ });
+ b.With(i->False(), [&] { b.Return(func); });
- auto tb = b.With(i->True());
- auto* v = tb.Var(ty.ptr<function, i32>());
- v->SetInitializer(b.Constant(42_i));
- tb.ExitIf(i);
-
- i->False()->Append(b.Return(func));
-
- func->StartTarget()->Append(i);
- func->StartTarget()->Append(b.Return(func));
-
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
generator_.EmitFunction(func);
@@ -142,12 +144,12 @@
TEST_F(SpvGeneratorImplTest, FunctionVar_Load) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
-
- auto* store_ty = ty.i32();
- auto* v = sb.Var(ty.ptr(function, store_ty));
- sb.Load(v);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* store_ty = ty.i32();
+ auto* v = b.Var(ty.ptr(function, store_ty));
+ b.Load(v);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -169,10 +171,11 @@
TEST_F(SpvGeneratorImplTest, FunctionVar_Store) {
auto* func = b.Function("foo", ty.void_());
- auto sb = b.With(func->StartTarget());
- auto* v = sb.Var(ty.ptr<function, i32>());
- sb.Store(v, 42_i);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ auto* v = b.Var(ty.ptr<function, i32>());
+ b.Store(v, 42_i);
+ b.Return(func);
+ });
ASSERT_TRUE(IRIsValid()) << Error();
@@ -273,11 +276,12 @@
v->SetInitializer(b.Constant(42_i));
b.RootBlock()->Append(v);
- auto sb = b.With(func->StartTarget());
- sb.Load(v);
- auto* add = sb.Add(store_ty, v, 1_i);
- sb.Store(v, add);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Load(v);
+ auto* add = b.Add(store_ty, v, 1_i);
+ b.Store(v, add);
+ b.Return(func);
+ });
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -354,11 +358,12 @@
auto* store_ty = ty.i32();
auto* v = b.RootBlock()->Append(b.Var(ty.ptr(workgroup, store_ty)));
- auto sb = b.With(func->StartTarget());
- sb.Load(v);
- auto* add = sb.Add(store_ty, v, 1_i);
- sb.Store(v, add);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Load(v);
+ auto* add = b.Add(store_ty, v, 1_i);
+ b.Store(v, add);
+ b.Return(func);
+ });
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -476,11 +481,12 @@
std::array{1u, 1u, 1u});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
- sb.Load(v);
- auto* add = sb.Add(ty.i32(), v, 1_i);
- sb.Store(v, add);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Load(v);
+ auto* add = b.Add(ty.i32(), v, 1_i);
+ b.Store(v, add);
+ b.Return(func);
+ });
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
@@ -587,9 +593,10 @@
std::array{1u, 1u, 1u});
mod.functions.Push(func);
- auto sb = b.With(func->StartTarget());
- sb.Load(v);
- sb.Return(func);
+ b.With(func->Block(), [&] {
+ b.Load(v);
+ b.Return(func);
+ });
ASSERT_TRUE(generator_.Generate()) << generator_.Diagnostics().str();
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
diff --git a/src/tint/writer/spirv/test_helper.h b/src/tint/writer/spirv/test_helper.h
index b9a90be..d030523 100644
--- a/src/tint/writer/spirv/test_helper.h
+++ b/src/tint/writer/spirv/test_helper.h
@@ -49,12 +49,12 @@
if (spirv_builder) {
return *spirv_builder;
}
- [&]() {
+ [&] {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< diag::Formatter().format(Diagnostics());
}();
program = std::make_unique<Program>(std::move(*this));
- [&]() {
+ [&] {
ASSERT_TRUE(program->IsValid()) << diag::Formatter().format(program->Diagnostics());
}();
spirv_builder = std::make_unique<spirv::Builder>(program.get());
@@ -71,16 +71,16 @@
if (spirv_builder) {
return *spirv_builder;
}
- [&]() {
+ [&] {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< diag::Formatter().format(Diagnostics());
}();
program = std::make_unique<Program>(std::move(*this));
- [&]() {
+ [&] {
ASSERT_TRUE(program->IsValid()) << diag::Formatter().format(program->Diagnostics());
}();
auto result = Sanitize(program.get(), options);
- [&]() {
+ [&] {
ASSERT_TRUE(result.program.IsValid())
<< diag::Formatter().format(result.program.Diagnostics());
}();
diff --git a/src/tint/writer/syntax_tree/generator_impl.cc b/src/tint/writer/syntax_tree/generator_impl.cc
index 56f7f57..e8e58a3 100644
--- a/src/tint/writer/syntax_tree/generator_impl.cc
+++ b/src/tint/writer/syntax_tree/generator_impl.cc
@@ -41,7 +41,7 @@
namespace tint::writer::syntax_tree {
-GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
+GeneratorImpl::GeneratorImpl(const Program* program) : ASTTextGenerator(program) {}
GeneratorImpl::~GeneratorImpl() = default;
diff --git a/src/tint/writer/syntax_tree/generator_impl.h b/src/tint/writer/syntax_tree/generator_impl.h
index dc2cf01..2473df2 100644
--- a/src/tint/writer/syntax_tree/generator_impl.h
+++ b/src/tint/writer/syntax_tree/generator_impl.h
@@ -36,12 +36,12 @@
#include "src/tint/program.h"
#include "src/tint/sem/struct.h"
#include "src/tint/utils/string_stream.h"
-#include "src/tint/writer/text_generator.h"
+#include "src/tint/writer/ast_text_generator.h"
namespace tint::writer::syntax_tree {
/// Implementation class for AST generator
-class GeneratorImpl : public TextGenerator {
+class GeneratorImpl : public ASTTextGenerator {
public:
/// Constructor
/// @param program the program
diff --git a/src/tint/writer/text_generator.cc b/src/tint/writer/text_generator.cc
index 4b0ab6a..89e1a6f 100644
--- a/src/tint/writer/text_generator.cc
+++ b/src/tint/writer/text_generator.cc
@@ -17,28 +17,14 @@
#include <algorithm>
#include <limits>
-#include "src/tint/utils/map.h"
+#include "src/tint/debug.h"
namespace tint::writer {
-TextGenerator::TextGenerator(const Program* program)
- : program_(program), builder_(ProgramBuilder::Wrap(program)) {}
+TextGenerator::TextGenerator() = default;
TextGenerator::~TextGenerator() = default;
-std::string TextGenerator::UniqueIdentifier(const std::string& prefix) {
- return builder_.Symbols().New(prefix).Name();
-}
-
-std::string TextGenerator::StructName(const type::Struct* s) {
- auto name = s->Name().Name();
- if (name.size() > 1 && name[0] == '_' && name[1] == '_') {
- name = utils::GetOrCreate(builtin_struct_names_, s,
- [&] { return UniqueIdentifier(name.substr(2)); });
- }
- return name;
-}
-
TextGenerator::LineWriter::LineWriter(TextBuffer* buf) : buffer(buf) {}
TextGenerator::LineWriter::LineWriter(LineWriter&& other) {
diff --git a/src/tint/writer/text_generator.h b/src/tint/writer/text_generator.h
index ad3eb80..8458db0 100644
--- a/src/tint/writer/text_generator.h
+++ b/src/tint/writer/text_generator.h
@@ -21,7 +21,6 @@
#include <vector>
#include "src/tint/diagnostic/diagnostic.h"
-#include "src/tint/program_builder.h"
#include "src/tint/utils/string_stream.h"
namespace tint::writer {
@@ -87,8 +86,7 @@
};
/// Constructor
- /// @param program the program used by the generator
- explicit TextGenerator(const Program* program);
+ TextGenerator();
~TextGenerator();
/// Increment the emitter indent level
@@ -102,18 +100,6 @@
/// @returns the list of diagnostics raised by the generator.
const diag::List& Diagnostics() const { return diagnostics_; }
- /// @return a new, unique identifier with the given prefix.
- /// @param prefix optional prefix to apply to the generated identifier. If
- /// empty "tint_symbol" will be used.
- std::string UniqueIdentifier(const std::string& prefix = "");
-
- /// @param s the semantic structure
- /// @returns the name of the structure, taking special care of builtin
- /// structures that start with double underscores. If the structure is a
- /// builtin, then the returned name will be a unique name without the leading
- /// underscores.
- std::string StructName(const type::Struct* s);
-
protected:
/// LineWriter is a helper that acts as a string buffer, who's content is
/// emitted to the TextBuffer as a single line on destruction.
@@ -183,16 +169,6 @@
TextBuffer* buffer_;
};
- /// @returns the resolved type of the ast::Expression `expr`
- /// @param expr the expression
- const type::Type* TypeOf(const ast::Expression* expr) const { return builder_.TypeOf(expr); }
-
- /// @returns the resolved type of the ast::TypeDecl `type_decl`
- /// @param type_decl the type
- const type::Type* TypeOf(const ast::TypeDecl* type_decl) const {
- return builder_.TypeOf(type_decl);
- }
-
/// @returns a new LineWriter, used for buffering and writing a line to
/// the end of #current_buffer_.
LineWriter line() { return LineWriter(current_buffer_); }
@@ -202,10 +178,6 @@
/// the end of `buffer`.
static LineWriter line(TextBuffer* buffer) { return LineWriter(buffer); }
- /// The program
- Program const* const program_;
- /// A ProgramBuilder that thinly wraps program_
- ProgramBuilder builder_;
/// Diagnostics generated by the generator
diag::List diagnostics_;
/// The buffer the TextGenerator is currently appending lines to
@@ -214,8 +186,6 @@
private:
/// The primary text buffer that the generator will emit
TextBuffer main_buffer_;
- /// Map of builtin structure to unique generated name
- std::unordered_map<const type::Struct*, std::string> builtin_struct_names_;
};
} // namespace tint::writer
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index 0d626f2..751d16f 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -42,7 +42,7 @@
namespace tint::writer::wgsl {
-GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
+GeneratorImpl::GeneratorImpl(const Program* program) : ASTTextGenerator(program) {}
GeneratorImpl::~GeneratorImpl() = default;
diff --git a/src/tint/writer/wgsl/generator_impl.h b/src/tint/writer/wgsl/generator_impl.h
index 5d69821..6c67d79 100644
--- a/src/tint/writer/wgsl/generator_impl.h
+++ b/src/tint/writer/wgsl/generator_impl.h
@@ -36,12 +36,12 @@
#include "src/tint/program.h"
#include "src/tint/sem/struct.h"
#include "src/tint/utils/string_stream.h"
-#include "src/tint/writer/text_generator.h"
+#include "src/tint/writer/ast_text_generator.h"
namespace tint::writer::wgsl {
/// Implementation class for WGSL generator
-class GeneratorImpl : public TextGenerator {
+class GeneratorImpl : public ASTTextGenerator {
public:
/// Constructor
/// @param program the program
diff --git a/src/tint/writer/wgsl/generator_impl_binary_test.cc b/src/tint/writer/wgsl/generator_impl_binary_test.cc
index 1829cf8..c40db05 100644
--- a/src/tint/writer/wgsl/generator_impl_binary_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_binary_test.cc
@@ -34,7 +34,7 @@
TEST_P(WgslBinaryTest, Emit) {
auto params = GetParam();
- auto op_ty = [&]() {
+ auto op_ty = [&] {
if (params.op == ast::BinaryOp::kLogicalAnd || params.op == ast::BinaryOp::kLogicalOr) {
return ty.bool_();
} else {
diff --git a/src/tint/writer/wgsl/test_helper.h b/src/tint/writer/wgsl/test_helper.h
index 4cf1e93..b2f652c 100644
--- a/src/tint/writer/wgsl/test_helper.h
+++ b/src/tint/writer/wgsl/test_helper.h
@@ -42,7 +42,7 @@
}
program = std::make_unique<Program>(std::move(*this));
diag::Formatter formatter;
- [&]() { ASSERT_TRUE(program->IsValid()) << formatter.format(program->Diagnostics()); }();
+ [&] { ASSERT_TRUE(program->IsValid()) << formatter.format(program->Diagnostics()); }();
gen_ = std::make_unique<GeneratorImpl>(program.get());
return *gen_;
}