ast: Add TraverseExpressions()

An ast::Expression traversal helper extracted from Resolver.

Change-Id: I88754cbc86cc12cbf8348fb36a3f038904017f3d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/67202
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index aa4a13d..b400b27 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -317,6 +317,7 @@
     "ast/texture.cc",
     "ast/texture.h",
     "ast/type.h",
+    "ast/traverse_expressions.h",
     "ast/type_constructor_expression.cc",
     "ast/type_constructor_expression.h",
     "ast/type_decl.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 53def54..a77bb9b 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -180,6 +180,7 @@
   ast/switch_statement.h
   ast/texture.cc
   ast/texture.h
+  ast/traverse_expressions.h
   ast/type_constructor_expression.cc
   ast/type_constructor_expression.h
   ast/type_name.cc
@@ -639,6 +640,7 @@
     ast/switch_statement_test.cc
     ast/test_helper.h
     ast/texture_test.cc
+    ast/traverse_expressions_test.cc
     ast/type_constructor_expression_test.cc
     ast/u32_test.cc
     ast/uint_literal_test.cc
diff --git a/src/ast/traverse_expressions.h b/src/ast/traverse_expressions.h
new file mode 100644
index 0000000..28df9f2
--- /dev/null
+++ b/src/ast/traverse_expressions.h
@@ -0,0 +1,141 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_AST_TRAVERSE_EXPRESSIONS_H_
+#define SRC_AST_TRAVERSE_EXPRESSIONS_H_
+
+#include <vector>
+
+#include "src/ast/array_accessor_expression.h"
+#include "src/ast/binary_expression.h"
+#include "src/ast/bitcast_expression.h"
+#include "src/ast/call_expression.h"
+#include "src/ast/member_accessor_expression.h"
+#include "src/ast/phony_expression.h"
+#include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/type_constructor_expression.h"
+#include "src/ast/unary_op_expression.h"
+#include "src/utils/reverse.h"
+
+namespace tint {
+namespace ast {
+
+/// The action to perform after calling the TraverseExpressions() callback
+/// function.
+enum class TraverseAction {
+  /// Stop traversal immediately.
+  Stop,
+  /// Descend into this expression.
+  Descend,
+  /// Do not descend into this expression.
+  Skip,
+};
+
+/// The order TraverseExpressions() will traverse expressions
+enum class TraverseOrder {
+  /// Expressions will be traversed from left to right
+  LeftToRight,
+  /// Expressions will be traversed from right to left
+  RightToLeft,
+};
+
+/// TraverseExpressions performs a depth-first traversal of the expression nodes
+/// from `root`, calling `callback` for each of the visited expressions that
+/// match the predicate parameter type, in pre-ordering (root first).
+/// @param root the root expression node
+/// @param diags the diagnostics used for error messages
+/// @param callback the callback function. Must be of the signature:
+///        `TraverseAction(const T*)` where T is an ast::Expression type.
+/// @return true on success, false on error
+template <TraverseOrder ORDER = TraverseOrder::LeftToRight, typename CALLBACK>
+bool TraverseExpressions(const ast::Expression* root,
+                         diag::List& diags,
+                         CALLBACK&& callback) {
+  using EXPR_TYPE = std::remove_pointer_t<traits::ParamTypeT<CALLBACK, 0>>;
+  std::vector<const ast::Expression*> to_visit{root};
+
+  auto push_pair = [&](const ast::Expression* left,
+                       const ast::Expression* right) {
+    if (ORDER == TraverseOrder::LeftToRight) {
+      to_visit.push_back(right);
+      to_visit.push_back(left);
+    } else {
+      to_visit.push_back(left);
+      to_visit.push_back(right);
+    }
+  };
+  auto push_list = [&](const std::vector<const ast::Expression*>& exprs) {
+    if (ORDER == TraverseOrder::LeftToRight) {
+      for (auto* expr : utils::Reverse(exprs)) {
+        to_visit.push_back(expr);
+      }
+    } else {
+      for (auto* expr : exprs) {
+        to_visit.push_back(expr);
+      }
+    }
+  };
+
+  while (!to_visit.empty()) {
+    auto* expr = to_visit.back();
+    to_visit.pop_back();
+
+    if (auto* filtered = expr->As<EXPR_TYPE>()) {
+      switch (callback(filtered)) {
+        case TraverseAction::Stop:
+          return true;
+        case TraverseAction::Skip:
+          continue;
+        case TraverseAction::Descend:
+          break;
+      }
+    }
+
+    if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
+      push_pair(array->array, array->index);
+    } else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
+      push_pair(bin_op->lhs, bin_op->rhs);
+    } else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
+      to_visit.push_back(bitcast->expr);
+    } else if (auto* call = expr->As<ast::CallExpression>()) {
+      // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
+      // function name in the traversal.
+      // to_visit.push_back(call->func);
+      push_list(call->args);
+    } else if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
+      push_list(type_ctor->values);
+    } else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
+      // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
+      // member name in the traversal.
+      // push_pair(member->structure, member->member);
+      to_visit.push_back(member->structure);
+    } else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
+      to_visit.push_back(unary->expr);
+    } else if (expr->IsAnyOf<ast::ScalarConstructorExpression,
+                             ast::IdentifierExpression,
+                             ast::PhonyExpression>()) {
+      // Leaf expression
+    } else {
+      TINT_ICE(AST, diags) << "unhandled expression type: "
+                           << expr->TypeInfo().name;
+      return false;
+    }
+  }
+  return true;
+}
+
+}  // namespace ast
+}  // namespace tint
+
+#endif  // SRC_AST_TRAVERSE_EXPRESSIONS_H_
diff --git a/src/ast/traverse_expressions_test.cc b/src/ast/traverse_expressions_test.cc
new file mode 100644
index 0000000..d1b2673
--- /dev/null
+++ b/src/ast/traverse_expressions_test.cc
@@ -0,0 +1,262 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/ast/traverse_expressions.h"
+#include "gmock/gmock.h"
+#include "src/ast/test_helper.h"
+
+namespace tint {
+namespace ast {
+namespace {
+
+using ::testing::ElementsAre;
+
+using TraverseExpressionsTest = TestHelper;
+
+TEST_F(TraverseExpressionsTest, DescendArrayAccessorExpression) {
+  std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
+  std::vector<const ast::Expression*> i = {IndexAccessor(e[0], e[1]),
+                                           IndexAccessor(e[2], e[3])};
+  auto* root = IndexAccessor(i[0], i[1]);
+  {
+    std::vector<const ast::Expression*> l2r;
+    TraverseExpressions<TraverseOrder::LeftToRight>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          l2r.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(l2r, ElementsAre(root, i[0], e[0], e[1], i[1], e[2], e[3]));
+  }
+  {
+    std::vector<const ast::Expression*> r2l;
+    TraverseExpressions<TraverseOrder::RightToLeft>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          r2l.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(r2l, ElementsAre(root, i[1], e[3], e[2], i[0], e[1], e[0]));
+  }
+}
+
+TEST_F(TraverseExpressionsTest, DescendBinaryExpression) {
+  std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
+  std::vector<const ast::Expression*> i = {Add(e[0], e[1]), Sub(e[2], e[3])};
+  auto* root = Mul(i[0], i[1]);
+  {
+    std::vector<const ast::Expression*> l2r;
+    TraverseExpressions<TraverseOrder::LeftToRight>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          l2r.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(l2r, ElementsAre(root, i[0], e[0], e[1], i[1], e[2], e[3]));
+  }
+  {
+    std::vector<const ast::Expression*> r2l;
+    TraverseExpressions<TraverseOrder::RightToLeft>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          r2l.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(r2l, ElementsAre(root, i[1], e[3], e[2], i[0], e[1], e[0]));
+  }
+}
+
+TEST_F(TraverseExpressionsTest, DescendBitcastExpression) {
+  auto* e = Expr(1);
+  auto* b0 = Bitcast<i32>(e);
+  auto* b1 = Bitcast<i32>(b0);
+  auto* b2 = Bitcast<i32>(b1);
+  auto* root = Bitcast<i32>(b2);
+  {
+    std::vector<const ast::Expression*> l2r;
+    TraverseExpressions<TraverseOrder::LeftToRight>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          l2r.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(l2r, ElementsAre(root, b2, b1, b0, e));
+  }
+  {
+    std::vector<const ast::Expression*> r2l;
+    TraverseExpressions<TraverseOrder::RightToLeft>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          r2l.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(r2l, ElementsAre(root, b2, b1, b0, e));
+  }
+}
+
+TEST_F(TraverseExpressionsTest, DescendCallExpression) {
+  std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
+  std::vector<const ast::Expression*> c = {Call("a", e[0], e[1]),
+                                           Call("b", e[2], e[3])};
+  auto* root = Call("c", c[0], c[1]);
+  {
+    std::vector<const ast::Expression*> l2r;
+    TraverseExpressions<TraverseOrder::LeftToRight>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          l2r.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(l2r, ElementsAre(root, c[0], e[0], e[1], c[1], e[2], e[3]));
+  }
+  {
+    std::vector<const ast::Expression*> r2l;
+    TraverseExpressions<TraverseOrder::RightToLeft>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          r2l.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(r2l, ElementsAre(root, c[1], e[3], e[2], c[0], e[1], e[0]));
+  }
+}
+
+TEST_F(TraverseExpressionsTest, DescendTypeConstructorExpression) {
+  std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
+  std::vector<const ast::Expression*> c = {vec2<i32>(e[0], e[1]),
+                                           vec2<i32>(e[2], e[3])};
+  auto* root = vec2<i32>(c[0], c[1]);
+  {
+    std::vector<const ast::Expression*> l2r;
+    TraverseExpressions<TraverseOrder::LeftToRight>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          l2r.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(l2r, ElementsAre(root, c[0], e[0], e[1], c[1], e[2], e[3]));
+  }
+  {
+    std::vector<const ast::Expression*> r2l;
+    TraverseExpressions<TraverseOrder::RightToLeft>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          r2l.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(r2l, ElementsAre(root, c[1], e[3], e[2], c[0], e[1], e[0]));
+  }
+}
+
+// TODO(crbug.com/tint/1257): Test ignores member accessor 'member' field.
+// Replace with the test below when fixed.
+TEST_F(TraverseExpressionsTest, DescendMemberIndexExpression) {
+  auto* e = Expr(1);
+  auto* m = MemberAccessor(e, Expr("a"));
+  auto* root = MemberAccessor(m, Expr("b"));
+  {
+    std::vector<const ast::Expression*> l2r;
+    TraverseExpressions<TraverseOrder::LeftToRight>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          l2r.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(l2r, ElementsAre(root, m, e));
+  }
+  {
+    std::vector<const ast::Expression*> r2l;
+    TraverseExpressions<TraverseOrder::RightToLeft>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          r2l.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(r2l, ElementsAre(root, m, e));
+  }
+}
+
+// TODO(crbug.com/tint/1257): The correct test for DescendMemberIndexExpression.
+TEST_F(TraverseExpressionsTest, DISABLED_DescendMemberIndexExpression) {
+  auto* e = Expr(1);
+  std::vector<const ast::IdentifierExpression*> i = {Expr("a"), Expr("b")};
+  auto* m = MemberAccessor(e, i[0]);
+  auto* root = MemberAccessor(m, i[1]);
+  {
+    std::vector<const ast::Expression*> l2r;
+    TraverseExpressions<TraverseOrder::LeftToRight>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          l2r.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(l2r, ElementsAre(root, m, e, i[0], i[1]));
+  }
+  {
+    std::vector<const ast::Expression*> r2l;
+    TraverseExpressions<TraverseOrder::RightToLeft>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          r2l.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(r2l, ElementsAre(root, i[1], m, i[0], e));
+  }
+}
+
+TEST_F(TraverseExpressionsTest, DescendUnaryExpression) {
+  auto* e = Expr(1);
+  auto* u0 = AddressOf(e);
+  auto* u1 = Deref(u0);
+  auto* u2 = AddressOf(u1);
+  auto* root = Deref(u2);
+  {
+    std::vector<const ast::Expression*> l2r;
+    TraverseExpressions<TraverseOrder::LeftToRight>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          l2r.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(l2r, ElementsAre(root, u2, u1, u0, e));
+  }
+  {
+    std::vector<const ast::Expression*> r2l;
+    TraverseExpressions<TraverseOrder::RightToLeft>(
+        root, Diagnostics(), [&](const ast::Expression* expr) {
+          r2l.push_back(expr);
+          return ast::TraverseAction::Descend;
+        });
+    EXPECT_THAT(r2l, ElementsAre(root, u2, u1, u0, e));
+  }
+}
+
+TEST_F(TraverseExpressionsTest, Skip) {
+  std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
+  std::vector<const ast::Expression*> i = {IndexAccessor(e[0], e[1]),
+                                           IndexAccessor(e[2], e[3])};
+  auto* root = IndexAccessor(i[0], i[1]);
+  std::vector<const ast::Expression*> order;
+  TraverseExpressions<TraverseOrder::LeftToRight>(
+      root, Diagnostics(), [&](const ast::Expression* expr) {
+        order.push_back(expr);
+        return expr == i[0] ? ast::TraverseAction::Skip
+                            : ast::TraverseAction::Descend;
+      });
+  EXPECT_THAT(order, ElementsAre(root, i[0], i[1], e[2], e[3]));
+}
+
+TEST_F(TraverseExpressionsTest, Stop) {
+  std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
+  std::vector<const ast::Expression*> i = {IndexAccessor(e[0], e[1]),
+                                           IndexAccessor(e[2], e[3])};
+  auto* root = IndexAccessor(i[0], i[1]);
+  std::vector<const ast::Expression*> order;
+  TraverseExpressions<TraverseOrder::LeftToRight>(
+      root, Diagnostics(), [&](const ast::Expression* expr) {
+        order.push_back(expr);
+        return expr == i[0] ? ast::TraverseAction::Stop
+                            : ast::TraverseAction::Descend;
+      });
+  EXPECT_THAT(order, ElementsAre(root, i[0]));
+}
+
+}  // namespace
+}  // namespace ast
+}  // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index b8b278e..aee387b 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -45,6 +45,7 @@
 #include "src/ast/storage_texture.h"
 #include "src/ast/struct_block_decoration.h"
 #include "src/ast/switch_statement.h"
