Import Tint changes from Dawn
Changes:
- 96a96697192cf8bf59979ef8060e0a975c618596 [tint] Fold trivial lets in SPIR-V reader by James Price <jrprice@google.com>
- b50aa3802e94fcb3111a2d6b0cb91f42ac015ea5 [tint][fuzzer] Fix bad mutation by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: 96a96697192cf8bf59979ef8060e0a975c618596
Change-Id: Ice4d2198f6775b2b17ff5f2782e835546d6c9fa5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/140261
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 38c374e..2dae7b4 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -436,6 +436,8 @@
"ast/transform/expand_compound_assignment.h",
"ast/transform/first_index_offset.cc",
"ast/transform/first_index_offset.h",
+ "ast/transform/fold_trivial_lets.cc",
+ "ast/transform/fold_trivial_lets.h",
"ast/transform/for_loop_to_loop.cc",
"ast/transform/for_loop_to_loop.h",
"ast/transform/localize_struct_array_assignment.cc",
@@ -1813,6 +1815,7 @@
"ast/transform/disable_uniformity_analysis_test.cc",
"ast/transform/expand_compound_assignment_test.cc",
"ast/transform/first_index_offset_test.cc",
+ "ast/transform/fold_trivial_lets_test.cc",
"ast/transform/for_loop_to_loop_test.cc",
"ast/transform/localize_struct_array_assignment_test.cc",
"ast/transform/merge_return_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 68fe434..76a37bb 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -389,6 +389,8 @@
ast/transform/expand_compound_assignment.h
ast/transform/first_index_offset.cc
ast/transform/first_index_offset.h
+ ast/transform/fold_trivial_lets.cc
+ ast/transform/fold_trivial_lets.h
ast/transform/for_loop_to_loop.cc
ast/transform/for_loop_to_loop.h
ast/transform/localize_struct_array_assignment.cc
@@ -1388,6 +1390,7 @@
ast/transform/disable_uniformity_analysis_test.cc
ast/transform/expand_compound_assignment_test.cc
ast/transform/first_index_offset_test.cc
+ ast/transform/fold_trivial_lets_test.cc
ast/transform/for_loop_to_loop_test.cc
ast/transform/expand_compound_assignment_test.cc
ast/transform/localize_struct_array_assignment_test.cc
diff --git a/src/tint/ast/transform/fold_trivial_lets.cc b/src/tint/ast/transform/fold_trivial_lets.cc
new file mode 100644
index 0000000..0cf0b7d
--- /dev/null
+++ b/src/tint/ast/transform/fold_trivial_lets.cc
@@ -0,0 +1,157 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/transform/fold_trivial_lets.h"
+
+#include <utility>
+
+#include "src/tint/ast/traverse_expressions.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/value_expression.h"
+#include "src/tint/utils/hashmap.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::FoldTrivialLets);
+
+namespace tint::ast::transform {
+
+/// PIMPL state for the transform.
+struct FoldTrivialLets::State {
+ /// The source program
+ const Program* const src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+ /// The semantic info.
+ const sem::Info& sem = {ctx.src->Sem()};
+
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program* program) : src(program) {}
+
+ /// Process a block.
+ /// @param block the block
+ void ProcessBlock(const BlockStatement* block) {
+ // PendingLet describes a let declaration that might be inlined.
+ struct PendingLet {
+ // The let declaration.
+ const VariableDeclStatement* decl = nullptr;
+ // The number of uses that have not yet been inlined.
+ size_t remaining_uses = 0;
+ };
+
+ // A map from semantic variables to their PendingLet descriptors.
+ utils::Hashmap<const sem::Variable*, PendingLet, 16> pending_lets;
+
+ // Helper that folds pending let declarations into `expr` if possible.
+ auto fold_lets = [&](const Expression* expr) {
+ TraverseExpressions(expr, b.Diagnostics(), [&](const IdentifierExpression* ident) {
+ if (auto* user = sem.Get<sem::VariableUser>(ident)) {
+ auto itr = pending_lets.Find(user->Variable());
+ if (itr) {
+ TINT_ASSERT(Transform, itr->remaining_uses > 0);
+
+ // We found a reference to a pending let, so replace it with the inlined
+ // initializer expression.
+ ctx.Replace(ident, ctx.Clone(itr->decl->variable->initializer));
+
+ // Decrement the remaining uses count and remove the let declaration if this
+ // was the last remaining use.
+ if (--itr->remaining_uses == 0) {
+ ctx.Remove(block->statements, itr->decl);
+ }
+ }
+ }
+ return TraverseAction::Descend;
+ });
+ };
+
+ // Loop over all statements in the block.
+ for (auto* stmt : block->statements) {
+ // Check for a let declarations.
+ if (auto* decl = stmt->As<VariableDeclStatement>()) {
+ if (auto* let = decl->variable->As<Let>()) {
+ // If the initializer doesn't have side effects, we might be able to inline it.
+ if (!sem.GetVal(let->initializer)->HasSideEffects()) { //
+ auto num_users = sem.Get(let)->Users().Length();
+ if (let->initializer->Is<IdentifierExpression>()) {
+ // The initializer is a single identifier expression.
+ // We can fold it into multiple uses in the next non-let statement.
+ // We also fold previous pending lets into this one, but only if
+ // it's only used once (to avoid duplicating potentially complex
+ // expressions).
+ if (num_users == 1) {
+ fold_lets(let->initializer);
+ }
+ pending_lets.Add(sem.Get(let), PendingLet{decl, num_users});
+ } else {
+ // The initializer is something more complex, so we only want to inline
+ // it if it's only used once.
+ // We also fold previous pending lets into this one.
+ fold_lets(let->initializer);
+ if (num_users == 1) {
+ pending_lets.Add(sem.Get(let), PendingLet{decl, 1});
+ }
+ }
+ continue;
+ }
+ }
+ }
+
+ // Fold pending let declarations into a select few places that are frequently generated
+ // by the SPIR_V reader.
+ if (auto* assign = stmt->As<AssignmentStatement>()) {
+ // We can fold into the RHS of an assignment statement if the RHS and LHS
+ // expressions have no side effects.
+ if (!sem.GetVal(assign->lhs)->HasSideEffects() &&
+ !sem.GetVal(assign->rhs)->HasSideEffects()) {
+ fold_lets(assign->rhs);
+ }
+ } else if (auto* ifelse = stmt->As<IfStatement>()) {
+ // We can fold into the condition of an if statement if the condition expression has
+ // no side effects.
+ if (!sem.GetVal(ifelse->condition)->HasSideEffects()) {
+ fold_lets(ifelse->condition);
+ }
+ }
+
+ // Clear any remaining pending lets.
+ // We do not try to fold lets beyond the first non-let statement.
+ pending_lets.Clear();
+ }
+ }
+
+ /// Runs the transform.
+ /// @returns the new program
+ ApplyResult Run() {
+ // Process all blocks in the module.
+ for (auto* node : src->ASTNodes().Objects()) {
+ if (auto* block = node->As<BlockStatement>()) {
+ ProcessBlock(block);
+ }
+ }
+ ctx.Clone();
+ return Program(std::move(b));
+ }
+};
+
+FoldTrivialLets::FoldTrivialLets() = default;
+
+FoldTrivialLets::~FoldTrivialLets() = default;
+
+Transform::ApplyResult FoldTrivialLets::Apply(const Program* src, const DataMap&, DataMap&) const {
+ return State(src).Run();
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/ast/transform/fold_trivial_lets.h b/src/tint/ast/transform/fold_trivial_lets.h
new file mode 100644
index 0000000..9fee63e
--- /dev/null
+++ b/src/tint/ast/transform/fold_trivial_lets.h
@@ -0,0 +1,44 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_AST_TRANSFORM_FOLD_TRIVIAL_LETS_H_
+#define SRC_TINT_AST_TRANSFORM_FOLD_TRIVIAL_LETS_H_
+
+#include "src/tint/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// FoldTrivialLets is a transform that inlines the initializers of let declarations whose
+/// initializers are just identifier expressions, or lets that are only used once. This is used to
+/// clean up unnecessary let declarations created by the SPIR-V reader.
+class FoldTrivialLets final : public utils::Castable<FoldTrivialLets, Transform> {
+ public:
+ /// Constructor
+ FoldTrivialLets();
+
+ /// Destructor
+ ~FoldTrivialLets() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_AST_TRANSFORM_FOLD_TRIVIAL_LETS_H_
diff --git a/src/tint/ast/transform/fold_trivial_lets_test.cc b/src/tint/ast/transform/fold_trivial_lets_test.cc
new file mode 100644
index 0000000..c000c7a
--- /dev/null
+++ b/src/tint/ast/transform/fold_trivial_lets_test.cc
@@ -0,0 +1,286 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/ast/transform/fold_trivial_lets.h"
+
+#include "src/tint/ast/transform/test_helper.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using FoldTrivialLetsTest = TransformTest;
+
+TEST_F(FoldTrivialLetsTest, Fold_IdentInitializer_AssignRHS) {
+ auto* src = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ v = (x + 1);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var v = 42;
+ v = (v + 1);
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_IdentInitializer_IfCondition) {
+ auto* src = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ if (x > 0) {
+ v = 0;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var v = 42;
+ if ((v > 0)) {
+ v = 0;
+ }
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, NoFold_IdentInitializer_StoreBeforeUse) {
+ auto* src = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ v = 0;
+ v = (x + 1);
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, NoFold_IdentInitializer_SideEffectsInUseExpression) {
+ auto* src = R"(
+var<private> v = 42;
+
+fn g() -> i32 {
+ v = 0;
+ return 1;
+}
+
+fn f() {
+ let x = v;
+ v = (g() + x);
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_IdentInitializer_MultiUse) {
+ auto* src = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ v = (x + x);
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var v = 42;
+ v = (v + v);
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_IdentInitializer_MultiUse_OnlySomeInlineable) {
+ auto* src = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ v = (x + x);
+ if (x > 0) {
+ v = 0;
+ }
+}
+)";
+
+ auto* expect = R"(
+fn f() {
+ var v = 42;
+ let x = v;
+ v = (v + v);
+ if ((x > 0)) {
+ v = 0;
+ }
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_ComplexInitializer_SingleUse) {
+ auto* src = R"(
+fn f(idx : i32) {
+ var v = array<vec4i, 4>();
+ let x = v[idx].y;
+ v[0].x = (x + 1);
+}
+)";
+
+ auto* expect = R"(
+fn f(idx : i32) {
+ var v = array<vec4i, 4>();
+ v[0].x = (v[idx].y + 1);
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, NoFold_ComplexInitializer_SingleUseWithSideEffects) {
+ auto* src = R"(
+var<private> i = 0;
+
+fn bar() -> i32 {
+ i++;
+ return i;
+}
+
+fn f() -> i32 {
+ var v = array<vec4i, 4>();
+ let x = v[bar()].y;
+ return (i + x);
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, NoFold_ComplexInitializer_MultiUse) {
+ auto* src = R"(
+fn f(idx : i32) {
+ var v = array<vec4i, 4>();
+ let x = v[idx].y;
+ v[0].x = (x + x);
+}
+)";
+
+ auto* expect = src;
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_ComplexInitializer_SingleUseViaSimpleLet) {
+ auto* src = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let x = ((a * b) + c);
+ let y = x;
+ return y;
+}
+)";
+
+ auto* expect = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let y = ((a * b) + c);
+ return y;
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_ComplexInitializer_SingleUseViaSimpleLetUsedTwice) {
+ auto* src = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let x = (a * b) + c;
+ let y = x;
+ let z = y + y;
+ return z;
+}
+)";
+
+ auto* expect = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let x = ((a * b) + c);
+ let z = (x + x);
+ return z;
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(FoldTrivialLetsTest, Fold_ComplexInitializer_MultiUseUseDifferentLets) {
+ auto* src = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let x = (a * b) + c;
+ let y = x;
+ let z = x + y;
+ return z;
+}
+)";
+
+ auto* expect = R"(
+fn f(a : i32, b : i32, c : i32) -> i32 {
+ let x = ((a * b) + c);
+ let z = (x + x);
+ return z;
+}
+)";
+
+ auto got = Run<FoldTrivialLets>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc b/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc
index 25ac065..6f6efc2 100644
--- a/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc
+++ b/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc
@@ -18,6 +18,8 @@
#include <vector>
#include "src/tint/program_builder.h"
+#include "src/tint/sem/call.h"
+#include "src/tint/sem/statement.h"
#include "src/tint/type/abstract_float.h"
#include "src/tint/type/abstract_int.h"
@@ -94,6 +96,16 @@
std::vector<ast::UnaryOp> MutationWrapUnaryOperator::GetValidUnaryWrapper(
const sem::ValueExpression& expr) {
+ if (auto* call_expr = expr.As<sem::Call>()) {
+ if (auto* stmt = call_expr->Stmt()) {
+ if (auto* call_stmt = stmt->Declaration()->As<ast::CallStatement>()) {
+ if (call_stmt->expr == expr.Declaration()) {
+ return {}; // A call statement must only wrap a call expression.
+ }
+ }
+ }
+ }
+
const auto* expr_type = expr.Type();
if (expr_type->is_bool_scalar_or_vector()) {
return {ast::UnaryOp::kNot};
diff --git a/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator_test.cc b/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator_test.cc
index c73c1b3..91c9453 100644
--- a/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator_test.cc
+++ b/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator_test.cc
@@ -519,5 +519,34 @@
node_id_map, &program, &node_id_map, nullptr));
}
+TEST(WrapUnaryOperatorTest, NotApplicable_CallStmt) {
+ std::string content = R"(
+ fn main() {
+ f();
+ }
+ fn f() -> bool {
+ return false;
+ }
+ )";
+ Source::File file("test.wgsl", content);
+ auto program = reader::wgsl::Parse(&file);
+ ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();
+
+ NodeIdMap node_id_map(program);
+
+ const auto& main_fn_statements = program.AST().Functions()[0]->body->statements;
+
+ const auto* call_stmt = main_fn_statements[0]->As<ast::CallStatement>();
+ ASSERT_NE(call_stmt, nullptr);
+
+ const auto expr_id = node_id_map.GetId(call_stmt->expr);
+ ASSERT_NE(expr_id, 0);
+
+ // The id provided for the expression is not a valid expression type.
+ ASSERT_FALSE(MaybeApplyMutation(
+ program, MutationWrapUnaryOperator(expr_id, node_id_map.TakeFreshId(), ast::UnaryOp::kNot),
+ node_id_map, &program, &node_id_map, nullptr));
+}
+
} // namespace
} // namespace tint::fuzzers::ast_fuzzer
diff --git a/src/tint/reader/spirv/parser.cc b/src/tint/reader/spirv/parser.cc
index 1699d2c..417667b 100644
--- a/src/tint/reader/spirv/parser.cc
+++ b/src/tint/reader/spirv/parser.cc
@@ -18,6 +18,7 @@
#include "src/tint/ast/transform/decompose_strided_array.h"
#include "src/tint/ast/transform/decompose_strided_matrix.h"
+#include "src/tint/ast/transform/fold_trivial_lets.h"
#include "src/tint/ast/transform/remove_unreachable_statements.h"
#include "src/tint/ast/transform/simplify_pointers.h"
#include "src/tint/ast/transform/spirv_atomic.h"
@@ -60,6 +61,7 @@
transform::DataMap outputs;
manager.Add<ast::transform::Unshadow>();
manager.Add<ast::transform::SimplifyPointers>();
+ manager.Add<ast::transform::FoldTrivialLets>();
manager.Add<ast::transform::DecomposeStridedMatrix>();
manager.Add<ast::transform::DecomposeStridedArray>();
manager.Add<ast::transform::RemoveUnreachableStatements>();