tint/transform: Refactor transforms
Replace the ShouldRun() method with Apply() which will do the
transformation if it needs to be done, otherwise returns
'SkipTransform'.
This reduces a bunch of duplicated scanning between the old ShouldRun()
and Transform().
This change also adjusts code style to make the transforms more
consistent.
Change-Id: I9a6b10cb8b4ed62676b12ef30fb7764d363386c6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107681
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/fuzzers/shuffle_transform.cc b/src/tint/fuzzers/shuffle_transform.cc
index 5f5f6e6..6ae405a 100644
--- a/src/tint/fuzzers/shuffle_transform.cc
+++ b/src/tint/fuzzers/shuffle_transform.cc
@@ -15,6 +15,7 @@
#include "src/tint/fuzzers/shuffle_transform.h"
#include <random>
+#include <utility>
#include "src/tint/program_builder.h"
@@ -22,15 +23,21 @@
ShuffleTransform::ShuffleTransform(size_t seed) : seed_(seed) {}
-void ShuffleTransform::Run(CloneContext& ctx,
- const tint::transform::DataMap&,
- tint::transform::DataMap&) const {
- auto decls = ctx.src->AST().GlobalDeclarations();
+transform::Transform::ApplyResult ShuffleTransform::Apply(const Program* src,
+ const transform::DataMap&,
+ transform::DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ auto decls = src->AST().GlobalDeclarations();
auto rng = std::mt19937_64{seed_};
std::shuffle(std::begin(decls), std::end(decls), rng);
for (auto* decl : decls) {
- ctx.dst->AST().AddGlobalDeclaration(ctx.Clone(decl));
+ b.AST().AddGlobalDeclaration(ctx.Clone(decl));
}
+
+ ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::fuzzers
diff --git a/src/tint/fuzzers/shuffle_transform.h b/src/tint/fuzzers/shuffle_transform.h
index 0a64fe3..ee54f97 100644
--- a/src/tint/fuzzers/shuffle_transform.h
+++ b/src/tint/fuzzers/shuffle_transform.h
@@ -20,16 +20,16 @@
namespace tint::fuzzers {
/// ShuffleTransform reorders the module scope declarations into a random order
-class ShuffleTransform : public tint::transform::Transform {
+class ShuffleTransform : public transform::Transform {
public:
/// Constructor
/// @param seed the random seed to use for the shuffling
explicit ShuffleTransform(size_t seed);
- protected:
- void Run(CloneContext& ctx,
- const tint::transform::DataMap&,
- tint::transform::DataMap&) const override;
+ /// @copydoc transform::Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const transform::DataMap& inputs,
+ transform::DataMap& outputs) const override;
private:
size_t seed_;
diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h
index 8e271d7..634f5da 100644
--- a/src/tint/sem/variable.h
+++ b/src/tint/sem/variable.h
@@ -23,6 +23,7 @@
#include "src/tint/ast/access.h"
#include "src/tint/ast/address_space.h"
+#include "src/tint/ast/parameter.h"
#include "src/tint/sem/binding_point.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/parameter_usage.h"
@@ -212,6 +213,11 @@
/// Destructor
~Parameter() override;
+ /// @returns the AST declaration node
+ const ast::Parameter* Declaration() const {
+ return static_cast<const ast::Parameter*>(Variable::Declaration());
+ }
+
/// @return the index of the parmeter in the function
uint32_t Index() const { return index_; }
diff --git a/src/tint/transform/add_block_attribute.cc b/src/tint/transform/add_block_attribute.cc
index 77d8719..513925f 100644
--- a/src/tint/transform/add_block_attribute.cc
+++ b/src/tint/transform/add_block_attribute.cc
@@ -31,21 +31,29 @@
AddBlockAttribute::~AddBlockAttribute() = default;
-void AddBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- auto& sem = ctx.src->Sem();
+Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ auto& sem = src->Sem();
// A map from a type in the source program to a block-decorated wrapper that contains it in the
// destination program.
utils::Hashmap<const sem::Type*, const ast::Struct*, 8> wrapper_structs;
// Process global 'var' declarations that are buffers.
- for (auto* global : ctx.src->AST().GlobalVariables()) {
+ bool made_changes = false;
+ for (auto* global : src->AST().GlobalVariables()) {
auto* var = sem.Get(global);
if (!ast::IsHostShareable(var->AddressSpace())) {
// Not declared in a host-sharable address space
continue;
}
+ made_changes = true;
+
auto* ty = var->Type()->UnwrapRef();
auto* str = ty->As<sem::Struct>();
@@ -61,33 +69,36 @@
const char* kMemberName = "inner";
auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] {
- auto* block = ctx.dst->ASTNodes().Create<BlockAttribute>(ctx.dst->ID(),
- ctx.dst->AllocateNodeID());
- auto wrapper_name = ctx.src->Symbols().NameFor(global->symbol) + "_block";
- auto* ret = ctx.dst->create<ast::Struct>(
- ctx.dst->Symbols().New(wrapper_name),
- utils::Vector{ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))},
+ auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
+ auto wrapper_name = src->Symbols().NameFor(global->symbol) + "_block";
+ auto* ret = b.create<ast::Struct>(
+ b.Symbols().New(wrapper_name),
+ utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))},
utils::Vector{block});
- ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), global, ret);
+ ctx.InsertBefore(src->AST().GlobalDeclarations(), global, ret);
return ret;
});
- ctx.Replace(global->type, ctx.dst->ty.Of(wrapper));
+ ctx.Replace(global->type, b.ty.Of(wrapper));
// Insert a member accessor to get the original type from the wrapper at
// any usage of the original variable.
for (auto* user : var->Users()) {
ctx.Replace(user->Declaration(),
- ctx.dst->MemberAccessor(ctx.Clone(global->symbol), kMemberName));
+ b.MemberAccessor(ctx.Clone(global->symbol), kMemberName));
}
} else {
// Add a block attribute to this struct directly.
- auto* block = ctx.dst->ASTNodes().Create<BlockAttribute>(ctx.dst->ID(),
- ctx.dst->AllocateNodeID());
+ auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
ctx.InsertFront(str->Declaration()->attributes, block);
}
}
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
ctx.Clone();
+ return Program(std::move(b));
}
AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, ast::NodeID nid)
diff --git a/src/tint/transform/add_block_attribute.h b/src/tint/transform/add_block_attribute.h
index 2bfd63e..d5e8c4e 100644
--- a/src/tint/transform/add_block_attribute.h
+++ b/src/tint/transform/add_block_attribute.h
@@ -53,14 +53,10 @@
/// Destructor
~AddBlockAttribute() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/add_empty_entry_point.cc b/src/tint/transform/add_empty_entry_point.cc
index 5ef4fe8..f71394d 100644
--- a/src/tint/transform/add_empty_entry_point.cc
+++ b/src/tint/transform/add_empty_entry_point.cc
@@ -23,12 +23,9 @@
using namespace tint::number_suffixes; // NOLINT
namespace tint::transform {
+namespace {
-AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
-
-AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
-
-bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const {
+bool ShouldRun(const Program* program) {
for (auto* func : program->AST().Functions()) {
if (func->IsEntryPoint()) {
return false;
@@ -37,13 +34,30 @@
return true;
}
-void AddEmptyEntryPoint::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {}, ctx.dst->ty.void_(), {},
- utils::Vector{
- ctx.dst->Stage(ast::PipelineStage::kCompute),
- ctx.dst->WorkgroupSize(1_i),
- });
+} // namespace
+
+AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
+
+AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
+
+Transform::ApplyResult AddEmptyEntryPoint::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {},
+ utils::Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/add_empty_entry_point.h b/src/tint/transform/add_empty_entry_point.h
index 5530355..828f3b5 100644
--- a/src/tint/transform/add_empty_entry_point.h
+++ b/src/tint/transform/add_empty_entry_point.h
@@ -27,19 +27,10 @@
/// Destructor
~AddEmptyEntryPoint() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/array_length_from_uniform.cc b/src/tint/transform/array_length_from_uniform.cc
index 3938b0c..70097f2 100644
--- a/src/tint/transform/array_length_from_uniform.cc
+++ b/src/tint/transform/array_length_from_uniform.cc
@@ -31,13 +31,153 @@
namespace tint::transform {
+namespace {
+
+bool ShouldRun(const Program* program) {
+ for (auto* fn : program->AST().Functions()) {
+ if (auto* sem_fn = program->Sem().Get(fn)) {
+ for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
+ if (builtin->Type() == sem::BuiltinType::kArrayLength) {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+} // namespace
+
ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
-/// The PIMPL state for this transform
+/// PIMPL state for the transform
struct ArrayLengthFromUniform::State {
+ /// Constructor
+ /// @param program the source program
+ /// @param in the input transform data
+ /// @param out the output transform data
+ explicit State(const Program* program, const DataMap& in, DataMap& out)
+ : src(program), inputs(in), outputs(out) {}
+
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+ auto* cfg = inputs.Get<Config>();
+ if (cfg == nullptr) {
+ b.Diagnostics().add_error(diag::System::Transform,
+ "missing transform data for " +
+ std::string(TypeInfo::Of<ArrayLengthFromUniform>().name));
+ return Program(std::move(b));
+ }
+
+ if (!ShouldRun(ctx.src)) {
+ return SkipTransform;
+ }
+
+ const char* kBufferSizeMemberName = "buffer_size";
+
+ // Determine the size of the buffer size array.
+ uint32_t max_buffer_size_index = 0;
+
+ IterateArrayLengthOnStorageVar([&](const ast::CallExpression*, const sem::VariableUser*,
+ const sem::GlobalVariable* var) {
+ auto binding = var->BindingPoint();
+ auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
+ if (idx_itr == cfg->bindpoint_to_size_index.end()) {
+ return;
+ }
+ if (idx_itr->second > max_buffer_size_index) {
+ max_buffer_size_index = idx_itr->second;
+ }
+ });
+
+ // Get (or create, on first call) the uniform buffer that will receive the
+ // size of each storage buffer in the module.
+ const ast::Variable* buffer_size_ubo = nullptr;
+ auto get_ubo = [&]() {
+ if (!buffer_size_ubo) {
+ // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
+ // We do this because UBOs require an element stride that is 16-byte
+ // aligned.
+ auto* buffer_size_struct = b.Structure(
+ b.Sym(), utils::Vector{
+ b.Member(kBufferSizeMemberName,
+ b.ty.array(b.ty.vec4(b.ty.u32()),
+ u32((max_buffer_size_index / 4) + 1))),
+ });
+ buffer_size_ubo =
+ b.GlobalVar(b.Sym(), b.ty.Of(buffer_size_struct), ast::AddressSpace::kUniform,
+ b.Group(AInt(cfg->ubo_binding.group)),
+ b.Binding(AInt(cfg->ubo_binding.binding)));
+ }
+ return buffer_size_ubo;
+ };
+
+ std::unordered_set<uint32_t> used_size_indices;
+
+ IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr,
+ const sem::VariableUser* storage_buffer_sem,
+ const sem::GlobalVariable* var) {
+ auto binding = var->BindingPoint();
+ auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
+ if (idx_itr == cfg->bindpoint_to_size_index.end()) {
+ return;
+ }
+
+ uint32_t size_index = idx_itr->second;
+ used_size_indices.insert(size_index);
+
+ // Load the total storage buffer size from the UBO.
+ uint32_t array_index = size_index / 4;
+ auto* vec_expr = b.IndexAccessor(
+ b.MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index));
+ uint32_t vec_index = size_index % 4;
+ auto* total_storage_buffer_size = b.IndexAccessor(vec_expr, u32(vec_index));
+
+ // Calculate actual array length
+ // total_storage_buffer_size - array_offset
+ // array_length = ----------------------------------------
+ // array_stride
+ const ast::Expression* total_size = total_storage_buffer_size;
+ auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
+ const sem::Array* array_type = nullptr;
+ if (auto* str = storage_buffer_type->As<sem::Struct>()) {
+ // The variable is a struct, so subtract the byte offset of the array
+ // member.
+ auto* array_member_sem = str->Members().back();
+ array_type = array_member_sem->Type()->As<sem::Array>();
+ total_size = b.Sub(total_storage_buffer_size, u32(array_member_sem->Offset()));
+ } else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
+ array_type = arr;
+ } else {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "expected form of arrayLength argument to be &array_var or "
+ "&struct_var.array_member";
+ return;
+ }
+ auto* array_length = b.Div(total_size, u32(array_type->Stride()));
+
+ ctx.Replace(call_expr, array_length);
+ });
+
+ outputs.Add<Result>(used_size_indices);
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ private:
+ /// The source program
+ const Program* const src;
+ /// The transform inputs
+ const DataMap& inputs;
+ /// The transform outputs
+ DataMap& outputs;
+ /// The target program builder
+ ProgramBuilder b;
/// The clone context
- CloneContext& ctx;
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Iterate over all arrayLength() builtins that operate on
/// storage buffer variables.
@@ -48,10 +188,10 @@
/// sem::GlobalVariable for the storage buffer.
template <typename F>
void IterateArrayLengthOnStorageVar(F&& functor) {
- auto& sem = ctx.src->Sem();
+ auto& sem = src->Sem();
// Find all calls to the arrayLength() builtin.
- for (auto* node : ctx.src->ASTNodes().Objects()) {
+ for (auto* node : src->ASTNodes().Objects()) {
auto* call_expr = node->As<ast::CallExpression>();
if (!call_expr) {
continue;
@@ -79,7 +219,7 @@
// arrayLength(&array_var)
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
if (!param || param->op != ast::UnaryOp::kAddressOf) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
@@ -90,7 +230,7 @@
}
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
if (!storage_buffer_sem) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
@@ -99,8 +239,7 @@
// Get the index to use for the buffer size array.
auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
if (!var) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "storage buffer is not a global variable";
+ TINT_ICE(Transform, b.Diagnostics()) << "storage buffer is not a global variable";
break;
}
functor(call_expr, storage_buffer_sem, var);
@@ -108,117 +247,10 @@
}
};
-bool ArrayLengthFromUniform::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* fn : program->AST().Functions()) {
- if (auto* sem_fn = program->Sem().Get(fn)) {
- for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
- if (builtin->Type() == sem::BuiltinType::kArrayLength) {
- return true;
- }
- }
- }
- }
- return false;
-}
-
-void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const {
- auto* cfg = inputs.Get<Config>();
- if (cfg == nullptr) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
- return;
- }
-
- const char* kBufferSizeMemberName = "buffer_size";
-
- // Determine the size of the buffer size array.
- uint32_t max_buffer_size_index = 0;
-
- State{ctx}.IterateArrayLengthOnStorageVar(
- [&](const ast::CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) {
- auto binding = var->BindingPoint();
- auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
- if (idx_itr == cfg->bindpoint_to_size_index.end()) {
- return;
- }
- if (idx_itr->second > max_buffer_size_index) {
- max_buffer_size_index = idx_itr->second;
- }
- });
-
- // Get (or create, on first call) the uniform buffer that will receive the
- // size of each storage buffer in the module.
- const ast::Variable* buffer_size_ubo = nullptr;
- auto get_ubo = [&]() {
- if (!buffer_size_ubo) {
- // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
- // We do this because UBOs require an element stride that is 16-byte
- // aligned.
- auto* buffer_size_struct = ctx.dst->Structure(
- ctx.dst->Sym(),
- utils::Vector{
- ctx.dst->Member(kBufferSizeMemberName,
- ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
- u32((max_buffer_size_index / 4) + 1))),
- });
- buffer_size_ubo = ctx.dst->GlobalVar(ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
- ast::AddressSpace::kUniform,
- ctx.dst->Group(AInt(cfg->ubo_binding.group)),
- ctx.dst->Binding(AInt(cfg->ubo_binding.binding)));
- }
- return buffer_size_ubo;
- };
-
- std::unordered_set<uint32_t> used_size_indices;
-
- State{ctx}.IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr,
- const sem::VariableUser* storage_buffer_sem,
- const sem::GlobalVariable* var) {
- auto binding = var->BindingPoint();
- auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
- if (idx_itr == cfg->bindpoint_to_size_index.end()) {
- return;
- }
-
- uint32_t size_index = idx_itr->second;
- used_size_indices.insert(size_index);
-
- // Load the total storage buffer size from the UBO.
- uint32_t array_index = size_index / 4;
- auto* vec_expr = ctx.dst->IndexAccessor(
- ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index));
- uint32_t vec_index = size_index % 4;
- auto* total_storage_buffer_size = ctx.dst->IndexAccessor(vec_expr, u32(vec_index));
-
- // Calculate actual array length
- // total_storage_buffer_size - array_offset
- // array_length = ----------------------------------------
- // array_stride
- const ast::Expression* total_size = total_storage_buffer_size;
- auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
- const sem::Array* array_type = nullptr;
- if (auto* str = storage_buffer_type->As<sem::Struct>()) {
- // The variable is a struct, so subtract the byte offset of the array
- // member.
- auto* array_member_sem = str->Members().back();
- array_type = array_member_sem->Type()->As<sem::Array>();
- total_size = ctx.dst->Sub(total_storage_buffer_size, u32(array_member_sem->Offset()));
- } else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
- array_type = arr;
- } else {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "expected form of arrayLength argument to be &array_var or "
- "&struct_var.array_member";
- return;
- }
- auto* array_length = ctx.dst->Div(total_size, u32(array_type->Stride()));
-
- ctx.Replace(call_expr, array_length);
- });
-
- ctx.Clone();
-
- outputs.Add<Result>(used_size_indices);
+Transform::ApplyResult ArrayLengthFromUniform::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap& outputs) const {
+ return State{src, inputs, outputs}.Run();
}
ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {}
diff --git a/src/tint/transform/array_length_from_uniform.h b/src/tint/transform/array_length_from_uniform.h
index 8bd6af5..507ea37 100644
--- a/src/tint/transform/array_length_from_uniform.h
+++ b/src/tint/transform/array_length_from_uniform.h
@@ -100,22 +100,12 @@
std::unordered_set<uint32_t> used_size_indices;
};
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
private:
- /// The PIMPL state for this transform
struct State;
};
diff --git a/src/tint/transform/array_length_from_uniform_test.cc b/src/tint/transform/array_length_from_uniform_test.cc
index 1058bf1..b5d9e77 100644
--- a/src/tint/transform/array_length_from_uniform_test.cc
+++ b/src/tint/transform/array_length_from_uniform_test.cc
@@ -28,7 +28,13 @@
TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) {
auto* src = R"()";
- EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+
+ DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+
+ EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) {
@@ -45,7 +51,13 @@
}
)";
- EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+
+ DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+
+ EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) {
@@ -63,7 +75,13 @@
}
)";
- EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src));
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+
+ DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+
+ EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src, data));
}
TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
diff --git a/src/tint/transform/binding_remapper.cc b/src/tint/transform/binding_remapper.cc
index 798b228..0781355 100644
--- a/src/tint/transform/binding_remapper.cc
+++ b/src/tint/transform/binding_remapper.cc
@@ -40,19 +40,21 @@
BindingRemapper::BindingRemapper() = default;
BindingRemapper::~BindingRemapper() = default;
-bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const {
- if (auto* remappings = inputs.Get<Remappings>()) {
- return !remappings->binding_points.empty() || !remappings->access_controls.empty();
- }
- return false;
-}
+Transform::ApplyResult BindingRemapper::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
-void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
auto* remappings = inputs.Get<Remappings>();
if (!remappings) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
- return;
+ b.Diagnostics().add_error(diag::System::Transform,
+ "missing transform data for " + std::string(TypeInfo().name));
+ return Program(std::move(b));
+ }
+
+ if (remappings->binding_points.empty() && remappings->access_controls.empty()) {
+ return SkipTransform;
}
// A set of post-remapped binding points that need to be decorated with a
@@ -62,11 +64,11 @@
if (remappings->allow_collisions) {
// Scan for binding point collisions generated by this transform.
// Populate all collisions in the `add_collision_attr` set.
- for (auto* func_ast : ctx.src->AST().Functions()) {
+ for (auto* func_ast : src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
- auto* func = ctx.src->Sem().Get(func_ast);
+ auto* func = src->Sem().Get(func_ast);
std::unordered_map<sem::BindingPoint, int> binding_point_counts;
for (auto* global : func->TransitivelyReferencedGlobals()) {
if (global->Declaration()->HasBindingPoint()) {
@@ -90,9 +92,9 @@
}
}
- for (auto* var : ctx.src->AST().Globals<ast::Var>()) {
+ for (auto* var : src->AST().Globals<ast::Var>()) {
if (var->HasBindingPoint()) {
- auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(var);
+ auto* global_sem = src->Sem().Get<sem::GlobalVariable>(var);
// The original binding point
BindingPoint from = global_sem->BindingPoint();
@@ -106,8 +108,8 @@
auto bp_it = remappings->binding_points.find(from);
if (bp_it != remappings->binding_points.end()) {
BindingPoint to = bp_it->second;
- auto* new_group = ctx.dst->Group(AInt(to.group));
- auto* new_binding = ctx.dst->Binding(AInt(to.binding));
+ auto* new_group = b.Group(AInt(to.group));
+ auto* new_binding = b.Binding(AInt(to.binding));
auto* old_group = ast::GetAttribute<ast::GroupAttribute>(var->attributes);
auto* old_binding = ast::GetAttribute<ast::BindingAttribute>(var->attributes);
@@ -122,37 +124,37 @@
if (ac_it != remappings->access_controls.end()) {
ast::Access ac = ac_it->second;
if (ac == ast::Access::kUndefined) {
- ctx.dst->Diagnostics().add_error(
+ b.Diagnostics().add_error(
diag::System::Transform,
"invalid access mode (" + std::to_string(static_cast<uint32_t>(ac)) + ")");
- return;
+ return Program(std::move(b));
}
- auto* sem = ctx.src->Sem().Get(var);
+ auto* sem = src->Sem().Get(var);
if (sem->AddressSpace() != ast::AddressSpace::kStorage) {
- ctx.dst->Diagnostics().add_error(
+ b.Diagnostics().add_error(
diag::System::Transform,
"cannot apply access control to variable with address space " +
std::string(utils::ToString(sem->AddressSpace())));
- return;
+ return Program(std::move(b));
}
auto* ty = sem->Type()->UnwrapRef();
const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
- auto* new_var =
- ctx.dst->Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty,
- var->declared_address_space, ac, ctx.Clone(var->initializer),
- ctx.Clone(var->attributes));
+ auto* new_var = b.Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty,
+ var->declared_address_space, ac, ctx.Clone(var->initializer),
+ ctx.Clone(var->attributes));
ctx.Replace(var, new_var);
}
// Add `DisableValidationAttribute`s if required
if (add_collision_attr.count(bp)) {
- auto* attribute = ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
+ auto* attribute = b.Disable(ast::DisabledValidation::kBindingPointCollision);
ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
}
}
}
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/binding_remapper.h b/src/tint/transform/binding_remapper.h
index 77fc5bc..b0efe0d 100644
--- a/src/tint/transform/binding_remapper.h
+++ b/src/tint/transform/binding_remapper.h
@@ -67,19 +67,10 @@
BindingRemapper();
~BindingRemapper() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/binding_remapper_test.cc b/src/tint/transform/binding_remapper_test.cc
index 564a3a5..5bafb7e 100644
--- a/src/tint/transform/binding_remapper_test.cc
+++ b/src/tint/transform/binding_remapper_test.cc
@@ -23,12 +23,6 @@
using BindingRemapperTest = TransformTest;
-TEST_F(BindingRemapperTest, ShouldRunNoRemappings) {
- auto* src = R"()";
-
- EXPECT_FALSE(ShouldRun<BindingRemapper>(src));
-}
-
TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) {
auto* src = R"()";
@@ -350,7 +344,7 @@
}
)";
- auto* expect = src;
+ auto* expect = R"(error: missing transform data for tint::transform::BindingRemapper)";
auto got = Run<BindingRemapper>(src);
diff --git a/src/tint/transform/builtin_polyfill.cc b/src/tint/transform/builtin_polyfill.cc
index 1e28e7e..db200e4 100644
--- a/src/tint/transform/builtin_polyfill.cc
+++ b/src/tint/transform/builtin_polyfill.cc
@@ -29,7 +29,7 @@
namespace tint::transform {
-/// The PIMPL state for the BuiltinPolyfill transform
+/// PIMPL state for the transform
struct BuiltinPolyfill::State {
/// Constructor
/// @param c the CloneContext
@@ -604,193 +604,100 @@
BuiltinPolyfill::~BuiltinPolyfill() = default;
-bool BuiltinPolyfill::ShouldRun(const Program* program, const DataMap& data) const {
- if (auto* cfg = data.Get<Config>()) {
- auto builtins = cfg->builtins;
- auto& sem = program->Sem();
- for (auto* node : program->ASTNodes().Objects()) {
- if (auto* call = sem.Get<sem::Call>(node)) {
- if (auto* builtin = call->Target()->As<sem::Builtin>()) {
- if (call->Stage() == sem::EvaluationStage::kConstant) {
- continue; // Don't polyfill @const expressions
- }
- switch (builtin->Type()) {
- case sem::BuiltinType::kAcosh:
- if (builtins.acosh != Level::kNone) {
- return true;
- }
- break;
- case sem::BuiltinType::kAsinh:
- if (builtins.asinh) {
- return true;
- }
- break;
- case sem::BuiltinType::kAtanh:
- if (builtins.atanh != Level::kNone) {
- return true;
- }
- break;
- case sem::BuiltinType::kClamp:
- if (builtins.clamp_int) {
- auto& sig = builtin->Signature();
- return sig.parameters[0]->Type()->is_integer_scalar_or_vector();
- }
- break;
- case sem::BuiltinType::kCountLeadingZeros:
- if (builtins.count_leading_zeros) {
- return true;
- }
- break;
- case sem::BuiltinType::kCountTrailingZeros:
- if (builtins.count_trailing_zeros) {
- return true;
- }
- break;
- case sem::BuiltinType::kExtractBits:
- if (builtins.extract_bits != Level::kNone) {
- return true;
- }
- break;
- case sem::BuiltinType::kFirstLeadingBit:
- if (builtins.first_leading_bit) {
- return true;
- }
- break;
- case sem::BuiltinType::kFirstTrailingBit:
- if (builtins.first_trailing_bit) {
- return true;
- }
- break;
- case sem::BuiltinType::kInsertBits:
- if (builtins.insert_bits != Level::kNone) {
- return true;
- }
- break;
- case sem::BuiltinType::kSaturate:
- if (builtins.saturate) {
- return true;
- }
- break;
- case sem::BuiltinType::kTextureSampleBaseClampToEdge:
- if (builtins.texture_sample_base_clamp_to_edge_2d_f32) {
- auto& sig = builtin->Signature();
- auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
- if (auto* stex = tex->Type()->As<sem::SampledTexture>()) {
- return stex->type()->Is<sem::F32>();
- }
- }
- break;
- case sem::BuiltinType::kQuantizeToF16:
- if (builtins.quantize_to_vec_f16) {
- if (builtin->ReturnType()->Is<sem::Vector>()) {
- return true;
- }
- }
- break;
- default:
- break;
- }
- }
- }
- }
- }
- return false;
-}
-
-void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) const {
+Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
+ const DataMap& data,
+ DataMap&) const {
auto* cfg = data.Get<Config>();
if (!cfg) {
- ctx.Clone();
- return;
+ return SkipTransform;
}
- std::unordered_map<const sem::Builtin*, Symbol> polyfills;
+ auto& builtins = cfg->builtins;
- ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
- auto builtins = cfg->builtins;
- State s{ctx, builtins};
- if (auto* call = s.sem.Get<sem::Call>(expr)) {
+ utils::Hashmap<const sem::Builtin*, Symbol, 8> polyfills;
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ State s{ctx, builtins};
+
+ bool made_changes = false;
+ for (auto* node : src->ASTNodes().Objects()) {
+ if (auto* call = src->Sem().Get<sem::Call>(node)) {
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (call->Stage() == sem::EvaluationStage::kConstant) {
- return nullptr; // Don't polyfill @const expressions
+ continue; // Don't polyfill @const expressions
}
Symbol polyfill;
switch (builtin->Type()) {
case sem::BuiltinType::kAcosh:
if (builtins.acosh != Level::kNone) {
- polyfill = utils::GetOrCreate(
- polyfills, builtin, [&] { return s.acosh(builtin->ReturnType()); });
+ polyfill = polyfills.GetOrCreate(
+ builtin, [&] { return s.acosh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kAsinh:
if (builtins.asinh) {
- polyfill = utils::GetOrCreate(
- polyfills, builtin, [&] { return s.asinh(builtin->ReturnType()); });
+ polyfill = polyfills.GetOrCreate(
+ builtin, [&] { return s.asinh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kAtanh:
if (builtins.atanh != Level::kNone) {
- polyfill = utils::GetOrCreate(
- polyfills, builtin, [&] { return s.atanh(builtin->ReturnType()); });
+ polyfill = polyfills.GetOrCreate(
+ builtin, [&] { return s.atanh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kClamp:
if (builtins.clamp_int) {
auto& sig = builtin->Signature();
if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.clampInteger(builtin->ReturnType());
- });
+ polyfill = polyfills.GetOrCreate(
+ builtin, [&] { return s.clampInteger(builtin->ReturnType()); });
}
}
break;
case sem::BuiltinType::kCountLeadingZeros:
if (builtins.count_leading_zeros) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
+ polyfill = polyfills.GetOrCreate(builtin, [&] {
return s.countLeadingZeros(builtin->ReturnType());
});
}
break;
case sem::BuiltinType::kCountTrailingZeros:
if (builtins.count_trailing_zeros) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
+ polyfill = polyfills.GetOrCreate(builtin, [&] {
return s.countTrailingZeros(builtin->ReturnType());
});
}
break;
case sem::BuiltinType::kExtractBits:
if (builtins.extract_bits != Level::kNone) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.extractBits(builtin->ReturnType());
- });
+ polyfill = polyfills.GetOrCreate(
+ builtin, [&] { return s.extractBits(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kFirstLeadingBit:
if (builtins.first_leading_bit) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.firstLeadingBit(builtin->ReturnType());
- });
+ polyfill = polyfills.GetOrCreate(
+ builtin, [&] { return s.firstLeadingBit(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kFirstTrailingBit:
if (builtins.first_trailing_bit) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.firstTrailingBit(builtin->ReturnType());
- });
+ polyfill = polyfills.GetOrCreate(
+ builtin, [&] { return s.firstTrailingBit(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kInsertBits:
if (builtins.insert_bits != Level::kNone) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.insertBits(builtin->ReturnType());
- });
+ polyfill = polyfills.GetOrCreate(
+ builtin, [&] { return s.insertBits(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kSaturate:
if (builtins.saturate) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
- return s.saturate(builtin->ReturnType());
- });
+ polyfill = polyfills.GetOrCreate(
+ builtin, [&] { return s.saturate(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kTextureSampleBaseClampToEdge:
@@ -799,7 +706,7 @@
auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
if (auto* stex = tex->Type()->As<sem::SampledTexture>()) {
if (stex->type()->Is<sem::F32>()) {
- polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
+ polyfill = polyfills.GetOrCreate(builtin, [&] {
return s.textureSampleBaseClampToEdge_2d_f32();
});
}
@@ -809,8 +716,8 @@
case sem::BuiltinType::kQuantizeToF16:
if (builtins.quantize_to_vec_f16) {
if (auto* vec = builtin->ReturnType()->As<sem::Vector>()) {
- polyfill = utils::GetOrCreate(polyfills, builtin,
- [&] { return s.quantizeToF16(vec); });
+ polyfill = polyfills.GetOrCreate(
+ builtin, [&] { return s.quantizeToF16(vec); });
}
}
break;
@@ -819,14 +726,20 @@
break;
}
if (polyfill.IsValid()) {
- return s.b.Call(polyfill, ctx.Clone(call->Declaration()->args));
+ auto* replacement = s.b.Call(polyfill, ctx.Clone(call->Declaration()->args));
+ ctx.Replace(call->Declaration(), replacement);
+ made_changes = true;
}
}
}
- return nullptr;
- });
+ }
+
+ if (!made_changes) {
+ return SkipTransform;
+ }
ctx.Clone();
+ return Program(std::move(b));
}
BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {}
diff --git a/src/tint/transform/builtin_polyfill.h b/src/tint/transform/builtin_polyfill.h
index 7106915..231d753 100644
--- a/src/tint/transform/builtin_polyfill.h
+++ b/src/tint/transform/builtin_polyfill.h
@@ -87,21 +87,13 @@
const Builtins builtins;
};
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
- protected:
+ private:
struct State;
-
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/builtin_polyfill_test.cc b/src/tint/transform/builtin_polyfill_test.cc
index 65547ab..4bff0ff 100644
--- a/src/tint/transform/builtin_polyfill_test.cc
+++ b/src/tint/transform/builtin_polyfill_test.cc
@@ -1561,7 +1561,8 @@
TEST_F(BuiltinPolyfillTest, DISABLED_InsertBits_ConstantExpression) {
auto* src = R"(
fn f() {
- let r : i32 = insertBits(1234, 5678, 5u, 6u);
+ let v = 1234i;
+ let r : i32 = insertBits(v, 5678, 5u, 6u);
}
)";
@@ -1975,10 +1976,6 @@
)";
auto* expect = R"(
-@group(0) @binding(0) var t : texture_2d<f32>;
-
-@group(0) @binding(1) var s : sampler;
-
fn tint_textureSampleBaseClampToEdge(t : texture_2d<f32>, s : sampler, coord : vec2<f32>) -> vec4<f32> {
let dims = vec2<f32>(textureDimensions(t, 0));
let half_texel = (vec2<f32>(0.5) / dims);
@@ -1986,6 +1983,10 @@
return textureSampleLevel(t, s, clamped, 0);
}
+@group(0) @binding(0) var t : texture_2d<f32>;
+
+@group(0) @binding(1) var s : sampler;
+
fn f() {
let r = tint_textureSampleBaseClampToEdge(t, s, vec2<f32>(0.5));
}
diff --git a/src/tint/transform/calculate_array_length.cc b/src/tint/transform/calculate_array_length.cc
index 2ca5e54..9dcdd7b 100644
--- a/src/tint/transform/calculate_array_length.cc
+++ b/src/tint/transform/calculate_array_length.cc
@@ -40,6 +40,19 @@
namespace {
+bool ShouldRun(const Program* program) {
+ for (auto* fn : program->AST().Functions()) {
+ if (auto* sem_fn = program->Sem().Get(fn)) {
+ for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
+ if (builtin->Type() == sem::BuiltinType::kArrayLength) {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
/// ArrayUsage describes a runtime array usage.
/// It is used as a key by the array_length_by_usage map.
struct ArrayUsage {
@@ -73,21 +86,16 @@
CalculateArrayLength::CalculateArrayLength() = default;
CalculateArrayLength::~CalculateArrayLength() = default;
-bool CalculateArrayLength::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* fn : program->AST().Functions()) {
- if (auto* sem_fn = program->Sem().Get(fn)) {
- for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
- if (builtin->Type() == sem::BuiltinType::kArrayLength) {
- return true;
- }
- }
- }
+Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
}
- return false;
-}
-void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- auto& sem = ctx.src->Sem();
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ auto& sem = src->Sem();
// get_buffer_size_intrinsic() emits the function decorated with
// BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
@@ -95,24 +103,20 @@
std::unordered_map<const sem::Reference*, Symbol> buffer_size_intrinsics;
auto get_buffer_size_intrinsic = [&](const sem::Reference* buffer_type) {
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
- auto name = ctx.dst->Sym();
+ auto name = b.Sym();
auto* type = CreateASTTypeFor(ctx, buffer_type);
- auto* disable_validation =
- ctx.dst->Disable(ast::DisabledValidation::kFunctionParameter);
- ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>(
+ auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter);
+ b.AST().AddFunction(b.create<ast::Function>(
name,
utils::Vector{
- ctx.dst->Param("buffer",
- ctx.dst->ty.pointer(type, buffer_type->AddressSpace(),
- buffer_type->Access()),
- utils::Vector{disable_validation}),
- ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(),
- ast::AddressSpace::kFunction)),
+ b.Param("buffer",
+ b.ty.pointer(type, buffer_type->AddressSpace(), buffer_type->Access()),
+ utils::Vector{disable_validation}),
+ b.Param("result", b.ty.pointer(b.ty.u32(), ast::AddressSpace::kFunction)),
},
- ctx.dst->ty.void_(), nullptr,
+ b.ty.void_(), nullptr,
utils::Vector{
- ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID(),
- ctx.dst->AllocateNodeID()),
+ b.ASTNodes().Create<BufferSizeIntrinsic>(b.ID(), b.AllocateNodeID()),
},
utils::Empty));
@@ -123,7 +127,7 @@
std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> array_length_by_usage;
// Find all the arrayLength() calls...
- for (auto* node : ctx.src->ASTNodes().Objects()) {
+ for (auto* node : src->ASTNodes().Objects()) {
if (auto* call_expr = node->As<ast::CallExpression>()) {
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
@@ -149,7 +153,7 @@
auto* arg = call_expr->args[0];
auto* address_of = arg->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "arrayLength() expected address-of, got " << arg->TypeInfo().name;
}
auto* storage_buffer_expr = address_of->expr;
@@ -158,7 +162,7 @@
}
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
if (!storage_buffer_sem) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
@@ -179,25 +183,24 @@
// Construct the variable that'll hold the result of
// RWByteAddressBuffer.GetDimensions()
- auto* buffer_size_result = ctx.dst->Decl(ctx.dst->Var(
- ctx.dst->Sym(), ctx.dst->ty.u32(), ctx.dst->Expr(0_u)));
+ auto* buffer_size_result =
+ b.Decl(b.Var(b.Sym(), b.ty.u32(), b.Expr(0_u)));
// Call storage_buffer.GetDimensions(&buffer_size_result)
- auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call(
+ auto* call_get_dims = b.CallStmt(b.Call(
// BufferSizeIntrinsic(X, ARGS...) is
// translated to:
// X.GetDimensions(ARGS..) by the writer
- buffer_size, ctx.dst->AddressOf(ctx.Clone(storage_buffer_expr)),
- ctx.dst->AddressOf(
- ctx.dst->Expr(buffer_size_result->variable->symbol))));
+ buffer_size, b.AddressOf(ctx.Clone(storage_buffer_expr)),
+ b.AddressOf(b.Expr(buffer_size_result->variable->symbol))));
// Calculate actual array length
// total_storage_buffer_size - array_offset
// array_length = ----------------------------------------
// array_stride
- auto name = ctx.dst->Sym();
+ auto name = b.Sym();
const ast::Expression* total_size =
- ctx.dst->Expr(buffer_size_result->variable);
+ b.Expr(buffer_size_result->variable);
const sem::Array* array_type = Switch(
storage_buffer_type->StoreType(),
@@ -205,23 +208,21 @@
// The variable is a struct, so subtract the byte offset of
// the array member.
auto* array_member_sem = str->Members().back();
- total_size =
- ctx.dst->Sub(total_size, u32(array_member_sem->Offset()));
+ total_size = b.Sub(total_size, u32(array_member_sem->Offset()));
return array_member_sem->Type()->As<sem::Array>();
},
[&](const sem::Array* arr) { return arr; });
if (!array_type) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "expected form of arrayLength argument to be "
"&array_var or &struct_var.array_member";
return name;
}
uint32_t array_stride = array_type->Size();
- auto* array_length_var = ctx.dst->Decl(
- ctx.dst->Let(name, ctx.dst->ty.u32(),
- ctx.dst->Div(total_size, u32(array_stride))));
+ auto* array_length_var = b.Decl(
+ b.Let(name, b.ty.u32(), b.Div(total_size, u32(array_stride))));
// Insert the array length calculations at the top of the block
ctx.InsertBefore(block->statements, block->statements[0],
@@ -234,13 +235,14 @@
});
// Replace the call to arrayLength() with the array length variable
- ctx.Replace(call_expr, ctx.dst->Expr(array_length));
+ ctx.Replace(call_expr, b.Expr(array_length));
}
}
}
}
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/calculate_array_length.h b/src/tint/transform/calculate_array_length.h
index 8db8dcc..e5714a8 100644
--- a/src/tint/transform/calculate_array_length.h
+++ b/src/tint/transform/calculate_array_length.h
@@ -59,19 +59,10 @@
/// Destructor
~CalculateArrayLength() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/canonicalize_entry_point_io.cc b/src/tint/transform/canonicalize_entry_point_io.cc
index b990965..8a14fb7 100644
--- a/src/tint/transform/canonicalize_entry_point_io.cc
+++ b/src/tint/transform/canonicalize_entry_point_io.cc
@@ -123,7 +123,7 @@
} // namespace
-/// State holds the current transform state for a single entry point.
+/// PIMPL state for the transform
struct CanonicalizeEntryPointIO::State {
/// OutputValue represents a shader result that the wrapper function produces.
struct OutputValue {
@@ -770,17 +770,22 @@
}
};
-void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+Transform::ApplyResult CanonicalizeEntryPointIO::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
- return;
+ b.Diagnostics().add_error(diag::System::Transform,
+ "missing transform data for " + std::string(TypeInfo().name));
+ return Program(std::move(b));
}
// Remove entry point IO attributes from struct declarations.
// New structures will be created for each entry point, as necessary.
- for (auto* ty : ctx.src->AST().TypeDecls()) {
+ for (auto* ty : src->AST().TypeDecls()) {
if (auto* struct_ty = ty->As<ast::Struct>()) {
for (auto* member : struct_ty->members) {
for (auto* attr : member->attributes) {
@@ -792,7 +797,7 @@
}
}
- for (auto* func_ast : ctx.src->AST().Functions()) {
+ for (auto* func_ast : src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
@@ -802,6 +807,7 @@
}
ctx.Clone();
+ return Program(std::move(b));
}
CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
diff --git a/src/tint/transform/canonicalize_entry_point_io.h b/src/tint/transform/canonicalize_entry_point_io.h
index 95f8b19..fbfed5e 100644
--- a/src/tint/transform/canonicalize_entry_point_io.h
+++ b/src/tint/transform/canonicalize_entry_point_io.h
@@ -127,15 +127,12 @@
CanonicalizeEntryPointIO();
~CanonicalizeEntryPointIO() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+ private:
struct State;
};
diff --git a/src/tint/transform/clamp_frag_depth.cc b/src/tint/transform/clamp_frag_depth.cc
index e67dda3..4551925 100644
--- a/src/tint/transform/clamp_frag_depth.cc
+++ b/src/tint/transform/clamp_frag_depth.cc
@@ -14,7 +14,7 @@
#include "src/tint/transform/clamp_frag_depth.h"
- #include <utility>
+#include <utility>
#include "src/tint/ast/attribute.h"
#include "src/tint/ast/builtin_attribute.h"
@@ -64,12 +64,7 @@
return false;
}
-} // anonymous namespace
-
-ClampFragDepth::ClampFragDepth() = default;
-ClampFragDepth::~ClampFragDepth() = default;
-
-bool ClampFragDepth::ShouldRun(const Program* program, const DataMap&) const {
+bool ShouldRun(const Program* program) {
auto& sem = program->Sem();
for (auto* fn : program->AST().Functions()) {
@@ -82,22 +77,33 @@
return false;
}
-void ClampFragDepth::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+} // anonymous namespace
+
+ClampFragDepth::ClampFragDepth() = default;
+ClampFragDepth::~ClampFragDepth() = default;
+
+Transform::ApplyResult ClampFragDepth::Apply(const Program* src, const DataMap&, DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
// Abort on any use of push constants in the module.
- for (auto* global : ctx.src->AST().GlobalVariables()) {
+ for (auto* global : src->AST().GlobalVariables()) {
if (auto* var = global->As<ast::Var>()) {
if (var->declared_address_space == ast::AddressSpace::kPushConstant) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "ClampFragDepth doesn't know how to handle module that already use push "
"constants.";
- return;
+ return Program(std::move(b));
}
}
}
- auto& b = *ctx.dst;
- auto& sem = ctx.src->Sem();
- auto& sym = ctx.src->Symbols();
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ auto& sem = src->Sem();
+ auto& sym = src->Symbols();
// At least one entry-point needs clamping. Add the following to the module:
//
@@ -197,6 +203,7 @@
});
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/clamp_frag_depth.h b/src/tint/transform/clamp_frag_depth.h
index 3b15f11..1e9d0d6 100644
--- a/src/tint/transform/clamp_frag_depth.h
+++ b/src/tint/transform/clamp_frag_depth.h
@@ -61,19 +61,10 @@
/// Destructor
~ClampFragDepth() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/combine_samplers.cc b/src/tint/transform/combine_samplers.cc
index 97650ad..e7286d4 100644
--- a/src/tint/transform/combine_samplers.cc
+++ b/src/tint/transform/combine_samplers.cc
@@ -47,10 +47,14 @@
CombineSamplers::BindingInfo::BindingInfo(const BindingInfo& other) = default;
CombineSamplers::BindingInfo::~BindingInfo() = default;
-/// The PIMPL state for the CombineSamplers transform
+/// PIMPL state for the transform
struct CombineSamplers::State {
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
/// The clone context
- CloneContext& ctx;
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// The binding info
const BindingInfo* binding_info;
@@ -88,9 +92,9 @@
}
/// Constructor
- /// @param context the clone context
+ /// @param program the source program
/// @param info the binding map information
- State(CloneContext& context, const BindingInfo* info) : ctx(context), binding_info(info) {}
+ State(const Program* program, const BindingInfo* info) : src(program), binding_info(info) {}
/// Creates a combined sampler global variables.
/// (Note this is actually a Texture node at the AST level, but it will be
@@ -145,8 +149,9 @@
}
}
- /// Performs the transformation
- void Run() {
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
auto& sem = ctx.src->Sem();
// Remove all texture and sampler global variables. These will be replaced
@@ -169,14 +174,14 @@
// Rewrite all function signatures to use combined samplers, and remove
// separate textures & samplers. Create new combined globals where found.
- ctx.ReplaceAll([&](const ast::Function* src) -> const ast::Function* {
- if (auto* func = sem.Get(src)) {
- auto pairs = func->TextureSamplerPairs();
+ ctx.ReplaceAll([&](const ast::Function* ast_fn) -> const ast::Function* {
+ if (auto* fn = sem.Get(ast_fn)) {
+ auto pairs = fn->TextureSamplerPairs();
if (pairs.IsEmpty()) {
return nullptr;
}
utils::Vector<const ast::Parameter*, 8> params;
- for (auto pair : func->TextureSamplerPairs()) {
+ for (auto pair : fn->TextureSamplerPairs()) {
const sem::Variable* texture_var = pair.first;
const sem::Variable* sampler_var = pair.second;
std::string name =
@@ -197,23 +202,23 @@
auto* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
params.Push(var);
- function_combined_texture_samplers_[func][pair] = var;
+ function_combined_texture_samplers_[fn][pair] = var;
}
}
// Filter out separate textures and samplers from the original
// function signature.
- for (auto* var : src->params) {
- if (!sem.Get(var->type)->IsAnyOf<sem::Texture, sem::Sampler>()) {
- params.Push(ctx.Clone(var));
+ for (auto* param : fn->Parameters()) {
+ if (!param->Type()->IsAnyOf<sem::Texture, sem::Sampler>()) {
+ params.Push(ctx.Clone(param->Declaration()));
}
}
// Create a new function signature that differs only in the parameter
// list.
- auto symbol = ctx.Clone(src->symbol);
- auto* return_type = ctx.Clone(src->return_type);
- auto* body = ctx.Clone(src->body);
- auto attributes = ctx.Clone(src->attributes);
- auto return_type_attributes = ctx.Clone(src->return_type_attributes);
+ auto symbol = ctx.Clone(ast_fn->symbol);
+ auto* return_type = ctx.Clone(ast_fn->return_type);
+ auto* body = ctx.Clone(ast_fn->body);
+ auto attributes = ctx.Clone(ast_fn->attributes);
+ auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes);
return ctx.dst->create<ast::Function>(symbol, params, return_type, body,
std::move(attributes),
std::move(return_type_attributes));
@@ -327,6 +332,7 @@
});
ctx.Clone();
+ return Program(std::move(b));
}
};
@@ -334,15 +340,18 @@
CombineSamplers::~CombineSamplers() = default;
-void CombineSamplers::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+Transform::ApplyResult CombineSamplers::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
auto* binding_info = inputs.Get<BindingInfo>();
if (!binding_info) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
- return;
+ ProgramBuilder b;
+ b.Diagnostics().add_error(diag::System::Transform,
+ "missing transform data for " + std::string(TypeInfo().name));
+ return Program(std::move(b));
}
- State(ctx, binding_info).Run();
+ return State(src, binding_info).Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/combine_samplers.h b/src/tint/transform/combine_samplers.h
index 8dfc098..6834abe 100644
--- a/src/tint/transform/combine_samplers.h
+++ b/src/tint/transform/combine_samplers.h
@@ -88,17 +88,13 @@
/// Destructor
~CombineSamplers() override;
- protected:
- /// The PIMPL state for this transform
- struct State;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ private:
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_memory_access.cc b/src/tint/transform/decompose_memory_access.cc
index 68324af..046583e 100644
--- a/src/tint/transform/decompose_memory_access.cc
+++ b/src/tint/transform/decompose_memory_access.cc
@@ -47,6 +47,18 @@
namespace {
+bool ShouldRun(const Program* program) {
+ for (auto* decl : program->AST().GlobalDeclarations()) {
+ if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
+ if (var->AddressSpace() == ast::AddressSpace::kStorage ||
+ var->AddressSpace() == ast::AddressSpace::kUniform) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
/// Offset is a simple ast::Expression builder interface, used to build byte
/// offsets for storage and uniform buffer accesses.
struct Offset : Castable<Offset> {
@@ -291,7 +303,7 @@
} // namespace
-/// State holds the current transform state
+/// PIMPL state for the transform
struct DecomposeMemoryAccess::State {
/// The clone context
CloneContext& ctx;
@@ -477,7 +489,7 @@
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles storage and uniform.
// * Runtime-sized arrays are not loadable.
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "unexpected non-constant array count";
arr_cnt = 1;
}
@@ -578,7 +590,7 @@
// * Override-expression counts can only be applied to workgroup
// arrays, and this method only handles storage and uniform.
// * Runtime-sized arrays are not storable.
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "unexpected non-constant array count";
arr_cnt = 1;
}
@@ -808,21 +820,16 @@
DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
-bool DecomposeMemoryAccess::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* decl : program->AST().GlobalDeclarations()) {
- if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
- if (var->AddressSpace() == ast::AddressSpace::kStorage ||
- var->AddressSpace() == ast::AddressSpace::kUniform) {
- return true;
- }
- }
+Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
}
- return false;
-}
-void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- auto& sem = ctx.src->Sem();
-
+ auto& sem = src->Sem();
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
State state(ctx);
// Scan the AST nodes for storage and uniform buffer accesses. Complex
@@ -833,7 +840,7 @@
// Inner-most expression nodes are guaranteed to be visited first because AST
// nodes are fully immutable and require their children to be constructed
// first so their pointer can be passed to the parent's initializer.
- for (auto* node : ctx.src->ASTNodes().Objects()) {
+ for (auto* node : src->ASTNodes().Objects()) {
if (auto* ident = node->As<ast::IdentifierExpression>()) {
// X
if (auto* var = sem.Get<sem::VariableUser>(ident)) {
@@ -1001,6 +1008,7 @@
}
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_memory_access.h b/src/tint/transform/decompose_memory_access.h
index 2e92a3a..21c196b 100644
--- a/src/tint/transform/decompose_memory_access.h
+++ b/src/tint/transform/decompose_memory_access.h
@@ -108,20 +108,12 @@
/// Destructor
~DecomposeMemoryAccess() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
-
+ private:
struct State;
};
diff --git a/src/tint/transform/decompose_strided_array.cc b/src/tint/transform/decompose_strided_array.cc
index e9f51a5..73a6629 100644
--- a/src/tint/transform/decompose_strided_array.cc
+++ b/src/tint/transform/decompose_strided_array.cc
@@ -34,13 +34,7 @@
using DecomposedArrays = std::unordered_map<const sem::Array*, Symbol>;
-} // namespace
-
-DecomposeStridedArray::DecomposeStridedArray() = default;
-
-DecomposeStridedArray::~DecomposeStridedArray() = default;
-
-bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) const {
+bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* ast = node->As<ast::Array>()) {
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
@@ -51,8 +45,22 @@
return false;
}
-void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- const auto& sem = ctx.src->Sem();
+} // namespace
+
+DecomposeStridedArray::DecomposeStridedArray() = default;
+
+DecomposeStridedArray::~DecomposeStridedArray() = default;
+
+Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ const auto& sem = src->Sem();
static constexpr const char* kMemberName = "el";
@@ -69,23 +77,23 @@
if (auto* arr = sem.Get(ast)) {
if (!arr->IsStrideImplicit()) {
auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
- auto name = ctx.dst->Symbols().New("strided_arr");
+ auto name = b.Symbols().New("strided_arr");
auto* member_ty = ctx.Clone(ast->type);
- auto* member = ctx.dst->Member(kMemberName, member_ty,
- utils::Vector{
- ctx.dst->MemberSize(AInt(arr->Stride())),
- });
- ctx.dst->Structure(name, utils::Vector{member});
+ auto* member = b.Member(kMemberName, member_ty,
+ utils::Vector{
+ b.MemberSize(AInt(arr->Stride())),
+ });
+ b.Structure(name, utils::Vector{member});
return name;
});
auto* count = ctx.Clone(ast->count);
- return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count);
+ return b.ty.array(b.ty.type_name(el_ty), count);
}
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
// Strip the @stride attribute
auto* ty = ctx.Clone(ast->type);
auto* count = ctx.Clone(ast->count);
- return ctx.dst->ty.array(ty, count);
+ return b.ty.array(ty, count);
}
}
return nullptr;
@@ -96,11 +104,11 @@
// to insert an additional member accessor for the single structure field.
// Example: `arr[i]` -> `arr[i].el`
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
- if (auto* ty = ctx.src->TypeOf(idx->object)) {
+ if (auto* ty = src->TypeOf(idx->object)) {
if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
if (!arr->IsStrideImplicit()) {
auto* expr = ctx.CloneWithoutTransform(idx);
- return ctx.dst->MemberAccessor(expr, kMemberName);
+ return b.MemberAccessor(expr, kMemberName);
}
}
}
@@ -136,21 +144,23 @@
if (auto it = decomposed.find(arr); it != decomposed.end()) {
args.Reserve(expr->args.Length());
for (auto* arg : expr->args) {
- args.Push(ctx.dst->Call(it->second, ctx.Clone(arg)));
+ args.Push(b.Call(it->second, ctx.Clone(arg)));
}
} else {
args = ctx.Clone(expr->args);
}
- return target.type ? ctx.dst->Construct(target.type, std::move(args))
- : ctx.dst->Call(target.name, std::move(args));
+ return target.type ? b.Construct(target.type, std::move(args))
+ : b.Call(target.name, std::move(args));
}
}
}
}
return nullptr;
});
+
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_strided_array.h b/src/tint/transform/decompose_strided_array.h
index 5dbaaa5..9555a9a 100644
--- a/src/tint/transform/decompose_strided_array.h
+++ b/src/tint/transform/decompose_strided_array.h
@@ -35,19 +35,10 @@
/// Destructor
~DecomposeStridedArray() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_strided_matrix.cc b/src/tint/transform/decompose_strided_matrix.cc
index 91aed43..5494ca2 100644
--- a/src/tint/transform/decompose_strided_matrix.cc
+++ b/src/tint/transform/decompose_strided_matrix.cc
@@ -53,24 +53,25 @@
};
};
-/// Return type of the callback function of GatherCustomStrideMatrixMembers
-enum GatherResult { kContinue, kStop };
+} // namespace
-/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
-/// storage and uniform structs, which are of a matrix type, and have a custom
-/// matrix stride attribute. For each matrix member found, `callback` is called.
-/// `callback` is a function with the signature:
-/// GatherResult(const sem::StructMember* member,
-/// sem::Matrix* matrix,
-/// uint32_t stride)
-/// If `callback` return GatherResult::kStop, then the scanning will immediately
-/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
-/// scanning will continue.
-template <typename F>
-void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
- for (auto* node : program->ASTNodes().Objects()) {
+DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
+
+DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
+
+Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ // Scan the program for all storage and uniform structure matrix members with
+ // a custom stride attribute. Replace these matrices with an equivalent array,
+ // and populate the `decomposed` map with the members that have been replaced.
+ utils::Hashmap<const ast::StructMember*, MatrixInfo, 8> decomposed;
+ for (auto* node : src->ASTNodes().Objects()) {
if (auto* str = node->As<ast::Struct>()) {
- auto* str_ty = program->Sem().Get(str);
+ auto* str_ty = src->Sem().Get(str);
if (!str_ty->UsedAs(ast::AddressSpace::kUniform) &&
!str_ty->UsedAs(ast::AddressSpace::kStorage)) {
continue;
@@ -89,46 +90,20 @@
if (matrix->ColumnStride() == stride) {
continue;
}
- if (callback(member, matrix, stride) == GatherResult::kStop) {
- return;
- }
+ // We've got ourselves a struct member of a matrix type with a custom
+ // stride. Replace this with an array of column vectors.
+ MatrixInfo info{stride, matrix};
+ auto* replacement =
+ b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
+ ctx.Replace(member->Declaration(), replacement);
+ decomposed.Add(member->Declaration(), info);
}
}
}
-}
-} // namespace
-
-DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
-
-DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
-
-bool DecomposeStridedMatrix::ShouldRun(const Program* program, const DataMap&) const {
- bool should_run = false;
- GatherCustomStrideMatrixMembers(program,
- [&](const sem::StructMember*, const sem::Matrix*, uint32_t) {
- should_run = true;
- return GatherResult::kStop;
- });
- return should_run;
-}
-
-void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- // Scan the program for all storage and uniform structure matrix members with
- // a custom stride attribute. Replace these matrices with an equivalent array,
- // and populate the `decomposed` map with the members that have been replaced.
- std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
- GatherCustomStrideMatrixMembers(
- ctx.src, [&](const sem::StructMember* member, const sem::Matrix* matrix, uint32_t stride) {
- // We've got ourselves a struct member of a matrix type with a custom
- // stride. Replace this with an array of column vectors.
- MatrixInfo info{stride, matrix};
- auto* replacement =
- ctx.dst->Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
- ctx.Replace(member->Declaration(), replacement);
- decomposed.emplace(member->Declaration(), info);
- return GatherResult::kContinue;
- });
+ if (decomposed.IsEmpty()) {
+ return SkipTransform;
+ }
// For all expressions where a single matrix column vector was indexed, we can
// preserve these without calling conversion functions.
@@ -136,12 +111,11 @@
// ssbo.mat[2] -> ssbo.mat[2]
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
- if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
- auto it = decomposed.find(access->Member()->Declaration());
- if (it != decomposed.end()) {
+ if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
+ if (decomposed.Contains(access->Member()->Declaration())) {
auto* obj = ctx.CloneWithoutTransform(expr->object);
auto* idx = ctx.Clone(expr->index);
- return ctx.dst->IndexAccessor(obj, idx);
+ return b.IndexAccessor(obj, idx);
}
}
return nullptr;
@@ -154,39 +128,36 @@
// ssbo.mat = mat_to_arr(m)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
- if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
- auto it = decomposed.find(access->Member()->Declaration());
- if (it == decomposed.end()) {
- return nullptr;
+ if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
+ if (auto* info = decomposed.Find(access->Member()->Declaration())) {
+ auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
+ auto name =
+ b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
+ std::to_string(info->matrix->rows()) + "_stride_" +
+ std::to_string(info->stride) + "_to_arr");
+
+ auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
+ auto array = [&] { return info->array(ctx.dst); };
+
+ auto mat = b.Sym("m");
+ utils::Vector<const ast::Expression*, 4> columns;
+ for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
+ columns.Push(b.IndexAccessor(mat, u32(i)));
+ }
+ b.Func(name,
+ utils::Vector{
+ b.Param(mat, matrix()),
+ },
+ array(),
+ utils::Vector{
+ b.Return(b.Construct(array(), columns)),
+ });
+ return name;
+ });
+ auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
+ auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs));
+ return b.Assign(lhs, rhs);
}
- MatrixInfo info = it->second;
- auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
- auto name =
- ctx.dst->Symbols().New("mat" + std::to_string(info.matrix->columns()) + "x" +
- std::to_string(info.matrix->rows()) + "_stride_" +
- std::to_string(info.stride) + "_to_arr");
-
- auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
- auto array = [&] { return info.array(ctx.dst); };
-
- auto mat = ctx.dst->Sym("m");
- utils::Vector<const ast::Expression*, 4> columns;
- for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) {
- columns.Push(ctx.dst->IndexAccessor(mat, u32(i)));
- }
- ctx.dst->Func(name,
- utils::Vector{
- ctx.dst->Param(mat, matrix()),
- },
- array(),
- utils::Vector{
- ctx.dst->Return(ctx.dst->Construct(array(), columns)),
- });
- return name;
- });
- auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
- auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs));
- return ctx.dst->Assign(lhs, rhs);
}
return nullptr;
});
@@ -196,41 +167,40 @@
// m = arr_to_mat(ssbo.mat)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
- if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
- auto it = decomposed.find(access->Member()->Declaration());
- if (it == decomposed.end()) {
- return nullptr;
+ if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) {
+ if (auto* info = decomposed.Find(access->Member()->Declaration())) {
+ auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
+ auto name =
+ b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
+ "x" + std::to_string(info->matrix->rows()) + "_stride_" +
+ std::to_string(info->stride));
+
+ auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
+ auto array = [&] { return info->array(ctx.dst); };
+
+ auto arr = b.Sym("arr");
+ utils::Vector<const ast::Expression*, 4> columns;
+ for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
+ columns.Push(b.IndexAccessor(arr, u32(i)));
+ }
+ b.Func(name,
+ utils::Vector{
+ b.Param(arr, array()),
+ },
+ matrix(),
+ utils::Vector{
+ b.Return(b.Construct(matrix(), columns)),
+ });
+ return name;
+ });
+ return b.Call(fn, ctx.CloneWithoutTransform(expr));
}
- MatrixInfo info = it->second;
- auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
- auto name = ctx.dst->Symbols().New(
- "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
- std::to_string(info.matrix->rows()) + "_stride_" + std::to_string(info.stride));
-
- auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
- auto array = [&] { return info.array(ctx.dst); };
-
- auto arr = ctx.dst->Sym("arr");
- utils::Vector<const ast::Expression*, 4> columns;
- for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) {
- columns.Push(ctx.dst->IndexAccessor(arr, u32(i)));
- }
- ctx.dst->Func(name,
- utils::Vector{
- ctx.dst->Param(arr, array()),
- },
- matrix(),
- utils::Vector{
- ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
- });
- return name;
- });
- return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
}
return nullptr;
});
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_strided_matrix.h b/src/tint/transform/decompose_strided_matrix.h
index 40e9c3e..947dfc6 100644
--- a/src/tint/transform/decompose_strided_matrix.h
+++ b/src/tint/transform/decompose_strided_matrix.h
@@ -35,19 +35,10 @@
/// Destructor
~DecomposeStridedMatrix() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/disable_uniformity_analysis.cc b/src/tint/transform/disable_uniformity_analysis.cc
index 918b1f1..ffd0b18 100644
--- a/src/tint/transform/disable_uniformity_analysis.cc
+++ b/src/tint/transform/disable_uniformity_analysis.cc
@@ -27,14 +27,20 @@
DisableUniformityAnalysis::~DisableUniformityAnalysis() = default;
-bool DisableUniformityAnalysis::ShouldRun(const Program* program, const DataMap&) const {
- return !program->Sem().Module()->Extensions().Contains(
- ast::Extension::kChromiumDisableUniformityAnalysis);
-}
+Transform::ApplyResult DisableUniformityAnalysis::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (src->Sem().Module()->Extensions().Contains(
+ ast::Extension::kChromiumDisableUniformityAnalysis)) {
+ return SkipTransform;
+ }
-void DisableUniformityAnalysis::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- ctx.dst->Enable(ast::Extension::kChromiumDisableUniformityAnalysis);
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ b.Enable(ast::Extension::kChromiumDisableUniformityAnalysis);
+
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/disable_uniformity_analysis.h b/src/tint/transform/disable_uniformity_analysis.h
index 3c9fb53..a9922af 100644
--- a/src/tint/transform/disable_uniformity_analysis.h
+++ b/src/tint/transform/disable_uniformity_analysis.h
@@ -27,19 +27,10 @@
/// Destructor
~DisableUniformityAnalysis() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/expand_compound_assignment.cc b/src/tint/transform/expand_compound_assignment.cc
index f15e28c..9fa81dd 100644
--- a/src/tint/transform/expand_compound_assignment.cc
+++ b/src/tint/transform/expand_compound_assignment.cc
@@ -31,11 +31,9 @@
namespace tint::transform {
-ExpandCompoundAssignment::ExpandCompoundAssignment() = default;
+namespace {
-ExpandCompoundAssignment::~ExpandCompoundAssignment() = default;
-
-bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&) const {
+bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (node->IsAnyOf<ast::CompoundAssignmentStatement, ast::IncrementDecrementStatement>()) {
return true;
@@ -44,21 +42,10 @@
return false;
}
-namespace {
+} // namespace
-/// Internal class used to collect statement expansions during the transform.
-class State {
- private:
- /// The clone context.
- CloneContext& ctx;
-
- /// The program builder.
- ProgramBuilder& b;
-
- /// The HoistToDeclBefore helper instance.
- HoistToDeclBefore hoist_to_decl_before;
-
- public:
+/// PIMPL state for the transform
+struct ExpandCompoundAssignment::State {
/// Constructor
/// @param context the clone context
explicit State(CloneContext& context) : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {}
@@ -158,15 +145,32 @@
ctx.Replace(stmt, b.Assign(new_lhs(), value));
}
- /// Finalize the transformation and clone the module.
- void Finalize() { ctx.Clone(); }
+ private:
+ /// The clone context.
+ CloneContext& ctx;
+
+ /// The program builder.
+ ProgramBuilder& b;
+
+ /// The HoistToDeclBefore helper instance.
+ HoistToDeclBefore hoist_to_decl_before;
};
-} // namespace
+ExpandCompoundAssignment::ExpandCompoundAssignment() = default;
-void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ExpandCompoundAssignment::~ExpandCompoundAssignment() = default;
+
+Transform::ApplyResult ExpandCompoundAssignment::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
State state(ctx);
- for (auto* node : ctx.src->ASTNodes().Objects()) {
+ for (auto* node : src->ASTNodes().Objects()) {
if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) {
state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
} else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) {
@@ -175,7 +179,9 @@
state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op);
}
}
- state.Finalize();
+
+ ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/expand_compound_assignment.h b/src/tint/transform/expand_compound_assignment.h
index 1081df7..6b299c5 100644
--- a/src/tint/transform/expand_compound_assignment.h
+++ b/src/tint/transform/expand_compound_assignment.h
@@ -45,19 +45,13 @@
/// Destructor
~ExpandCompoundAssignment() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ private:
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/first_index_offset.cc b/src/tint/transform/first_index_offset.cc
index cafca32..eb698be 100644
--- a/src/tint/transform/first_index_offset.cc
+++ b/src/tint/transform/first_index_offset.cc
@@ -35,6 +35,15 @@
constexpr char kFirstVertexName[] = "first_vertex_index";
constexpr char kFirstInstanceName[] = "first_instance_index";
+bool ShouldRun(const Program* program) {
+ for (auto* fn : program->AST().Functions()) {
+ if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace
FirstIndexOffset::BindingPoint::BindingPoint() = default;
@@ -49,16 +58,16 @@
FirstIndexOffset::FirstIndexOffset() = default;
FirstIndexOffset::~FirstIndexOffset() = default;
-bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* fn : program->AST().Functions()) {
- if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
- return true;
- }
+Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap& outputs) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
}
- return false;
-}
-void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
// Get the uniform buffer binding point
uint32_t ub_binding = binding_;
uint32_t ub_group = group_;
@@ -115,17 +124,17 @@
if (has_vertex_or_instance_index) {
// Add uniform buffer members and calculate byte offsets
utils::Vector<const ast::StructMember*, 8> members;
- members.Push(ctx.dst->Member(kFirstVertexName, ctx.dst->ty.u32()));
- members.Push(ctx.dst->Member(kFirstInstanceName, ctx.dst->ty.u32()));
- auto* struct_ = ctx.dst->Structure(ctx.dst->Sym(), std::move(members));
+ members.Push(b.Member(kFirstVertexName, b.ty.u32()));
+ members.Push(b.Member(kFirstInstanceName, b.ty.u32()));
+ auto* struct_ = b.Structure(b.Sym(), std::move(members));
// Create a global to hold the uniform buffer
- Symbol buffer_name = ctx.dst->Sym();
- ctx.dst->GlobalVar(buffer_name, ctx.dst->ty.Of(struct_), ast::AddressSpace::kUniform,
- utils::Vector{
- ctx.dst->Binding(AInt(ub_binding)),
- ctx.dst->Group(AInt(ub_group)),
- });
+ Symbol buffer_name = b.Sym();
+ b.GlobalVar(buffer_name, b.ty.Of(struct_), ast::AddressSpace::kUniform,
+ utils::Vector{
+ b.Binding(AInt(ub_binding)),
+ b.Group(AInt(ub_group)),
+ });
// Fix up all references to the builtins with the offsets
ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* {
@@ -150,9 +159,10 @@
});
}
- ctx.Clone();
-
outputs.Add<Data>(has_vertex_or_instance_index);
+
+ ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/first_index_offset.h b/src/tint/transform/first_index_offset.h
index 04758cd..f84d811 100644
--- a/src/tint/transform/first_index_offset.h
+++ b/src/tint/transform/first_index_offset.h
@@ -103,19 +103,10 @@
/// Destructor
~FirstIndexOffset() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
private:
uint32_t binding_ = 0;
diff --git a/src/tint/transform/for_loop_to_loop.cc b/src/tint/transform/for_loop_to_loop.cc
index e585790..63ccb12 100644
--- a/src/tint/transform/for_loop_to_loop.cc
+++ b/src/tint/transform/for_loop_to_loop.cc
@@ -14,17 +14,17 @@
#include "src/tint/transform/for_loop_to_loop.h"
+#include <utility>
+
#include "src/tint/ast/break_statement.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ForLoopToLoop);
namespace tint::transform {
-ForLoopToLoop::ForLoopToLoop() = default;
+namespace {
-ForLoopToLoop::~ForLoopToLoop() = default;
-
-bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const {
+bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::ForLoopStatement>()) {
return true;
@@ -33,19 +33,31 @@
return false;
}
-void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+} // namespace
+
+ForLoopToLoop::ForLoopToLoop() = default;
+
+ForLoopToLoop::~ForLoopToLoop() = default;
+
+Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&, DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
ctx.ReplaceAll([&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
utils::Vector<const ast::Statement*, 8> stmts;
if (auto* cond = for_loop->condition) {
// !condition
- auto* not_cond =
- ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
+ auto* not_cond = b.Not(ctx.Clone(cond));
// { break; }
- auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
+ auto* break_body = b.Block(b.Break());
// if (!condition) { break; }
- stmts.Push(ctx.dst->If(not_cond, break_body));
+ stmts.Push(b.If(not_cond, break_body));
}
for (auto* stmt : for_loop->body->statements) {
stmts.Push(ctx.Clone(stmt));
@@ -53,20 +65,21 @@
const ast::BlockStatement* continuing = nullptr;
if (auto* cont = for_loop->continuing) {
- continuing = ctx.dst->Block(ctx.Clone(cont));
+ continuing = b.Block(ctx.Clone(cont));
}
- auto* body = ctx.dst->Block(stmts);
- auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
+ auto* body = b.Block(stmts);
+ auto* loop = b.Loop(body, continuing);
if (auto* init = for_loop->initializer) {
- return ctx.dst->Block(ctx.Clone(init), loop);
+ return b.Block(ctx.Clone(init), loop);
}
return loop;
});
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/for_loop_to_loop.h b/src/tint/transform/for_loop_to_loop.h
index 5ab690a..fe3db97 100644
--- a/src/tint/transform/for_loop_to_loop.h
+++ b/src/tint/transform/for_loop_to_loop.h
@@ -29,19 +29,10 @@
/// Destructor
~ForLoopToLoop() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/localize_struct_array_assignment.cc b/src/tint/transform/localize_struct_array_assignment.cc
index 8077393..bfe8865 100644
--- a/src/tint/transform/localize_struct_array_assignment.cc
+++ b/src/tint/transform/localize_struct_array_assignment.cc
@@ -32,70 +32,15 @@
namespace tint::transform {
-/// Private implementation of LocalizeStructArrayAssignment transform
-class LocalizeStructArrayAssignment::State {
- private:
- CloneContext& ctx;
- ProgramBuilder& b;
-
- /// Returns true if `expr` contains an index accessor expression to a
- /// structure member of array type.
- bool ContainsStructArrayIndex(const ast::Expression* expr) {
- bool result = false;
- ast::TraverseExpressions(
- expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
- // Indexing using a runtime value?
- auto* idx_sem = ctx.src->Sem().Get(ia->index);
- if (!idx_sem->ConstantValue()) {
- // Indexing a member access expr?
- if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
- // That accesses an array?
- if (ctx.src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) {
- result = true;
- return ast::TraverseAction::Stop;
- }
- }
- }
- return ast::TraverseAction::Descend;
- });
-
- return result;
- }
-
- // Returns the type and address space of the originating variable of the lhs
- // of the assignment statement.
- // See https://www.w3.org/TR/WGSL/#originating-variable-section
- std::pair<const sem::Type*, ast::AddressSpace> GetOriginatingTypeAndAddressSpace(
- const ast::AssignmentStatement* assign_stmt) {
- auto* source_var = ctx.src->Sem().Get(assign_stmt->lhs)->SourceVariable();
- if (!source_var) {
- TINT_ICE(Transform, b.Diagnostics())
- << "Unable to determine originating variable for lhs of assignment "
- "statement";
- return {};
- }
-
- auto* type = source_var->Type();
- if (auto* ref = type->As<sem::Reference>()) {
- return {ref->StoreType(), ref->AddressSpace()};
- } else if (auto* ptr = type->As<sem::Pointer>()) {
- return {ptr->StoreType(), ptr->AddressSpace()};
- }
-
- TINT_ICE(Transform, b.Diagnostics())
- << "Expecting to find variable of type pointer or reference on lhs "
- "of assignment statement";
- return {};
- }
-
- public:
+/// PIMPL state for the transform
+struct LocalizeStructArrayAssignment::State {
/// Constructor
- /// @param ctx_in the CloneContext primed with the input program and
- /// ProgramBuilder
- explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
/// Runs the transform
- void Run() {
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
struct Shared {
bool process_nested_nodes = false;
utils::Vector<const ast::Statement*, 4> insert_before_stmts;
@@ -189,6 +134,65 @@
});
ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ private:
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+
+ /// Returns true if `expr` contains an index accessor expression to a
+ /// structure member of array type.
+ bool ContainsStructArrayIndex(const ast::Expression* expr) {
+ bool result = false;
+ ast::TraverseExpressions(
+ expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
+ // Indexing using a runtime value?
+ auto* idx_sem = src->Sem().Get(ia->index);
+ if (!idx_sem->ConstantValue()) {
+ // Indexing a member access expr?
+ if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
+ // That accesses an array?
+ if (src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) {
+ result = true;
+ return ast::TraverseAction::Stop;
+ }
+ }
+ }
+ return ast::TraverseAction::Descend;
+ });
+
+ return result;
+ }
+
+ // Returns the type and address space of the originating variable of the lhs
+ // of the assignment statement.
+ // See https://www.w3.org/TR/WGSL/#originating-variable-section
+ std::pair<const sem::Type*, ast::AddressSpace> GetOriginatingTypeAndAddressSpace(
+ const ast::AssignmentStatement* assign_stmt) {
+ auto* source_var = src->Sem().Get(assign_stmt->lhs)->SourceVariable();
+ if (!source_var) {
+ TINT_ICE(Transform, b.Diagnostics())
+ << "Unable to determine originating variable for lhs of assignment "
+ "statement";
+ return {};
+ }
+
+ auto* type = source_var->Type();
+ if (auto* ref = type->As<sem::Reference>()) {
+ return {ref->StoreType(), ref->AddressSpace()};
+ } else if (auto* ptr = type->As<sem::Pointer>()) {
+ return {ptr->StoreType(), ptr->AddressSpace()};
+ }
+
+ TINT_ICE(Transform, b.Diagnostics())
+ << "Expecting to find variable of type pointer or reference on lhs "
+ "of assignment statement";
+ return {};
}
};
@@ -196,9 +200,10 @@
LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;
-void LocalizeStructArrayAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- State state(ctx);
- state.Run();
+Transform::ApplyResult LocalizeStructArrayAssignment::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ return State{src}.Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/localize_struct_array_assignment.h b/src/tint/transform/localize_struct_array_assignment.h
index 130f8cc..169e33c 100644
--- a/src/tint/transform/localize_struct_array_assignment.h
+++ b/src/tint/transform/localize_struct_array_assignment.h
@@ -36,17 +36,13 @@
/// Destructor
~LocalizeStructArrayAssignment() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
private:
- class State;
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/manager.cc b/src/tint/transform/manager.cc
index 4e83320..79603c8 100644
--- a/src/tint/transform/manager.cc
+++ b/src/tint/transform/manager.cc
@@ -31,9 +31,9 @@
Manager::Manager() = default;
Manager::~Manager() = default;
-Output Manager::Run(const Program* program, const DataMap& data) const {
- const Program* in = program;
-
+Transform::ApplyResult Manager::Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const {
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
auto print_program = [&](const char* msg, const Transform* transform) {
auto wgsl = Program::printer(in);
@@ -46,34 +46,30 @@
};
#endif
- Output out;
+ std::optional<Program> output;
+
for (const auto& transform : transforms_) {
- if (!transform->ShouldRun(in, data)) {
- TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name
- << std::endl);
- continue;
- }
TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get()));
- auto res = transform->Run(in, data);
- out.program = std::move(res.program);
- out.data.Add(std::move(res.data));
- in = &out.program;
- if (!in->IsValid()) {
- TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get()));
- return out;
- }
+ if (auto result = transform->Apply(program, inputs, outputs)) {
+ output.emplace(std::move(result.value()));
+ program = &output.value();
- if (transform == transforms_.back()) {
- TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
+ if (!program->IsValid()) {
+ TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get()));
+ break;
+ }
+
+ if (transform == transforms_.back()) {
+ TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
+ }
+ } else {
+ TINT_IF_PRINT_PROGRAM(std::cout << "Skipped " << transform->TypeInfo().name
+ << std::endl);
}
}
- if (program == in) {
- out.program = program->Clone();
- }
-
- return out;
+ return output;
}
} // namespace tint::transform
diff --git a/src/tint/transform/manager.h b/src/tint/transform/manager.h
index 9d4049f..64ca847 100644
--- a/src/tint/transform/manager.h
+++ b/src/tint/transform/manager.h
@@ -47,11 +47,10 @@
transforms_.emplace_back(std::make_unique<T>(std::forward<ARGS>(args)...));
}
- /// Runs the transforms on `program`, returning the transformation result.
- /// @param program the source program to transform
- /// @param data optional extra transform-specific input data
- /// @returns the transformed program and diagnostics
- Output Run(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
private:
std::vector<std::unique_ptr<Transform>> transforms_;
diff --git a/src/tint/transform/merge_return.cc b/src/tint/transform/merge_return.cc
index aec6b6d..2b45b73 100644
--- a/src/tint/transform/merge_return.cc
+++ b/src/tint/transform/merge_return.cc
@@ -65,15 +65,6 @@
MergeReturn::~MergeReturn() = default;
-bool MergeReturn::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* func : program->AST().Functions()) {
- if (NeedsTransform(program, func)) {
- return true;
- }
- }
- return false;
-}
-
namespace {
/// Internal class used to during the transform.
@@ -223,7 +214,12 @@
} // namespace
-void MergeReturn::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+Transform::ApplyResult MergeReturn::Apply(const Program* src, const DataMap&, DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ bool made_changes = false;
+
for (auto* func : ctx.src->AST().Functions()) {
if (!NeedsTransform(ctx.src, func)) {
continue;
@@ -231,9 +227,15 @@
State state(ctx, func);
state.ProcessStatement(func->body);
+ made_changes = true;
+ }
+
+ if (!made_changes) {
+ return SkipTransform;
}
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/merge_return.h b/src/tint/transform/merge_return.h
index 1334a5c..f6db5c2 100644
--- a/src/tint/transform/merge_return.h
+++ b/src/tint/transform/merge_return.h
@@ -27,19 +27,10 @@
/// Destructor
~MergeReturn() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/module_scope_var_to_entry_point_param.cc b/src/tint/transform/module_scope_var_to_entry_point_param.cc
index f9c11e5..16a622e 100644
--- a/src/tint/transform/module_scope_var_to_entry_point_param.cc
+++ b/src/tint/transform/module_scope_var_to_entry_point_param.cc
@@ -38,6 +38,15 @@
// The name of the struct member for arrays that are wrapped in structures.
const char* kWrappedArrayMemberName = "arr";
+bool ShouldRun(const Program* program) {
+ for (auto* decl : program->AST().GlobalDeclarations()) {
+ if (decl->Is<ast::Variable>()) {
+ return true;
+ }
+ }
+ return false;
+}
+
// Returns `true` if `type` is or contains a matrix type.
bool ContainsMatrix(const sem::Type* type) {
type = type->UnwrapRef();
@@ -56,7 +65,7 @@
}
} // namespace
-/// State holds the current transform state.
+/// PIMPL state for the transform
struct ModuleScopeVarToEntryPointParam::State {
/// The clone context.
CloneContext& ctx;
@@ -501,19 +510,20 @@
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
-bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* decl : program->AST().GlobalDeclarations()) {
- if (decl->Is<ast::Variable>()) {
- return true;
- }
+Transform::ApplyResult ModuleScopeVarToEntryPointParam::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
}
- return false;
-}
-void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
State state{ctx};
state.Process();
+
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/module_scope_var_to_entry_point_param.h b/src/tint/transform/module_scope_var_to_entry_point_param.h
index 75bdaf3..377151f 100644
--- a/src/tint/transform/module_scope_var_to_entry_point_param.h
+++ b/src/tint/transform/module_scope_var_to_entry_point_param.h
@@ -69,20 +69,12 @@
/// Destructor
~ModuleScopeVarToEntryPointParam() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
-
+ private:
struct State;
};
diff --git a/src/tint/transform/multiplanar_external_texture.cc b/src/tint/transform/multiplanar_external_texture.cc
index 002b858..c3ebf4a 100644
--- a/src/tint/transform/multiplanar_external_texture.cc
+++ b/src/tint/transform/multiplanar_external_texture.cc
@@ -31,6 +31,17 @@
namespace tint::transform {
namespace {
+bool ShouldRun(const Program* program) {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* ty = node->As<ast::Type>()) {
+ if (program->Sem().Get<sem::ExternalTexture>(ty)) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
/// This struct stores symbols for new bindings created as a result of transforming a
/// texture_external instance.
struct NewBindingSymbols {
@@ -40,7 +51,7 @@
};
} // namespace
-/// State holds the current transform state
+/// PIMPL state for the transform
struct MultiplanarExternalTexture::State {
/// The clone context.
CloneContext& ctx;
@@ -537,30 +548,26 @@
MultiplanarExternalTexture::MultiplanarExternalTexture() = default;
MultiplanarExternalTexture::~MultiplanarExternalTexture() = default;
-bool MultiplanarExternalTexture::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (auto* ty = node->As<ast::Type>()) {
- if (program->Sem().Get<sem::ExternalTexture>(ty)) {
- return true;
- }
- }
- }
- return false;
-}
-
// Within this transform, an instance of a texture_external binding is unpacked into two
// texture_2d<f32> bindings representing two possible planes of a single texture and a uniform
// buffer binding representing a struct of parameters. Calls to texture builtins that contain a
// texture_external parameter will be transformed into a newly generated version of the function,
// which can perform the desired operation on a single RGBA plane or on separate Y and UV planes.
-void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+Transform::ApplyResult MultiplanarExternalTexture::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
auto* new_binding_points = inputs.Get<NewBindingPoints>();
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
if (!new_binding_points) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "missing new binding point data for " + std::string(TypeInfo().name));
- return;
+ b.Diagnostics().add_error(diag::System::Transform, "missing new binding point data for " +
+ std::string(TypeInfo().name));
+ return Program(std::move(b));
}
State state(ctx, new_binding_points);
@@ -568,6 +575,7 @@
state.Process();
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/multiplanar_external_texture.h b/src/tint/transform/multiplanar_external_texture.h
index a10fed4..695e38c 100644
--- a/src/tint/transform/multiplanar_external_texture.h
+++ b/src/tint/transform/multiplanar_external_texture.h
@@ -80,21 +80,13 @@
/// Destructor
~MultiplanarExternalTexture() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
- protected:
+ private:
struct State;
-
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/multiplanar_external_texture_test.cc b/src/tint/transform/multiplanar_external_texture_test.cc
index dacbb1e..4416d35 100644
--- a/src/tint/transform/multiplanar_external_texture_test.cc
+++ b/src/tint/transform/multiplanar_external_texture_test.cc
@@ -23,7 +23,11 @@
TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) {
auto* src = R"()";
- EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+
+ EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src, data));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) {
@@ -31,14 +35,22 @@
type ET = texture_external;
)";
- EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+
+ EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) {
auto* src = R"(
@group(0) @binding(0) var ext_tex : texture_external;
)";
- EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+
+ EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) {
@@ -46,7 +58,11 @@
fn f(ext_tex : texture_external) {}
)";
- EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
+ DataMap data;
+ data.Add<MultiplanarExternalTexture::NewBindingPoints>(
+ MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
+
+ EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
}
// Running the transform without passing in data for the new bindings should result in an error.
diff --git a/src/tint/transform/num_workgroups_from_uniform.cc b/src/tint/transform/num_workgroups_from_uniform.cc
index 2122f07..e6681ca 100644
--- a/src/tint/transform/num_workgroups_from_uniform.cc
+++ b/src/tint/transform/num_workgroups_from_uniform.cc
@@ -29,6 +29,18 @@
namespace tint::transform {
namespace {
+
+bool ShouldRun(const Program* program) {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* attr = node->As<ast::BuiltinAttribute>()) {
+ if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
/// Accessor describes the identifiers used in a member accessor that is being
/// used to retrieve the num_workgroups builtin from a parameter.
struct Accessor {
@@ -44,41 +56,40 @@
size_t operator()(const Accessor& a) const { return utils::Hash(a.param, a.member); }
};
};
+
} // namespace
NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
-bool NumWorkgroupsFromUniform::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (auto* attr = node->As<ast::BuiltinAttribute>()) {
- if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) {
- return true;
- }
- }
- }
- return false;
-}
+Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
-void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
- return;
+ b.Diagnostics().add_error(diag::System::Transform,
+ "missing transform data for " + std::string(TypeInfo().name));
+ return Program(std::move(b));
+ }
+
+ if (!ShouldRun(src)) {
+ return SkipTransform;
}
const char* kNumWorkgroupsMemberName = "num_workgroups";
// Find all entry point parameters that declare the num_workgroups builtin.
std::unordered_set<Accessor, Accessor::Hasher> to_replace;
- for (auto* func : ctx.src->AST().Functions()) {
+ for (auto* func : src->AST().Functions()) {
// num_workgroups is only valid for compute stages.
if (func->PipelineStage() != ast::PipelineStage::kCompute) {
continue;
}
- for (auto* param : ctx.src->Sem().Get(func)->Parameters()) {
+ for (auto* param : src->Sem().Get(func)->Parameters()) {
// Because the CanonicalizeEntryPointIO transform has been run, builtins
// will only appear as struct members.
auto* str = param->Type()->As<sem::Struct>();
@@ -108,7 +119,7 @@
// If this is the only member, remove the struct and parameter too.
if (str->Members().size() == 1) {
ctx.Remove(func->params, param->Declaration());
- ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration());
+ ctx.Remove(src->AST().GlobalDeclarations(), str->Declaration());
}
}
}
@@ -119,11 +130,10 @@
const ast::Variable* num_workgroups_ubo = nullptr;
auto get_ubo = [&]() {
if (!num_workgroups_ubo) {
- auto* num_workgroups_struct = ctx.dst->Structure(
- ctx.dst->Sym(),
- utils::Vector{
- ctx.dst->Member(kNumWorkgroupsMemberName, ctx.dst->ty.vec3(ctx.dst->ty.u32())),
- });
+ auto* num_workgroups_struct =
+ b.Structure(b.Sym(), utils::Vector{
+ b.Member(kNumWorkgroupsMemberName, b.ty.vec3(b.ty.u32())),
+ });
uint32_t group, binding;
if (cfg->ubo_binding.has_value()) {
@@ -135,9 +145,9 @@
// plus 1, or group 0 if no resource bound.
group = 0;
- for (auto* global : ctx.src->AST().GlobalVariables()) {
+ for (auto* global : src->AST().GlobalVariables()) {
if (global->HasBindingPoint()) {
- auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(global);
+ auto* global_sem = src->Sem().Get<sem::GlobalVariable>(global);
auto binding_point = global_sem->BindingPoint();
if (binding_point.group >= group) {
group = binding_point.group + 1;
@@ -148,16 +158,16 @@
binding = 0;
}
- num_workgroups_ubo = ctx.dst->GlobalVar(
- ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform,
- ctx.dst->Group(AInt(group)), ctx.dst->Binding(AInt(binding)));
+ num_workgroups_ubo =
+ b.GlobalVar(b.Sym(), b.ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform,
+ b.Group(AInt(group)), b.Binding(AInt(binding)));
}
return num_workgroups_ubo;
};
// Now replace all the places where the builtins are accessed with the value
// loaded from the uniform buffer.
- for (auto* node : ctx.src->ASTNodes().Objects()) {
+ for (auto* node : src->ASTNodes().Objects()) {
auto* accessor = node->As<ast::MemberAccessorExpression>();
if (!accessor) {
continue;
@@ -168,12 +178,12 @@
}
if (to_replace.count({ident->symbol, accessor->member->symbol})) {
- ctx.Replace(accessor,
- ctx.dst->MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName));
+ ctx.Replace(accessor, b.MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName));
}
}
ctx.Clone();
+ return Program(std::move(b));
}
NumWorkgroupsFromUniform::Config::Config(std::optional<sem::BindingPoint> ubo_bp)
diff --git a/src/tint/transform/num_workgroups_from_uniform.h b/src/tint/transform/num_workgroups_from_uniform.h
index 292c823..25308f2 100644
--- a/src/tint/transform/num_workgroups_from_uniform.h
+++ b/src/tint/transform/num_workgroups_from_uniform.h
@@ -72,19 +72,10 @@
std::optional<sem::BindingPoint> ubo_binding;
};
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/num_workgroups_from_uniform_test.cc b/src/tint/transform/num_workgroups_from_uniform_test.cc
index 093081c..435ab33 100644
--- a/src/tint/transform/num_workgroups_from_uniform_test.cc
+++ b/src/tint/transform/num_workgroups_from_uniform_test.cc
@@ -28,7 +28,9 @@
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) {
auto* src = R"()";
- EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src));
+ DataMap data;
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src, data));
}
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) {
@@ -38,7 +40,9 @@
}
)";
- EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src));
+ DataMap data;
+ data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
+ EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src, data));
}
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
@@ -55,7 +59,6 @@
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
-
EXPECT_EQ(expect, str(got));
}
diff --git a/src/tint/transform/packed_vec3.cc b/src/tint/transform/packed_vec3.cc
index dde5aca..e947a53 100644
--- a/src/tint/transform/packed_vec3.cc
+++ b/src/tint/transform/packed_vec3.cc
@@ -33,14 +33,15 @@
namespace tint::transform {
-/// The PIMPL state for the PackedVec3 transform
+/// PIMPL state for the transform
struct PackedVec3::State {
/// Constructor
- /// @param c the CloneContext
- explicit State(CloneContext& c) : ctx(c) {}
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
/// Runs the transform
- void Run() {
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
// Packed vec3<T> struct members
utils::Hashset<const sem::StructMember*, 8> members;
@@ -72,6 +73,10 @@
}
}
+ if (members.IsEmpty()) {
+ return SkipTransform;
+ }
+
// Walk the nodes, starting with the most deeply nested, finding all the AST expressions
// that load a whole packed vector (not a scalar / swizzle of the vector).
utils::Hashset<const sem::Expression*, 16> refs;
@@ -137,36 +142,20 @@
}
ctx.Clone();
- }
-
- /// @returns true if this transform should be run for the given program
- /// @param program the program to inspect
- static bool ShouldRun(const Program* program) {
- for (auto* decl : program->AST().GlobalDeclarations()) {
- if (auto* str = program->Sem().Get<sem::Struct>(decl)) {
- if (str->IsHostShareable()) {
- for (auto* member : str->Members()) {
- if (auto* vec = member->Type()->As<sem::Vector>()) {
- if (vec->Width() == 3) {
- return true;
- }
- }
- }
- }
- }
- }
- return false;
+ return Program(std::move(b));
}
private:
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
/// The clone context
- CloneContext& ctx;
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Alias to the semantic info in ctx.src
const sem::Info& sem = ctx.src->Sem();
/// Alias to the symbols in ctx.src
const SymbolTable& sym = ctx.src->Symbols();
- /// Alias to the ctx.dst program builder
- ProgramBuilder& b = *ctx.dst;
};
PackedVec3::Attribute::Attribute(ProgramID pid, ast::NodeID nid) : Base(pid, nid) {}
@@ -183,12 +172,8 @@
PackedVec3::PackedVec3() = default;
PackedVec3::~PackedVec3() = default;
-bool PackedVec3::ShouldRun(const Program* program, const DataMap&) const {
- return State::ShouldRun(program);
-}
-
-void PackedVec3::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- State(ctx).Run();
+Transform::ApplyResult PackedVec3::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State{src}.Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/packed_vec3.h b/src/tint/transform/packed_vec3.h
index 9d899cb..0d304fa 100644
--- a/src/tint/transform/packed_vec3.h
+++ b/src/tint/transform/packed_vec3.h
@@ -56,21 +56,13 @@
/// Destructor
~PackedVec3() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
private:
struct State;
-
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/pad_structs.cc b/src/tint/transform/pad_structs.cc
index 10b0565..4ceb39d 100644
--- a/src/tint/transform/pad_structs.cc
+++ b/src/tint/transform/pad_structs.cc
@@ -50,8 +50,10 @@
PadStructs::~PadStructs() = default;
-void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- auto& sem = ctx.src->Sem();
+Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ auto& sem = src->Sem();
std::unordered_map<const ast::Struct*, const ast::Struct*> replaced_structs;
utils::Hashset<const ast::StructMember*, 8> padding_members;
@@ -65,7 +67,7 @@
bool has_runtime_sized_array = false;
utils::Vector<const ast::StructMember*, 8> new_members;
for (auto* mem : str->Members()) {
- auto name = ctx.src->Symbols().NameFor(mem->Name());
+ auto name = src->Symbols().NameFor(mem->Name());
if (offset < mem->Offset()) {
CreatePadding(&new_members, &padding_members, ctx.dst, mem->Offset() - offset);
@@ -75,7 +77,7 @@
auto* ty = mem->Type();
const ast::Type* type = CreateASTTypeFor(ctx, ty);
- new_members.Push(ctx.dst->Member(name, type));
+ new_members.Push(b.Member(name, type));
uint32_t size = ty->Size();
if (ty->Is<sem::Struct>() && str->UsedAs(ast::AddressSpace::kUniform)) {
@@ -97,8 +99,8 @@
if (offset < struct_size && !has_runtime_sized_array) {
CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset);
}
- auto* new_struct = ctx.dst->create<ast::Struct>(ctx.Clone(ast_str->name),
- std::move(new_members), utils::Empty);
+ auto* new_struct =
+ b.create<ast::Struct>(ctx.Clone(ast_str->name), std::move(new_members), utils::Empty);
replaced_structs[ast_str] = new_struct;
return new_struct;
});
@@ -131,16 +133,17 @@
auto* arg = ast_call->args.begin();
for (auto* member : new_struct->members) {
if (padding_members.Contains(member)) {
- new_args.Push(ctx.dst->Expr(0_u));
+ new_args.Push(b.Expr(0_u));
} else {
new_args.Push(ctx.Clone(*arg));
arg++;
}
}
- return ctx.dst->Construct(CreateASTTypeFor(ctx, str), new_args);
+ return b.Construct(CreateASTTypeFor(ctx, str), new_args);
});
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/pad_structs.h b/src/tint/transform/pad_structs.h
index 55fec74..e9996d4 100644
--- a/src/tint/transform/pad_structs.h
+++ b/src/tint/transform/pad_structs.h
@@ -30,14 +30,10 @@
/// Destructor
~PadStructs() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/promote_initializers_to_let.cc b/src/tint/transform/promote_initializers_to_let.cc
index 315a4ce..9e02c45 100644
--- a/src/tint/transform/promote_initializers_to_let.cc
+++ b/src/tint/transform/promote_initializers_to_let.cc
@@ -13,6 +13,9 @@
// limitations under the License.
#include "src/tint/transform/promote_initializers_to_let.h"
+
+#include <utility>
+
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/statement.h"
@@ -27,9 +30,16 @@
PromoteInitializersToLet::~PromoteInitializersToLet() = default;
-void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
HoistToDeclBefore hoist_to_decl_before(ctx);
+ bool any_promoted = false;
+
// Hoists array and structure initializers to a constant variable, declared
// just before the statement of usage.
auto promote = [&](const sem::Expression* expr) {
@@ -59,14 +69,15 @@
return true;
}
+ any_promoted = true;
return hoist_to_decl_before.Add(expr, expr->Declaration(), true);
};
- for (auto* node : ctx.src->ASTNodes().Objects()) {
+ for (auto* node : src->ASTNodes().Objects()) {
bool ok = Switch(
node, //
[&](const ast::CallExpression* expr) {
- if (auto* sem = ctx.src->Sem().Get(expr)) {
+ if (auto* sem = src->Sem().Get(expr)) {
auto* ctor = sem->UnwrapMaterialize()->As<sem::Call>();
if (ctor->Target()->Is<sem::TypeInitializer>()) {
return promote(sem);
@@ -75,7 +86,7 @@
return true;
},
[&](const ast::IdentifierExpression* expr) {
- if (auto* sem = ctx.src->Sem().Get(expr)) {
+ if (auto* sem = src->Sem().Get(expr)) {
if (auto* user = sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
// Identifier resolves to a variable
if (auto* stmt = user->Stmt()) {
@@ -96,13 +107,17 @@
return true;
},
[&](Default) { return true; });
-
if (!ok) {
- return;
+ return Program(std::move(b));
}
}
+ if (!any_promoted) {
+ return SkipTransform;
+ }
+
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/promote_initializers_to_let.h b/src/tint/transform/promote_initializers_to_let.h
index 78793c7..b1bb291 100644
--- a/src/tint/transform/promote_initializers_to_let.h
+++ b/src/tint/transform/promote_initializers_to_let.h
@@ -33,14 +33,10 @@
/// Destructor
~PromoteInitializersToLet() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/promote_side_effects_to_decl.cc b/src/tint/transform/promote_side_effects_to_decl.cc
index 2bcda04..ea13b0b 100644
--- a/src/tint/transform/promote_side_effects_to_decl.cc
+++ b/src/tint/transform/promote_side_effects_to_decl.cc
@@ -53,34 +53,36 @@
// to else {if}s so that the next transform, DecomposeSideEffects, can insert
// hoisted expressions above their current location.
struct SimplifySideEffectStatements : Castable<PromoteSideEffectsToDecl, Transform> {
- class State;
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override;
+ ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override;
};
-class SimplifySideEffectStatements::State : public StateBase {
- HoistToDeclBefore hoist_to_decl_before;
+Transform::ApplyResult SimplifySideEffectStatements::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
- public:
- explicit State(CloneContext& ctx_in) : StateBase(ctx_in), hoist_to_decl_before(ctx_in) {}
+ bool made_changes = false;
- void Run() {
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* expr = node->As<ast::Expression>()) {
- auto* sem_expr = sem.Get(expr);
- if (!sem_expr || !sem_expr->HasSideEffects()) {
- continue;
- }
-
- hoist_to_decl_before.Prepare(sem_expr);
+ HoistToDeclBefore hoist_to_decl_before(ctx);
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* expr = node->As<ast::Expression>()) {
+ auto* sem_expr = src->Sem().Get(expr);
+ if (!sem_expr || !sem_expr->HasSideEffects()) {
+ continue;
}
- }
- ctx.Clone();
- }
-};
-void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- State state(ctx);
- state.Run();
+ hoist_to_decl_before.Prepare(sem_expr);
+ made_changes = true;
+ }
+ }
+
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
+ ctx.Clone();
+ return Program(std::move(b));
}
// Decomposes side-effecting expressions to ensure order of evaluation. This
@@ -89,7 +91,7 @@
struct DecomposeSideEffects : Castable<PromoteSideEffectsToDecl, Transform> {
class CollectHoistsState;
class DecomposeState;
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override;
+ ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override;
};
// CollectHoistsState traverses the AST top-down, identifying which expressions
@@ -667,12 +669,15 @@
}
return nullptr;
});
-
- ctx.Clone();
}
};
-void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+Transform::ApplyResult DecomposeSideEffects::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
// First collect side-effecting expressions to hoist
CollectHoistsState collect_hoists_state{ctx};
auto to_hoist = collect_hoists_state.Run();
@@ -680,6 +685,9 @@
// Now decompose these expressions
DecomposeState decompose_state{ctx, std::move(to_hoist)};
decompose_state.Run();
+
+ ctx.Clone();
+ return Program(std::move(b));
}
} // namespace
@@ -687,13 +695,13 @@
PromoteSideEffectsToDecl::PromoteSideEffectsToDecl() = default;
PromoteSideEffectsToDecl::~PromoteSideEffectsToDecl() = default;
-Output PromoteSideEffectsToDecl::Run(const Program* program, const DataMap& data) const {
+Transform::ApplyResult PromoteSideEffectsToDecl::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap& outputs) const {
transform::Manager manager;
manager.Add<SimplifySideEffectStatements>();
manager.Add<DecomposeSideEffects>();
-
- auto output = manager.Run(program, data);
- return output;
+ return manager.Apply(src, inputs, outputs);
}
} // namespace tint::transform
diff --git a/src/tint/transform/promote_side_effects_to_decl.h b/src/tint/transform/promote_side_effects_to_decl.h
index d5d1126..99e80c6 100644
--- a/src/tint/transform/promote_side_effects_to_decl.h
+++ b/src/tint/transform/promote_side_effects_to_decl.h
@@ -31,12 +31,10 @@
/// Destructor
~PromoteSideEffectsToDecl() override;
- protected:
- /// Runs the transform on `program`, returning the transformation result.
- /// @param program the source program to transform
- /// @param data optional extra transform-specific data
- /// @returns the transformation result
- Output Run(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/remove_continue_in_switch.cc b/src/tint/transform/remove_continue_in_switch.cc
index e5df23f..cf0158f 100644
--- a/src/tint/transform/remove_continue_in_switch.cc
+++ b/src/tint/transform/remove_continue_in_switch.cc
@@ -32,53 +32,19 @@
TINT_INSTANTIATE_TYPEINFO(tint::transform::RemoveContinueInSwitch);
namespace tint::transform {
-namespace {
-class State {
- private:
- CloneContext& ctx;
- ProgramBuilder& b;
- const sem::Info& sem;
-
- // Map of switch statement to 'tint_continue' variable.
- std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name;
-
- // If `cont` is within a switch statement within a loop, returns a pointer to
- // that switch statement.
- static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
- const ast::ContinueStatement* cont) {
- // Find whether first parent is a switch or a loop
- auto* sem_stmt = sem.Get(cont);
- auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
- sem::ForLoopStatement, sem::WhileStatement>();
- if (!sem_parent) {
- return nullptr;
- }
- return sem_parent->Declaration()->As<ast::SwitchStatement>();
- }
-
- public:
+/// PIMPL state for the transform
+struct RemoveContinueInSwitch::State {
/// Constructor
- /// @param ctx_in the context
- explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
-
- /// Returns true if this transform should be run for the given program
- static bool ShouldRun(const Program* program) {
- for (auto* node : program->ASTNodes().Objects()) {
- auto* stmt = node->As<ast::ContinueStatement>();
- if (!stmt) {
- continue;
- }
- if (GetParentSwitchInLoop(program->Sem(), stmt)) {
- return true;
- }
- }
- return false;
- }
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
/// Runs the transform
- void Run() {
- for (auto* node : ctx.src->ASTNodes().Objects()) {
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+ bool made_changes = false;
+
+ for (auto* node : src->ASTNodes().Objects()) {
auto* cont = node->As<ast::ContinueStatement>();
if (!cont) {
continue;
@@ -90,6 +56,8 @@
continue;
}
+ made_changes = true;
+
auto cont_var_name =
tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&]() {
// Create and insert 'var tint_continue : bool = false;' before the
@@ -116,22 +84,50 @@
ctx.Replace(cont, new_stmt);
}
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ private:
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+ /// Alias to src->sem
+ const sem::Info& sem = src->Sem();
+
+ // Map of switch statement to 'tint_continue' variable.
+ std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name;
+
+ // If `cont` is within a switch statement within a loop, returns a pointer to
+ // that switch statement.
+ static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
+ const ast::ContinueStatement* cont) {
+ // Find whether first parent is a switch or a loop
+ auto* sem_stmt = sem.Get(cont);
+ auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
+ sem::ForLoopStatement, sem::WhileStatement>();
+ if (!sem_parent) {
+ return nullptr;
+ }
+ return sem_parent->Declaration()->As<ast::SwitchStatement>();
}
};
-} // namespace
-
RemoveContinueInSwitch::RemoveContinueInSwitch() = default;
RemoveContinueInSwitch::~RemoveContinueInSwitch() = default;
-bool RemoveContinueInSwitch::ShouldRun(const Program* program, const DataMap& /*data*/) const {
- return State::ShouldRun(program);
-}
-
-void RemoveContinueInSwitch::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- State state(ctx);
- state.Run();
+Transform::ApplyResult RemoveContinueInSwitch::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ State state(src);
+ return state.Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/remove_continue_in_switch.h b/src/tint/transform/remove_continue_in_switch.h
index 9e5a4d5..1070906 100644
--- a/src/tint/transform/remove_continue_in_switch.h
+++ b/src/tint/transform/remove_continue_in_switch.h
@@ -31,19 +31,13 @@
/// Destructor
~RemoveContinueInSwitch() override;
- protected:
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ private:
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/remove_phonies.cc b/src/tint/transform/remove_phonies.cc
index 1f84538..080152c 100644
--- a/src/tint/transform/remove_phonies.cc
+++ b/src/tint/transform/remove_phonies.cc
@@ -41,34 +41,25 @@
RemovePhonies::~RemovePhonies() = default;
-bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (node->Is<ast::PhonyExpression>()) {
- return true;
- }
- if (auto* stmt = node->As<ast::CallStatement>()) {
- if (program->Sem().Get(stmt->expr)->ConstantValue() != nullptr) {
- return true;
- }
- }
- }
- return false;
-}
+Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&, DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
-void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- auto& sem = ctx.src->Sem();
+ auto& sem = src->Sem();
- std::unordered_map<SinkSignature, Symbol, utils::Hasher<SinkSignature>> sinks;
+ utils::Hashmap<SinkSignature, Symbol, 8, utils::Hasher<SinkSignature>> sinks;
- for (auto* node : ctx.src->ASTNodes().Objects()) {
+ bool made_changes = false;
+ for (auto* node : src->ASTNodes().Objects()) {
Switch(
node,
[&](const ast::AssignmentStatement* stmt) {
if (stmt->lhs->Is<ast::PhonyExpression>()) {
+ made_changes = true;
+
std::vector<const ast::Expression*> side_effects;
if (!ast::TraverseExpressions(
- stmt->rhs, ctx.dst->Diagnostics(),
- [&](const ast::CallExpression* expr) {
+ stmt->rhs, b.Diagnostics(), [&](const ast::CallExpression* expr) {
// ast::CallExpression may map to a function or builtin call
// (both may have side-effects), or a type initializer or
// type conversion (both do not have side effects).
@@ -100,8 +91,7 @@
if (auto* call = side_effects[0]->As<ast::CallExpression>()) {
// Phony assignment with single call side effect.
// Replace phony assignment with call.
- ctx.Replace(stmt,
- [&, call] { return ctx.dst->CallStmt(ctx.Clone(call)); });
+ ctx.Replace(stmt, [&, call] { return b.CallStmt(ctx.Clone(call)); });
return;
}
}
@@ -114,22 +104,21 @@
for (auto* arg : side_effects) {
sig.push_back(sem.Get(arg)->Type()->UnwrapRef());
}
- auto sink = utils::GetOrCreate(sinks, sig, [&] {
- auto name = ctx.dst->Symbols().New("phony_sink");
+ auto sink = sinks.GetOrCreate(sig, [&] {
+ auto name = b.Symbols().New("phony_sink");
utils::Vector<const ast::Parameter*, 8> params;
for (auto* ty : sig) {
auto* ast_ty = CreateASTTypeFor(ctx, ty);
- params.Push(
- ctx.dst->Param("p" + std::to_string(params.Length()), ast_ty));
+ params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty));
}
- ctx.dst->Func(name, params, ctx.dst->ty.void_(), {});
+ b.Func(name, params, b.ty.void_(), {});
return name;
});
utils::Vector<const ast::Expression*, 8> args;
for (auto* arg : side_effects) {
args.Push(ctx.Clone(arg));
}
- return ctx.dst->CallStmt(ctx.dst->Call(sink, args));
+ return b.CallStmt(b.Call(sink, args));
});
}
},
@@ -138,12 +127,18 @@
// TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects.
auto* sem_expr = sem.Get(stmt->expr);
if ((sem_expr->ConstantValue() != nullptr) && !sem_expr->HasSideEffects()) {
+ made_changes = true;
ctx.Remove(sem.Get(stmt)->Block()->Declaration()->statements, stmt);
}
});
}
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/remove_phonies.h b/src/tint/transform/remove_phonies.h
index daa1812..99a049e 100644
--- a/src/tint/transform/remove_phonies.h
+++ b/src/tint/transform/remove_phonies.h
@@ -33,19 +33,10 @@
/// Destructor
~RemovePhonies() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/remove_unreachable_statements.cc b/src/tint/transform/remove_unreachable_statements.cc
index 964d767..f9bf202 100644
--- a/src/tint/transform/remove_unreachable_statements.cc
+++ b/src/tint/transform/remove_unreachable_statements.cc
@@ -36,27 +36,28 @@
RemoveUnreachableStatements::~RemoveUnreachableStatements() = default;
-bool RemoveUnreachableStatements::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* node : program->ASTNodes().Objects()) {
- if (auto* stmt = program->Sem().Get<sem::Statement>(node)) {
+Transform::ApplyResult RemoveUnreachableStatements::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ bool made_changes = false;
+ for (auto* node : src->ASTNodes().Objects()) {
+ if (auto* stmt = src->Sem().Get<sem::Statement>(node)) {
if (!stmt->IsReachable()) {
- return true;
+ RemoveStatement(ctx, stmt->Declaration());
+ made_changes = true;
}
}
}
- return false;
-}
-void RemoveUnreachableStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* stmt = ctx.src->Sem().Get<sem::Statement>(node)) {
- if (!stmt->IsReachable()) {
- RemoveStatement(ctx, stmt->Declaration());
- }
- }
+ if (!made_changes) {
+ return SkipTransform;
}
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/remove_unreachable_statements.h b/src/tint/transform/remove_unreachable_statements.h
index 7f8b947..f5848f5 100644
--- a/src/tint/transform/remove_unreachable_statements.h
+++ b/src/tint/transform/remove_unreachable_statements.h
@@ -32,19 +32,10 @@
/// Destructor
~RemoveUnreachableStatements() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/renamer.cc b/src/tint/transform/renamer.cc
index 562a52f..0fd1113 100644
--- a/src/tint/transform/renamer.cc
+++ b/src/tint/transform/renamer.cc
@@ -1252,39 +1252,31 @@
Renamer::Renamer() = default;
Renamer::~Renamer() = default;
-Output Renamer::Run(const Program* in, const DataMap& inputs) const {
- ProgramBuilder out;
- // Disable auto-cloning of symbols, since we want to rename them.
- CloneContext ctx(&out, in, false);
+Transform::ApplyResult Renamer::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap& outputs) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ false};
// Swizzles, builtin calls and builtin structure members need to keep their
// symbols preserved.
- std::unordered_set<const ast::IdentifierExpression*> preserve;
- for (auto* node : in->ASTNodes().Objects()) {
+ utils::Hashset<const ast::IdentifierExpression*, 8> preserve;
+ for (auto* node : src->ASTNodes().Objects()) {
if (auto* member = node->As<ast::MemberAccessorExpression>()) {
- auto* sem = in->Sem().Get(member);
- if (!sem) {
- TINT_ICE(Transform, out.Diagnostics())
- << "MemberAccessorExpression has no semantic info";
- continue;
- }
+ auto* sem = src->Sem().Get(member);
if (sem->Is<sem::Swizzle>()) {
- preserve.emplace(member->member);
- } else if (auto* str_expr = in->Sem().Get(member->structure)) {
+ preserve.Add(member->member);
+ } else if (auto* str_expr = src->Sem().Get(member->structure)) {
if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) {
if (ty->Declaration() == nullptr) { // Builtin structure
- preserve.emplace(member->member);
+ preserve.Add(member->member);
}
}
}
} else if (auto* call = node->As<ast::CallExpression>()) {
- auto* sem = in->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>();
- if (!sem) {
- TINT_ICE(Transform, out.Diagnostics()) << "CallExpression has no semantic info";
- continue;
- }
+ auto* sem = src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>();
if (sem->Target()->Is<sem::Builtin>()) {
- preserve.emplace(call->target.name);
+ preserve.Add(call->target.name);
}
}
}
@@ -1300,7 +1292,7 @@
}
ctx.ReplaceAll([&](Symbol sym_in) {
- auto name_in = ctx.src->Symbols().NameFor(sym_in);
+ auto name_in = src->Symbols().NameFor(sym_in);
if (preserve_unicode || text::utf8::IsASCII(name_in)) {
switch (target) {
case Target::kAll:
@@ -1343,17 +1335,20 @@
});
ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* {
- if (preserve.count(ident)) {
+ if (preserve.Contains(ident)) {
auto sym_in = ident->symbol;
- auto str = in->Symbols().NameFor(sym_in);
- auto sym_out = out.Symbols().Register(str);
+ auto str = src->Symbols().NameFor(sym_in);
+ auto sym_out = b.Symbols().Register(str);
return ctx.dst->create<ast::IdentifierExpression>(ctx.Clone(ident->source), sym_out);
}
return nullptr; // Clone ident. Uses the symbol remapping above.
});
- ctx.Clone();
- return Output(Program(std::move(out)), std::make_unique<Data>(std::move(remappings)));
+ ctx.Clone(); // Must come before the std::move()
+
+ outputs.Add<Data>(std::move(remappings));
+
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/renamer.h b/src/tint/transform/renamer.h
index 000aee9..8a9f97e 100644
--- a/src/tint/transform/renamer.h
+++ b/src/tint/transform/renamer.h
@@ -85,11 +85,10 @@
/// Destructor
~Renamer() override;
- /// Runs the transform on `program`, returning the transformation result.
- /// @param program the source program to transform
- /// @param data optional extra transform-specific input data
- /// @returns the transformation result
- Output Run(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/robustness.cc b/src/tint/transform/robustness.cc
index a22f84f..75d63f2 100644
--- a/src/tint/transform/robustness.cc
+++ b/src/tint/transform/robustness.cc
@@ -33,36 +33,48 @@
namespace tint::transform {
-/// State holds the current transform state
+/// PIMPL state for the transform
struct Robustness::State {
- /// The clone context
- CloneContext& ctx;
+ /// Constructor
+ /// @param program the source program
+ /// @param omitted the omitted address spaces
+ State(const Program* program, std::unordered_set<ast::AddressSpace>&& omitted)
+ : src(program), omitted_address_spaces(std::move(omitted)) {}
- /// Set of address spacees to not apply the transform to
- std::unordered_set<ast::AddressSpace> omitted_classes;
-
- /// Applies the transformation state to `ctx`.
- void Transform() {
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr) { return Transform(expr); });
ctx.ReplaceAll([&](const ast::CallExpression* expr) { return Transform(expr); });
+
+ ctx.Clone();
+ return Program(std::move(b));
}
+ private:
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+
+ /// Set of address spaces to not apply the transform to
+ std::unordered_set<ast::AddressSpace> omitted_address_spaces;
+
/// Apply bounds clamping to array, vector and matrix indexing
/// @param expr the array, vector or matrix index expression
/// @return the clamped replacement expression, or nullptr if `expr` should be cloned without
/// changes.
const ast::IndexAccessorExpression* Transform(const ast::IndexAccessorExpression* expr) {
- auto* sem =
- ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::IndexAccessorExpression>();
+ auto* sem = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::IndexAccessorExpression>();
auto* ret_type = sem->Type();
auto* ref = ret_type->As<sem::Reference>();
- if (ref && omitted_classes.count(ref->AddressSpace()) != 0) {
+ if (ref && omitted_address_spaces.count(ref->AddressSpace()) != 0) {
return nullptr;
}
- ProgramBuilder& b = *ctx.dst;
-
// idx return the cloned index expression, as a u32.
auto idx = [&]() -> const ast::Expression* {
auto* i = ctx.Clone(expr->index);
@@ -109,8 +121,8 @@
} else {
// Note: Don't be tempted to use the array override variable as an expression
// here, the name might be shadowed!
- ctx.dst->Diagnostics().add_error(diag::System::Transform,
- sem::Array::kErrExpectedConstantCount);
+ b.Diagnostics().add_error(diag::System::Transform,
+ sem::Array::kErrExpectedConstantCount);
return nullptr;
}
@@ -119,7 +131,7 @@
[&](Default) {
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled object type in robustness of array index: "
- << ctx.src->FriendlyName(ret_type->UnwrapRef());
+ << src->FriendlyName(ret_type->UnwrapRef());
return nullptr;
});
@@ -127,9 +139,9 @@
return nullptr; // Clamping not needed
}
- auto src = ctx.Clone(expr->source);
- auto* obj = ctx.Clone(expr->object);
- return b.IndexAccessor(src, obj, clamped_idx);
+ auto idx_src = ctx.Clone(expr->source);
+ auto* idx_obj = ctx.Clone(expr->object);
+ return b.IndexAccessor(idx_src, idx_obj, clamped_idx);
}
/// @param type builtin type
@@ -145,15 +157,13 @@
/// @return the clamped replacement call expression, or nullptr if `expr`
/// should be cloned without changes.
const ast::CallExpression* Transform(const ast::CallExpression* expr) {
- auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
+ auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* call_target = call->Target();
auto* builtin = call_target->As<sem::Builtin>();
if (!builtin || !TextureBuiltinNeedsClamping(builtin->Type())) {
return nullptr; // No transform, just clone.
}
- ProgramBuilder& b = *ctx.dst;
-
// Indices of the mandatory texture and coords parameters, and the optional
// array and level parameters.
auto& signature = builtin->Signature();
@@ -261,7 +271,7 @@
// Clamp the level argument, if provided
if (level_idx >= 0) {
auto* arg = expr->args[static_cast<size_t>(level_idx)];
- ctx.Replace(arg, level_arg ? level_arg() : ctx.dst->Expr(0_a));
+ ctx.Replace(arg, level_arg ? level_arg() : b.Expr(0_a));
}
return nullptr; // Clone, which will use the argument replacements above.
@@ -276,28 +286,27 @@
Robustness::Robustness() = default;
Robustness::~Robustness() = default;
-void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+Transform::ApplyResult Robustness::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
Config cfg;
if (auto* cfg_data = inputs.Get<Config>()) {
cfg = *cfg_data;
}
- std::unordered_set<ast::AddressSpace> omitted_classes;
- for (auto sc : cfg.omitted_classes) {
+ std::unordered_set<ast::AddressSpace> omitted_address_spaces;
+ for (auto sc : cfg.omitted_address_spaces) {
switch (sc) {
case AddressSpace::kUniform:
- omitted_classes.insert(ast::AddressSpace::kUniform);
+ omitted_address_spaces.insert(ast::AddressSpace::kUniform);
break;
case AddressSpace::kStorage:
- omitted_classes.insert(ast::AddressSpace::kStorage);
+ omitted_address_spaces.insert(ast::AddressSpace::kStorage);
break;
}
}
- State state{ctx, std::move(omitted_classes)};
-
- state.Transform();
- ctx.Clone();
+ return State{src, std::move(omitted_address_spaces)}.Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/robustness.h b/src/tint/transform/robustness.h
index 21a7ff9..14c5fe1 100644
--- a/src/tint/transform/robustness.h
+++ b/src/tint/transform/robustness.h
@@ -54,9 +54,9 @@
/// @returns this Config
Config& operator=(const Config&);
- /// Address spacees to omit from apply the transform to.
+ /// Address spaces to omit from apply the transform to.
/// This allows for optimizing on hardware that provide safe accesses.
- std::unordered_set<AddressSpace> omitted_classes;
+ std::unordered_set<AddressSpace> omitted_address_spaces;
};
/// Constructor
@@ -64,14 +64,10 @@
/// Destructor
~Robustness() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
private:
struct State;
diff --git a/src/tint/transform/robustness_test.cc b/src/tint/transform/robustness_test.cc
index 16d958f..990bbde 100644
--- a/src/tint/transform/robustness_test.cc
+++ b/src/tint/transform/robustness_test.cc
@@ -1274,7 +1274,7 @@
)";
Robustness::Config cfg;
- cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage);
+ cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage);
DataMap data;
data.Add<Robustness::Config>(cfg);
@@ -1325,7 +1325,7 @@
)";
Robustness::Config cfg;
- cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform);
+ cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform);
DataMap data;
data.Add<Robustness::Config>(cfg);
@@ -1376,8 +1376,8 @@
)";
Robustness::Config cfg;
- cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage);
- cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform);
+ cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage);
+ cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform);
DataMap data;
data.Add<Robustness::Config>(cfg);
diff --git a/src/tint/transform/simplify_pointers.cc b/src/tint/transform/simplify_pointers.cc
index ea35699..b2b99ed 100644
--- a/src/tint/transform/simplify_pointers.cc
+++ b/src/tint/transform/simplify_pointers.cc
@@ -45,14 +45,18 @@
} // namespace
-/// The PIMPL state for the SimplifyPointers transform
+/// PIMPL state for the transform
struct SimplifyPointers::State {
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
/// The clone context
- CloneContext& ctx;
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Constructor
- /// @param context the clone context
- explicit State(CloneContext& context) : ctx(context) {}
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
/// Traverses the expression `expr` looking for non-literal array indexing
/// expressions that would affect the computed address of a pointer
@@ -120,10 +124,11 @@
}
}
- /// Performs the transformation
- void Run() {
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
// A map of saved expressions to their saved variable name
- std::unordered_map<const ast::Expression*, Symbol> saved_vars;
+ utils::Hashmap<const ast::Expression*, Symbol, 8> saved_vars;
// Register the ast::Expression transform handler.
// This performs two different transformations:
@@ -135,9 +140,8 @@
// variable identifier.
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
// Look to see if we need to swap this Expression with a saved variable.
- auto it = saved_vars.find(expr);
- if (it != saved_vars.end()) {
- return ctx.dst->Expr(it->second);
+ if (auto* saved_var = saved_vars.Find(expr)) {
+ return ctx.dst->Expr(*saved_var);
}
// Reduce the expression, folding away chains of address-of / indirections
@@ -174,7 +178,7 @@
// Scan the initializer expression for array index expressions that need
// to be hoist to temporary "saved" variables.
- std::vector<const ast::VariableDeclStatement*> saved;
+ utils::Vector<const ast::VariableDeclStatement*, 8> saved;
CollectSavedArrayIndices(
var->Declaration()->initializer, [&](const ast::Expression* idx_expr) {
// We have a sub-expression that needs to be saved.
@@ -182,18 +186,18 @@
auto saved_name = ctx.dst->Symbols().New(
ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save");
auto* decl = ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr)));
- saved.emplace_back(decl);
+ saved.Push(decl);
// Record the substitution of `idx_expr` to the saved variable
// with the symbol `saved_name`. This will be used by the
// ReplaceAll() handler above.
- saved_vars.emplace(idx_expr, saved_name);
+ saved_vars.Add(idx_expr, saved_name);
});
// Find the place to insert the saved declarations.
// Special care needs to be made for lets declared as the initializer
// part of for-loops. In this case the block will hold the for-loop
// statement, not the let.
- if (!saved.empty()) {
+ if (!saved.IsEmpty()) {
auto* stmt = ctx.src->Sem().Get(let);
auto* block = stmt->Block();
// Find the statement owned by the block (either the let decl or a
@@ -219,7 +223,9 @@
RemoveStatement(ctx, let);
}
}
+
ctx.Clone();
+ return Program(std::move(b));
}
};
@@ -227,8 +233,8 @@
SimplifyPointers::~SimplifyPointers() = default;
-void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- State(ctx).Run();
+Transform::ApplyResult SimplifyPointers::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State(src).Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/simplify_pointers.h b/src/tint/transform/simplify_pointers.h
index 787c7d8..6e040bb 100644
--- a/src/tint/transform/simplify_pointers.h
+++ b/src/tint/transform/simplify_pointers.h
@@ -39,16 +39,13 @@
/// Destructor
~SimplifyPointers() override;
- protected:
- struct State;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ private:
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/single_entry_point.cc b/src/tint/transform/single_entry_point.cc
index 8d26a7f..87787ae 100644
--- a/src/tint/transform/single_entry_point.cc
+++ b/src/tint/transform/single_entry_point.cc
@@ -30,33 +30,37 @@
SingleEntryPoint::~SingleEntryPoint() = default;
-void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
-
- return;
+ b.Diagnostics().add_error(diag::System::Transform,
+ "missing transform data for " + std::string(TypeInfo().name));
+ return Program(std::move(b));
}
// Find the target entry point.
const ast::Function* entry_point = nullptr;
- for (auto* f : ctx.src->AST().Functions()) {
+ for (auto* f : src->AST().Functions()) {
if (!f->IsEntryPoint()) {
continue;
}
- if (ctx.src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) {
+ if (src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) {
entry_point = f;
break;
}
}
if (entry_point == nullptr) {
- ctx.dst->Diagnostics().add_error(diag::System::Transform,
- "entry point '" + cfg->entry_point_name + "' not found");
- return;
+ b.Diagnostics().add_error(diag::System::Transform,
+ "entry point '" + cfg->entry_point_name + "' not found");
+ return Program(std::move(b));
}
- auto& sem = ctx.src->Sem();
+ auto& sem = src->Sem();
// Build set of referenced module-scope variables for faster lookups later.
std::unordered_set<const ast::Variable*> referenced_vars;
@@ -66,12 +70,12 @@
// Clone any module-scope variables, types, and functions that are statically referenced by the
// target entry point.
- for (auto* decl : ctx.src->AST().GlobalDeclarations()) {
+ for (auto* decl : src->AST().GlobalDeclarations()) {
Switch(
decl, //
[&](const ast::TypeDecl* ty) {
// TODO(jrprice): Strip unused types.
- ctx.dst->AST().AddTypeDecl(ctx.Clone(ty));
+ b.AST().AddTypeDecl(ctx.Clone(ty));
},
[&](const ast::Override* override) {
if (referenced_vars.count(override)) {
@@ -80,37 +84,39 @@
// so that its allocated ID so that it won't be affected by other
// stripped away overrides
auto* global = sem.Get(override);
- const auto* id = ctx.dst->Id(global->OverrideId());
+ const auto* id = b.Id(global->OverrideId());
ctx.InsertFront(override->attributes, id);
}
- ctx.dst->AST().AddGlobalVariable(ctx.Clone(override));
+ b.AST().AddGlobalVariable(ctx.Clone(override));
}
},
[&](const ast::Var* var) {
if (referenced_vars.count(var)) {
- ctx.dst->AST().AddGlobalVariable(ctx.Clone(var));
+ b.AST().AddGlobalVariable(ctx.Clone(var));
}
},
[&](const ast::Const* c) {
// Always keep 'const' declarations, as these can be used by attributes and array
// sizes, which are not tracked as transitively used by functions. They also don't
// typically get emitted by the backend unless they're actually used.
- ctx.dst->AST().AddGlobalVariable(ctx.Clone(c));
+ b.AST().AddGlobalVariable(ctx.Clone(c));
},
[&](const ast::Function* func) {
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) {
- ctx.dst->AST().AddFunction(ctx.Clone(func));
+ b.AST().AddFunction(ctx.Clone(func));
}
},
- [&](const ast::Enable* ext) { ctx.dst->AST().AddEnable(ctx.Clone(ext)); },
+ [&](const ast::Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); },
[&](Default) {
- TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
+ TINT_UNREACHABLE(Transform, b.Diagnostics())
<< "unhandled global declaration: " << decl->TypeInfo().name;
});
}
// Clone the entry point.
- ctx.dst->AST().AddFunction(ctx.Clone(entry_point));
+ b.AST().AddFunction(ctx.Clone(entry_point));
+
+ return Program(std::move(b));
}
SingleEntryPoint::Config::Config(std::string entry_point) : entry_point_name(entry_point) {}
diff --git a/src/tint/transform/single_entry_point.h b/src/tint/transform/single_entry_point.h
index 59aa021..7aba5e8 100644
--- a/src/tint/transform/single_entry_point.h
+++ b/src/tint/transform/single_entry_point.h
@@ -53,14 +53,10 @@
/// Destructor
~SingleEntryPoint() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/spirv_atomic.cc b/src/tint/transform/spirv_atomic.cc
index 127c702..eba6681 100644
--- a/src/tint/transform/spirv_atomic.cc
+++ b/src/tint/transform/spirv_atomic.cc
@@ -37,7 +37,7 @@
using namespace tint::number_suffixes; // NOLINT
-/// Private implementation of transform
+/// PIMPL state for the transform
struct SpirvAtomic::State {
private:
/// A struct that has been forked because a subset of members were made atomic.
@@ -46,19 +46,24 @@
std::unordered_set<size_t> atomic_members;
};
- CloneContext& ctx;
- ProgramBuilder& b = *ctx.dst;
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
std::unordered_map<const ast::Struct*, ForkedStruct> forked_structs;
std::unordered_set<const sem::Variable*> atomic_variables;
utils::UniqueVector<const sem::Expression*, 8> atomic_expressions;
public:
/// Constructor
- /// @param c the clone context
- explicit State(CloneContext& c) : ctx(c) {}
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
/// Runs the transform
- void Run() {
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
// Look for stub functions generated by the SPIR-V reader, which are used as placeholders
// for atomic builtin calls.
for (auto* fn : ctx.src->AST().Functions()) {
@@ -102,6 +107,10 @@
}
}
+ if (atomic_expressions.IsEmpty()) {
+ return SkipTransform;
+ }
+
// Transform all variables and structure members that were used in atomic operations as
// atomic types. This propagates up originating expression chains.
ProcessAtomicExpressions();
@@ -143,6 +152,7 @@
ReplaceLoadsAndStores();
ctx.Clone();
+ return Program(std::move(b));
}
private:
@@ -297,17 +307,8 @@
ctx->dst->AllocateNodeID(), builtin);
}
-bool SpirvAtomic::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* fn : program->AST().Functions()) {
- if (ast::HasAttribute<Stub>(fn->attributes)) {
- return true;
- }
- }
- return false;
-}
-
-void SpirvAtomic::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- State{ctx}.Run();
+Transform::ApplyResult SpirvAtomic::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State{src}.Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/spirv_atomic.h b/src/tint/transform/spirv_atomic.h
index e1311c5..0f99dba 100644
--- a/src/tint/transform/spirv_atomic.h
+++ b/src/tint/transform/spirv_atomic.h
@@ -63,21 +63,13 @@
const sem::BuiltinType builtin;
};
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
- protected:
+ private:
struct State;
-
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc
index 920fc69..0746bda 100644
--- a/src/tint/transform/std140.cc
+++ b/src/tint/transform/std140.cc
@@ -77,14 +77,20 @@
namespace tint::transform {
-/// The PIMPL state for the Std140 transform
+/// PIMPL state for the transform
struct Std140::State {
/// Constructor
- /// @param c the CloneContext
- explicit State(CloneContext& c) : ctx(c) {}
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
/// Runs the transform
- void Run() {
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+ if (!ShouldRun()) {
+ // Transform is not required
+ return SkipTransform;
+ }
+
// Begin by creating forked types for any type that is used as a uniform buffer, that
// either directly or transitively contains a matrix that needs splitting for std140 layout.
ForkTypes();
@@ -116,11 +122,11 @@
});
ctx.Clone();
+ return Program(std::move(b));
}
/// @returns true if this transform should be run for the given program
- /// @param program the program to inspect
- static bool ShouldRun(const Program* program) {
+ bool ShouldRun() const {
// Returns true if the type needs to be forked for std140 usage.
auto needs_fork = [&](const sem::Type* ty) {
while (auto* arr = ty->As<sem::Array>()) {
@@ -135,7 +141,7 @@
};
// Scan structures for members that need forking
- for (auto* ty : program->Types()) {
+ for (auto* ty : src->Types()) {
if (auto* str = ty->As<sem::Struct>()) {
if (str->UsedAs(ast::AddressSpace::kUniform)) {
for (auto* member : str->Members()) {
@@ -148,8 +154,8 @@
}
// Scan uniform variables that have types that need forking
- for (auto* decl : program->AST().GlobalVariables()) {
- auto* global = program->Sem().Get(decl);
+ for (auto* decl : src->AST().GlobalVariables()) {
+ auto* global = src->Sem().Get(decl);
if (global->AddressSpace() == ast::AddressSpace::kUniform) {
if (needs_fork(global->Type()->UnwrapRef())) {
return true;
@@ -197,14 +203,16 @@
}
};
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
/// The clone context
- CloneContext& ctx;
- /// Alias to the semantic info in ctx.src
- const sem::Info& sem = ctx.src->Sem();
- /// Alias to the symbols in ctx.src
- const SymbolTable& sym = ctx.src->Symbols();
- /// Alias to the ctx.dst program builder
- ProgramBuilder& b = *ctx.dst;
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+ /// Alias to the semantic info in src
+ const sem::Info& sem = src->Sem();
+ /// Alias to the symbols in src
+ const SymbolTable& sym = src->Symbols();
/// Map of load function signature, to the generated function
utils::Hashmap<LoadFnKey, Symbol, 8, LoadFnKey::Hasher> load_fns;
@@ -218,7 +226,7 @@
// Map of original structure to 'std140' forked structure
utils::Hashmap<const sem::Struct*, Symbol, 8> std140_structs;
- // Map of structure member in ctx.src of a matrix type, to list of decomposed column
+ // Map of structure member in src of a matrix type, to list of decomposed column
// members in ctx.dst.
utils::Hashmap<const sem::StructMember*, utils::Vector<const ast::StructMember*, 4>, 8>
std140_mat_members;
@@ -232,7 +240,7 @@
utils::Vector<Symbol, 4> columns;
};
- // Map of matrix type in ctx.src, to decomposed column structure in ctx.dst.
+ // Map of matrix type in src, to decomposed column structure in ctx.dst.
utils::Hashmap<const sem::Matrix*, Std140Matrix, 8> std140_mats;
/// AccessChain describes a chain of access expressions to uniform buffer variable.
@@ -266,7 +274,7 @@
/// map (via Std140Type()).
void ForkTypes() {
// For each module scope declaration...
- for (auto* global : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) {
+ for (auto* global : src->Sem().Module()->DependencyOrderedDeclarations()) {
// Check to see if this is a structure used by a uniform buffer...
auto* str = sem.Get<sem::Struct>(global);
if (str && str->UsedAs(ast::AddressSpace::kUniform)) {
@@ -317,7 +325,7 @@
if (fork_std140) {
// Clone any members that have not already been cloned.
for (auto& member : members) {
- if (member->program_id == ctx.src->ID()) {
+ if (member->program_id == src->ID()) {
member = ctx.Clone(member);
}
}
@@ -326,7 +334,7 @@
auto name = b.Symbols().New(sym.NameFor(str->Name()) + "_std140");
auto* std140 = b.create<ast::Struct>(name, std::move(members),
ctx.Clone(str->Declaration()->attributes));
- ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), global, std140);
+ ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140);
std140_structs.Add(str, name);
}
}
@@ -337,14 +345,13 @@
/// type that has been forked for std140-layout.
/// Populates the #std140_uniforms set.
void ReplaceUniformVarTypes() {
- for (auto* global : ctx.src->AST().GlobalVariables()) {
+ for (auto* global : src->AST().GlobalVariables()) {
if (auto* var = global->As<ast::Var>()) {
if (var->declared_address_space == ast::AddressSpace::kUniform) {
auto* v = sem.Get(var);
if (auto* std140_ty = Std140Type(v->Type()->UnwrapRef())) {
ctx.Replace(global->type, std140_ty);
std140_uniforms.Add(v);
- continue;
}
}
}
@@ -404,7 +411,7 @@
auto std140_mat = std140_mats.GetOrCreate(mat, [&] {
auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" +
std::to_string(mat->rows()) + "_" +
- ctx.src->FriendlyName(mat->type()));
+ src->FriendlyName(mat->type()));
auto members =
DecomposedMatrixStructMembers(mat, "col", mat->Align(), mat->Size());
b.Structure(name, members);
@@ -421,7 +428,7 @@
if (auto* std140 = Std140Type(arr->ElemType())) {
utils::Vector<const ast::Attribute*, 1> attrs;
if (!arr->IsStrideImplicit()) {
- attrs.Push(ctx.dst->create<ast::StrideAttribute>(arr->Stride()));
+ attrs.Push(b.create<ast::StrideAttribute>(arr->Stride()));
}
auto count = arr->ConstantCount();
if (!count) {
@@ -429,7 +436,7 @@
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "unexpected non-constant array count";
count = 1;
}
@@ -440,7 +447,7 @@
});
}
- /// @param mat the matrix to decompose (in ctx.src)
+ /// @param mat the matrix to decompose (in src)
/// @param name_prefix the name prefix to apply to each of the returned column vector members.
/// @param align the alignment in bytes of the matrix.
/// @param size the size in bytes of the matrix.
@@ -473,7 +480,7 @@
// Build the member
const auto col_name = name_prefix + std::to_string(i);
const auto* col_ty = CreateASTTypeFor(ctx, mat->ColumnType());
- const auto* col_member = ctx.dst->Member(col_name, col_ty, std::move(attributes));
+ const auto* col_member = b.Member(col_name, col_ty, std::move(attributes));
// Record the member for std140_mat_members
out.Push(col_member);
}
@@ -618,7 +625,7 @@
/// @returns a name suffix for a std140 -> non-std140 conversion function based on the type
/// being converted.
- const std::string ConvertSuffix(const sem::Type* ty) const {
+ const std::string ConvertSuffix(const sem::Type* ty) {
return Switch(
ty, //
[&](const sem::Struct* str) { return sym.NameFor(str->Name()); },
@@ -629,8 +636,7 @@
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "unexpected non-constant array count";
+ TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count";
count = 1;
}
return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType());
@@ -642,7 +648,7 @@
[&](const sem::F32*) { return "f32"; },
[&](Default) {
TINT_ICE(Transform, b.Diagnostics())
- << "unhandled type for conversion name: " << ctx.src->FriendlyName(ty);
+ << "unhandled type for conversion name: " << src->FriendlyName(ty);
return "";
});
}
@@ -718,8 +724,7 @@
stmts.Push(b.Return(b.Construct(mat_ty, std::move(mat_args))));
} else {
TINT_ICE(Transform, b.Diagnostics())
- << "failed to find std140 matrix info for: "
- << ctx.src->FriendlyName(ty);
+ << "failed to find std140 matrix info for: " << src->FriendlyName(ty);
}
}, //
[&](const sem::Array* arr) {
@@ -736,7 +741,7 @@
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "unexpected non-constant array count";
count = 1;
}
@@ -749,7 +754,7 @@
},
[&](Default) {
TINT_ICE(Transform, b.Diagnostics())
- << "unhandled type for conversion: " << ctx.src->FriendlyName(ty);
+ << "unhandled type for conversion: " << src->FriendlyName(ty);
});
// Generate the function
@@ -1063,7 +1068,7 @@
if (std::get_if<UniformVariable>(&access)) {
const auto* expr = b.Expr(ctx.Clone(chain.var->Declaration()->symbol));
- const auto name = ctx.src->Symbols().NameFor(chain.var->Declaration()->symbol);
+ const auto name = src->Symbols().NameFor(chain.var->Declaration()->symbol);
ty = chain.var->Type()->UnwrapRef();
return {expr, ty, name};
}
@@ -1090,7 +1095,7 @@
}, //
[&](Default) -> ExprTypeName {
TINT_ICE(Transform, b.Diagnostics())
- << "unhandled type for access chain: " << ctx.src->FriendlyName(ty);
+ << "unhandled type for access chain: " << src->FriendlyName(ty);
return {};
});
}
@@ -1104,14 +1109,14 @@
for (auto el : *swizzle) {
rhs += xyzw[el];
}
- auto swizzle_ty = ctx.src->Types().Find<sem::Vector>(
+ auto swizzle_ty = src->Types().Find<sem::Vector>(
vec->type(), static_cast<uint32_t>(swizzle->Length()));
auto* expr = b.MemberAccessor(lhs, rhs);
return {expr, swizzle_ty, rhs};
}, //
[&](Default) -> ExprTypeName {
TINT_ICE(Transform, b.Diagnostics())
- << "unhandled type for access chain: " << ctx.src->FriendlyName(ty);
+ << "unhandled type for access chain: " << src->FriendlyName(ty);
return {};
});
}
@@ -1140,7 +1145,7 @@
}, //
[&](Default) -> ExprTypeName {
TINT_ICE(Transform, b.Diagnostics())
- << "unhandled type for access chain: " << ctx.src->FriendlyName(ty);
+ << "unhandled type for access chain: " << src->FriendlyName(ty);
return {};
});
}
@@ -1150,12 +1155,8 @@
Std140::~Std140() = default;
-bool Std140::ShouldRun(const Program* program, const DataMap&) const {
- return State::ShouldRun(program);
-}
-
-void Std140::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- State(ctx).Run();
+Transform::ApplyResult Std140::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State(src).Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/std140.h b/src/tint/transform/std140.h
index ec5cad5..49e663d 100644
--- a/src/tint/transform/std140.h
+++ b/src/tint/transform/std140.h
@@ -34,21 +34,13 @@
/// Destructor
~Std140() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
private:
struct State;
-
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/substitute_override.cc b/src/tint/transform/substitute_override.cc
index 2de04e0..7c2d0a2 100644
--- a/src/tint/transform/substitute_override.cc
+++ b/src/tint/transform/substitute_override.cc
@@ -15,6 +15,7 @@
#include "src/tint/transform/substitute_override.h"
#include <functional>
+#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/builtin.h"
@@ -25,12 +26,9 @@
TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride::Config);
namespace tint::transform {
+namespace {
-SubstituteOverride::SubstituteOverride() = default;
-
-SubstituteOverride::~SubstituteOverride() = default;
-
-bool SubstituteOverride::ShouldRun(const Program* program, const DataMap&) const {
+bool ShouldRun(const Program* program) {
for (auto* node : program->AST().GlobalVariables()) {
if (node->Is<ast::Override>()) {
return true;
@@ -39,18 +37,32 @@
return false;
}
-void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) const {
+} // namespace
+
+SubstituteOverride::SubstituteOverride() = default;
+
+SubstituteOverride::~SubstituteOverride() = default;
+
+Transform::ApplyResult SubstituteOverride::Apply(const Program* src,
+ const DataMap& config,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
const auto* data = config.Get<Config>();
if (!data) {
- ctx.dst->Diagnostics().add_error(diag::System::Transform,
- "Missing override substitution data");
- return;
+ b.Diagnostics().add_error(diag::System::Transform, "Missing override substitution data");
+ return Program(std::move(b));
+ }
+
+ if (!ShouldRun(ctx.src)) {
+ return SkipTransform;
}
ctx.ReplaceAll([&](const ast::Override* w) -> const ast::Const* {
auto* sem = ctx.src->Sem().Get(w);
- auto src = ctx.Clone(w->source);
+ auto source = ctx.Clone(w->source);
auto sym = ctx.Clone(w->symbol);
auto* ty = ctx.Clone(w->type);
@@ -58,30 +70,30 @@
auto iter = data->map.find(sem->OverrideId());
if (iter == data->map.end()) {
if (!w->initializer) {
- ctx.dst->Diagnostics().add_error(
+ b.Diagnostics().add_error(
diag::System::Transform,
"Initializer not provided for override, and override not overridden.");
return nullptr;
}
- return ctx.dst->Const(src, sym, ty, ctx.Clone(w->initializer));
+ return b.Const(source, sym, ty, ctx.Clone(w->initializer));
}
auto value = iter->second;
auto* ctor = Switch(
sem->Type(),
- [&](const sem::Bool*) { return ctx.dst->Expr(!std::equal_to<double>()(value, 0.0)); },
- [&](const sem::I32*) { return ctx.dst->Expr(i32(value)); },
- [&](const sem::U32*) { return ctx.dst->Expr(u32(value)); },
- [&](const sem::F32*) { return ctx.dst->Expr(f32(value)); },
- [&](const sem::F16*) { return ctx.dst->Expr(f16(value)); });
+ [&](const sem::Bool*) { return b.Expr(!std::equal_to<double>()(value, 0.0)); },
+ [&](const sem::I32*) { return b.Expr(i32(value)); },
+ [&](const sem::U32*) { return b.Expr(u32(value)); },
+ [&](const sem::F32*) { return b.Expr(f32(value)); },
+ [&](const sem::F16*) { return b.Expr(f16(value)); });
if (!ctor) {
- ctx.dst->Diagnostics().add_error(diag::System::Transform,
- "Failed to create override-expression");
+ b.Diagnostics().add_error(diag::System::Transform,
+ "Failed to create override-expression");
return nullptr;
}
- return ctx.dst->Const(src, sym, ty, ctor);
+ return b.Const(source, sym, ty, ctor);
});
// Ensure that objects that are indexed with an override-expression are materialized.
@@ -89,11 +101,10 @@
// resulting type of the index may change. See: crbug.com/tint/1697.
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
- if (auto* sem = ctx.src->Sem().Get(expr)) {
+ if (auto* sem = src->Sem().Get(expr)) {
if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) {
if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() &&
access->Index()->Stage() == sem::EvaluationStage::kOverride) {
- auto& b = *ctx.dst;
auto* obj = b.Call(sem::str(sem::BuiltinType::kTintMaterialize),
ctx.Clone(expr->object));
return b.IndexAccessor(obj, ctx.Clone(expr->index));
@@ -104,6 +115,7 @@
});
ctx.Clone();
+ return Program(std::move(b));
}
SubstituteOverride::Config::Config() = default;
diff --git a/src/tint/transform/substitute_override.h b/src/tint/transform/substitute_override.h
index 940e11d..853acc7 100644
--- a/src/tint/transform/substitute_override.h
+++ b/src/tint/transform/substitute_override.h
@@ -75,19 +75,10 @@
/// Destructor
~SubstituteOverride() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/test_helper.h b/src/tint/transform/test_helper.h
index bc82fe5..ac48c5f 100644
--- a/src/tint/transform/test_helper.h
+++ b/src/tint/transform/test_helper.h
@@ -122,7 +122,18 @@
}
const Transform& t = TRANSFORM();
- return t.ShouldRun(&program, data);
+
+ DataMap outputs;
+ auto result = t.Apply(&program, data, outputs);
+ if (!result) {
+ return false;
+ }
+ if (!result->IsValid()) {
+ ADD_FAILURE() << "Apply() called by ShouldRun() returned errors: "
+ << result->Diagnostics().str();
+ return true;
+ }
+ return result.has_value();
}
/// @param in the input WGSL source
diff --git a/src/tint/transform/transform.cc b/src/tint/transform/transform.cc
index 3e03411..c37f3b4 100644
--- a/src/tint/transform/transform.cc
+++ b/src/tint/transform/transform.cc
@@ -46,24 +46,19 @@
Transform::Transform() = default;
Transform::~Transform() = default;
-Output Transform::Run(const Program* program, const DataMap& data /* = {} */) const {
- ProgramBuilder builder;
- CloneContext ctx(&builder, program);
+Output Transform::Run(const Program* src, const DataMap& data /* = {} */) const {
Output output;
- Run(ctx, data, output.data);
- output.program = Program(std::move(builder));
+ if (auto program = Apply(src, data, output.data)) {
+ output.program = std::move(program.value());
+ } else {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+ ctx.Clone();
+ output.program = Program(std::move(b));
+ }
return output;
}
-void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics())
- << "Transform::Run() unimplemented for " << TypeInfo().name;
-}
-
-bool Transform::ShouldRun(const Program*, const DataMap&) const {
- return true;
-}
-
void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) {
auto* sem = ctx.src->Sem().Get(stmt);
if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {
diff --git a/src/tint/transform/transform.h b/src/tint/transform/transform.h
index c3e3d1d..6580e25 100644
--- a/src/tint/transform/transform.h
+++ b/src/tint/transform/transform.h
@@ -158,26 +158,30 @@
/// Destructor
~Transform() override;
- /// Runs the transform on `program`, returning the transformation result.
+ /// Runs the transform on @p program, returning the transformation result or a clone of
+ /// @p program.
/// @param program the source program to transform
/// @param data optional extra transform-specific input data
/// @returns the transformation result
- virtual Output Run(const Program* program, const DataMap& data = {}) const;
+ Output Run(const Program* program, const DataMap& data = {}) const;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- virtual bool ShouldRun(const Program* program, const DataMap& data = {}) const;
+ /// The return value of Apply().
+ /// If SkipTransform (std::nullopt), then the transform is not needed to be run.
+ using ApplyResult = std::optional<Program>;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
+ /// Value returned from Apply() to indicate that the transform does not need to be run
+ static inline constexpr std::nullopt_t SkipTransform = std::nullopt;
+
+ /// Runs the transform on `program`, return.
+ /// @param program the input program
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
- virtual void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const;
+ /// @returns a transformed program, or std::nullopt if the transform didn't need to run.
+ virtual ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const = 0;
+ protected:
/// Removes the statement `stmt` from the transformed program.
/// RemoveStatement handles edge cases, like statements in the initializer and
/// continuing of for-loops.
diff --git a/src/tint/transform/transform_test.cc b/src/tint/transform/transform_test.cc
index d063ba2..82fdf6a 100644
--- a/src/tint/transform/transform_test.cc
+++ b/src/tint/transform/transform_test.cc
@@ -23,7 +23,9 @@
// Inherit from Transform so we have access to protected methods
struct CreateASTTypeForTest : public testing::Test, public Transform {
- Output Run(const Program*, const DataMap&) const override { return {}; }
+ ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override {
+ return SkipTransform;
+ }
const ast::Type* create(std::function<sem::Type*(ProgramBuilder&)> create_sem_type) {
ProgramBuilder sem_type_builder;
diff --git a/src/tint/transform/unshadow.cc b/src/tint/transform/unshadow.cc
index 975e2ed..93ce595 100644
--- a/src/tint/transform/unshadow.cc
+++ b/src/tint/transform/unshadow.cc
@@ -28,27 +28,32 @@
namespace tint::transform {
-/// The PIMPL state for the Unshadow transform
+/// PIMPL state for the transform
struct Unshadow::State {
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
/// The clone context
- CloneContext& ctx;
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Constructor
- /// @param context the clone context
- explicit State(CloneContext& context) : ctx(context) {}
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
- /// Performs the transformation
- void Run() {
- auto& sem = ctx.src->Sem();
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ Transform::ApplyResult Run() {
+ auto& sem = src->Sem();
// Maps a variable to its new name.
- std::unordered_map<const sem::Variable*, Symbol> renamed_to;
+ utils::Hashmap<const sem::Variable*, Symbol, 8> renamed_to;
auto rename = [&](const sem::Variable* v) -> const ast::Variable* {
auto* decl = v->Declaration();
- auto name = ctx.src->Symbols().NameFor(decl->symbol);
- auto symbol = ctx.dst->Symbols().New(name);
- renamed_to.emplace(v, symbol);
+ auto name = src->Symbols().NameFor(decl->symbol);
+ auto symbol = b.Symbols().New(name);
+ renamed_to.Add(v, symbol);
auto source = ctx.Clone(decl->source);
auto* type = ctx.Clone(decl->type);
@@ -57,20 +62,20 @@
return Switch(
decl, //
[&](const ast::Var* var) {
- return ctx.dst->Var(source, symbol, type, var->declared_address_space,
- var->declared_access, initializer, attributes);
+ return b.Var(source, symbol, type, var->declared_address_space,
+ var->declared_access, initializer, attributes);
},
[&](const ast::Let*) {
- return ctx.dst->Let(source, symbol, type, initializer, attributes);
+ return b.Let(source, symbol, type, initializer, attributes);
},
[&](const ast::Const*) {
- return ctx.dst->Const(source, symbol, type, initializer, attributes);
+ return b.Const(source, symbol, type, initializer, attributes);
},
- [&](const ast::Parameter*) {
- return ctx.dst->Param(source, symbol, type, attributes);
+ [&](const ast::Parameter*) { //
+ return b.Param(source, symbol, type, attributes);
},
[&](Default) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "unexpected variable type: " << decl->TypeInfo().name;
return nullptr;
});
@@ -92,14 +97,15 @@
ctx.ReplaceAll(
[&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
if (auto* user = sem.Get<sem::VariableUser>(ident)) {
- auto it = renamed_to.find(user->Variable());
- if (it != renamed_to.end()) {
- return ctx.dst->Expr(it->second);
+ if (auto* renamed = renamed_to.Find(user->Variable())) {
+ return b.Expr(*renamed);
}
}
return nullptr;
});
+
ctx.Clone();
+ return Program(std::move(b));
}
};
@@ -107,8 +113,8 @@
Unshadow::~Unshadow() = default;
-void Unshadow::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- State(ctx).Run();
+Transform::ApplyResult Unshadow::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State(src).Run();
}
} // namespace tint::transform
diff --git a/src/tint/transform/unshadow.h b/src/tint/transform/unshadow.h
index 5ffe839..8ebf105 100644
--- a/src/tint/transform/unshadow.h
+++ b/src/tint/transform/unshadow.h
@@ -29,16 +29,13 @@
/// Destructor
~Unshadow() override;
- protected:
- struct State;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ private:
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/unwind_discard_functions.cc b/src/tint/transform/unwind_discard_functions.cc
index 4e20d55..068fe35 100644
--- a/src/tint/transform/unwind_discard_functions.cc
+++ b/src/tint/transform/unwind_discard_functions.cc
@@ -35,7 +35,51 @@
namespace tint::transform {
namespace {
-class State {
+bool ShouldRun(const Program* program) {
+ auto& sem = program->Sem();
+ for (auto* f : program->AST().Functions()) {
+ if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+/// PIMPL state for the transform
+struct UnwindDiscardFunctions::State {
+ /// Constructor
+ /// @param ctx_in the context
+ explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
+
+ /// Runs the transform
+ void Run() {
+ ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* {
+ // Iterate block statements and replace them as needed.
+ for (auto* stmt : block->statements) {
+ if (auto* new_stmt = Statement(stmt)) {
+ ctx.Replace(stmt, new_stmt);
+ }
+
+ // Handle for loops, as they are the only other AST node that
+ // contains statements outside of BlockStatements.
+ if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
+ if (auto* new_stmt = Statement(fl->initializer)) {
+ ctx.Replace(fl->initializer, new_stmt);
+ }
+ if (auto* new_stmt = Statement(fl->continuing)) {
+ // NOTE: Should never reach here as we cannot discard in a
+ // continuing block.
+ ctx.Replace(fl->continuing, new_stmt);
+ }
+ }
+ }
+
+ return nullptr;
+ });
+ }
+
private:
CloneContext& ctx;
ProgramBuilder& b;
@@ -163,7 +207,7 @@
// Returns true if `stmt` is a for-loop initializer statement.
bool IsForLoopInitStatement(const ast::Statement* stmt) {
if (auto* sem_stmt = sem.Get(stmt)) {
- if (auto* sem_fl = As<sem::ForLoopStatement>(sem_stmt->Parent())) {
+ if (auto* sem_fl = tint::As<sem::ForLoopStatement>(sem_stmt->Parent())) {
return sem_fl->Declaration()->initializer == stmt;
}
}
@@ -305,60 +349,26 @@
return TryInsertAfter(s, sem_expr);
});
}
-
- public:
- /// Constructor
- /// @param ctx_in the context
- explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
-
- /// Runs the transform
- void Run() {
- ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* {
- // Iterate block statements and replace them as needed.
- for (auto* stmt : block->statements) {
- if (auto* new_stmt = Statement(stmt)) {
- ctx.Replace(stmt, new_stmt);
- }
-
- // Handle for loops, as they are the only other AST node that
- // contains statements outside of BlockStatements.
- if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
- if (auto* new_stmt = Statement(fl->initializer)) {
- ctx.Replace(fl->initializer, new_stmt);
- }
- if (auto* new_stmt = Statement(fl->continuing)) {
- // NOTE: Should never reach here as we cannot discard in a
- // continuing block.
- ctx.Replace(fl->continuing, new_stmt);
- }
- }
- }
-
- return nullptr;
- });
-
- ctx.Clone();
- }
};
-} // namespace
-
UnwindDiscardFunctions::UnwindDiscardFunctions() = default;
UnwindDiscardFunctions::~UnwindDiscardFunctions() = default;
-void UnwindDiscardFunctions::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+Transform::ApplyResult UnwindDiscardFunctions::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
State state(ctx);
state.Run();
-}
-bool UnwindDiscardFunctions::ShouldRun(const Program* program, const DataMap& /*data*/) const {
- auto& sem = program->Sem();
- for (auto* f : program->AST().Functions()) {
- if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) {
- return true;
- }
- }
- return false;
+ ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/unwind_discard_functions.h b/src/tint/transform/unwind_discard_functions.h
index 105a9d8..7614c27 100644
--- a/src/tint/transform/unwind_discard_functions.h
+++ b/src/tint/transform/unwind_discard_functions.h
@@ -44,19 +44,13 @@
/// Destructor
~UnwindDiscardFunctions() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
+ private:
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc
index 0aed996..e0b35a1 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before.cc
@@ -30,7 +30,59 @@
namespace tint::transform {
/// Private implementation of HoistToDeclBefore transform
-class HoistToDeclBefore::State {
+struct HoistToDeclBefore::State {
+ /// Constructor
+ /// @param ctx_in the clone context
+ explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
+
+ /// @copydoc HoistToDeclBefore::Add()
+ bool Add(const sem::Expression* before_expr,
+ const ast::Expression* expr,
+ bool as_let,
+ const char* decl_name) {
+ auto name = b.Symbols().New(decl_name);
+
+ if (as_let) {
+ auto builder = [this, expr, name] {
+ return b.Decl(b.Let(name, ctx.CloneWithoutTransform(expr)));
+ };
+ if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
+ return false;
+ }
+ } else {
+ auto builder = [this, expr, name] {
+ return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr)));
+ };
+ if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
+ return false;
+ }
+ }
+
+ // Replace the initializer expression with a reference to the let
+ ctx.Replace(expr, b.Expr(name));
+ return true;
+ }
+
+ /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
+ bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) {
+ if (stmt) {
+ auto builder = [stmt] { return stmt; };
+ return InsertBeforeImpl(before_stmt, std::move(builder));
+ }
+ return InsertBeforeImpl(before_stmt, Decompose{});
+ }
+
+ /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&)
+ bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) {
+ return InsertBeforeImpl(before_stmt, std::move(builder));
+ }
+
+ /// @copydoc HoistToDeclBefore::Prepare()
+ bool Prepare(const sem::Expression* before_expr) {
+ return InsertBefore(before_expr->Stmt(), nullptr);
+ }
+
+ private:
CloneContext& ctx;
ProgramBuilder& b;
@@ -215,6 +267,8 @@
template <typename BUILDER>
bool InsertBeforeImpl(const sem::Statement* before_stmt, BUILDER&& builder) {
+ (void)builder; // Avoid 'unused parameter' warning due to 'if constexpr'
+
auto* ip = before_stmt->Declaration();
auto* else_if = before_stmt->As<sem::IfStatement>();
@@ -299,58 +353,6 @@
<< "unhandled expression parent statement type: " << parent->TypeInfo().name;
return false;
}
-
- public:
- /// Constructor
- /// @param ctx_in the clone context
- explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
-
- /// @copydoc HoistToDeclBefore::Add()
- bool Add(const sem::Expression* before_expr,
- const ast::Expression* expr,
- bool as_let,
- const char* decl_name) {
- auto name = b.Symbols().New(decl_name);
-
- if (as_let) {
- auto builder = [this, expr, name] {
- return b.Decl(b.Let(name, ctx.CloneWithoutTransform(expr)));
- };
- if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
- return false;
- }
- } else {
- auto builder = [this, expr, name] {
- return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr)));
- };
- if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
- return false;
- }
- }
-
- // Replace the initializer expression with a reference to the let
- ctx.Replace(expr, b.Expr(name));
- return true;
- }
-
- /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
- bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) {
- if (stmt) {
- auto builder = [stmt] { return stmt; };
- return InsertBeforeImpl(before_stmt, std::move(builder));
- }
- return InsertBeforeImpl(before_stmt, Decompose{});
- }
-
- /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&)
- bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) {
- return InsertBeforeImpl(before_stmt, std::move(builder));
- }
-
- /// @copydoc HoistToDeclBefore::Prepare()
- bool Prepare(const sem::Expression* before_expr) {
- return InsertBefore(before_expr->Stmt(), nullptr);
- }
};
HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_unique<State>(ctx)) {}
diff --git a/src/tint/transform/utils/hoist_to_decl_before.h b/src/tint/transform/utils/hoist_to_decl_before.h
index b2e993e..d9a8a8a 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.h
+++ b/src/tint/transform/utils/hoist_to_decl_before.h
@@ -77,7 +77,7 @@
bool Prepare(const sem::Expression* before_expr);
private:
- class State;
+ struct State;
std::unique_ptr<State> state_;
};
diff --git a/src/tint/transform/var_for_dynamic_index.cc b/src/tint/transform/var_for_dynamic_index.cc
index d30831e..81af013 100644
--- a/src/tint/transform/var_for_dynamic_index.cc
+++ b/src/tint/transform/var_for_dynamic_index.cc
@@ -13,6 +13,9 @@
// limitations under the License.
#include "src/tint/transform/var_for_dynamic_index.h"
+
+#include <utility>
+
#include "src/tint/program_builder.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
@@ -22,7 +25,12 @@
VarForDynamicIndex::~VarForDynamicIndex() = default;
-void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
HoistToDeclBefore hoist_to_decl_before(ctx);
// Extracts array and matrix values that are dynamically indexed to a
@@ -30,7 +38,7 @@
auto dynamic_index_to_var = [&](const ast::IndexAccessorExpression* access_expr) {
auto* index_expr = access_expr->index;
auto* object_expr = access_expr->object;
- auto& sem = ctx.src->Sem();
+ auto& sem = src->Sem();
if (sem.Get(index_expr)->ConstantValue()) {
// Index expression resolves to a compile time value.
@@ -49,15 +57,21 @@
return hoist_to_decl_before.Add(indexed, object_expr, false, "var_for_index");
};
- for (auto* node : ctx.src->ASTNodes().Objects()) {
+ bool index_accessor_found = false;
+ for (auto* node : src->ASTNodes().Objects()) {
if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) {
if (!dynamic_index_to_var(access_expr)) {
- return;
+ return Program(std::move(b));
}
+ index_accessor_found = true;
}
}
+ if (!index_accessor_found) {
+ return SkipTransform;
+ }
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/var_for_dynamic_index.h b/src/tint/transform/var_for_dynamic_index.h
index 39ef2f2..070a2cd 100644
--- a/src/tint/transform/var_for_dynamic_index.h
+++ b/src/tint/transform/var_for_dynamic_index.h
@@ -31,14 +31,10 @@
/// Destructor
~VarForDynamicIndex() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/vectorize_matrix_conversions.cc b/src/tint/transform/vectorize_matrix_conversions.cc
index 576b885..94fbdf3 100644
--- a/src/tint/transform/vectorize_matrix_conversions.cc
+++ b/src/tint/transform/vectorize_matrix_conversions.cc
@@ -30,11 +30,9 @@
namespace tint::transform {
-VectorizeMatrixConversions::VectorizeMatrixConversions() = default;
+namespace {
-VectorizeMatrixConversions::~VectorizeMatrixConversions() = default;
-
-bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap&) const {
+bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* sem = program->Sem().Get<sem::Expression>(node)) {
if (auto* call = sem->UnwrapMaterialize()->As<sem::Call>()) {
@@ -50,14 +48,29 @@
return false;
}
-void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+} // namespace
+
+VectorizeMatrixConversions::VectorizeMatrixConversions() = default;
+
+VectorizeMatrixConversions::~VectorizeMatrixConversions() = default;
+
+Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
using HelperFunctionKey =
utils::UnorderedKeyWrapper<std::tuple<const sem::Matrix*, const sem::Matrix*>>;
std::unordered_map<HelperFunctionKey, Symbol> matrix_convs;
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
- auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
+ auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* ty_conv = call->Target()->As<sem::TypeConversion>();
if (!ty_conv) {
return nullptr;
@@ -72,16 +85,16 @@
return nullptr;
}
- auto& src = args[0];
+ auto& matrix = args[0];
- auto* src_type = args[0]->Type()->UnwrapRef()->As<sem::Matrix>();
+ auto* src_type = matrix->Type()->UnwrapRef()->As<sem::Matrix>();
if (!src_type) {
return nullptr;
}
// The source and destination type of a matrix conversion must have a same shape.
if (!(src_type->rows() == dst_type->rows() && src_type->columns() == dst_type->columns())) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "source and destination matrix has different shape in matrix conversion";
return nullptr;
}
@@ -90,47 +103,45 @@
utils::Vector<const ast::Expression*, 4> columns;
for (uint32_t c = 0; c < dst_type->columns(); c++) {
auto* src_matrix_expr = src_expression_builder();
- auto* src_column_expr =
- ctx.dst->IndexAccessor(src_matrix_expr, ctx.dst->Expr(tint::AInt(c)));
- columns.Push(ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()),
- src_column_expr));
+ auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c)));
+ columns.Push(
+ b.Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()), src_column_expr));
}
- return ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type), columns);
+ return b.Construct(CreateASTTypeFor(ctx, dst_type), columns);
};
// Replace the matrix conversion to column vector conversions and a matrix construction.
- if (!src->HasSideEffects()) {
+ if (!matrix->HasSideEffects()) {
// Simply use the argument's declaration if it has no side effects.
return build_vectorized_conversion_expression([&]() { //
- return ctx.Clone(src->Declaration());
+ return ctx.Clone(matrix->Declaration());
});
} else {
// If has side effects, use a helper function.
auto fn =
utils::GetOrCreate(matrix_convs, HelperFunctionKey{{src_type, dst_type}}, [&] {
- auto name =
- ctx.dst->Symbols().New("convert_mat" + std::to_string(src_type->columns()) +
- "x" + std::to_string(src_type->rows()) + "_" +
- ctx.dst->FriendlyName(src_type->type()) + "_" +
- ctx.dst->FriendlyName(dst_type->type()));
- ctx.dst->Func(
- name,
- utils::Vector{
- ctx.dst->Param("value", CreateASTTypeFor(ctx, src_type)),
- },
- CreateASTTypeFor(ctx, dst_type),
- utils::Vector{
- ctx.dst->Return(build_vectorized_conversion_expression([&]() { //
- return ctx.dst->Expr("value");
- })),
- });
+ auto name = b.Symbols().New(
+ "convert_mat" + std::to_string(src_type->columns()) + "x" +
+ std::to_string(src_type->rows()) + "_" + b.FriendlyName(src_type->type()) +
+ "_" + b.FriendlyName(dst_type->type()));
+ b.Func(name,
+ utils::Vector{
+ b.Param("value", CreateASTTypeFor(ctx, src_type)),
+ },
+ CreateASTTypeFor(ctx, dst_type),
+ utils::Vector{
+ b.Return(build_vectorized_conversion_expression([&]() { //
+ return b.Expr("value");
+ })),
+ });
return name;
});
- return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration()));
+ return b.Call(fn, ctx.Clone(args[0]->Declaration()));
}
});
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/vectorize_matrix_conversions.h b/src/tint/transform/vectorize_matrix_conversions.h
index f16467c..c86240c 100644
--- a/src/tint/transform/vectorize_matrix_conversions.h
+++ b/src/tint/transform/vectorize_matrix_conversions.h
@@ -28,19 +28,10 @@
/// Destructor
~VectorizeMatrixConversions() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/vectorize_scalar_matrix_initializers.cc b/src/tint/transform/vectorize_scalar_matrix_initializers.cc
index 97b0e4f..e6e1c46 100644
--- a/src/tint/transform/vectorize_scalar_matrix_initializers.cc
+++ b/src/tint/transform/vectorize_scalar_matrix_initializers.cc
@@ -27,12 +27,9 @@
TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixInitializers);
namespace tint::transform {
+namespace {
-VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default;
-
-VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default;
-
-bool VectorizeScalarMatrixInitializers::ShouldRun(const Program* program, const DataMap&) const {
+bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* call = program->Sem().Get<sem::Call>(node)) {
if (call->Target()->Is<sem::TypeInitializer>() && call->Type()->Is<sem::Matrix>()) {
@@ -46,11 +43,26 @@
return false;
}
-void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+} // namespace
+
+VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default;
+
+VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default;
+
+Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
std::unordered_map<const sem::Matrix*, Symbol> scalar_inits;
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
- auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
+ auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* ty_init = call->Target()->As<sem::TypeInitializer>();
if (!ty_init) {
return nullptr;
@@ -87,10 +99,10 @@
}
// Construct the column vector.
- columns.Push(ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(),
- std::move(row_values)));
+ columns.Push(b.vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(),
+ std::move(row_values)));
}
- return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
+ return b.Construct(CreateASTTypeFor(ctx, mat_type), columns);
};
if (args.Length() == 1) {
@@ -98,23 +110,22 @@
// This is done to ensure that the single argument value is only evaluated once, and
// with the correct expression evaluation order.
auto fn = utils::GetOrCreate(scalar_inits, mat_type, [&] {
- auto name =
- ctx.dst->Symbols().New("build_mat" + std::to_string(mat_type->columns()) + "x" +
- std::to_string(mat_type->rows()));
- ctx.dst->Func(name,
- utils::Vector{
- // Single scalar parameter
- ctx.dst->Param("value", CreateASTTypeFor(ctx, mat_type->type())),
- },
- CreateASTTypeFor(ctx, mat_type),
- utils::Vector{
- ctx.dst->Return(build_mat([&](uint32_t, uint32_t) { //
- return ctx.dst->Expr("value");
- })),
- });
+ auto name = b.Symbols().New("build_mat" + std::to_string(mat_type->columns()) +
+ "x" + std::to_string(mat_type->rows()));
+ b.Func(name,
+ utils::Vector{
+ // Single scalar parameter
+ b.Param("value", CreateASTTypeFor(ctx, mat_type->type())),
+ },
+ CreateASTTypeFor(ctx, mat_type),
+ utils::Vector{
+ b.Return(build_mat([&](uint32_t, uint32_t) { //
+ return b.Expr("value");
+ })),
+ });
return name;
});
- return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration()));
+ return b.Call(fn, ctx.Clone(args[0]->Declaration()));
}
if (args.Length() == mat_type->columns() * mat_type->rows()) {
@@ -123,12 +134,13 @@
});
}
- TINT_ICE(Transform, ctx.dst->Diagnostics())
+ TINT_ICE(Transform, b.Diagnostics())
<< "matrix initializer has unexpected number of arguments";
return nullptr;
});
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/vectorize_scalar_matrix_initializers.h b/src/tint/transform/vectorize_scalar_matrix_initializers.h
index 342754a..f9c0164 100644
--- a/src/tint/transform/vectorize_scalar_matrix_initializers.h
+++ b/src/tint/transform/vectorize_scalar_matrix_initializers.h
@@ -29,19 +29,10 @@
/// Destructor
~VectorizeScalarMatrixInitializers() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/vertex_pulling.cc b/src/tint/transform/vertex_pulling.cc
index 00b5a06..d5ee424 100644
--- a/src/tint/transform/vertex_pulling.cc
+++ b/src/tint/transform/vertex_pulling.cc
@@ -201,13 +201,46 @@
return {BaseType::kInvalid, 0};
}
-struct State {
- State(CloneContext& context, const VertexPulling::Config& c) : ctx(context), cfg(c) {}
- State(const State&) = default;
- ~State() = default;
+} // namespace
- /// LocationReplacement describes an ast::Variable replacement for a
- /// location input.
+/// PIMPL state for the transform
+struct VertexPulling::State {
+ /// Constructor
+ /// @param program the source program
+ /// @param c the VertexPulling config
+ State(const Program* program, const VertexPulling::Config& c) : src(program), cfg(c) {}
+
+ /// Runs the transform
+ /// @returns the new program or SkipTransform if the transform is not required
+ ApplyResult Run() {
+ // Find entry point
+ const ast::Function* func = nullptr;
+ for (auto* fn : src->AST().Functions()) {
+ if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
+ if (func != nullptr) {
+ b.Diagnostics().add_error(
+ diag::System::Transform,
+ "VertexPulling found more than one vertex entry point");
+ return Program(std::move(b));
+ }
+ func = fn;
+ }
+ }
+ if (func == nullptr) {
+ b.Diagnostics().add_error(diag::System::Transform,
+ "Vertex stage entry point not found");
+ return Program(std::move(b));
+ }
+
+ AddVertexStorageBuffers();
+ Process(func);
+
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+
+ private:
+ /// LocationReplacement describes an ast::Variable replacement for a location input.
struct LocationReplacement {
/// The variable to replace in the source Program
ast::Variable* from;
@@ -215,13 +248,22 @@
ast::Variable* to;
};
+ /// LocationInfo describes an input location
struct LocationInfo {
+ /// A builder that builds the expression that resolves to the (transformed) input location
std::function<const ast::Expression*()> expr;
+ /// The store type of the location variable
const sem::Type* type;
};
- CloneContext& ctx;
+ /// The source program
+ const Program* const src;
+ /// The transform config
VertexPulling::Config const cfg;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
std::unordered_map<uint32_t, LocationInfo> location_info;
std::function<const ast::Expression*()> vertex_index_expr = nullptr;
std::function<const ast::Expression*()> instance_index_expr = nullptr;
@@ -235,7 +277,7 @@
Symbol GetVertexBufferName(uint32_t index) {
return utils::GetOrCreate(vertex_buffer_names, index, [&] {
static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_";
- return ctx.dst->Symbols().New(kVertexBufferNamePrefix + std::to_string(index));
+ return b.Symbols().New(kVertexBufferNamePrefix + std::to_string(index));
});
}
@@ -243,7 +285,7 @@
Symbol GetStructBufferName() {
if (!struct_buffer_name.IsValid()) {
static const char kStructBufferName[] = "tint_vertex_data";
- struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName);
+ struct_buffer_name = b.Symbols().New(kStructBufferName);
}
return struct_buffer_name;
}
@@ -252,21 +294,19 @@
void AddVertexStorageBuffers() {
// Creating the struct type
static const char kStructName[] = "TintVertexData";
- auto* struct_type =
- ctx.dst->Structure(ctx.dst->Symbols().New(kStructName),
- utils::Vector{
- ctx.dst->Member(GetStructBufferName(), ctx.dst->ty.array<u32>()),
- });
+ auto* struct_type = b.Structure(b.Symbols().New(kStructName),
+ utils::Vector{
+ b.Member(GetStructBufferName(), b.ty.array<u32>()),
+ });
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
// The decorated variable with struct type
- ctx.dst->GlobalVar(GetVertexBufferName(i), ctx.dst->ty.Of(struct_type),
- ast::AddressSpace::kStorage, ast::Access::kRead,
- ctx.dst->Binding(AInt(i)), ctx.dst->Group(AInt(cfg.pulling_group)));
+ b.GlobalVar(GetVertexBufferName(i), b.ty.Of(struct_type), ast::AddressSpace::kStorage,
+ ast::Access::kRead, b.Binding(AInt(i)), b.Group(AInt(cfg.pulling_group)));
}
}
/// Creates and returns the assignment to the variables from the buffers
- ast::BlockStatement* CreateVertexPullingPreamble() {
+ const ast::BlockStatement* CreateVertexPullingPreamble() {
// Assign by looking at the vertex descriptor to find attributes with
// matching location.
@@ -276,7 +316,7 @@
const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx];
if ((buffer_layout.array_stride & 3) != 0) {
- ctx.dst->Diagnostics().add_error(
+ b.Diagnostics().add_error(
diag::System::Transform,
"WebGPU requires that vertex stride must be a multiple of 4 bytes, "
"but VertexPulling array stride for buffer " +
@@ -292,15 +332,15 @@
// buffer_array_base is the base array offset for all the vertex
// attributes. These are units of uint (4 bytes).
auto buffer_array_base =
- ctx.dst->Symbols().New("buffer_array_base_" + std::to_string(buffer_idx));
+ b.Symbols().New("buffer_array_base_" + std::to_string(buffer_idx));
auto* attribute_offset = index_expr;
if (buffer_layout.array_stride != 4) {
- attribute_offset = ctx.dst->Mul(index_expr, u32(buffer_layout.array_stride / 4u));
+ attribute_offset = b.Mul(index_expr, u32(buffer_layout.array_stride / 4u));
}
// let pulling_offset_n = <attribute_offset>
- stmts.Push(ctx.dst->Decl(ctx.dst->Let(buffer_array_base, attribute_offset)));
+ stmts.Push(b.Decl(b.Let(buffer_array_base, attribute_offset)));
for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) {
auto it = location_info.find(attribute_desc.shader_location);
@@ -320,8 +360,8 @@
err << "VertexAttributeDescriptor for location "
<< std::to_string(attribute_desc.shader_location) << " has format "
<< attribute_desc.format << " but shader expects "
- << var.type->FriendlyName(ctx.src->Symbols());
- ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str());
+ << var.type->FriendlyName(src->Symbols());
+ b.Diagnostics().add_error(diag::System::Transform, err.str());
return nullptr;
}
@@ -337,16 +377,16 @@
// WGSL variable vector width is smaller than the loaded vector width
switch (var_dt.width) {
case 1:
- value = ctx.dst->MemberAccessor(fetch, "x");
+ value = b.MemberAccessor(fetch, "x");
break;
case 2:
- value = ctx.dst->MemberAccessor(fetch, "xy");
+ value = b.MemberAccessor(fetch, "xy");
break;
case 3:
- value = ctx.dst->MemberAccessor(fetch, "xyz");
+ value = b.MemberAccessor(fetch, "xyz");
break;
default:
- TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.width;
+ TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.width;
return nullptr;
}
} else if (var_dt.width > fmt_dt.width) {
@@ -355,32 +395,32 @@
utils::Vector<const ast::Expression*, 8> values{fetch};
switch (var_dt.base_type) {
case BaseType::kI32:
- ty = ctx.dst->ty.i32();
+ ty = b.ty.i32();
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
- values.Push(ctx.dst->Expr((i == 3) ? 1_i : 0_i));
+ values.Push(b.Expr((i == 3) ? 1_i : 0_i));
}
break;
case BaseType::kU32:
- ty = ctx.dst->ty.u32();
+ ty = b.ty.u32();
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
- values.Push(ctx.dst->Expr((i == 3) ? 1_u : 0_u));
+ values.Push(b.Expr((i == 3) ? 1_u : 0_u));
}
break;
case BaseType::kF32:
- ty = ctx.dst->ty.f32();
+ ty = b.ty.f32();
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
- values.Push(ctx.dst->Expr((i == 3) ? 1_f : 0_f));
+ values.Push(b.Expr((i == 3) ? 1_f : 0_f));
}
break;
default:
- TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.base_type;
+ TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.base_type;
return nullptr;
}
- value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values);
+ value = b.Construct(b.ty.vec(ty, var_dt.width), values);
}
// Assign the value to the WGSL variable
- stmts.Push(ctx.dst->Assign(var.expr(), value));
+ stmts.Push(b.Assign(var.expr(), value));
}
}
@@ -388,7 +428,7 @@
return nullptr;
}
- return ctx.dst->create<ast::BlockStatement>(std::move(stmts));
+ return b.Block(std::move(stmts));
}
/// Generates an expression reading from a buffer a specific format.
@@ -407,7 +447,7 @@
};
// Returns a i32 loaded from buffer_base + offset.
- auto load_i32 = [&] { return ctx.dst->Bitcast<i32>(load_u32()); };
+ auto load_i32 = [&] { return b.Bitcast<i32>(load_u32()); };
// Returns a u32 loaded from buffer_base + offset + 4.
auto load_next_u32 = [&] {
@@ -415,7 +455,7 @@
};
// Returns a i32 loaded from buffer_base + offset + 4.
- auto load_next_i32 = [&] { return ctx.dst->Bitcast<i32>(load_next_u32()); };
+ auto load_next_i32 = [&] { return b.Bitcast<i32>(load_next_u32()); };
// Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
// The low 16 bits are 0.
@@ -427,17 +467,17 @@
LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
switch (offset & 3) {
case 0:
- return ctx.dst->Shl(low_u32, 16_u);
+ return b.Shl(low_u32, 16_u);
case 1:
- return ctx.dst->And(ctx.dst->Shl(low_u32, 8_u), 0xffff0000_u);
+ return b.And(b.Shl(low_u32, 8_u), 0xffff0000_u);
case 2:
- return ctx.dst->And(low_u32, 0xffff0000_u);
+ return b.And(low_u32, 0xffff0000_u);
default: { // 3:
auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
VertexFormat::kUint32);
- auto* shr = ctx.dst->Shr(low_u32, 8_u);
- auto* shl = ctx.dst->Shl(high_u32, 24_u);
- return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000_u);
+ auto* shr = b.Shr(low_u32, 8_u);
+ auto* shl = b.Shl(high_u32, 24_u);
+ return b.And(b.Or(shl, shr), 0xffff0000_u);
}
}
};
@@ -450,24 +490,24 @@
LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
switch (offset & 3) {
case 0:
- return ctx.dst->And(low_u32, 0xffff_u);
+ return b.And(low_u32, 0xffff_u);
case 1:
- return ctx.dst->And(ctx.dst->Shr(low_u32, 8_u), 0xffff_u);
+ return b.And(b.Shr(low_u32, 8_u), 0xffff_u);
case 2:
- return ctx.dst->Shr(low_u32, 16_u);
+ return b.Shr(low_u32, 16_u);
default: { // 3:
auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
VertexFormat::kUint32);
- auto* shr = ctx.dst->Shr(low_u32, 24_u);
- auto* shl = ctx.dst->Shl(high_u32, 8_u);
- return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff_u);
+ auto* shr = b.Shr(low_u32, 24_u);
+ auto* shl = b.Shl(high_u32, 8_u);
+ return b.And(b.Or(shl, shr), 0xffff_u);
}
}
};
// Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
// The low 16 bits are 0.
- auto load_i16_h = [&] { return ctx.dst->Bitcast<i32>(load_u16_h()); };
+ auto load_i16_h = [&] { return b.Bitcast<i32>(load_u16_h()); };
// Assumptions are made that alignment must be at least as large as the size
// of a single component.
@@ -480,128 +520,121 @@
// Vectors of basic primitives
case VertexFormat::kUint32x2:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
- VertexFormat::kUint32, 2);
+ return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 2);
case VertexFormat::kUint32x3:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
- VertexFormat::kUint32, 3);
+ return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 3);
case VertexFormat::kUint32x4:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
- VertexFormat::kUint32, 4);
+ return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 4);
case VertexFormat::kSint32x2:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
- VertexFormat::kSint32, 2);
+ return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 2);
case VertexFormat::kSint32x3:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
- VertexFormat::kSint32, 3);
+ return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 3);
case VertexFormat::kSint32x4:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
- VertexFormat::kSint32, 4);
+ return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 4);
case VertexFormat::kFloat32x2:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
- VertexFormat::kFloat32, 2);
+ return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
+ 2);
case VertexFormat::kFloat32x3:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
- VertexFormat::kFloat32, 3);
+ return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
+ 3);
case VertexFormat::kFloat32x4:
- return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
- VertexFormat::kFloat32, 4);
+ return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
+ 4);
case VertexFormat::kUint8x2: {
// yyxx0000, yyxx0000
- auto* u16s = ctx.dst->vec2<u32>(load_u16_h());
+ auto* u16s = b.vec2<u32>(load_u16_h());
// xx000000, yyxx0000
- auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2<u32>(8_u, 0_u));
+ auto* shl = b.Shl(u16s, b.vec2<u32>(8_u, 0_u));
// 000000xx, 000000yy
- return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24_u));
+ return b.Shr(shl, b.vec2<u32>(24_u));
}
case VertexFormat::kUint8x4: {
// wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
- auto* u32s = ctx.dst->vec4<u32>(load_u32());
+ auto* u32s = b.vec4<u32>(load_u32());
// xx000000, yyxx0000, zzyyxx00, wwzzyyxx
- auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4<u32>(24_u, 16_u, 8_u, 0_u));
+ auto* shl = b.Shl(u32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u));
// 000000xx, 000000yy, 000000zz, 000000ww
- return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24_u));
+ return b.Shr(shl, b.vec4<u32>(24_u));
}
case VertexFormat::kUint16x2: {
// yyyyxxxx, yyyyxxxx
- auto* u32s = ctx.dst->vec2<u32>(load_u32());
+ auto* u32s = b.vec2<u32>(load_u32());
// xxxx0000, yyyyxxxx
- auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2<u32>(16_u, 0_u));
+ auto* shl = b.Shl(u32s, b.vec2<u32>(16_u, 0_u));
// 0000xxxx, 0000yyyy
- return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16_u));
+ return b.Shr(shl, b.vec2<u32>(16_u));
}
case VertexFormat::kUint16x4: {
// yyyyxxxx, wwwwzzzz
- auto* u32s = ctx.dst->vec2<u32>(load_u32(), load_next_u32());
+ auto* u32s = b.vec2<u32>(load_u32(), load_next_u32());
// yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
- auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy");
+ auto* xxyy = b.MemberAccessor(u32s, "xxyy");
// xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
- auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16_u, 0_u, 16_u, 0_u));
+ auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u));
// 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
- return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16_u));
+ return b.Shr(shl, b.vec4<u32>(16_u));
}
case VertexFormat::kSint8x2: {
// yyxx0000, yyxx0000
- auto* i16s = ctx.dst->vec2<i32>(load_i16_h());
+ auto* i16s = b.vec2<i32>(load_i16_h());
// xx000000, yyxx0000
- auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2<u32>(8_u, 0_u));
+ auto* shl = b.Shl(i16s, b.vec2<u32>(8_u, 0_u));
// ssssssxx, ssssssyy
- return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24_u));
+ return b.Shr(shl, b.vec2<u32>(24_u));
}
case VertexFormat::kSint8x4: {
// wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
- auto* i32s = ctx.dst->vec4<i32>(load_i32());
+ auto* i32s = b.vec4<i32>(load_i32());
// xx000000, yyxx0000, zzyyxx00, wwzzyyxx
- auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4<u32>(24_u, 16_u, 8_u, 0_u));
+ auto* shl = b.Shl(i32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u));
// ssssssxx, ssssssyy, sssssszz, ssssssww
- return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24_u));
+ return b.Shr(shl, b.vec4<u32>(24_u));
}
case VertexFormat::kSint16x2: {
// yyyyxxxx, yyyyxxxx
- auto* i32s = ctx.dst->vec2<i32>(load_i32());
+ auto* i32s = b.vec2<i32>(load_i32());
// xxxx0000, yyyyxxxx
- auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2<u32>(16_u, 0_u));
+ auto* shl = b.Shl(i32s, b.vec2<u32>(16_u, 0_u));
// ssssxxxx, ssssyyyy
- return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16_u));
+ return b.Shr(shl, b.vec2<u32>(16_u));
}
case VertexFormat::kSint16x4: {
// yyyyxxxx, wwwwzzzz
- auto* i32s = ctx.dst->vec2<i32>(load_i32(), load_next_i32());
+ auto* i32s = b.vec2<i32>(load_i32(), load_next_i32());
// yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
- auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy");
+ auto* xxyy = b.MemberAccessor(i32s, "xxyy");
// xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
- auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16_u, 0_u, 16_u, 0_u));
+ auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u));
// ssssxxxx, ssssyyyy, sssszzzz, sssswwww
- return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16_u));
+ return b.Shr(shl, b.vec4<u32>(16_u));
}
case VertexFormat::kUnorm8x2:
- return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy");
+ return b.MemberAccessor(b.Call("unpack4x8unorm", load_u16_l()), "xy");
case VertexFormat::kSnorm8x2:
- return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy");
+ return b.MemberAccessor(b.Call("unpack4x8snorm", load_u16_l()), "xy");
case VertexFormat::kUnorm8x4:
- return ctx.dst->Call("unpack4x8unorm", load_u32());
+ return b.Call("unpack4x8unorm", load_u32());
case VertexFormat::kSnorm8x4:
- return ctx.dst->Call("unpack4x8snorm", load_u32());
+ return b.Call("unpack4x8snorm", load_u32());
case VertexFormat::kUnorm16x2:
- return ctx.dst->Call("unpack2x16unorm", load_u32());
+ return b.Call("unpack2x16unorm", load_u32());
case VertexFormat::kSnorm16x2:
- return ctx.dst->Call("unpack2x16snorm", load_u32());
+ return b.Call("unpack2x16snorm", load_u32());
case VertexFormat::kFloat16x2:
- return ctx.dst->Call("unpack2x16float", load_u32());
+ return b.Call("unpack2x16float", load_u32());
case VertexFormat::kUnorm16x4:
- return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16unorm", load_u32()),
- ctx.dst->Call("unpack2x16unorm", load_next_u32()));
+ return b.vec4<f32>(b.Call("unpack2x16unorm", load_u32()),
+ b.Call("unpack2x16unorm", load_next_u32()));
case VertexFormat::kSnorm16x4:
- return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16snorm", load_u32()),
- ctx.dst->Call("unpack2x16snorm", load_next_u32()));
+ return b.vec4<f32>(b.Call("unpack2x16snorm", load_u32()),
+ b.Call("unpack2x16snorm", load_next_u32()));
case VertexFormat::kFloat16x4:
- return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16float", load_u32()),
- ctx.dst->Call("unpack2x16float", load_next_u32()));
+ return b.vec4<f32>(b.Call("unpack2x16float", load_u32()),
+ b.Call("unpack2x16float", load_next_u32()));
}
- TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
- << "format " << static_cast<int>(format);
+ TINT_UNREACHABLE(Transform, b.Diagnostics()) << "format " << static_cast<int>(format);
return nullptr;
}
@@ -623,12 +656,12 @@
const ast ::Expression* index = nullptr;
if (offset > 0) {
- index = ctx.dst->Add(array_base, u32(offset / 4));
+ index = b.Add(array_base, u32(offset / 4));
} else {
- index = ctx.dst->Expr(array_base);
+ index = b.Expr(array_base);
}
- u = ctx.dst->IndexAccessor(
- ctx.dst->MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index);
+ u = b.IndexAccessor(
+ b.MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index);
} else {
// Unaligned load
@@ -639,22 +672,22 @@
uint32_t shift = 8u * (offset & 3u);
- auto* low_shr = ctx.dst->Shr(low, u32(shift));
- auto* high_shl = ctx.dst->Shl(high, u32(32u - shift));
- u = ctx.dst->Or(low_shr, high_shl);
+ auto* low_shr = b.Shr(low, u32(shift));
+ auto* high_shl = b.Shl(high, u32(32u - shift));
+ u = b.Or(low_shr, high_shl);
}
switch (format) {
case VertexFormat::kUint32:
return u;
case VertexFormat::kSint32:
- return ctx.dst->Bitcast(ctx.dst->ty.i32(), u);
+ return b.Bitcast(b.ty.i32(), u);
case VertexFormat::kFloat32:
- return ctx.dst->Bitcast(ctx.dst->ty.f32(), u);
+ return b.Bitcast(b.ty.f32(), u);
default:
break;
}
- TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
+ TINT_UNREACHABLE(Transform, b.Diagnostics())
<< "invalid format for LoadPrimitive" << static_cast<int>(format);
return nullptr;
}
@@ -682,8 +715,7 @@
expr_list.Push(LoadPrimitive(array_base, primitive_offset, buffer, base_format));
}
- return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count),
- std::move(expr_list));
+ return b.Construct(b.create<ast::Vector>(base_type, count), std::move(expr_list));
}
/// Process a non-struct entry point parameter.
@@ -696,34 +728,30 @@
// Create a function-scope variable to replace the parameter.
auto func_var_sym = ctx.Clone(param->symbol);
auto* func_var_type = ctx.Clone(param->type);
- auto* func_var = ctx.dst->Var(func_var_sym, func_var_type);
- ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
+ auto* func_var = b.Var(func_var_sym, func_var_type);
+ ctx.InsertFront(func->body->statements, b.Decl(func_var));
// Capture mapping from location to the new variable.
LocationInfo info;
- info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); };
+ info.expr = [this, func_var]() { return b.Expr(func_var); };
- auto* sem = ctx.src->Sem().Get<sem::Parameter>(param);
+ auto* sem = src->Sem().Get<sem::Parameter>(param);
info.type = sem->Type();
if (!sem->Location().has_value()) {
- TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Location missing value";
+ TINT_ICE(Transform, b.Diagnostics()) << "Location missing value";
return;
}
location_info[sem->Location().value()] = info;
} else if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
// Check for existing vertex_index and instance_index builtins.
if (builtin->builtin == ast::BuiltinValue::kVertexIndex) {
- vertex_index_expr = [this, param]() {
- return ctx.dst->Expr(ctx.Clone(param->symbol));
- };
+ vertex_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); };
} else if (builtin->builtin == ast::BuiltinValue::kInstanceIndex) {
- instance_index_expr = [this, param]() {
- return ctx.dst->Expr(ctx.Clone(param->symbol));
- };
+ instance_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); };
}
new_function_parameters.Push(ctx.Clone(param));
} else {
- TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter";
+ TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
}
}
@@ -746,7 +774,7 @@
for (auto* member : struct_ty->members) {
auto member_sym = ctx.Clone(member->symbol);
std::function<const ast::Expression*()> member_expr = [this, param_sym, member_sym]() {
- return ctx.dst->MemberAccessor(param_sym, member_sym);
+ return b.MemberAccessor(param_sym, member_sym);
};
if (ast::HasAttribute<ast::LocationAttribute>(member->attributes)) {
@@ -754,7 +782,7 @@
LocationInfo info;
info.expr = member_expr;
- auto* sem = ctx.src->Sem().Get(member);
+ auto* sem = src->Sem().Get(member);
info.type = sem->Type();
TINT_ASSERT(Transform, sem->Location().has_value());
@@ -770,7 +798,7 @@
}
members_to_clone.Push(member);
} else {
- TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter";
+ TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
}
}
@@ -781,8 +809,8 @@
}
// Create a function-scope variable to replace the parameter.
- auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type));
- ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
+ auto* func_var = b.Var(param_sym, ctx.Clone(param->type));
+ ctx.InsertFront(func->body->statements, b.Decl(func_var));
if (!members_to_clone.IsEmpty()) {
// Create a new struct without the location attributes.
@@ -791,20 +819,20 @@
auto member_sym = ctx.Clone(member->symbol);
auto* member_type = ctx.Clone(member->type);
auto member_attrs = ctx.Clone(member->attributes);
- new_members.Push(ctx.dst->Member(member_sym, member_type, std::move(member_attrs)));
+ new_members.Push(b.Member(member_sym, member_type, std::move(member_attrs)));
}
- auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members);
+ auto* new_struct = b.Structure(b.Sym(), new_members);
// Create a new function parameter with this struct.
- auto* new_param = ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct));
+ auto* new_param = b.Param(b.Sym(), b.ty.Of(new_struct));
new_function_parameters.Push(new_param);
// Copy values from the new parameter to the function-scope variable.
for (auto* member : members_to_clone) {
auto member_name = ctx.Clone(member->symbol);
ctx.InsertFront(func->body->statements,
- ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name),
- ctx.dst->MemberAccessor(new_param, member_name)));
+ b.Assign(b.MemberAccessor(func_var, member_name),
+ b.MemberAccessor(new_param, member_name)));
}
}
}
@@ -818,7 +846,7 @@
// Process entry point parameters.
for (auto* param : func->params) {
- auto* sem = ctx.src->Sem().Get(param);
+ auto* sem = src->Sem().Get(param);
if (auto* str = sem->Type()->As<sem::Struct>()) {
ProcessStructParameter(func, param, str->Declaration());
} else {
@@ -830,11 +858,11 @@
if (!vertex_index_expr) {
for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
if (layout.step_mode == VertexStepMode::kVertex) {
- auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
- new_function_parameters.Push(ctx.dst->Param(
- name, ctx.dst->ty.u32(),
- utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kVertexIndex)}));
- vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
+ auto name = b.Symbols().New("tint_pulling_vertex_index");
+ new_function_parameters.Push(
+ b.Param(name, b.ty.u32(),
+ utils::Vector{b.Builtin(ast::BuiltinValue::kVertexIndex)}));
+ vertex_index_expr = [this, name]() { return b.Expr(name); };
break;
}
}
@@ -842,11 +870,11 @@
if (!instance_index_expr) {
for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
if (layout.step_mode == VertexStepMode::kInstance) {
- auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
- new_function_parameters.Push(ctx.dst->Param(
- name, ctx.dst->ty.u32(),
- utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kInstanceIndex)}));
- instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
+ auto name = b.Symbols().New("tint_pulling_instance_index");
+ new_function_parameters.Push(
+ b.Param(name, b.ty.u32(),
+ utils::Vector{b.Builtin(ast::BuiltinValue::kInstanceIndex)}));
+ instance_index_expr = [this, name]() { return b.Expr(name); };
break;
}
}
@@ -864,53 +892,24 @@
auto attrs = ctx.Clone(func->attributes);
auto ret_attrs = ctx.Clone(func->return_type_attributes);
auto* new_func =
- ctx.dst->create<ast::Function>(func->source, func_sym, new_function_parameters,
- ret_type, body, std::move(attrs), std::move(ret_attrs));
+ b.create<ast::Function>(func->source, func_sym, new_function_parameters, ret_type, body,
+ std::move(attrs), std::move(ret_attrs));
ctx.Replace(func, new_func);
}
};
-} // namespace
-
VertexPulling::VertexPulling() = default;
VertexPulling::~VertexPulling() = default;
-void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
+Transform::ApplyResult VertexPulling::Apply(const Program* src,
+ const DataMap& inputs,
+ DataMap&) const {
auto cfg = cfg_;
if (auto* cfg_data = inputs.Get<Config>()) {
cfg = *cfg_data;
}
- // Find entry point
- const ast::Function* func = nullptr;
- for (auto* fn : ctx.src->AST().Functions()) {
- if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
- if (func != nullptr) {
- ctx.dst->Diagnostics().add_error(
- diag::System::Transform,
- "VertexPulling found more than one vertex entry point");
- return;
- }
- func = fn;
- }
- }
- if (func == nullptr) {
- ctx.dst->Diagnostics().add_error(diag::System::Transform,
- "Vertex stage entry point not found");
- return;
- }
-
- // TODO(idanr): Need to check shader locations in descriptor cover all
- // attributes
-
- // TODO(idanr): Make sure we covered all error cases, to guarantee the
- // following stages will pass
-
- State state{ctx, cfg};
- state.AddVertexStorageBuffers();
- state.Process(func);
-
- ctx.Clone();
+ return State{src, cfg}.Run();
}
VertexPulling::Config::Config() = default;
diff --git a/src/tint/transform/vertex_pulling.h b/src/tint/transform/vertex_pulling.h
index 6dd35bc..c0f88a5 100644
--- a/src/tint/transform/vertex_pulling.h
+++ b/src/tint/transform/vertex_pulling.h
@@ -171,16 +171,14 @@
/// Destructor
~VertexPulling() override;
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
private:
+ struct State;
+
Config cfg_;
};
diff --git a/src/tint/transform/while_to_loop.cc b/src/tint/transform/while_to_loop.cc
index 45944e6..d359d2e 100644
--- a/src/tint/transform/while_to_loop.cc
+++ b/src/tint/transform/while_to_loop.cc
@@ -14,18 +14,17 @@
#include "src/tint/transform/while_to_loop.h"
+#include <utility>
+
#include "src/tint/ast/break_statement.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::WhileToLoop);
namespace tint::transform {
+namespace {
-WhileToLoop::WhileToLoop() = default;
-
-WhileToLoop::~WhileToLoop() = default;
-
-bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const {
+bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::WhileStatement>()) {
return true;
@@ -34,20 +33,32 @@
return false;
}
-void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
+} // namespace
+
+WhileToLoop::WhileToLoop() = default;
+
+WhileToLoop::~WhileToLoop() = default;
+
+Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
+ }
+
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* {
utils::Vector<const ast::Statement*, 16> stmts;
auto* cond = w->condition;
// !condition
- auto* not_cond =
- ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
+ auto* not_cond = b.Not(ctx.Clone(cond));
// { break; }
- auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
+ auto* break_body = b.Block(b.Break());
// if (!condition) { break; }
- stmts.Push(ctx.dst->If(not_cond, break_body));
+ stmts.Push(b.If(not_cond, break_body));
for (auto* stmt : w->body->statements) {
stmts.Push(ctx.Clone(stmt));
@@ -55,13 +66,14 @@
const ast::BlockStatement* continuing = nullptr;
- auto* body = ctx.dst->Block(stmts);
- auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
+ auto* body = b.Block(stmts);
+ auto* loop = b.Loop(body, continuing);
return loop;
});
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/while_to_loop.h b/src/tint/transform/while_to_loop.h
index 4915d68..187799a 100644
--- a/src/tint/transform/while_to_loop.h
+++ b/src/tint/transform/while_to_loop.h
@@ -29,19 +29,10 @@
/// Destructor
~WhileToLoop() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
};
} // namespace tint::transform
diff --git a/src/tint/transform/zero_init_workgroup_memory.cc b/src/tint/transform/zero_init_workgroup_memory.cc
index ea65436..ed3584e9 100644
--- a/src/tint/transform/zero_init_workgroup_memory.cc
+++ b/src/tint/transform/zero_init_workgroup_memory.cc
@@ -31,10 +31,24 @@
TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory);
namespace tint::transform {
+namespace {
+
+bool ShouldRun(const Program* program) {
+ for (auto* global : program->AST().GlobalVariables()) {
+ if (auto* var = global->As<ast::Var>()) {
+ if (var->declared_address_space == ast::AddressSpace::kWorkgroup) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+} // namespace
using StatementList = utils::Vector<const ast::Statement*, 8>;
-/// PIMPL state for the ZeroInitWorkgroupMemory transform
+/// PIMPL state for the transform
struct ZeroInitWorkgroupMemory::State {
/// The clone context
CloneContext& ctx;
@@ -424,24 +438,24 @@
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
-bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program, const DataMap&) const {
- for (auto* global : program->AST().GlobalVariables()) {
- if (auto* var = global->As<ast::Var>()) {
- if (var->declared_address_space == ast::AddressSpace::kWorkgroup) {
- return true;
- }
- }
+Transform::ApplyResult ZeroInitWorkgroupMemory::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ if (!ShouldRun(src)) {
+ return SkipTransform;
}
- return false;
-}
-void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
- for (auto* fn : ctx.src->AST().Functions()) {
+ ProgramBuilder b;
+ CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
+
+ for (auto* fn : src->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kCompute) {
State{ctx}.Run(fn);
}
}
+
ctx.Clone();
+ return Program(std::move(b));
}
} // namespace tint::transform
diff --git a/src/tint/transform/zero_init_workgroup_memory.h b/src/tint/transform/zero_init_workgroup_memory.h
index 07feaa8..64f4da8 100644
--- a/src/tint/transform/zero_init_workgroup_memory.h
+++ b/src/tint/transform/zero_init_workgroup_memory.h
@@ -30,19 +30,10 @@
/// Destructor
~ZeroInitWorkgroupMemory() override;
- /// @param program the program to inspect
- /// @param data optional extra transform-specific input data
- /// @returns true if this transform should be run for the given program
- bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
-
- protected:
- /// Runs the transform using the CloneContext built for transforming a
- /// program. Run() is responsible for calling Clone() on the CloneContext.
- /// @param ctx the CloneContext primed with the input program and
- /// ProgramBuilder
- /// @param inputs optional extra transform-specific input data
- /// @param outputs optional extra transform-specific output data
- void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
private:
struct State;
diff --git a/test/tint/bug/tint/1739.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/1739.wgsl.expected.dxc.hlsl
index 6a6f107..a80ef5b 100644
--- a/test/tint/bug/tint/1739.wgsl.expected.dxc.hlsl
+++ b/test/tint/bug/tint/1739.wgsl.expected.dxc.hlsl
@@ -1,3 +1,7 @@
+int2 tint_clamp(int2 e, int2 low, int2 high) {
+ return min(max(e, low), high);
+}
+
struct GammaTransferParams {
float G;
float A;
@@ -46,10 +50,6 @@
return float4(color, 1.0f);
}
-int2 tint_clamp(int2 e, int2 low, int2 high) {
- return min(max(e, low), high);
-}
-
float3x4 tint_symbol_6(uint4 buffer[11], uint offset) {
const uint scalar_offset = ((offset + 0u)) / 4;
const uint scalar_offset_1 = ((offset + 16u)) / 4;
diff --git a/test/tint/bug/tint/1739.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/1739.wgsl.expected.fxc.hlsl
index 6a6f107..a80ef5b 100644
--- a/test/tint/bug/tint/1739.wgsl.expected.fxc.hlsl
+++ b/test/tint/bug/tint/1739.wgsl.expected.fxc.hlsl
@@ -1,3 +1,7 @@
+int2 tint_clamp(int2 e, int2 low, int2 high) {
+ return min(max(e, low), high);
+}
+
struct GammaTransferParams {
float G;
float A;
@@ -46,10 +50,6 @@
return float4(color, 1.0f);
}
-int2 tint_clamp(int2 e, int2 low, int2 high) {
- return min(max(e, low), high);
-}
-
float3x4 tint_symbol_6(uint4 buffer[11], uint offset) {
const uint scalar_offset = ((offset + 0u)) / 4;
const uint scalar_offset_1 = ((offset + 16u)) / 4;
diff --git a/test/tint/bug/tint/1739.wgsl.expected.msl b/test/tint/bug/tint/1739.wgsl.expected.msl
index e1802c5..ed7a995 100644
--- a/test/tint/bug/tint/1739.wgsl.expected.msl
+++ b/test/tint/bug/tint/1739.wgsl.expected.msl
@@ -14,6 +14,10 @@
T elements[N];
};
+int2 tint_clamp(int2 e, int2 low, int2 high) {
+ return min(max(e, low), high);
+}
+
struct GammaTransferParams {
/* 0x0000 */ float G;
/* 0x0004 */ float A;
@@ -57,10 +61,6 @@
return float4(color, 1.0f);
}
-int2 tint_clamp(int2 e, int2 low, int2 high) {
- return min(max(e, low), high);
-}
-
kernel void tint_symbol(texture2d<float, access::sample> tint_symbol_5 [[texture(0)]], texture2d<float, access::sample> tint_symbol_6 [[texture(1)]], const constant ExternalTextureParams* tint_symbol_7 [[buffer(0)]], texture2d<float, access::write> tint_symbol_8 [[texture(2)]]) {
int2 const tint_symbol_1 = tint_clamp(int2(10), int2(0), int2((uint2(uint2(tint_symbol_5.get_width(), tint_symbol_5.get_height())) - uint2(1u))));
float4 red = textureLoadExternal(tint_symbol_5, tint_symbol_6, tint_symbol_1, *(tint_symbol_7));
diff --git a/test/tint/bug/tint/1739.wgsl.expected.spvasm b/test/tint/bug/tint/1739.wgsl.expected.spvasm
index 2cdbccb..2f01aa5 100644
--- a/test/tint/bug/tint/1739.wgsl.expected.spvasm
+++ b/test/tint/bug/tint/1739.wgsl.expected.spvasm
@@ -5,7 +5,7 @@
; Schema: 0
OpCapability Shader
OpCapability ImageQuery
- %25 = OpExtInstImport "GLSL.std.450"
+ %28 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
@@ -31,6 +31,10 @@
OpName %ext_tex_params "ext_tex_params"
OpName %t "t"
OpName %outImage "outImage"
+ OpName %tint_clamp "tint_clamp"
+ OpName %e "e"
+ OpName %low "low"
+ OpName %high "high"
OpName %gammaCorrection "gammaCorrection"
OpName %v "v"
OpName %params "params"
@@ -40,10 +44,6 @@
OpName %coord "coord"
OpName %params_0 "params"
OpName %color "color"
- OpName %tint_clamp "tint_clamp"
- OpName %e "e"
- OpName %low "low"
- OpName %high "high"
OpName %main "main"
OpName %red "red"
OpName %green "green"
@@ -95,20 +95,20 @@
%18 = OpTypeImage %float 2D 0 0 0 2 Rgba8
%_ptr_UniformConstant_18 = OpTypePointer UniformConstant %18
%outImage = OpVariable %_ptr_UniformConstant_18 UniformConstant
- %19 = OpTypeFunction %v3float %v3float %GammaTransferParams
+ %int = OpTypeInt 32 1
+ %v2int = OpTypeVector %int 2
+ %19 = OpTypeFunction %v2int %v2int %v2int %v2int
+ %30 = OpTypeFunction %v3float %v3float %GammaTransferParams
%bool = OpTypeBool
%v3bool = OpTypeVector %bool 3
%_ptr_Function_v3float = OpTypePointer Function %v3float
- %39 = OpConstantNull %v3float
- %int = OpTypeInt 32 1
- %v2int = OpTypeVector %int 2
- %59 = OpTypeFunction %v4float %3 %3 %v2int %ExternalTextureParams
+ %49 = OpConstantNull %v3float
+ %69 = OpTypeFunction %v4float %3 %3 %v2int %ExternalTextureParams
%uint_1 = OpConstant %uint 1
- %76 = OpConstantNull %int
+ %84 = OpConstantNull %int
%v2float = OpTypeVector %float 2
%float_1 = OpConstant %float 1
- %90 = OpConstantNull %uint
- %108 = OpTypeFunction %v2int %v2int %v2int %v2int
+ %98 = OpConstantNull %uint
%void = OpTypeVoid
%116 = OpTypeFunction %void
%int_10 = OpConstant %int 10
@@ -125,106 +125,106 @@
%int_118 = OpConstant %int 118
%154 = OpConstantComposite %v2int %int_70 %int_118
%int_1 = OpConstant %int 1
- %168 = OpConstantComposite %v2int %int_1 %76
-%gammaCorrection = OpFunction %v3float None %19
+ %168 = OpConstantComposite %v2int %int_1 %84
+ %tint_clamp = OpFunction %v2int None %19
+ %e = OpFunctionParameter %v2int
+ %low = OpFunctionParameter %v2int
+ %high = OpFunctionParameter %v2int
+ %26 = OpLabel
+ %29 = OpExtInst %v2int %28 SMax %e %low
+ %27 = OpExtInst %v2int %28 SMin %29 %high
+ OpReturnValue %27
+ OpFunctionEnd
+%gammaCorrection = OpFunction %v3float None %30
%v = OpFunctionParameter %v3float
%params = OpFunctionParameter %GammaTransferParams
- %23 = OpLabel
- %37 = OpVariable %_ptr_Function_v3float Function %39
- %49 = OpVariable %_ptr_Function_v3float Function %39
- %55 = OpVariable %_ptr_Function_v3float Function %39
- %24 = OpExtInst %v3float %25 FAbs %v
- %26 = OpCompositeExtract %float %params 4
- %27 = OpCompositeConstruct %v3float %26 %26 %26
- %28 = OpFOrdLessThan %v3bool %24 %27
- %31 = OpExtInst %v3float %25 FSign %v
- %32 = OpCompositeExtract %float %params 3
- %33 = OpExtInst %v3float %25 FAbs %v
- %34 = OpVectorTimesScalar %v3float %33 %32
- %35 = OpCompositeExtract %float %params 6
- %40 = OpCompositeConstruct %v3float %35 %35 %35
- %36 = OpFAdd %v3float %34 %40
- %41 = OpFMul %v3float %31 %36
- %42 = OpExtInst %v3float %25 FSign %v
- %44 = OpCompositeExtract %float %params 1
- %45 = OpExtInst %v3float %25 FAbs %v
- %46 = OpVectorTimesScalar %v3float %45 %44
- %47 = OpCompositeExtract %float %params 2
- %50 = OpCompositeConstruct %v3float %47 %47 %47
- %48 = OpFAdd %v3float %46 %50
- %51 = OpCompositeExtract %float %params 0
- %52 = OpCompositeConstruct %v3float %51 %51 %51
- %43 = OpExtInst %v3float %25 Pow %48 %52
- %53 = OpCompositeExtract %float %params 5
- %56 = OpCompositeConstruct %v3float %53 %53 %53
- %54 = OpFAdd %v3float %43 %56
- %57 = OpFMul %v3float %42 %54
- %58 = OpSelect %v3float %28 %41 %57
- OpReturnValue %58
+ %34 = OpLabel
+ %47 = OpVariable %_ptr_Function_v3float Function %49
+ %59 = OpVariable %_ptr_Function_v3float Function %49
+ %65 = OpVariable %_ptr_Function_v3float Function %49
+ %35 = OpExtInst %v3float %28 FAbs %v
+ %36 = OpCompositeExtract %float %params 4
+ %37 = OpCompositeConstruct %v3float %36 %36 %36
+ %38 = OpFOrdLessThan %v3bool %35 %37
+ %41 = OpExtInst %v3float %28 FSign %v
+ %42 = OpCompositeExtract %float %params 3
+ %43 = OpExtInst %v3float %28 FAbs %v
+ %44 = OpVectorTimesScalar %v3float %43 %42
+ %45 = OpCompositeExtract %float %params 6
+ %50 = OpCompositeConstruct %v3float %45 %45 %45
+ %46 = OpFAdd %v3float %44 %50
+ %51 = OpFMul %v3float %41 %46
+ %52 = OpExtInst %v3float %28 FSign %v
+ %54 = OpCompositeExtract %float %params 1
+ %55 = OpExtInst %v3float %28 FAbs %v
+ %56 = OpVectorTimesScalar %v3float %55 %54
+ %57 = OpCompositeExtract %float %params 2
+ %60 = OpCompositeConstruct %v3float %57 %57 %57
+ %58 = OpFAdd %v3float %56 %60
+ %61 = OpCompositeExtract %float %params 0
+ %62 = OpCompositeConstruct %v3float %61 %61 %61
+ %53 = OpExtInst %v3float %28 Pow %58 %62
+ %63 = OpCompositeExtract %float %params 5
+ %66 = OpCompositeConstruct %v3float %63 %63 %63
+ %64 = OpFAdd %v3float %53 %66
+ %67 = OpFMul %v3float %52 %64
+ %68 = OpSelect %v3float %38 %51 %67
+ OpReturnValue %68
OpFunctionEnd
-%textureLoadExternal = OpFunction %v4float None %59
+%textureLoadExternal = OpFunction %v4float None %69
%plane0 = OpFunctionParameter %3
%plane1 = OpFunctionParameter %3
%coord = OpFunctionParameter %v2int
%params_0 = OpFunctionParameter %ExternalTextureParams
- %67 = OpLabel
- %color = OpVariable %_ptr_Function_v3float Function %39
- %69 = OpCompositeExtract %uint %params_0 0
- %71 = OpIEqual %bool %69 %uint_1
- OpSelectionMerge %72 None
- OpBranchConditional %71 %73 %74
- %73 = OpLabel
- %75 = OpImageFetch %v4float %plane0 %coord Lod %76
- %77 = OpVectorShuffle %v3float %75 %75 0 1 2
- OpStore %color %77
- OpBranch %72
- %74 = OpLabel
- %78 = OpImageFetch %v4float %plane0 %coord Lod %76
- %79 = OpCompositeExtract %float %78 0
- %80 = OpImageFetch %v4float %plane1 %coord Lod %76
- %82 = OpVectorShuffle %v2float %80 %80 0 1
- %83 = OpCompositeExtract %float %82 0
- %84 = OpCompositeExtract %float %82 1
- %86 = OpCompositeConstruct %v4float %79 %83 %84 %float_1
- %87 = OpCompositeExtract %mat3v4float %params_0 2
- %88 = OpVectorTimesMatrix %v3float %86 %87
- OpStore %color %88
- OpBranch %72
- %72 = OpLabel
- %89 = OpCompositeExtract %uint %params_0 1
- %91 = OpIEqual %bool %89 %90
- OpSelectionMerge %92 None
- OpBranchConditional %91 %93 %92
- %93 = OpLabel
- %95 = OpLoad %v3float %color
- %96 = OpCompositeExtract %GammaTransferParams %params_0 3
- %94 = OpFunctionCall %v3float %gammaCorrection %95 %96
- OpStore %color %94
- %97 = OpCompositeExtract %mat3v3float %params_0 5
- %98 = OpLoad %v3float %color
- %99 = OpMatrixTimesVector %v3float %97 %98
- OpStore %color %99
- %101 = OpLoad %v3float %color
- %102 = OpCompositeExtract %GammaTransferParams %params_0 4
- %100 = OpFunctionCall %v3float %gammaCorrection %101 %102
- OpStore %color %100
- OpBranch %92
- %92 = OpLabel
+ %75 = OpLabel
+ %color = OpVariable %_ptr_Function_v3float Function %49
+ %77 = OpCompositeExtract %uint %params_0 0
+ %79 = OpIEqual %bool %77 %uint_1
+ OpSelectionMerge %80 None
+ OpBranchConditional %79 %81 %82
+ %81 = OpLabel
+ %83 = OpImageFetch %v4float %plane0 %coord Lod %84
+ %85 = OpVectorShuffle %v3float %83 %83 0 1 2
+ OpStore %color %85
+ OpBranch %80
+ %82 = OpLabel
+ %86 = OpImageFetch %v4float %plane0 %coord Lod %84
+ %87 = OpCompositeExtract %float %86 0
+ %88 = OpImageFetch %v4float %plane1 %coord Lod %84
+ %90 = OpVectorShuffle %v2float %88 %88 0 1
+ %91 = OpCompositeExtract %float %90 0
+ %92 = OpCompositeExtract %float %90 1
+ %94 = OpCompositeConstruct %v4float %87 %91 %92 %float_1
+ %95 = OpCompositeExtract %mat3v4float %params_0 2
+ %96 = OpVectorTimesMatrix %v3float %94 %95
+ OpStore %color %96
+ OpBranch %80
+ %80 = OpLabel
+ %97 = OpCompositeExtract %uint %params_0 1
+ %99 = OpIEqual %bool %97 %98
+ OpSelectionMerge %100 None
+ OpBranchConditional %99 %101 %100
+ %101 = OpLabel
%103 = OpLoad %v3float %color
- %104 = OpCompositeExtract %float %103 0
- %105 = OpCompositeExtract %float %103 1
- %106 = OpCompositeExtract %float %103 2
- %107 = OpCompositeConstruct %v4float %104 %105 %106 %float_1
- OpReturnValue %107
- OpFunctionEnd
- %tint_clamp = OpFunction %v2int None %108
- %e = OpFunctionParameter %v2int
- %low = OpFunctionParameter %v2int
- %high = OpFunctionParameter %v2int
- %113 = OpLabel
- %115 = OpExtInst %v2int %25 SMax %e %low
- %114 = OpExtInst %v2int %25 SMin %115 %high
- OpReturnValue %114
+ %104 = OpCompositeExtract %GammaTransferParams %params_0 3
+ %102 = OpFunctionCall %v3float %gammaCorrection %103 %104
+ OpStore %color %102
+ %105 = OpCompositeExtract %mat3v3float %params_0 5
+ %106 = OpLoad %v3float %color
+ %107 = OpMatrixTimesVector %v3float %105 %106
+ OpStore %color %107
+ %109 = OpLoad %v3float %color
+ %110 = OpCompositeExtract %GammaTransferParams %params_0 4
+ %108 = OpFunctionCall %v3float %gammaCorrection %109 %110
+ OpStore %color %108
+ OpBranch %100
+ %100 = OpLabel
+ %111 = OpLoad %v3float %color
+ %112 = OpCompositeExtract %float %111 0
+ %113 = OpCompositeExtract %float %111 1
+ %114 = OpCompositeExtract %float %111 2
+ %115 = OpCompositeConstruct %v4float %112 %113 %114 %float_1
+ OpReturnValue %115
OpFunctionEnd
%main = OpFunction %void None %116
%119 = OpLabel
diff --git a/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl b/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl
index c1df2b1..2799157 100644
--- a/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl
+++ b/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl
@@ -1,6 +1,3 @@
-Texture2D<float4> arg_0 : register(t0, space1);
-SamplerState arg_1 : register(s1, space1);
-
float4 tint_textureSampleBaseClampToEdge(Texture2D<float4> t, SamplerState s, float2 coord) {
int3 tint_tmp;
t.GetDimensions(0, tint_tmp.x, tint_tmp.y, tint_tmp.z);
@@ -10,6 +7,9 @@
return t.SampleLevel(s, clamped, 0.0f);
}
+Texture2D<float4> arg_0 : register(t0, space1);
+SamplerState arg_1 : register(s1, space1);
+
void textureSampleBaseClampToEdge_9ca02c() {
float4 res = tint_textureSampleBaseClampToEdge(arg_0, arg_1, (0.0f).xx);
}
diff --git a/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl b/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl
index c1df2b1..2799157 100644
--- a/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl
+++ b/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl
@@ -1,6 +1,3 @@
-Texture2D<float4> arg_0 : register(t0, space1);
-SamplerState arg_1 : register(s1, space1);
-
float4 tint_textureSampleBaseClampToEdge(Texture2D<float4> t, SamplerState s, float2 coord) {
int3 tint_tmp;
t.GetDimensions(0, tint_tmp.x, tint_tmp.y, tint_tmp.z);
@@ -10,6 +7,9 @@
return t.SampleLevel(s, clamped, 0.0f);
}
+Texture2D<float4> arg_0 : register(t0, space1);
+SamplerState arg_1 : register(s1, space1);
+
void textureSampleBaseClampToEdge_9ca02c() {
float4 res = tint_textureSampleBaseClampToEdge(arg_0, arg_1, (0.0f).xx);
}
diff --git a/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl b/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl
index 870d5d4..1e3e8bd 100644
--- a/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl
+++ b/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl
@@ -1,6 +1,3 @@
-Texture2D<float4> arg_0 : register(t0, space1);
-SamplerState arg_1 : register(s1, space1);
-
float4 tint_textureSampleBaseClampToEdge(Texture2D<float4> t, SamplerState s, float2 coord) {
int3 tint_tmp;
t.GetDimensions(0, tint_tmp.x, tint_tmp.y, tint_tmp.z);
@@ -10,6 +7,9 @@
return t.SampleLevel(s, clamped, 0.0f);
}
+Texture2D<float4> arg_0 : register(t0, space1);
+SamplerState arg_1 : register(s1, space1);
+
void textureSampleBaseClampToEdge_9ca02c() {
float2 arg_2 = (0.0f).xx;
float4 res = tint_textureSampleBaseClampToEdge(arg_0, arg_1, arg_2);
diff --git a/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl b/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl
index 870d5d4..1e3e8bd 100644
--- a/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl
+++ b/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl
@@ -1,6 +1,3 @@
-Texture2D<float4> arg_0 : register(t0, space1);
-SamplerState arg_1 : register(s1, space1);
-
float4 tint_textureSampleBaseClampToEdge(Texture2D<float4> t, SamplerState s, float2 coord) {
int3 tint_tmp;
t.GetDimensions(0, tint_tmp.x, tint_tmp.y, tint_tmp.z);
@@ -10,6 +7,9 @@
return t.SampleLevel(s, clamped, 0.0f);
}
+Texture2D<float4> arg_0 : register(t0, space1);
+SamplerState arg_1 : register(s1, space1);
+
void textureSampleBaseClampToEdge_9ca02c() {
float2 arg_2 = (0.0f).xx;
float4 res = tint_textureSampleBaseClampToEdge(arg_0, arg_1, arg_2);