+#include "src/ast/traverse_expressions.h"
 #include "src/ast/type_name.h"
 #include "src/ast/unary_op_expression.h"
 #include "src/ast/variable_decl_statement.h"
@@ -469,7 +470,6 @@
 
   // Does the variable have a constructor?
   if (auto* ctor = var->constructor) {
-    Mark(var->constructor);
     if (!Expression(var->constructor)) {
       return nullptr;
     }
@@ -1886,7 +1886,6 @@
         continue;
       }
 
-      Mark(expr);
       if (!Expression(expr)) {
         return false;
       }
@@ -2061,7 +2060,6 @@
     return true;
   }
   if (auto* c = stmt->As<ast::CallStatement>()) {
-    Mark(c->expr);
     if (!Expression(c->expr)) {
       return false;
     }
@@ -2138,7 +2136,6 @@
       builder_->create<sem::IfStatement>(stmt, current_compound_statement_);
   builder_->Sem().Add(stmt, sem);
   return Scope(sem, [&] {
-    Mark(stmt->condition);
     if (!Expression(stmt->condition)) {
       return false;
     }
@@ -2175,7 +2172,6 @@
   builder_->Sem().Add(stmt, sem);
   return Scope(sem, [&] {
     if (auto* cond = stmt->condition) {
-      Mark(cond);
       if (!Expression(cond)) {
         return false;
       }
@@ -2250,7 +2246,6 @@
     }
 
     if (auto* condition = stmt->condition) {
-      Mark(condition);
       if (!Expression(condition)) {
         return false;
       }
@@ -2279,58 +2274,14 @@
   });
 }
 
-bool Resolver::TraverseExpressions(const ast::Expression* root,
-                                   std::vector<const ast::Expression*>& out) {
-  std::vector<const ast::Expression*> to_visit;
-  to_visit.emplace_back(root);
-
-  auto add = [&](const ast::Expression* e) {
-    Mark(e);
-    to_visit.emplace_back(e);
-  };
-
-  while (!to_visit.empty()) {
-    auto* expr = to_visit.back();
-    to_visit.pop_back();
-
-    out.emplace_back(expr);
-
-    if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
-      add(array->array);
-      add(array->index);
-    } else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
-      add(bin_op->lhs);
-      add(bin_op->rhs);
-    } else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
-      add(bitcast->expr);
-    } else if (auto* call = expr->As<ast::CallExpression>()) {
-      for (auto* arg : call->args) {
-        add(arg);
-      }
-    } else if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
-      for (auto* value : type_ctor->values) {
-        add(value);
-      }
-    } else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
-      add(member->structure);
-    } else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
-      add(unary->expr);
-    } else if (expr->IsAnyOf<ast::ScalarConstructorExpression,
-                             ast::IdentifierExpression>()) {
-      // Leaf expression
-    } else {
-      TINT_ICE(Resolver, diagnostics_)
-          << "unhandled expression type: " << expr->TypeInfo().name;
-      return false;
-    }
-  }
-
-  return true;
-}
-
 bool Resolver::Expression(const ast::Expression* root) {
   std::vector<const ast::Expression*> sorted;
-  if (!TraverseExpressions(root, sorted)) {
+  if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
+          root, diagnostics_, [&](const ast::Expression* expr) {
+            Mark(expr);
+            sorted.emplace_back(expr);
+            return ast::TraverseAction::Descend;
+          })) {
     return false;
   }
 
