Rework the FirstIndexOffset transform
So that it transforms more on clone than in-place.
Bug: dawn:548
Bug: tint:390
Change-Id: I0127bc02c4e0e88c924042c491d274363422cc52
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35420
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/module.h b/src/ast/module.h
index 6b08b4f..59a67b0 100644
--- a/src/ast/module.h
+++ b/src/ast/module.h
@@ -56,6 +56,10 @@
Module Clone();
/// Clone this module into `ctx->mod` using the provided CloneContext
+ /// The module will be cloned in this order:
+ /// * Constructed types
+ /// * Global variables
+ /// * Functions
/// @param ctx the clone context
void Clone(CloneContext* ctx);
diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc
index fcca0c6..37014d0 100644
--- a/src/transform/first_index_offset.cc
+++ b/src/transform/first_index_offset.cc
@@ -25,6 +25,7 @@
#include "src/ast/builtin_decoration.h"
#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
+#include "src/ast/clone_context.h"
#include "src/ast/constructor_expression.h"
#include "src/ast/decorated_variable.h"
#include "src/ast/else_statement.h"
@@ -61,6 +62,20 @@
constexpr char kFirstInstanceName[] = "tint_first_instance_index";
constexpr char kIndexOffsetPrefix[] = "tint_first_index_offset_";
+ast::DecoratedVariable* clone_variable_with_new_name(ast::CloneContext* ctx,
+ ast::DecoratedVariable* in,
+ std::string new_name) {
+ auto* var = ctx->mod->create<ast::Variable>(ctx->Clone(in->source()),
+ new_name, in->storage_class(),
+ ctx->Clone(in->type()));
+ var->set_is_const(in->is_const());
+ var->set_constructor(ctx->Clone(in->constructor()));
+
+ auto* out = ctx->mod->create<ast::DecoratedVariable>(var);
+ out->set_decorations(ctx->Clone(in->decorations()));
+ return out;
+}
+
} // namespace
FirstIndexOffset::FirstIndexOffset(uint32_t binding, uint32_t set)
@@ -69,17 +84,29 @@
FirstIndexOffset::~FirstIndexOffset() = default;
Transform::Output FirstIndexOffset::Run(ast::Module* in) {
- Output out;
- out.module = in->Clone();
- auto* mod = &out.module;
+ // First do a quick check to see if the transform has already been applied.
+ for (ast::Variable* var : in->global_variables()) {
+ if (auto* dec_var = var->As<ast::DecoratedVariable>()) {
+ if (dec_var->name() == kBufferName) {
+ diag::Diagnostic err;
+ err.message = "First index offset transform has already been applied.";
+ err.severity = diag::Severity::Error;
+ Output out;
+ out.diagnostics.add(std::move(err));
+ return out;
+ }
+ }
+ }
// Running TypeDeterminer as we require local_referenced_builtin_variables()
- // to be populated
- TypeDeterminer td(mod);
+ // to be populated. TODO(bclayton) - it should not be necessary to re-run the
+ // type determiner if semantic information is already generated. Remove.
+ TypeDeterminer td(in);
if (!td.Determine()) {
diag::Diagnostic err;
err.severity = diag::Severity::Error;
err.message = td.error();
+ Output out;
out.diagnostics.add(std::move(err));
return out;
}
@@ -87,51 +114,69 @@
std::string vertex_index_name;
std::string instance_index_name;
- for (ast::Variable* var : mod->global_variables()) {
- if (auto* dec_var = var->As<ast::DecoratedVariable>()) {
- if (dec_var->name() == kBufferName) {
- diag::Diagnostic err;
- err.message = "First index offset transform has already been applied.";
- err.severity = diag::Severity::Error;
- out.diagnostics.add(std::move(err));
- return out;
- }
+ Output out;
- for (ast::VariableDecoration* dec : dec_var->decorations()) {
- if (auto* blt_dec = dec->As<ast::BuiltinDecoration>()) {
- ast::Builtin blt_type = blt_dec->value();
- if (blt_type == ast::Builtin::kVertexIdx) {
- vertex_index_name = var->name();
- var->set_name(kIndexOffsetPrefix + var->name());
- has_vertex_index_ = true;
- } else if (blt_type == ast::Builtin::kInstanceIdx) {
- instance_index_name = var->name();
- var->set_name(kIndexOffsetPrefix + var->name());
- has_instance_index_ = true;
- }
+ // Lazilly construct the UniformBuffer on first call to
+ // maybe_create_buffer_var()
+ ast::Variable* buffer_var = nullptr;
+ auto maybe_create_buffer_var = [&] {
+ if (buffer_var == nullptr) {
+ buffer_var = AddUniformBuffer(&out.module);
+ }
+ };
+
+ // Clone the AST, renaming the kVertexIdx and kInstanceIdx builtins, and add
+ // a CreateFirstIndexOffset() statement to each function that uses one of
+ // these builtins.
+ ast::CloneContext ctx(&out.module);
+ ctx.ReplaceAll([&](ast::DecoratedVariable* var) -> ast::DecoratedVariable* {
+ for (ast::VariableDecoration* dec : var->decorations()) {
+ if (auto* blt_dec = dec->As<ast::BuiltinDecoration>()) {
+ ast::Builtin blt_type = blt_dec->value();
+ if (blt_type == ast::Builtin::kVertexIdx) {
+ vertex_index_name = var->name();
+ has_vertex_index_ = true;
+ return clone_variable_with_new_name(&ctx, var,
+ kIndexOffsetPrefix + var->name());
+ } else if (blt_type == ast::Builtin::kInstanceIdx) {
+ instance_index_name = var->name();
+ has_instance_index_ = true;
+ return clone_variable_with_new_name(&ctx, var,
+ kIndexOffsetPrefix + var->name());
}
}
}
- }
+ return nullptr; // Just clone var
+ });
+ ctx.ReplaceAll( // Note: This happens in the same pass as the rename above
+ // which determines the original builtin variable names,
+ // but this should be fine, as variables are cloned first.
+ [&](ast::Function* func) -> ast::Function* {
+ maybe_create_buffer_var();
+ if (buffer_var == nullptr) {
+ return nullptr; // no transform need, just clone func
+ }
+ auto* body = ctx.mod->create<ast::BlockStatement>(
+ ctx.Clone(func->body()->source()));
+ for (const auto& data : func->local_referenced_builtin_variables()) {
+ if (data.second->value() == ast::Builtin::kVertexIdx) {
+ body->append(CreateFirstIndexOffset(
+ vertex_index_name, kFirstVertexName, buffer_var, ctx.mod));
+ } else if (data.second->value() == ast::Builtin::kInstanceIdx) {
+ body->append(CreateFirstIndexOffset(
+ instance_index_name, kFirstInstanceName, buffer_var, ctx.mod));
+ }
+ }
+ for (auto* s : *func->body()) {
+ body->append(ctx.Clone(s));
+ }
+ return ctx.mod->create<ast::Function>(
+ ctx.Clone(func->source()), func->name(), ctx.Clone(func->params()),
+ ctx.Clone(func->return_type()), ctx.Clone(body),
+ ctx.Clone(func->decorations()));
+ });
- if (!has_vertex_index_ && !has_instance_index_) {
- return out;
- }
-
- ast::Variable* buffer_var = AddUniformBuffer(mod);
-
- for (ast::Function* func : mod->functions()) {
- for (const auto& data : func->local_referenced_builtin_variables()) {
- if (data.second->value() == ast::Builtin::kVertexIdx) {
- AddFirstIndexOffset(vertex_index_name, kFirstVertexName, buffer_var,
- func, mod);
- } else if (data.second->value() == ast::Builtin::kInstanceIdx) {
- AddFirstIndexOffset(instance_index_name, kFirstInstanceName, buffer_var,
- func, mod);
- }
- }
- }
-
+ in->Clone(&ctx);
return out;
}
@@ -187,12 +232,10 @@
auto* idx_var =
mod->create<ast::DecoratedVariable>(mod->create<ast::Variable>(
Source{}, kBufferName, ast::StorageClass::kUniform, struct_type));
-
- ast::VariableDecorationList decorations;
- decorations.push_back(
- mod->create<ast::BindingDecoration>(binding_, Source{}));
- decorations.push_back(mod->create<ast::SetDecoration>(set_, Source{}));
- idx_var->set_decorations(std::move(decorations));
+ idx_var->set_decorations({
+ mod->create<ast::BindingDecoration>(binding_, Source{}),
+ mod->create<ast::SetDecoration>(set_, Source{}),
+ });
mod->AddGlobalVariable(idx_var);
@@ -201,11 +244,11 @@
return idx_var;
}
-void FirstIndexOffset::AddFirstIndexOffset(const std::string& original_name,
- const std::string& field_name,
- ast::Variable* buffer_var,
- ast::Function* func,
- ast::Module* mod) {
+ast::VariableDeclStatement* FirstIndexOffset::CreateFirstIndexOffset(
+ const std::string& original_name,
+ const std::string& field_name,
+ ast::Variable* buffer_var,
+ ast::Module* mod) {
auto* buffer = mod->create<ast::IdentifierExpression>(buffer_var->name());
auto* var = mod->create<ast::Variable>(Source{}, original_name,
ast::StorageClass::kNone,
@@ -217,8 +260,7 @@
mod->create<ast::IdentifierExpression>(kIndexOffsetPrefix + var->name()),
mod->create<ast::MemberAccessorExpression>(
buffer, mod->create<ast::IdentifierExpression>(field_name))));
- func->body()->insert(0,
- mod->create<ast::VariableDeclStatement>(std::move(var)));
+ return mod->create<ast::VariableDeclStatement>(var);
}
} // namespace transform
diff --git a/src/transform/first_index_offset.h b/src/transform/first_index_offset.h
index 163bfbb..873ffc8 100644
--- a/src/transform/first_index_offset.h
+++ b/src/transform/first_index_offset.h
@@ -18,6 +18,7 @@
#include <string>
#include "src/ast/module.h"
+#include "src/ast/variable_decl_statement.h"
#include "src/transform/transform.h"
namespace tint {
@@ -94,12 +95,12 @@
/// @param original_name the name of the original builtin used in function
/// @param field_name name of field in firstVertex/Instance buffer
/// @param buffer_var variable of firstVertex/Instance buffer
- /// @param func function to modify
- void AddFirstIndexOffset(const std::string& original_name,
- const std::string& field_name,
- ast::Variable* buffer_var,
- ast::Function* func,
- ast::Module* module);
+ /// @param module the target module to contain the new ast nodes
+ ast::VariableDeclStatement* CreateFirstIndexOffset(
+ const std::string& original_name,
+ const std::string& field_name,
+ ast::Variable* buffer_var,
+ ast::Module* module);
uint32_t binding_;
uint32_t set_;
diff --git a/src/transform/first_index_offset_test.cc b/src/transform/first_index_offset_test.cc
index 3e20748..b1dbdd0 100644
--- a/src/transform/first_index_offset_test.cc
+++ b/src/transform/first_index_offset_test.cc
@@ -76,6 +76,8 @@
struct Builder : public ModuleBuilder {
void Build() override {
AddBuiltinInput("vert_idx", ast::Builtin::kVertexIdx);
+ AddFunction("test")->body()->append(create<ast::ReturnStatement>(
+ Source{}, create<ast::IdentifierExpression>("vert_idx")));
}
};
@@ -106,15 +108,16 @@
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
- EXPECT_EQ("Module{\n}\n", result.module.to_str());
+ auto got = result.module.to_str();
+ auto* expected = "Module{\n}\n";
+ EXPECT_EQ(got, expected);
}
TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) {
struct Builder : public ModuleBuilder {
void Build() override {
AddBuiltinInput("vert_idx", ast::Builtin::kVertexIdx);
- ast::Function* func = AddFunction("test");
- func->body()->append(create<ast::ReturnStatement>(
+ AddFunction("test")->body()->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("vert_idx")));
}
};
@@ -131,7 +134,9 @@
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
- EXPECT_EQ(R"(Module{
+ auto got = result.module.to_str();
+ auto* expected =
+ R"(Module{
TintFirstIndexOffsetData Struct{
[[block]]
StructMember{[[ offset 0 ]] tint_first_vertex_index: __u32}
@@ -180,14 +185,16 @@
}
}
}
-)",
- result.module.to_str());
+)";
+ EXPECT_EQ(got, expected);
}
TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) {
struct Builder : public ModuleBuilder {
void Build() override {
AddBuiltinInput("inst_idx", ast::Builtin::kInstanceIdx);
+ AddFunction("test")->body()->append(create<ast::ReturnStatement>(
+ Source{}, create<ast::IdentifierExpression>("inst_idx")));
}
};
@@ -202,7 +209,9 @@
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
- EXPECT_EQ(R"(Module{
+
+ auto got = result.module.to_str();
+ auto* expected = R"(Module{
TintFirstIndexOffsetData Struct{
[[block]]
StructMember{[[ offset 0 ]] tint_first_instance_index: __u32}
@@ -224,9 +233,35 @@
uniform
__struct_TintFirstIndexOffsetData
}
+ Function test -> __u32
+ ()
+ {
+ VariableDeclStatement{
+ VariableConst{
+ inst_idx
+ none
+ __u32
+ {
+ Binary[__u32]{
+ Identifier[__ptr_in__u32]{tint_first_index_offset_inst_idx}
+ add
+ MemberAccessor[__ptr_uniform__u32]{
+ Identifier[__ptr_uniform__struct_TintFirstIndexOffsetData]{tint_first_index_data}
+ Identifier[not set]{tint_first_instance_index}
+ }
+ }
+ }
+ }
+ }
+ Return{
+ {
+ Identifier[__u32]{inst_idx}
+ }
+ }
+ }
}
-)",
- result.module.to_str());
+)";
+ EXPECT_EQ(got, expected);
}
TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) {
@@ -234,6 +269,8 @@
void Build() override {
AddBuiltinInput("inst_idx", ast::Builtin::kInstanceIdx);
AddBuiltinInput("vert_idx", ast::Builtin::kVertexIdx);
+ AddFunction("test")->body()->append(
+ create<ast::ReturnStatement>(Source{}, Expr(1u)));
}
};
@@ -251,7 +288,9 @@
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
- EXPECT_EQ(R"(Module{
+
+ auto got = result.module.to_str();
+ auto* expected = R"(Module{
TintFirstIndexOffsetData Struct{
[[block]]
StructMember{[[ offset 0 ]] tint_first_vertex_index: __u32}
@@ -282,9 +321,18 @@
uniform
__struct_TintFirstIndexOffsetData
}
+ Function test -> __u32
+ ()
+ {
+ Return{
+ {
+ ScalarConstructor[__u32]{1}
+ }
+ }
+ }
}
-)",
- result.module.to_str());
+)";
+ EXPECT_EQ(got, expected);
EXPECT_TRUE(transform_ptr->HasVertexIndex());
EXPECT_EQ(transform_ptr->GetFirstVertexOffset(), 0u);
@@ -321,7 +369,9 @@
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
- EXPECT_EQ(R"(Module{
+
+ auto got = result.module.to_str();
+ auto* expected = R"(Module{
TintFirstIndexOffsetData Struct{
[[block]]
StructMember{[[ offset 0 ]] tint_first_vertex_index: __u32}
@@ -383,8 +433,8 @@
}
}
}
-)",
- result.module.to_str());
+)";
+ EXPECT_EQ(got, expected);
}
} // namespace