Remove BlockStatement::insert()
Bug: tint:396
Bug: tint:390
Change-Id: I719b84804164fa801ded505ed56717948f06c7a7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35502
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h
index 9b5417a..d485bdf 100644
--- a/src/ast/block_statement.h
+++ b/src/ast/block_statement.h
@@ -35,14 +35,6 @@
BlockStatement(BlockStatement&&);
~BlockStatement() override;
- /// Insert a statement to the block
- /// @param index the index to insert at
- /// @param stmt the statement to insert
- void insert(size_t index, Statement* stmt) {
- auto offset = static_cast<decltype(statements_)::difference_type>(index);
- statements_.insert(statements_.begin() + offset, stmt);
- }
-
/// @returns true if the block is empty
bool empty() const { return statements_.empty(); }
/// @returns the number of statements directly in the block
@@ -60,16 +52,12 @@
/// Retrieves the statement at `idx`
/// @param idx the index. The index is not bounds checked.
/// @returns the statement at `idx`
- const Statement* get(size_t idx) const { return statements_[idx]; }
+ Statement* get(size_t idx) const { return statements_[idx]; }
/// Retrieves the statement at `idx`
/// @param idx the index. The index is not bounds checked.
/// @returns the statement at `idx`
- Statement* operator[](size_t idx) { return statements_[idx]; }
- /// Retrieves the statement at `idx`
- /// @param idx the index. The index is not bounds checked.
- /// @returns the statement at `idx`
- const Statement* operator[](size_t idx) const { return statements_[idx]; }
+ Statement* operator[](size_t idx) const { return statements_[idx]; }
/// @returns the beginning iterator
StatementList::const_iterator begin() const { return statements_.begin(); }
diff --git a/src/ast/block_statement_test.cc b/src/ast/block_statement_test.cc
index 71b71d2..78134c6 100644
--- a/src/ast/block_statement_test.cc
+++ b/src/ast/block_statement_test.cc
@@ -37,24 +37,6 @@
EXPECT_EQ(b[0], ptr);
}
-TEST_F(BlockStatementTest, Creation_WithInsert) {
- auto* s1 = create<DiscardStatement>(Source{});
- auto* s2 = create<DiscardStatement>(Source{});
- auto* s3 = create<DiscardStatement>(Source{});
-
- BlockStatement b(Source{}, StatementList{});
- b.insert(0, s1);
- b.insert(0, s2);
- b.insert(1, s3);
-
- // |b| should contain s2, s3, s1
-
- ASSERT_EQ(b.size(), 3u);
- EXPECT_EQ(b[0], s2);
- EXPECT_EQ(b[1], s3);
- EXPECT_EQ(b[2], s1);
-}
-
TEST_F(BlockStatementTest, Creation_WithSource) {
BlockStatement b(Source{Source::Location{20, 2}}, ast::StatementList{});
auto src = b.source();
diff --git a/src/ast/module.cc b/src/ast/module.cc
index 6d5cf58..a357554 100644
--- a/src/ast/module.cc
+++ b/src/ast/module.cc
@@ -33,15 +33,30 @@
Module Module::Clone() {
Module out;
CloneContext ctx(&out);
- Clone(&ctx);
+
+ // Symbol table must be cloned first so that the resulting module has the
+ // symbols before we start the tree mutations.
+ ctx.mod->symbol_table_ = symbol_table_;
+
+ CloneUsing(&ctx);
return out;
}
-void Module::Clone(CloneContext* ctx) {
+Module Module::Clone(const std::function<void(CloneContext* ctx)>& init) {
+ Module out;
+ CloneContext ctx(&out);
+
// Symbol table must be cloned first so that the resulting module has the
// symbols before we start the tree mutations.
- ctx->mod->symbol_table_ = symbol_table_;
+ ctx.mod->symbol_table_ = symbol_table_;
+ init(&ctx);
+
+ CloneUsing(&ctx);
+ return out;
+}
+
+void Module::CloneUsing(CloneContext* ctx) {
for (auto* ty : constructed_types_) {
ctx->mod->constructed_types_.emplace_back(ctx->Clone(ty));
}
diff --git a/src/ast/module.h b/src/ast/module.h
index 1facd79..e5d153d 100644
--- a/src/ast/module.h
+++ b/src/ast/module.h
@@ -15,6 +15,7 @@
#ifndef SRC_AST_MODULE_H_
#define SRC_AST_MODULE_H_
+#include <functional>
#include <memory>
#include <string>
#include <type_traits>
@@ -55,13 +56,11 @@
/// @return a deep copy of this module
Module Clone();
- /// Clone this module into `ctx->mod` using the provided CloneContext
- /// The module will be cloned in this order:
- /// * Constructed types
- /// * Global variables
- /// * Functions
- /// @param ctx the clone context
- void Clone(CloneContext* ctx);
+ /// @param init a callback function to configure the CloneContex before
+ /// cloning any of the module's state
+ /// @return a deep copy of this module, calling `init` to first initialize the
+ /// context.
+ Module Clone(const std::function<void(CloneContext* ctx)>& init);
/// Add a global variable to the module
/// @param var the variable to add
@@ -181,6 +180,14 @@
private:
Module(const Module&) = delete;
+ /// Clone this module into `ctx->mod` using the provided CloneContext
+ /// The module will be cloned in this order:
+ /// * Constructed types
+ /// * Global variables
+ /// * Functions
+ /// @param ctx the clone context
+ void CloneUsing(CloneContext* ctx);
+
SymbolTable symbol_table_;
VariableList global_variables_;
// The constructed types are owned by the type manager
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index 6219b32..663bdc6 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -1418,7 +1418,12 @@
// body_stmt
// : BRACKET_LEFT statements BRACKET_RIGHT
Expect<ast::BlockStatement*> ParserImpl::expect_body_stmt() {
- return expect_brace_block("", [&] { return expect_statements(); });
+ return expect_brace_block("", [&]() -> Expect<ast::BlockStatement*> {
+ auto stmts = expect_statements();
+ if (stmts.errored)
+ return Failure::kErrored;
+ return create<ast::BlockStatement>(Source{}, stmts.value);
+ });
}
// paren_rhs_stmt
@@ -1437,7 +1442,7 @@
// statements
// : statement*
-Expect<ast::BlockStatement*> ParserImpl::expect_statements() {
+Expect<ast::StatementList> ParserImpl::expect_statements() {
bool errored = false;
ast::StatementList stmts;
@@ -1455,7 +1460,7 @@
if (errored)
return Failure::kErrored;
- return create<ast::BlockStatement>(Source{}, stmts);
+ return stmts;
}
// statement
@@ -1859,15 +1864,16 @@
return Failure::kNoMatch;
return expect_brace_block("loop", [&]() -> Maybe<ast::LoopStatement*> {
- auto body = expect_statements();
- if (body.errored)
+ auto stmts = expect_statements();
+ if (stmts.errored)
return Failure::kErrored;
auto continuing = continuing_stmt();
if (continuing.errored)
return Failure::kErrored;
- return create<ast::LoopStatement>(source, body.value, continuing.value);
+ auto* body = create<ast::BlockStatement>(source, stmts.value);
+ return create<ast::LoopStatement>(source, body, continuing.value);
});
}
@@ -1958,9 +1964,9 @@
if (header.errored)
return Failure::kErrored;
- auto body =
+ auto stmts =
expect_brace_block("for loop", [&] { return expect_statements(); });
- if (body.errored)
+ if (stmts.errored)
return Failure::kErrored;
// The for statement is a syntactic sugar on top of the loop statement.
@@ -1980,7 +1986,7 @@
auto* break_if_not_condition =
create<ast::IfStatement>(not_condition->source(), not_condition,
break_body, ast::ElseStatementList{});
- body->insert(0, break_if_not_condition);
+ stmts.value.insert(stmts.value.begin(), break_if_not_condition);
}
ast::BlockStatement* continuing_body = nullptr;
@@ -1991,7 +1997,8 @@
});
}
- auto* loop = create<ast::LoopStatement>(source, body.value, continuing_body);
+ auto* body = create<ast::BlockStatement>(source, stmts.value);
+ auto* loop = create<ast::LoopStatement>(source, body, continuing_body);
if (header->initializer != nullptr) {
return create<ast::BlockStatement>(source, ast::StatementList{
diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h
index 1bdd73b..c523d06 100644
--- a/src/reader/wgsl/parser_impl.h
+++ b/src/reader/wgsl/parser_impl.h
@@ -468,7 +468,7 @@
Expect<ast::Expression*> expect_paren_rhs_stmt();
/// Parses a `statements` grammar element
/// @returns the statements parsed
- Expect<ast::BlockStatement*> expect_statements();
+ Expect<ast::StatementList> expect_statements();
/// Parses a `statement` grammar element
/// @returns the parsed statement or nullptr
Maybe<ast::Statement*> statement();
diff --git a/src/reader/wgsl/parser_impl_for_stmt_test.cc b/src/reader/wgsl/parser_impl_for_stmt_test.cc
index 2172e21..b112779 100644
--- a/src/reader/wgsl/parser_impl_for_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_for_stmt_test.cc
@@ -15,6 +15,7 @@
#include <string>
#include "gtest/gtest.h"
+#include "src/ast/block_statement.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@@ -30,15 +31,15 @@
auto e_loop = p_loop->expect_statements();
EXPECT_FALSE(e_loop.errored);
EXPECT_FALSE(p_loop->has_error()) << p_loop->error();
- ASSERT_NE(e_loop.value, nullptr);
auto p_for = parser(for_str);
auto e_for = p_for->expect_statements();
EXPECT_FALSE(e_for.errored);
EXPECT_FALSE(p_for->has_error()) << p_for->error();
- ASSERT_NE(e_for.value, nullptr);
- EXPECT_EQ(e_loop->str(), e_for->str());
+ std::string loop = ast::BlockStatement({}, e_loop.value).str();
+ std::string for_ = ast::BlockStatement({}, e_for.value).str();
+ EXPECT_EQ(loop, for_);
}
};
diff --git a/src/reader/wgsl/parser_impl_statements_test.cc b/src/reader/wgsl/parser_impl_statements_test.cc
index 88b6012..d33176c 100644
--- a/src/reader/wgsl/parser_impl_statements_test.cc
+++ b/src/reader/wgsl/parser_impl_statements_test.cc
@@ -29,8 +29,8 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_EQ(e->size(), 2u);
- EXPECT_TRUE(e->get(0)->Is<ast::DiscardStatement>());
- EXPECT_TRUE(e->get(1)->Is<ast::ReturnStatement>());
+ EXPECT_TRUE(e.value[0]->Is<ast::DiscardStatement>());
+ EXPECT_TRUE(e.value[1]->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, Statements_Empty) {
diff --git a/src/transform/bound_array_accessors.cc b/src/transform/bound_array_accessors.cc
index d68da7c..3adf6b7 100644
--- a/src/transform/bound_array_accessors.cc
+++ b/src/transform/bound_array_accessors.cc
@@ -55,11 +55,11 @@
Transform::Output BoundArrayAccessors::Run(ast::Module* mod) {
Output out;
- ast::CloneContext ctx(&out.module);
- ctx.ReplaceAll([&](ast::ArrayAccessorExpression* expr) {
- return Transform(expr, &ctx, &out.diagnostics);
+ out.module = mod->Clone([&](ast::CloneContext* ctx) {
+ ctx->ReplaceAll([&, ctx](ast::ArrayAccessorExpression* expr) {
+ return Transform(expr, ctx, &out.diagnostics);
+ });
});
- mod->Clone(&ctx);
return out;
}
diff --git a/src/transform/emit_vertex_point_size.cc b/src/transform/emit_vertex_point_size.cc
index 1d22a26..d8640d8 100644
--- a/src/transform/emit_vertex_point_size.cc
+++ b/src/transform/emit_vertex_point_size.cc
@@ -19,6 +19,7 @@
#include "src/ast/assignment_statement.h"
#include "src/ast/block_statement.h"
+#include "src/ast/clone_context.h"
#include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/scalar_constructor_expression.h"
@@ -39,45 +40,56 @@
Transform::Output EmitVertexPointSize::Run(ast::Module* in) {
Output out;
- out.module = in->Clone();
- auto* mod = &out.module;
- if (!mod->HasStage(ast::PipelineStage::kVertex)) {
+ if (!in->HasStage(ast::PipelineStage::kVertex)) {
// If the module doesn't have any vertex stages, then there's nothing to do.
+ out.module = in->Clone();
return out;
}
- auto* f32 = mod->create<ast::type::F32>();
+ tint::ast::AssignmentStatement* pointsize_assign = nullptr;
+ auto get_pointsize_assign = [&pointsize_assign](ast::Module* mod) {
+ if (pointsize_assign != nullptr) {
+ return pointsize_assign;
+ }
- // Declare the pointsize builtin output variable.
- auto* pointsize_var =
- mod->create<ast::Variable>(Source{}, // source
- kPointSizeVar, // name
- ast::StorageClass::kOutput, // storage_class
- f32, // type
- false, // is_const
- nullptr, // constructor
- ast::VariableDecorationList{
- // decorations
- mod->create<ast::BuiltinDecoration>(
- ast::Builtin::kPointSize, Source{}),
- });
- mod->AddGlobalVariable(pointsize_var);
+ auto* f32 = mod->create<ast::type::F32>();
- // Build the AST expression & statement for assigning pointsize one.
- auto* one = mod->create<ast::ScalarConstructorExpression>(
- Source{}, mod->create<ast::FloatLiteral>(Source{}, f32, 1.0f));
- auto* pointsize_ident = mod->create<ast::IdentifierExpression>(
- Source{}, mod->RegisterSymbol(kPointSizeVar), kPointSizeVar);
- auto* pointsize_assign =
- mod->create<ast::AssignmentStatement>(Source{}, pointsize_ident, one);
+ // Declare the pointsize builtin output variable.
+ auto* pointsize_var =
+ mod->create<ast::Variable>(Source{}, // source
+ kPointSizeVar, // name
+ ast::StorageClass::kOutput, // storage_class
+ f32, // type
+ false, // is_const
+ nullptr, // constructor
+ ast::VariableDecorationList{
+ // decorations
+ mod->create<ast::BuiltinDecoration>(
+ ast::Builtin::kPointSize, Source{}),
+ });
+ mod->AddGlobalVariable(pointsize_var);
+
+ // Build the AST expression & statement for assigning pointsize one.
+ auto* one = mod->create<ast::ScalarConstructorExpression>(
+ Source{}, mod->create<ast::FloatLiteral>(Source{}, f32, 1.0f));
+ auto* pointsize_ident = mod->create<ast::IdentifierExpression>(
+ Source{}, mod->RegisterSymbol(kPointSizeVar), kPointSizeVar);
+ pointsize_assign =
+ mod->create<ast::AssignmentStatement>(Source{}, pointsize_ident, one);
+ return pointsize_assign;
+ };
// Add the pointsize assignment statement to the front of all vertex stages.
- for (auto* func : mod->functions()) {
- if (func->pipeline_stage() == ast::PipelineStage::kVertex) {
- func->body()->insert(0, pointsize_assign);
- }
- }
+ out.module = in->Clone([&](ast::CloneContext* ctx) {
+ ctx->ReplaceAll([&, ctx](ast::Function* func) -> ast::Function* {
+ if (func->pipeline_stage() != ast::PipelineStage::kVertex) {
+ return nullptr; // Just clone func
+ }
+ return CloneWithStatementsAtStart(ctx, func,
+ {get_pointsize_assign(ctx->mod)});
+ });
+ });
return out;
}
diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc
index 97c7670..807a2c6 100644
--- a/src/transform/first_index_offset.cc
+++ b/src/transform/first_index_offset.cc
@@ -112,70 +112,63 @@
std::string vertex_index_name;
std::string instance_index_name;
- Output out;
-
// Lazilly construct the UniformBuffer on first call to
// maybe_create_buffer_var()
ast::Variable* buffer_var = nullptr;
- auto maybe_create_buffer_var = [&] {
+ auto maybe_create_buffer_var = [&](ast::Module* mod) {
if (buffer_var == nullptr) {
- buffer_var = AddUniformBuffer(&out.module);
+ buffer_var = AddUniformBuffer(mod);
}
};
// Clone the AST, renaming the kVertexIdx and kInstanceIdx builtins, and add
// a CreateFirstIndexOffset() statement to each function that uses one of
// these builtins.
- ast::CloneContext ctx(&out.module);
- ctx.ReplaceAll([&](ast::Variable* var) -> ast::Variable* {
- for (ast::VariableDecoration* dec : var->decorations()) {
- if (auto* blt_dec = dec->As<ast::BuiltinDecoration>()) {
- ast::Builtin blt_type = blt_dec->value();
- if (blt_type == ast::Builtin::kVertexIdx) {
- vertex_index_name = var->name();
- has_vertex_index_ = true;
- return clone_variable_with_new_name(&ctx, var,
- kIndexOffsetPrefix + var->name());
- } else if (blt_type == ast::Builtin::kInstanceIdx) {
- instance_index_name = var->name();
- has_instance_index_ = true;
- return clone_variable_with_new_name(&ctx, var,
- kIndexOffsetPrefix + var->name());
- }
- }
- }
- return nullptr; // Just clone var
- });
- ctx.ReplaceAll( // Note: This happens in the same pass as the rename above
- // which determines the original builtin variable names,
- // but this should be fine, as variables are cloned first.
- [&](ast::Function* func) -> ast::Function* {
- maybe_create_buffer_var();
- if (buffer_var == nullptr) {
- return nullptr; // no transform need, just clone func
- }
- ast::StatementList statements;
- for (const auto& data : func->local_referenced_builtin_variables()) {
- if (data.second->value() == ast::Builtin::kVertexIdx) {
- statements.emplace_back(CreateFirstIndexOffset(
- vertex_index_name, kFirstVertexName, buffer_var, ctx.mod));
- } else if (data.second->value() == ast::Builtin::kInstanceIdx) {
- statements.emplace_back(CreateFirstIndexOffset(
- instance_index_name, kFirstInstanceName, buffer_var, ctx.mod));
+
+ Output out;
+ out.module = in->Clone([&](ast::CloneContext* ctx) {
+ ctx->ReplaceAll([&, ctx](ast::Variable* var) -> ast::Variable* {
+ for (ast::VariableDecoration* dec : var->decorations()) {
+ if (auto* blt_dec = dec->As<ast::BuiltinDecoration>()) {
+ ast::Builtin blt_type = blt_dec->value();
+ if (blt_type == ast::Builtin::kVertexIdx) {
+ vertex_index_name = var->name();
+ has_vertex_index_ = true;
+ return clone_variable_with_new_name(
+ ctx, var, kIndexOffsetPrefix + var->name());
+ } else if (blt_type == ast::Builtin::kInstanceIdx) {
+ instance_index_name = var->name();
+ has_instance_index_ = true;
+ return clone_variable_with_new_name(
+ ctx, var, kIndexOffsetPrefix + var->name());
}
}
- for (auto* s : *func->body()) {
- statements.emplace_back(ctx.Clone(s));
- }
- return ctx.mod->create<ast::Function>(
- ctx.Clone(func->source()), func->symbol(), func->name(),
- ctx.Clone(func->params()), ctx.Clone(func->return_type()),
- ctx.mod->create<ast::BlockStatement>(
- ctx.Clone(func->body()->source()), statements),
- ctx.Clone(func->decorations()));
- });
+ }
+ return nullptr; // Just clone var
+ });
+ ctx->ReplaceAll( // Note: This happens in the same pass as the rename above
+ // which determines the original builtin variable names,
+ // but this should be fine, as variables are cloned first.
+ [&, ctx](ast::Function* func) -> ast::Function* {
+ maybe_create_buffer_var(ctx->mod);
+ if (buffer_var == nullptr) {
+ return nullptr; // no transform need, just clone func
+ }
+ ast::StatementList statements;
+ for (const auto& data : func->local_referenced_builtin_variables()) {
+ if (data.second->value() == ast::Builtin::kVertexIdx) {
+ statements.emplace_back(CreateFirstIndexOffset(
+ vertex_index_name, kFirstVertexName, buffer_var, ctx->mod));
+ } else if (data.second->value() == ast::Builtin::kInstanceIdx) {
+ statements.emplace_back(CreateFirstIndexOffset(
+ instance_index_name, kFirstInstanceName, buffer_var,
+ ctx->mod));
+ }
+ }
+ return CloneWithStatementsAtStart(ctx, func, statements);
+ });
+ });
- in->Clone(&ctx);
return out;
}
diff --git a/src/transform/transform.cc b/src/transform/transform.cc
index f6bdfd2..a03b943 100644
--- a/src/transform/transform.cc
+++ b/src/transform/transform.cc
@@ -14,11 +14,30 @@
#include "src/transform/transform.h"
+#include "src/ast/block_statement.h"
+#include "src/ast/clone_context.h"
+#include "src/ast/function.h"
+
namespace tint {
namespace transform {
Transform::Transform() = default;
Transform::~Transform() = default;
+ast::Function* Transform::CloneWithStatementsAtStart(
+ ast::CloneContext* ctx,
+ ast::Function* in,
+ ast::StatementList statements) {
+ for (auto* s : *in->body()) {
+ statements.emplace_back(ctx->Clone(s));
+ }
+ return ctx->mod->create<ast::Function>(
+ ctx->Clone(in->source()), in->symbol(), in->name(),
+ ctx->Clone(in->params()), ctx->Clone(in->return_type()),
+ ctx->mod->create<ast::BlockStatement>(ctx->Clone(in->body()->source()),
+ statements),
+ ctx->Clone(in->decorations()));
+}
+
} // namespace transform
} // namespace tint
diff --git a/src/transform/transform.h b/src/transform/transform.h
index 2a68467..211be8a 100644
--- a/src/transform/transform.h
+++ b/src/transform/transform.h
@@ -48,6 +48,18 @@
/// @param module the source module to transform
/// @returns the transformation result
virtual Output Run(ast::Module* module) = 0;
+
+ protected:
+ /// Clones the function `in` adding `statements` to the beginning of the
+ /// cloned function body.
+ /// @param ctx the clone context
+ /// @param in the function to clone
+ /// @param statements the statements to prepend to `in`'s body
+ /// @return the cloned function
+ static ast::Function* CloneWithStatementsAtStart(
+ ast::CloneContext* ctx,
+ ast::Function* in,
+ ast::StatementList statements);
};
} // namespace transform
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index 163d875..74feab3 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -20,6 +20,7 @@
#include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
#include "src/ast/bitcast_expression.h"
+#include "src/ast/clone_context.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/stride_decoration.h"
@@ -69,27 +70,24 @@
}
Transform::Output VertexPulling::Run(ast::Module* in) {
- Output out;
- out.module = in->Clone();
-
- ast::Module* mod = &out.module;
-
// Check SetVertexState was called
if (!cfg.vertex_state_set) {
diag::Diagnostic err;
err.severity = diag::Severity::Error;
err.message = "SetVertexState not called";
+ Output out;
out.diagnostics.add(std::move(err));
return out;
}
// Find entry point
- auto* func = mod->FindFunctionBySymbolAndStage(
- mod->GetSymbol(cfg.entry_point_name), ast::PipelineStage::kVertex);
+ auto* func = in->FindFunctionBySymbolAndStage(
+ in->GetSymbol(cfg.entry_point_name), ast::PipelineStage::kVertex);
if (func == nullptr) {
diag::Diagnostic err;
err.severity = diag::Severity::Error;
err.message = "Vertex stage entry point not found";
+ Output out;
out.diagnostics.add(std::move(err));
return out;
}
@@ -99,13 +97,22 @@
// TODO(idanr): Make sure we covered all error cases, to guarantee the
// following stages will pass
+ Output out;
+ out.module = in->Clone([&](ast::CloneContext* ctx) {
+ State state{in, ctx->mod, cfg};
+ state.FindOrInsertVertexIndexIfUsed();
+ state.FindOrInsertInstanceIndexIfUsed();
+ state.ConvertVertexInputVariablesToPrivate();
+ state.AddVertexStorageBuffers();
- State state{mod, cfg};
- state.FindOrInsertVertexIndexIfUsed();
- state.FindOrInsertInstanceIndexIfUsed();
- state.ConvertVertexInputVariablesToPrivate();
- state.AddVertexStorageBuffers();
- func->body()->insert(0, state.CreateVertexPullingPreamble());
+ ctx->ReplaceAll([func, ctx, state](ast::Function* f) -> ast::Function* {
+ if (f == func) {
+ return CloneWithStatementsAtStart(
+ ctx, f, {state.CreateVertexPullingPreamble()});
+ }
+ return nullptr; // Just clone func
+ });
+ });
return out;
}
@@ -114,11 +121,14 @@
VertexPulling::Config::Config(const Config&) = default;
VertexPulling::Config::~Config() = default;
-VertexPulling::State::State(ast::Module* m, const Config& c) : mod(m), cfg(c) {}
+VertexPulling::State::State(ast::Module* i, ast::Module* o, const Config& c)
+ : in(i), out(o), cfg(c) {}
+
+VertexPulling::State::State(const State&) = default;
VertexPulling::State::~State() = default;
-std::string VertexPulling::State::GetVertexBufferName(uint32_t index) {
+std::string VertexPulling::State::GetVertexBufferName(uint32_t index) const {
return kVertexBufferNamePrefix + std::to_string(index);
}
@@ -135,7 +145,7 @@
}
// Look for an existing vertex index builtin
- for (auto* v : mod->global_variables()) {
+ for (auto* v : in->global_variables()) {
if (v->storage_class() != ast::StorageClass::kInput) {
continue;
}
@@ -154,7 +164,7 @@
vertex_index_name = kDefaultVertexIndexName;
auto* var =
- mod->create<ast::Variable>(Source{}, // source
+ out->create<ast::Variable>(Source{}, // source
vertex_index_name, // name
ast::StorageClass::kInput, // storage_class
GetI32Type(), // type
@@ -162,11 +172,11 @@
nullptr, // constructor
ast::VariableDecorationList{
// decorations
- mod->create<ast::BuiltinDecoration>(
+ out->create<ast::BuiltinDecoration>(
ast::Builtin::kVertexIdx, Source{}),
});
- mod->AddGlobalVariable(var);
+ out->AddGlobalVariable(var);
}
void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
@@ -182,7 +192,7 @@
}
// Look for an existing instance index builtin
- for (auto* v : mod->global_variables()) {
+ for (auto* v : in->global_variables()) {
if (v->storage_class() != ast::StorageClass::kInput) {
continue;
}
@@ -201,7 +211,7 @@
instance_index_name = kDefaultInstanceIndexName;
auto* var =
- mod->create<ast::Variable>(Source{}, // source
+ out->create<ast::Variable>(Source{}, // source
instance_index_name, // name
ast::StorageClass::kInput, // storage_class
GetI32Type(), // type
@@ -209,14 +219,14 @@
nullptr, // constructor
ast::VariableDecorationList{
// decorations
- mod->create<ast::BuiltinDecoration>(
+ out->create<ast::BuiltinDecoration>(
ast::Builtin::kInstanceIdx, Source{}),
});
- mod->AddGlobalVariable(var);
+ out->AddGlobalVariable(var);
}
void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
- for (auto*& v : mod->global_variables()) {
+ for (auto*& v : in->global_variables()) {
if (v->storage_class() != ast::StorageClass::kInput) {
continue;
}
@@ -227,7 +237,7 @@
// This is where the replacement happens. Expressions use identifier
// strings instead of pointers, so we don't need to update any other
// place in the AST.
- v = mod->create<ast::Variable>(
+ v = out->create<ast::Variable>(
Source{}, // source
v->name(), // name
ast::StorageClass::kPrivate, // storage_class
@@ -245,31 +255,31 @@
void VertexPulling::State::AddVertexStorageBuffers() {
// TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
// The array inside the struct definition
- auto* internal_array_type = mod->create<ast::type::Array>(
+ auto* internal_array_type = out->create<ast::type::Array>(
GetU32Type(), 0,
ast::ArrayDecorationList{
- mod->create<ast::StrideDecoration>(4u, Source{}),
+ out->create<ast::StrideDecoration>(4u, Source{}),
});
// Creating the struct type
ast::StructMemberList members;
ast::StructMemberDecorationList member_dec;
member_dec.push_back(
- mod->create<ast::StructMemberOffsetDecoration>(0u, Source{}));
+ out->create<ast::StructMemberOffsetDecoration>(0u, Source{}));
- members.push_back(mod->create<ast::StructMember>(
+ members.push_back(out->create<ast::StructMember>(
Source{}, kStructBufferName, internal_array_type, std::move(member_dec)));
ast::StructDecorationList decos;
- decos.push_back(mod->create<ast::StructBlockDecoration>(Source{}));
+ decos.push_back(out->create<ast::StructBlockDecoration>(Source{}));
- auto* struct_type = mod->create<ast::type::Struct>(
- mod->RegisterSymbol(kStructName), kStructName,
- mod->create<ast::Struct>(Source{}, std::move(members), std::move(decos)));
+ auto* struct_type = out->create<ast::type::Struct>(
+ out->RegisterSymbol(kStructName), kStructName,
+ out->create<ast::Struct>(Source{}, std::move(members), std::move(decos)));
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
// The decorated variable with struct type
- auto* var = mod->create<ast::Variable>(
+ auto* var = out->create<ast::Variable>(
Source{}, // source
GetVertexBufferName(i), // name
ast::StorageClass::kStorageBuffer, // storage_class
@@ -278,23 +288,23 @@
nullptr, // constructor
ast::VariableDecorationList{
// decorations
- mod->create<ast::BindingDecoration>(i, Source{}),
- mod->create<ast::SetDecoration>(cfg.pulling_set, Source{}),
+ out->create<ast::BindingDecoration>(i, Source{}),
+ out->create<ast::SetDecoration>(cfg.pulling_set, Source{}),
});
- mod->AddGlobalVariable(var);
+ out->AddGlobalVariable(var);
}
- mod->AddConstructedType(struct_type);
+ out->AddConstructedType(struct_type);
}
-ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() {
+ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
// Assign by looking at the vertex descriptor to find attributes with matching
// location.
ast::StatementList stmts;
// Declare the |kPullingPosVarName| variable in the shader
- auto* pos_declaration = mod->create<ast::VariableDeclStatement>(
- Source{}, mod->create<ast::Variable>(
+ auto* pos_declaration = out->create<ast::VariableDeclStatement>(
+ Source{}, out->create<ast::Variable>(
Source{}, // source
kPullingPosVarName, // name
ast::StorageClass::kFunction, // storage_class
@@ -323,45 +333,46 @@
? vertex_index_name
: instance_index_name;
// Identifier to index by
- auto* index_identifier = mod->create<ast::IdentifierExpression>(
- Source{}, mod->RegisterSymbol(name), name);
+ auto* index_identifier = out->create<ast::IdentifierExpression>(
+ Source{}, out->RegisterSymbol(name), name);
// An expression for the start of the read in the buffer in bytes
- auto* pos_value = mod->create<ast::BinaryExpression>(
+ auto* pos_value = out->create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kAdd,
- mod->create<ast::BinaryExpression>(
+ out->create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kMultiply, index_identifier,
GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
GenUint(static_cast<uint32_t>(attribute_desc.offset)));
// Update position of the read
- auto* set_pos_expr = mod->create<ast::AssignmentStatement>(
+ auto* set_pos_expr = out->create<ast::AssignmentStatement>(
Source{}, CreatePullingPositionIdent(), pos_value);
stmts.emplace_back(set_pos_expr);
- stmts.emplace_back(mod->create<ast::AssignmentStatement>(
+ stmts.emplace_back(out->create<ast::AssignmentStatement>(
Source{},
- mod->create<ast::IdentifierExpression>(
- Source{}, mod->RegisterSymbol(v->name()), v->name()),
+ out->create<ast::IdentifierExpression>(
+ Source{}, out->RegisterSymbol(v->name()), v->name()),
AccessByFormat(i, attribute_desc.format)));
}
}
- return mod->create<ast::BlockStatement>(Source{}, stmts);
+ return out->create<ast::BlockStatement>(Source{}, stmts);
}
-ast::Expression* VertexPulling::State::GenUint(uint32_t value) {
- return mod->create<ast::ScalarConstructorExpression>(
- Source{}, mod->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
+ast::Expression* VertexPulling::State::GenUint(uint32_t value) const {
+ return out->create<ast::ScalarConstructorExpression>(
+ Source{}, out->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
}
-ast::Expression* VertexPulling::State::CreatePullingPositionIdent() {
- return mod->create<ast::IdentifierExpression>(
- Source{}, mod->RegisterSymbol(kPullingPosVarName), kPullingPosVarName);
+ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const {
+ return out->create<ast::IdentifierExpression>(
+ Source{}, out->RegisterSymbol(kPullingPosVarName), kPullingPosVarName);
}
-ast::Expression* VertexPulling::State::AccessByFormat(uint32_t buffer,
- VertexFormat format) {
+ast::Expression* VertexPulling::State::AccessByFormat(
+ uint32_t buffer,
+ VertexFormat format) const {
// TODO(idanr): this doesn't account for the format of the attribute in the
// shader. ex: vec<u32> in shader, and attribute claims VertexFormat::Float4
// right now, we would try to assign a vec4<f32> to this attribute, but we
@@ -388,43 +399,44 @@
}
ast::Expression* VertexPulling::State::AccessU32(uint32_t buffer,
- ast::Expression* pos) {
+ ast::Expression* pos) const {
// Here we divide by 4, since the buffer is uint32 not uint8. The input buffer
// has byte offsets for each attribute, and we will convert it to u32 indexes
// by dividing. Then, that element is going to be read, and if needed,
// unpacked into an appropriate variable. All reads should end up here as a
// base case.
auto vbuf_name = GetVertexBufferName(buffer);
- return mod->create<ast::ArrayAccessorExpression>(
+ return out->create<ast::ArrayAccessorExpression>(
Source{},
- mod->create<ast::MemberAccessorExpression>(
+ out->create<ast::MemberAccessorExpression>(
Source{},
- mod->create<ast::IdentifierExpression>(
- Source{}, mod->RegisterSymbol(vbuf_name), vbuf_name),
- mod->create<ast::IdentifierExpression>(
- Source{}, mod->RegisterSymbol(kStructBufferName),
+ out->create<ast::IdentifierExpression>(
+ Source{}, out->RegisterSymbol(vbuf_name), vbuf_name),
+ out->create<ast::IdentifierExpression>(
+ Source{}, out->RegisterSymbol(kStructBufferName),
kStructBufferName)),
- mod->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide, pos,
+ out->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide, pos,
GenUint(4)));
}
ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer,
- ast::Expression* pos) {
+ ast::Expression* pos) const {
// as<T> reinterprets bits
- return mod->create<ast::BitcastExpression>(Source{}, GetI32Type(),
+ return out->create<ast::BitcastExpression>(Source{}, GetI32Type(),
AccessU32(buffer, pos));
}
ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer,
- ast::Expression* pos) {
+ ast::Expression* pos) const {
// as<T> reinterprets bits
- return mod->create<ast::BitcastExpression>(Source{}, GetF32Type(),
+ return out->create<ast::BitcastExpression>(Source{}, GetF32Type(),
AccessU32(buffer, pos));
}
-ast::Expression* VertexPulling::State::AccessPrimitive(uint32_t buffer,
- ast::Expression* pos,
- VertexFormat format) {
+ast::Expression* VertexPulling::State::AccessPrimitive(
+ uint32_t buffer,
+ ast::Expression* pos,
+ VertexFormat format) const {
// This function uses a position expression to read, rather than using the
// position variable. This allows us to read from offset positions relative to
// |kPullingPosVarName|. We can't call AccessByFormat because it reads only
@@ -445,31 +457,31 @@
uint32_t element_stride,
ast::type::Type* base_type,
VertexFormat base_format,
- uint32_t count) {
+ uint32_t count) const {
ast::ExpressionList expr_list;
for (uint32_t i = 0; i < count; ++i) {
// Offset read position by element_stride for each component
- auto* cur_pos = mod->create<ast::BinaryExpression>(
+ auto* cur_pos = out->create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kAdd, CreatePullingPositionIdent(),
GenUint(element_stride * i));
expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
}
- return mod->create<ast::TypeConstructorExpression>(
- Source{}, mod->create<ast::type::Vector>(base_type, count),
+ return out->create<ast::TypeConstructorExpression>(
+ Source{}, out->create<ast::type::Vector>(base_type, count),
std::move(expr_list));
}
-ast::type::Type* VertexPulling::State::GetU32Type() {
- return mod->create<ast::type::U32>();
+ast::type::Type* VertexPulling::State::GetU32Type() const {
+ return out->create<ast::type::U32>();
}
-ast::type::Type* VertexPulling::State::GetI32Type() {
- return mod->create<ast::type::I32>();
+ast::type::Type* VertexPulling::State::GetI32Type() const {
+ return out->create<ast::type::I32>();
}
-ast::type::Type* VertexPulling::State::GetF32Type() {
- return mod->create<ast::type::F32>();
+ast::type::Type* VertexPulling::State::GetF32Type() const {
+ return out->create<ast::type::F32>();
}
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
diff --git a/src/transform/vertex_pulling.h b/src/transform/vertex_pulling.h
index 4f55b1f..6d4f23c 100644
--- a/src/transform/vertex_pulling.h
+++ b/src/transform/vertex_pulling.h
@@ -178,12 +178,13 @@
Config cfg;
struct State {
- State(ast::Module* m, const Config& c);
+ State(ast::Module* in, ast::Module* out, const Config& c);
+ explicit State(const State&);
~State();
/// Generate the vertex buffer binding name
/// @param index index to append to buffer name
- std::string GetVertexBufferName(uint32_t index);
+ std::string GetVertexBufferName(uint32_t index) const;
/// Inserts vertex_idx binding, or finds the existing one
void FindOrInsertVertexIndexIfUsed();
@@ -198,36 +199,36 @@
void AddVertexStorageBuffers();
/// Creates and returns the assignment to the variables from the buffers
- ast::BlockStatement* CreateVertexPullingPreamble();
+ ast::BlockStatement* CreateVertexPullingPreamble() const;
/// Generates an expression holding a constant uint
/// @param value uint value
- ast::Expression* GenUint(uint32_t value);
+ ast::Expression* GenUint(uint32_t value) const;
/// Generates an expression to read the shader value `kPullingPosVarName`
- ast::Expression* CreatePullingPositionIdent();
+ ast::Expression* CreatePullingPositionIdent() const;
/// Generates an expression reading from a buffer a specific format.
/// This reads the value wherever `kPullingPosVarName` points to at the time
/// of the read.
/// @param buffer the index of the vertex buffer
/// @param format the format to read
- ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format);
+ ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format) const;
/// Generates an expression reading a uint32 from a vertex buffer
/// @param buffer the index of the vertex buffer
/// @param pos an expression for the position of the access, in bytes
- ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos);
+ ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos) const;
/// Generates an expression reading an int32 from a vertex buffer
/// @param buffer the index of the vertex buffer
/// @param pos an expression for the position of the access, in bytes
- ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos);
+ ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos) const;
/// Generates an expression reading a float from a vertex buffer
/// @param buffer the index of the vertex buffer
/// @param pos an expression for the position of the access, in bytes
- ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos);
+ ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos) const;
/// Generates an expression reading a basic type (u32, i32, f32) from a
/// vertex buffer
@@ -236,7 +237,7 @@
/// @param format the underlying vertex format
ast::Expression* AccessPrimitive(uint32_t buffer,
ast::Expression* pos,
- VertexFormat format);
+ VertexFormat format) const;
/// Generates an expression reading a vec2/3/4 from a vertex buffer.
/// This reads the value wherever `kPullingPosVarName` points to at the time
@@ -250,14 +251,15 @@
uint32_t element_stride,
ast::type::Type* base_type,
VertexFormat base_format,
- uint32_t count);
+ uint32_t count) const;
// Used to grab corresponding types from the type manager
- ast::type::Type* GetU32Type();
- ast::type::Type* GetI32Type();
- ast::type::Type* GetF32Type();
+ ast::type::Type* GetU32Type() const;
+ ast::type::Type* GetI32Type() const;
+ ast::type::Type* GetF32Type() const;
- ast::Module* const mod;
+ ast::Module* const in;
+ ast::Module* const out;
Config const cfg;
std::unordered_map<uint32_t, ast::Variable*> location_to_var;
diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc
index 30145ce..818853f 100644
--- a/src/transform/vertex_pulling_test.cc
+++ b/src/transform/vertex_pulling_test.cc
@@ -176,11 +176,6 @@
StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
}
Variable{
- var_a
- private
- __f32
- }
- Variable{
Decorations{
BuiltinDecoration{vertex_idx}
}
@@ -197,6 +192,11 @@
storage_buffer
__struct_TintVertexData
}
+ Variable{
+ var_a
+ private
+ __f32
+ }
Function main -> __void
StageDecoration{vertex}
()
@@ -263,11 +263,6 @@
StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
}
Variable{
- var_a
- private
- __f32
- }
- Variable{
Decorations{
BuiltinDecoration{instance_idx}
}
@@ -284,6 +279,11 @@
storage_buffer
__struct_TintVertexData
}
+ Variable{
+ var_a
+ private
+ __f32
+ }
Function main -> __void
StageDecoration{vertex}
()
@@ -350,11 +350,6 @@
StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
}
Variable{
- var_a
- private
- __f32
- }
- Variable{
Decorations{
BuiltinDecoration{vertex_idx}
}
@@ -371,6 +366,11 @@
storage_buffer
__struct_TintVertexData
}
+ Variable{
+ var_a
+ private
+ __f32
+ }
Function main -> __void
StageDecoration{vertex}
()
@@ -466,6 +466,24 @@
StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
}
Variable{
+ Decorations{
+ BindingDecoration{0}
+ SetDecoration{4}
+ }
+ _tint_pulling_vertex_buffer_0
+ storage_buffer
+ __struct_TintVertexData
+ }
+ Variable{
+ Decorations{
+ BindingDecoration{1}
+ SetDecoration{4}
+ }
+ _tint_pulling_vertex_buffer_1
+ storage_buffer
+ __struct_TintVertexData
+ }
+ Variable{
var_a
private
__f32
@@ -491,24 +509,6 @@
in
__i32
}
- Variable{
- Decorations{
- BindingDecoration{0}
- SetDecoration{4}
- }
- _tint_pulling_vertex_buffer_0
- storage_buffer
- __struct_TintVertexData
- }
- Variable{
- Decorations{
- BindingDecoration{1}
- SetDecoration{4}
- }
- _tint_pulling_vertex_buffer_1
- storage_buffer
- __struct_TintVertexData
- }
Function main -> __void
StageDecoration{vertex}
()
@@ -608,16 +608,6 @@
StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
}
Variable{
- var_a
- private
- __f32
- }
- Variable{
- var_b
- private
- __array__f32_4
- }
- Variable{
Decorations{
BuiltinDecoration{vertex_idx}
}
@@ -634,6 +624,16 @@
storage_buffer
__struct_TintVertexData
}
+ Variable{
+ var_a
+ private
+ __f32
+ }
+ Variable{
+ var_b
+ private
+ __array__f32_4
+ }
Function main -> __void
StageDecoration{vertex}
()
@@ -795,21 +795,6 @@
StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
}
Variable{
- var_a
- private
- __array__f32_2
- }
- Variable{
- var_b
- private
- __array__f32_3
- }
- Variable{
- var_c
- private
- __array__f32_4
- }
- Variable{
Decorations{
BuiltinDecoration{vertex_idx}
}
@@ -844,6 +829,21 @@
storage_buffer
__struct_TintVertexData
}
+ Variable{
+ var_a
+ private
+ __array__f32_2
+ }
+ Variable{
+ var_b
+ private
+ __array__f32_3
+ }
+ Variable{
+ var_c
+ private
+ __array__f32_4
+ }
Function main -> __void
StageDecoration{vertex}
()