@@ -3874,7 +3825,6 @@
   // sem::Array uses a size of 0 for a runtime-sized array.
   uint32_t count = 0;
   if (auto* count_expr = arr->count) {
-    Mark(count_expr);
     if (!Expression(count_expr)) {
       return nullptr;
     }
@@ -4340,7 +4290,6 @@
   current_function_->return_statements.push_back(ret);
 
   if (auto* value = ret->value) {
-    Mark(value);
     if (!Expression(value)) {
       return false;
     }
@@ -4424,7 +4373,6 @@
       builder_->create<sem::SwitchStatement>(stmt, current_compound_statement_);
   builder_->Sem().Add(stmt, sem);
   return Scope(sem, [&] {
-    Mark(stmt->condition);
     if (!Expression(stmt->condition)) {
       return false;
     }
@@ -4442,9 +4390,6 @@
 }
 
 bool Resolver::Assignment(const ast::AssignmentStatement* a) {
-  Mark(a->lhs);
-  Mark(a->rhs);
-
   if (!Expression(a->lhs) || !Expression(a->rhs)) {
     return false;
   }
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index b7f839b..212360b 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -262,15 +262,6 @@
   bool UnaryOp(const ast::UnaryOpExpression*);
   bool VariableDeclStatement(const ast::VariableDeclStatement*);
 
-  /// Performs a depth-first traversal of the expression nodes from `root`,
-  /// collecting all the visited expressions in pre-ordering (root first).
-  /// @param root the root expression node
-  /// @param out the ordered list of visited expression nodes, starting with the
-  ///        root node, and ending with leaf nodes
-  /// @return true on success, false on error
-  bool TraverseExpressions(const ast::Expression* root,
-                           std::vector<const ast::Expression*>& out);
-
   // AST and Type validation methods
   // Each return true on success, false on failure.
   bool ValidateArray(const sem::Array* arr, const Source& source);
diff --git a/test/BUILD.gn b/test/BUILD.gn
index 53dc747..085a8dc 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -207,6 +207,7 @@
     "../src/ast/switch_statement_test.cc",
     "../src/ast/test_helper.h",
     "../src/ast/texture_test.cc",
+    "../src/ast/traverse_expressions_test.cc",
     "../src/ast/type_constructor_expression_test.cc",
     "../src/ast/u32_test.cc",
     "../src/ast/uint_literal_test.cc",