Import Tint changes from Dawn
Changes:
- 28779af91cd356a28917ae522c21d7f97b56555a tint: impement short-circuiting for const eval of logical... by Antonio Maiorano <amaiorano@google.com>
- f528d33d52340ec0855354cd48bea241d7ed6f57 tint/transform: fix PromoteInitializersToLet for constant... by Ben Clayton <bclayton@google.com>
GitOrigin-RevId: 28779af91cd356a28917ae522c21d7f97b56555a
Change-Id: I046db0bc60d3582fdb190be4a6efdaadf0e7727f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/113604
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/ast/node_id.h b/src/tint/ast/node_id.h
index 79683b0..f6bddb1 100644
--- a/src/tint/ast/node_id.h
+++ b/src/tint/ast/node_id.h
@@ -25,7 +25,12 @@
/// Equality operator
/// @param other the other NodeID
/// @returns true if the NodeIDs are the same
- bool operator==(const NodeID& other) const { return value == other.value; }
+ bool operator==(NodeID other) const { return value == other.value; }
+
+ /// Less-than comparison operator
+ /// @param other the other NodeID
+ /// @returns true if the other comes before this node
+ bool operator<(NodeID other) const { return value < other.value; }
/// The numerical value for the node identifier
size_t value = 0;
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index a2ca5b4..ff23562 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -1977,6 +1977,16 @@
return create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, Expr(std::forward<EXPR>(expr)));
}
+ /// @param source the source information
+ /// @param expr the expression to perform a unary not on
+ /// @return an ast::UnaryOpExpression that is the unary not of the input
+ /// expression
+ template <typename EXPR>
+ const ast::UnaryOpExpression* Not(const Source& source, EXPR&& expr) {
+ return create<ast::UnaryOpExpression>(source, ast::UnaryOp::kNot,
+ Expr(std::forward<EXPR>(expr)));
+ }
+
/// @param expr the expression to perform a unary complement on
/// @return an ast::UnaryOpExpression that is the unary complement of the
/// input expression
@@ -2121,6 +2131,17 @@
Expr(std::forward<RHS>(rhs)));
}
+ /// @param source the source information
+ /// @param lhs the left hand argument to the division operation
+ /// @param rhs the right hand argument to the division operation
+ /// @returns a `ast::BinaryExpression` dividing `lhs` by `rhs`
+ template <typename LHS, typename RHS>
+ const ast::BinaryExpression* Div(const Source& source, LHS&& lhs, RHS&& rhs) {
+ return create<ast::BinaryExpression>(source, ast::BinaryOp::kDivide,
+ Expr(std::forward<LHS>(lhs)),
+ Expr(std::forward<RHS>(rhs)));
+ }
+
/// @param lhs the left hand argument to the modulo operation
/// @param rhs the right hand argument to the modulo operation
/// @returns a `ast::BinaryExpression` applying modulo of `lhs` by `rhs`
@@ -2177,6 +2198,17 @@
ast::BinaryOp::kLogicalAnd, Expr(std::forward<LHS>(lhs)), Expr(std::forward<RHS>(rhs)));
}
+ /// @param source the source information
+ /// @param lhs the left hand argument to the logical and operation
+ /// @param rhs the right hand argument to the logical and operation
+ /// @returns a `ast::BinaryExpression` of `lhs` && `rhs`
+ template <typename LHS, typename RHS>
+ const ast::BinaryExpression* LogicalAnd(const Source& source, LHS&& lhs, RHS&& rhs) {
+ return create<ast::BinaryExpression>(source, ast::BinaryOp::kLogicalAnd,
+ Expr(std::forward<LHS>(lhs)),
+ Expr(std::forward<RHS>(rhs)));
+ }
+
/// @param lhs the left hand argument to the logical or operation
/// @param rhs the right hand argument to the logical or operation
/// @returns a `ast::BinaryExpression` of `lhs` || `rhs`
@@ -2186,6 +2218,17 @@
ast::BinaryOp::kLogicalOr, Expr(std::forward<LHS>(lhs)), Expr(std::forward<RHS>(rhs)));
}
+ /// @param source the source information
+ /// @param lhs the left hand argument to the logical or operation
+ /// @param rhs the right hand argument to the logical or operation
+ /// @returns a `ast::BinaryExpression` of `lhs` || `rhs`
+ template <typename LHS, typename RHS>
+ const ast::BinaryExpression* LogicalOr(const Source& source, LHS&& lhs, RHS&& rhs) {
+ return create<ast::BinaryExpression>(source, ast::BinaryOp::kLogicalOr,
+ Expr(std::forward<LHS>(lhs)),
+ Expr(std::forward<RHS>(rhs)));
+ }
+
/// @param lhs the left hand argument to the greater than operation
/// @param rhs the right hand argument to the greater than operation
/// @returns a `ast::BinaryExpression` of `lhs` > `rhs`
@@ -2234,6 +2277,17 @@
Expr(std::forward<RHS>(rhs)));
}
+ /// @param source the source information
+ /// @param lhs the left hand argument to the equal expression
+ /// @param rhs the right hand argument to the equal expression
+ /// @returns a `ast::BinaryExpression` comparing `lhs` equal to `rhs`
+ template <typename LHS, typename RHS>
+ const ast::BinaryExpression* Equal(const Source& source, LHS&& lhs, RHS&& rhs) {
+ return create<ast::BinaryExpression>(source, ast::BinaryOp::kEqual,
+ Expr(std::forward<LHS>(lhs)),
+ Expr(std::forward<RHS>(rhs)));
+ }
+
/// @param lhs the left hand argument to the not-equal expression
/// @param rhs the right hand argument to the not-equal expression
/// @returns a `ast::BinaryExpression` comparing `lhs` equal to `rhs` for
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 02b36d5..6469897 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -1814,13 +1814,17 @@
ConstEval::Result ConstEval::OpLogicalAnd(const type::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
+ // Note: Due to short-circuiting, this function is only called if lhs is true, so we could
+ // technically only return the value of the rhs.
return CreateElement(builder, source, ty, args[0]->As<bool>() && args[1]->As<bool>());
}
ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
- return CreateElement(builder, source, ty, args[0]->As<bool>() || args[1]->As<bool>());
+ // Note: Due to short-circuiting, this function is only called if lhs is false, so we could
+ // technically only return the value of the rhs.
+ return CreateElement(builder, source, ty, args[1]->As<bool>());
}
ConstEval::Result ConstEval::OpAnd(const type::Type* ty,
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index 9753ceb..b3fc800 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -14,6 +14,7 @@
#include "src/tint/resolver/const_eval_test.h"
+#include "src/tint/reader/wgsl/parser.h"
#include "src/tint/utils/result.h"
using namespace tint::number_suffixes; // NOLINT
@@ -1366,5 +1367,917 @@
ShiftRightCases<i32>(), //
ShiftRightCases<u32>()))));
+namespace LogicalShortCircuit {
+
+/// Validates that `binary` is a short-circuiting logical and expression
+static void ValidateAnd(const sem::Info& sem, const ast::BinaryExpression* binary) {
+ auto* lhs = binary->lhs;
+ auto* rhs = binary->rhs;
+
+ auto* lhs_sem = sem.Get(lhs);
+ ASSERT_TRUE(lhs_sem->ConstantValue());
+ EXPECT_EQ(lhs_sem->ConstantValue()->As<bool>(), false);
+ EXPECT_EQ(lhs_sem->Stage(), sem::EvaluationStage::kConstant);
+
+ auto* rhs_sem = sem.Get(rhs);
+ EXPECT_EQ(rhs_sem->ConstantValue(), nullptr);
+ EXPECT_EQ(rhs_sem->Stage(), sem::EvaluationStage::kNotEvaluated);
+
+ auto* binary_sem = sem.Get(binary);
+ ASSERT_TRUE(binary_sem->ConstantValue());
+ EXPECT_EQ(binary_sem->ConstantValue()->As<bool>(), false);
+ EXPECT_EQ(binary_sem->Stage(), sem::EvaluationStage::kConstant);
+}
+
+/// Validates that `binary` is a short-circuiting logical or expression
+static void ValidateOr(const sem::Info& sem, const ast::BinaryExpression* binary) {
+ auto* lhs = binary->lhs;
+ auto* rhs = binary->rhs;
+
+ auto* lhs_sem = sem.Get(lhs);
+ ASSERT_TRUE(lhs_sem->ConstantValue());
+ EXPECT_EQ(lhs_sem->ConstantValue()->As<bool>(), true);
+ EXPECT_EQ(lhs_sem->Stage(), sem::EvaluationStage::kConstant);
+
+ auto* rhs_sem = sem.Get(rhs);
+ EXPECT_EQ(rhs_sem->ConstantValue(), nullptr);
+ EXPECT_EQ(rhs_sem->Stage(), sem::EvaluationStage::kNotEvaluated);
+
+ auto* binary_sem = sem.Get(binary);
+ ASSERT_TRUE(binary_sem->ConstantValue());
+ EXPECT_EQ(binary_sem->ConstantValue()->As<bool>(), true);
+ EXPECT_EQ(binary_sem->Stage(), sem::EvaluationStage::kConstant);
+}
+
+// Naming convention for tests below:
+//
+// [Non]ShortCircuit_[And|Or]_[Error|Invalid]_<Op>
+//
+// Where:
+// ShortCircuit: the rhs will not be const-evaluated
+// NonShortCircuitL the rhs will be const-evaluated
+//
+// And/Or: type of binary expression
+//
+// Error: a non-const evaluation error (e.g. parser or validation error)
+// Invalid: a const-evaluation error
+//
+// <Op> the type of operation on the rhs that may or may not be short-circuited.
+
+////////////////////////////////////////////////
+// Short-Circuit Unary
+////////////////////////////////////////////////
+
+// NOTE: Cannot demonstrate short-circuiting an invalid unary op as const eval of unary does not
+// fail.
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Unary) {
+ // const one = 1;
+ // const result = (one == 0) && (!0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Not(Source{{12, 34}}, 0_a);
+ GlobalConst("result", LogicalAnd(lhs, rhs));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator ! (abstract-int)
+
+2 candidate operators:
+ operator ! (bool) -> bool
+ operator ! (vecN<bool>) -> vecN<bool>
+)");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Unary) {
+ // const one = 1;
+ // const result = (one == 1) || (!0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Not(Source{{12, 34}}, 0_a);
+ GlobalConst("result", LogicalOr(lhs, rhs));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator ! (abstract-int)
+
+2 candidate operators:
+ operator ! (bool) -> bool
+ operator ! (vecN<bool>) -> vecN<bool>
+)");
+}
+
+////////////////////////////////////////////////
+// Short-Circuit Binary
+////////////////////////////////////////////////
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_Binary) {
+ // const one = 1;
+ // const result = (one == 0) && ((2 / 0) == 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Div(2_a, 0_a), 0_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateAnd(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_Binary) {
+ // const one = 1;
+ // const result = (one == 1) && ((2 / 0) == 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(Div(Source{{12, 34}}, 2_a, 0_a), 0_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: '2 / 0' cannot be represented as 'abstract-int'");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Binary) {
+ // const one = 1;
+ // const result = (one == 0) && (2 / 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Div(2_a, 0_a);
+ auto* binary = LogicalAnd(Source{{12, 34}}, lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: no matching overload for operator && (bool, abstract-int)
+
+1 candidate operator:
+ operator && (bool, bool) -> bool
+)");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_Binary) {
+ // const one = 1;
+ // const result = (one == 1) || ((2 / 0) == 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(Div(2_a, 0_a), 0_a);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateOr(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_Binary) {
+ // const one = 1;
+ // const result = (one == 0) || ((2 / 0) == 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Div(Source{{12, 34}}, 2_a, 0_a), 0_a);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: '2 / 0' cannot be represented as 'abstract-int'");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Binary) {
+ // const one = 1;
+ // const result = (one == 1) || (2 / 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Div(2_a, 0_a);
+ auto* binary = LogicalOr(Source{{12, 34}}, lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: no matching overload for operator || (bool, abstract-int)
+
+1 candidate operator:
+ operator || (bool, bool) -> bool
+)");
+}
+
+////////////////////////////////////////////////
+// Short-Circuit Materialize
+////////////////////////////////////////////////
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_Materialize) {
+ // const one = 1;
+ // const result = (one == 0) && (1.7976931348623157e+308 == 0.0f);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Expr(1.7976931348623157e+308_a), 0_f);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateAnd(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_Materialize) {
+ // const one = 1;
+ // const result = (one == 1) && (1.7976931348623157e+308 == 0.0f);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(Expr(Source{{12, 34}}, 1.7976931348623157e+308_a), 0_f);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: value 1.7976931348623157081e+308 cannot be represented as 'f32'");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Materialize) {
+ // const one = 1;
+ // const result = (one == 0) && (1.7976931348623157e+308 == 0i);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Source{{12, 34}}, 1.7976931348623157e+308_a, 0_i);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: no matching overload for operator == (abstract-float, i32)
+
+2 candidate operators:
+ operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+)");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_Materialize) {
+ // const one = 1;
+ // const result = (one == 1) || (1.7976931348623157e+308 == 0.0f);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(1.7976931348623157e+308_a, 0_f);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateOr(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_Materialize) {
+ // const one = 1;
+ // const result = (one == 0) || (1.7976931348623157e+308 == 0.0f);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Expr(Source{{12, 34}}, 1.7976931348623157e+308_a), 0_f);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: value 1.7976931348623157081e+308 cannot be represented as 'f32'");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Materialize) {
+ // const one = 1;
+ // const result = (one == 1) || (1.7976931348623157e+308 == 0i);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(Source{{12, 34}}, Expr(1.7976931348623157e+308_a), 0_i);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: no matching overload for operator == (abstract-float, i32)
+
+2 candidate operators:
+ operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+)");
+}
+
+////////////////////////////////////////////////
+// Short-Circuit Index
+////////////////////////////////////////////////
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_Index) {
+ // const one = 1;
+ // const a = array(1i, 2i, 3i);
+ // const i = 4;
+ // const result = (one == 0) && (a[i] == 0);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
+ GlobalConst("i", Expr(4_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(IndexAccessor("a", "i"), 0_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateAnd(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_Index) {
+ // const one = 1;
+ // const a = array(1i, 2i, 3i);
+ // const i = 3;
+ // const result = (one == 1) && (a[i] == 0);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
+ GlobalConst("i", Expr(3_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(IndexAccessor("a", Expr(Source{{12, 34}}, "i")), 0_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: index 3 out of bounds [0..2]");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Index) {
+ // const one = 1;
+ // const a = array(1i, 2i, 3i);
+ // const i = 3;
+ // const result = (one == 0) && (a[i] == 0.0f);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
+ GlobalConst("i", Expr(3_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Source{{12, 34}}, IndexAccessor("a", "i"), 0.0_f);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: no matching overload for operator == (i32, f32)
+
+2 candidate operators:
+ operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+)");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_Index) {
+ // const one = 1;
+ // const a = array(1i, 2i, 3i);
+ // const i = 4;
+ // const result = (one == 1) || (a[i] == 0);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
+ GlobalConst("i", Expr(4_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(IndexAccessor("a", "i"), 0_a);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateOr(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_Index) {
+ // const one = 1;
+ // const a = array(1i, 2i, 3i);
+ // const i = 3;
+ // const result = (one == 0) || (a[i] == 0);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
+ GlobalConst("i", Expr(3_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(IndexAccessor("a", Expr(Source{{12, 34}}, "i")), 0_a);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: index 3 out of bounds [0..2]");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Index) {
+ // const one = 1;
+ // const a = array(1i, 2i, 3i);
+ // const i = 3;
+ // const result = (one == 1) || (a[i] == 0.0f);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", array<i32, 3>(1_i, 2_i, 3_i));
+ GlobalConst("i", Expr(3_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(Source{{12, 34}}, IndexAccessor("a", "i"), 0.0_f);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: no matching overload for operator == (i32, f32)
+
+2 candidate operators:
+ operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+)");
+}
+
+////////////////////////////////////////////////
+// Short-Circuit Bitcast
+////////////////////////////////////////////////
+
+// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
+TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_And_Invalid_Bitcast) {
+ // const one = 1;
+ // const a = 0x7F800000;
+ // const result = (one == 0) && (bitcast<f32>(a) == 0.0);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", Expr(0x7F800000_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Bitcast<f32>("a"), 0.0_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateAnd(Sem(), binary);
+}
+
+// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
+TEST_F(ResolverConstEvalTest, DISABLED_NonShortCircuit_And_Invalid_Bitcast) {
+ // const one = 1;
+ // const a = 0x7F800000;
+ // const result = (one == 1) && (bitcast<f32>(a) == 0.0);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", Expr(0x7F800000_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(Bitcast(Source{{12, 34}}, ty.f32(), "a"), 0.0_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: value not representable as f32 message here");
+}
+
+// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
+TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_And_Error_Bitcast) {
+ // const one = 1;
+ // const a = 0x7F800000;
+ // const result = (one == 0) && (bitcast<f32>(a) == 0i);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", Expr(0x7F800000_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Source{{12, 34}}, Bitcast(ty.f32(), "a"), 0_i);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload message here)");
+}
+
+// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
+TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_Or_Invalid_Bitcast) {
+ // const one = 1;
+ // const a = 0x7F800000;
+ // const result = (one == 1) || (bitcast<f32>(a) == 0.0);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", Expr(0x7F800000_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(Bitcast<f32>("a"), 0.0_a);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateOr(Sem(), binary);
+}
+
+// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
+TEST_F(ResolverConstEvalTest, DISABLED_NonShortCircuit_Or_Invalid_Bitcast) {
+ // const one = 1;
+ // const a = 0x7F800000;
+ // const result = (one == 0) || (bitcast<f32>(a) == 0.0);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", Expr(0x7F800000_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Bitcast(Source{{12, 34}}, ty.f32(), "a"), 0.0_a);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: value not representable as f32 message here");
+}
+
+// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
+TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_Or_Error_Bitcast) {
+ // const one = 1;
+ // const a = 0x7F800000;
+ // const result = (one == 1) || (bitcast<f32>(a) == 0i);
+ GlobalConst("one", Expr(1_a));
+ GlobalConst("a", Expr(0x7F800000_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(Source{{12, 34}}, Bitcast(ty.f32(), "a"), 0_i);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload message here)");
+}
+
+////////////////////////////////////////////////
+// Short-Circuit Type Init/Convert
+////////////////////////////////////////////////
+
+// NOTE: Cannot demonstrate short-circuiting an invalid init/convert as const eval of init/convert
+// always succeeds.
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Init) {
+ // const one = 1;
+ // const result = (one == 0) && (vec2<f32>(1.0, true).x == 0.0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(MemberAccessor(vec2<f32>(Source{{12, 34}}, 1.0_a, Expr(true)), "x"), 0.0_a);
+ GlobalConst("result", LogicalAnd(lhs, rhs));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: no matching initializer for vec2<f32>(abstract-float, bool)
+
+4 candidate initializers:
+ vec2(x: T, y: T) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ vec2(T) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ vec2(vec2<T>) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ vec2<T>() -> vec2<T> where: T is f32, f16, i32, u32 or bool
+
+5 candidate conversions:
+ vec2<T>(vec2<U>) -> vec2<f32> where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool
+ vec2<T>(vec2<U>) -> vec2<f16> where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool
+ vec2<T>(vec2<U>) -> vec2<i32> where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool
+ vec2<T>(vec2<U>) -> vec2<u32> where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool
+ vec2<T>(vec2<U>) -> vec2<bool> where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32
+)");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Init) {
+ // const one = 1;
+ // const result = (one == 1) || (vec2<f32>(1.0, true).x == 0.0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(MemberAccessor(vec2<f32>(Source{{12, 34}}, 1.0_a, Expr(true)), "x"), 0.0_a);
+ GlobalConst("result", LogicalOr(lhs, rhs));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: no matching initializer for vec2<f32>(abstract-float, bool)
+
+4 candidate initializers:
+ vec2(x: T, y: T) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ vec2(T) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ vec2(vec2<T>) -> vec2<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ vec2<T>() -> vec2<T> where: T is f32, f16, i32, u32 or bool
+
+5 candidate conversions:
+ vec2<T>(vec2<U>) -> vec2<f32> where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool
+ vec2<T>(vec2<U>) -> vec2<f16> where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool
+ vec2<T>(vec2<U>) -> vec2<i32> where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool
+ vec2<T>(vec2<U>) -> vec2<u32> where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool
+ vec2<T>(vec2<U>) -> vec2<bool> where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32
+)");
+}
+
+////////////////////////////////////////////////
+// Short-Circuit Array/Struct Init
+////////////////////////////////////////////////
+
+// NOTE: Cannot demonstrate short-circuiting an invalid array/struct init as const eval of
+// array/struct init always succeeds.
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_StructInit) {
+ // struct S {
+ // a : i32,
+ // b : f32,
+ // }
+ // const one = 1;
+ // const result = (one == 0) && Foo(1, true).a == 0;
+ Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(
+ MemberAccessor(Construct(ty.type_name("S"), Expr(1_a), Expr(Source{{12, 34}}, true)), "a"),
+ 0_a);
+ GlobalConst("result", LogicalAnd(lhs, rhs));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in struct initializer does not match struct member type: "
+ "expected 'f32', found 'bool'");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_StructInit) {
+ // struct S {
+ // a : i32,
+ // b : f32,
+ // }
+ // const one = 1;
+ // const result = (one == 1) || Foo(1, true).a == 0;
+ Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(
+ MemberAccessor(Construct(ty.type_name("S"), Expr(1_a), Expr(Source{{12, 34}}, true)), "a"),
+ 0_a);
+ GlobalConst("result", LogicalOr(lhs, rhs));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type in struct initializer does not match struct member type: "
+ "expected 'f32', found 'bool'");
+}
+
+////////////////////////////////////////////////
+// Short-Circuit Builtin Call
+////////////////////////////////////////////////
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_BuiltinCall) {
+ // const one = 1;
+ // return (one == 0) && (extractBits(1, 0, 99) == 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Call("extractBits", 1_a, 0_a, 99_a), 0_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateAnd(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_BuiltinCall) {
+ // const one = 1;
+ // return (one == 1) && (extractBits(1, 0, 99) == 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(Call(Source{{12, 34}}, "extractBits", 1_a, 0_a, 99_a), 0_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_BuiltinCall) {
+ // const one = 1;
+ // return (one == 0) && (extractBits(1, 0, 99) == 0.0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Source{{12, 34}}, Call("extractBits", 1_a, 0_a, 99_a), 0.0_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: no matching overload for operator == (i32, abstract-float)
+
+2 candidate operators:
+ operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+)");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_BuiltinCall) {
+ // const one = 1;
+ // return (one == 1) || (extractBits(1, 0, 99) == 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(Call("extractBits", 1_a, 0_a, 99_a), 0_a);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateOr(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_BuiltinCall) {
+ // const one = 1;
+ // return (one == 0) || (extractBits(1, 0, 99) == 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(Call(Source{{12, 34}}, "extractBits", 1_a, 0_a, 99_a), 0_a);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_BuiltinCall) {
+ // const one = 1;
+ // return (one == 1) || (extractBits(1, 0, 99) == 0.0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(Source{{12, 34}}, Call("extractBits", 1_a, 0_a, 99_a), 0.0_a);
+ auto* binary = LogicalOr(lhs, rhs);
+ GlobalConst("result", binary);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: no matching overload for operator == (i32, abstract-float)
+
+2 candidate operators:
+ operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+)");
+}
+
+////////////////////////////////////////////////
+// Short-Circuit Literal
+////////////////////////////////////////////////
+
+// NOTE: Cannot demonstrate short-circuiting an invalid literal as const eval of a literal does not
+// fail.
+
+#if TINT_BUILD_WGSL_READER
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Literal) {
+ // NOTE: This fails parsing rather than resolving, which is why we can't use the ProgramBuilder
+ // for this test.
+ auto src = R"(
+const one = 1;
+const result = (one == 0) && (1111111111111111111111111111111i == 0);
+)";
+
+ auto file = std::make_unique<Source::File>("test", src);
+ auto program = reader::wgsl::Parse(file.get());
+ EXPECT_FALSE(program.IsValid());
+
+ diag::Formatter::Style style;
+ style.print_newline_at_end = false;
+ auto error = diag::Formatter(style).format(program.Diagnostics());
+ EXPECT_EQ(error, R"(test:3:31 error: value cannot be represented as 'i32'
+const result = (one == 0) && (1111111111111111111111111111111i == 0);
+ ^
+)");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Literal) {
+ // NOTE: This fails parsing rather than resolving, which is why we can't use the ProgramBuilder
+ // for this test.
+ auto src = R"(
+const one = 1;
+const result = (one == 1) || (1111111111111111111111111111111i == 0);
+)";
+
+ auto file = std::make_unique<Source::File>("test", src);
+ auto program = reader::wgsl::Parse(file.get());
+ EXPECT_FALSE(program.IsValid());
+
+ diag::Formatter::Style style;
+ style.print_newline_at_end = false;
+ auto error = diag::Formatter(style).format(program.Diagnostics());
+ EXPECT_EQ(error, R"(test:3:31 error: value cannot be represented as 'i32'
+const result = (one == 1) || (1111111111111111111111111111111i == 0);
+ ^
+)");
+}
+#endif // TINT_BUILD_WGSL_READER
+
+////////////////////////////////////////////////
+// Short-Circuit Member Access
+////////////////////////////////////////////////
+
+// NOTE: Cannot demonstrate short-circuiting an invalid member access as const eval of member access
+// always succeeds.
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_MemberAccess) {
+ // struct S {
+ // a : i32,
+ // b : f32,
+ // }
+ // const s = S(1, 2.0);
+ // const one = 1;
+ // const result = (one == 0) && (s.c == 0);
+ Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
+ GlobalConst("s", Construct(ty.type_name("S"), Expr(1_a), Expr(2.0_a)));
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(MemberAccessor(Source{{12, 34}}, "s", Expr("c")), 0_a);
+ GlobalConst("result", LogicalAnd(lhs, rhs));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: struct member c not found");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_MemberAccess) {
+ // struct S {
+ // a : i32,
+ // b : f32,
+ // }
+ // const s = S(1, 2.0);
+ // const one = 1;
+ // const result = (one == 1) || (s.c == 0);
+ Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
+ GlobalConst("s", Construct(ty.type_name("S"), Expr(1_a), Expr(2.0_a)));
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(MemberAccessor(Source{{12, 34}}, "s", Expr("c")), 0_a);
+ GlobalConst("result", LogicalOr(lhs, rhs));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: struct member c not found");
+}
+
+////////////////////////////////////////////////
+// Short-Circuit Swizzle
+////////////////////////////////////////////////
+
+// NOTE: Cannot demonstrate short-circuiting an invalid swizzle as const eval of swizzle always
+// succeeds.
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Swizzle) {
+ // const one = 1;
+ // const result = (one == 0) && (vec2(1, 2).z == 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal(MemberAccessor(vec2<AInt>(1_a, 2_a), Expr(Source{{12, 34}}, "z")), 0_a);
+ GlobalConst("result", LogicalAnd(lhs, rhs));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: invalid vector swizzle member");
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Swizzle) {
+ // const one = 1;
+ // const result = (one == 1) || (vec2(1, 2).z == 0);
+ GlobalConst("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal(MemberAccessor(vec2<AInt>(1_a, 2_a), Expr(Source{{12, 34}}, "z")), 0_a);
+ GlobalConst("result", LogicalOr(lhs, rhs));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: invalid vector swizzle member");
+}
+
+////////////////////////////////////////////////
+// Short-Circuit Nested
+////////////////////////////////////////////////
+
+#if TINT_BUILD_WGSL_READER
+using ResolverConstEvalTestShortCircuit = ResolverTestWithParam<std::tuple<const char*, bool>>;
+TEST_P(ResolverConstEvalTestShortCircuit, Test) {
+ const char* expr = std::get<0>(GetParam());
+ bool should_pass = std::get<1>(GetParam());
+
+ auto src = std::string(R"(
+const one = 1;
+const result = )");
+ src = src + expr + ";";
+ auto file = std::make_unique<Source::File>("test", src);
+ auto program = reader::wgsl::Parse(file.get());
+
+ if (should_pass) {
+ diag::Formatter::Style style;
+ style.print_newline_at_end = false;
+ auto error = diag::Formatter(style).format(program.Diagnostics());
+
+ EXPECT_TRUE(program.IsValid()) << error;
+ } else {
+ EXPECT_FALSE(program.IsValid());
+ }
+}
+INSTANTIATE_TEST_SUITE_P(Nested,
+ ResolverConstEvalTestShortCircuit,
+ testing::ValuesIn(std::vector<std::tuple<const char*, bool>>{
+ // AND nested rhs
+ {"(one == 0) && ((one == 0) && ((2/0)==0))", true},
+ {"(one == 1) && ((one == 0) && ((2/0)==0))", true},
+ {"(one == 0) && ((one == 1) && ((2/0)==0))", true},
+ {"(one == 1) && ((one == 1) && ((2/0)==0))", false},
+ // AND nested lhs
+ {"((one == 0) && ((2/0)==0)) && (one == 0)", true},
+ {"((one == 0) && ((2/0)==0)) && (one == 1)", true},
+ {"((one == 1) && ((2/0)==0)) && (one == 0)", false},
+ {"((one == 1) && ((2/0)==0)) && (one == 1)", false},
+ // OR nested rhs
+ {"(one == 1) || ((one == 1) || ((2/0)==0))", true},
+ {"(one == 0) || ((one == 1) || ((2/0)==0))", true},
+ {"(one == 1) || ((one == 0) || ((2/0)==0))", true},
+ {"(one == 0) || ((one == 0) || ((2/0)==0))", false},
+ // OR nested lhs
+ {"((one == 1) || ((2/0)==0)) || (one == 1)", true},
+ {"((one == 1) || ((2/0)==0)) || (one == 0)", true},
+ {"((one == 0) || ((2/0)==0)) || (one == 1)", false},
+ {"((one == 0) || ((2/0)==0)) || (one == 0)", false},
+ // AND nested both sides
+ {"((one == 0) && ((2/0)==0)) && ((one == 0) && ((2/0)==0))", true},
+ {"((one == 0) && ((2/0)==0)) && ((one == 1) && ((2/0)==0))", true},
+ {"((one == 1) && ((2/0)==0)) && ((one == 0) && ((2/0)==0))", false},
+ {"((one == 1) && ((2/0)==0)) && ((one == 1) && ((2/0)==0))", false},
+ // OR nested both sides
+ {"((one == 1) || ((2/0)==0)) && ((one == 1) || ((2/0)==0))", true},
+ {"((one == 1) || ((2/0)==0)) && ((one == 0) || ((2/0)==0))", false},
+ {"((one == 0) || ((2/0)==0)) && ((one == 1) || ((2/0)==0))", false},
+ {"((one == 0) || ((2/0)==0)) && ((one == 0) || ((2/0)==0))", false},
+ // AND chained
+ {"(one == 0) && (one == 0) && ((2 / 0) == 0)", true},
+ {"(one == 1) && (one == 0) && ((2 / 0) == 0)", true},
+ {"(one == 0) && (one == 1) && ((2 / 0) == 0)", true},
+ {"(one == 1) && (one == 1) && ((2 / 0) == 0)", false},
+ // OR chained
+ {"(one == 1) || (one == 1) || ((2 / 0) == 0)", true},
+ {"(one == 0) || (one == 1) || ((2 / 0) == 0)", true},
+ {"(one == 1) || (one == 0) || ((2 / 0) == 0)", true},
+ {"(one == 0) || (one == 0) || ((2 / 0) == 0)", false},
+ }));
+#endif // TINT_BUILD_WGSL_READER
+
+} // namespace LogicalShortCircuit
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/resolver/evaluation_stage_test.cc b/src/tint/resolver/evaluation_stage_test.cc
index 4d637b5..258031c 100644
--- a/src/tint/resolver/evaluation_stage_test.cc
+++ b/src/tint/resolver/evaluation_stage_test.cc
@@ -293,5 +293,53 @@
EXPECT_EQ(Sem().Get(expr)->Stage(), sem::EvaluationStage::kRuntime);
}
+TEST_F(ResolverEvaluationStageTest, Binary_Runtime) {
+ // let one = 1;
+ // let result = (one == 1) && (one == 1);
+ auto* one = Let("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal("one", 1_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ auto* result = Let("result", binary);
+ WrapInFunction(one, result);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(lhs)->Stage(), sem::EvaluationStage::kRuntime);
+ EXPECT_EQ(Sem().Get(rhs)->Stage(), sem::EvaluationStage::kRuntime);
+ EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverEvaluationStageTest, Binary_Const) {
+ // const one = 1;
+ // const result = (one == 1) && (one == 1);
+ auto* one = Const("one", Expr(1_a));
+ auto* lhs = Equal("one", 1_a);
+ auto* rhs = Equal("one", 1_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ auto* result = Const("result", binary);
+ WrapInFunction(one, result);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(lhs)->Stage(), sem::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().Get(rhs)->Stage(), sem::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kConstant);
+}
+
+TEST_F(ResolverEvaluationStageTest, Binary_NotEvaluated) {
+ // const one = 1;
+ // const result = (one == 0) && (one == 1);
+ auto* one = Const("one", Expr(1_a));
+ auto* lhs = Equal("one", 0_a);
+ auto* rhs = Equal("one", 1_a);
+ auto* binary = LogicalAnd(lhs, rhs);
+ auto* result = Const("result", binary);
+ WrapInFunction(one, result);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(lhs)->Stage(), sem::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().Get(rhs)->Stage(), sem::EvaluationStage::kNotEvaluated);
+ EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kConstant);
+}
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index eb1b7a1..ce18dc4 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1510,6 +1510,11 @@
failed = true;
return ast::TraverseAction::Stop;
}
+ if (auto* binary = expr->As<ast::BinaryExpression>();
+ binary && binary->IsLogical()) {
+ // Store potential const-eval short-circuit pair
+ logical_binary_lhs_to_parent_.Add(binary->lhs, binary);
+ }
sorted.Push(expr);
return ast::TraverseAction::Descend;
})) {
@@ -1568,6 +1573,26 @@
if (expr == root) {
return sem_expr;
}
+
+ // If we just processed the lhs of a constexpr logical binary expression, mark the rhs for
+ // short-circuiting.
+ if (sem_expr->ConstantValue()) {
+ if (auto binary = logical_binary_lhs_to_parent_.Find(expr)) {
+ const bool lhs_is_true = sem_expr->ConstantValue()->As<bool>();
+ if (((*binary)->IsLogicalAnd() && !lhs_is_true) ||
+ ((*binary)->IsLogicalOr() && lhs_is_true)) {
+ // Mark entire expression tree to not const-evaluate
+ auto r = ast::TraverseExpressions( //
+ (*binary)->rhs, diagnostics_, [&](const ast::Expression* e) {
+ skip_const_eval_.Add(e);
+ return ast::TraverseAction::Descend;
+ });
+ if (!r) {
+ return nullptr;
+ }
+ }
+ }
+ }
}
TINT_ICE(Resolver, diagnostics_) << "Expression() did not find root node";
@@ -1779,27 +1804,32 @@
return nullptr;
}
- auto expr_val = expr->ConstantValue();
- if (!expr_val) {
- TINT_ICE(Resolver, builder_->Diagnostics())
- << decl->source << "Materialize(" << decl->TypeInfo().name
- << ") called on expression with no constant value";
- return nullptr;
+ const sem::Constant* materialized_val = nullptr;
+ if (!skip_const_eval_.Contains(decl)) {
+ auto expr_val = expr->ConstantValue();
+ if (!expr_val) {
+ TINT_ICE(Resolver, builder_->Diagnostics())
+ << decl->source << "Materialize(" << decl->TypeInfo().name
+ << ") called on expression with no constant value";
+ return nullptr;
+ }
+
+ auto val = const_eval_.Convert(concrete_ty, expr_val, decl->source);
+ if (!val) {
+ // Convert() has already failed and raised an diagnostic error.
+ return nullptr;
+ }
+ materialized_val = val.Get();
+ if (!materialized_val) {
+ TINT_ICE(Resolver, builder_->Diagnostics())
+ << decl->source << "ConvertValue(" << builder_->FriendlyName(expr_val->Type())
+ << " -> " << builder_->FriendlyName(concrete_ty) << ") returned invalid value";
+ return nullptr;
+ }
}
- auto materialized_val = const_eval_.Convert(concrete_ty, expr_val, decl->source);
- if (!materialized_val) {
- // ConvertValue() has already failed and raised an diagnostic error.
- return nullptr;
- }
-
- if (!materialized_val.Get()) {
- TINT_ICE(Resolver, builder_->Diagnostics())
- << decl->source << "ConvertValue(" << builder_->FriendlyName(expr_val->Type()) << " -> "
- << builder_->FriendlyName(concrete_ty) << ") returned invalid value";
- return nullptr;
- }
- auto* m = builder_->create<sem::Materialize>(expr, current_statement_, materialized_val.Get());
+ auto* m =
+ builder_->create<sem::Materialize>(expr, current_statement_, concrete_ty, materialized_val);
m->Behaviors() = expr->Behaviors();
builder_->Sem().Replace(decl, m);
return m;
@@ -1894,12 +1924,16 @@
ty = builder_->create<type::Reference>(ty, ref->AddressSpace(), ref->Access());
}
- auto stage = sem::EarliestStage(obj->Stage(), idx->Stage());
const sem::Constant* val = nullptr;
- if (auto r = const_eval_.Index(obj, idx)) {
- val = r.Get();
+ auto stage = sem::EarliestStage(obj->Stage(), idx->Stage());
+ if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+ stage = sem::EvaluationStage::kNotEvaluated;
} else {
- return nullptr;
+ if (auto r = const_eval_.Index(obj, idx)) {
+ val = r.Get();
+ } else {
+ return nullptr;
+ }
}
bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
auto* sem = builder_->create<sem::IndexAccessorExpression>(
@@ -1922,6 +1956,7 @@
RegisterLoadIfNeeded(inner);
const sem::Constant* val = nullptr;
+ // TODO(crbug.com/tint/1582): short circuit 'expr' once const eval of Bitcast is implemented.
if (auto r = const_eval_.Bitcast(ty, inner)) {
val = r.Get();
} else {
@@ -1981,8 +2016,12 @@
if (!MaybeMaterializeArguments(args, ctor_or_conv.target)) {
return nullptr;
}
+
const sem::Constant* value = nullptr;
auto stage = sem::EarliestStage(ctor_or_conv.target->Stage(), args_stage);
+ if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+ stage = sem::EvaluationStage::kNotEvaluated;
+ }
if (stage == sem::EvaluationStage::kConstant) {
auto const_args = ConvertArguments(args, ctor_or_conv.target);
if (!const_args) {
@@ -2302,13 +2341,17 @@
// If the builtin is @const, and all arguments have constant values, evaluate the builtin
// now.
- auto stage = sem::EarliestStage(arg_stage, builtin.sem->Stage());
const sem::Constant* value = nullptr;
+ auto stage = sem::EarliestStage(arg_stage, builtin.sem->Stage());
+ if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+ stage = sem::EvaluationStage::kNotEvaluated;
+ }
if (stage == sem::EvaluationStage::kConstant) {
auto const_args = ConvertArguments(args, builtin.sem);
if (!const_args) {
return nullptr;
}
+
if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(),
const_args.Get(), expr->source)) {
value = r.Get();
@@ -2787,19 +2830,25 @@
const sem::Constant* value = nullptr;
if (stage == sem::EvaluationStage::kConstant) {
if (op.const_eval_fn) {
- auto const_args = utils::Vector{lhs->ConstantValue(), rhs->ConstantValue()};
- // Implicit conversion (e.g. AInt -> AFloat)
- if (!Convert(const_args[0], op.lhs, lhs->Declaration()->source)) {
- return nullptr;
- }
- if (!Convert(const_args[1], op.rhs, rhs->Declaration()->source)) {
- return nullptr;
- }
-
- if (auto r = (const_eval_.*op.const_eval_fn)(op.result, const_args, expr->source)) {
- value = r.Get();
+ if (skip_const_eval_.Contains(expr)) {
+ stage = sem::EvaluationStage::kNotEvaluated;
+ } else if (skip_const_eval_.Contains(expr->rhs)) {
+ // Only the rhs should be short-circuited, use the lhs value
+ value = lhs->ConstantValue();
} else {
- return nullptr;
+ auto const_args = utils::Vector{lhs->ConstantValue(), rhs->ConstantValue()};
+ // Implicit conversion (e.g. AInt -> AFloat)
+ if (!Convert(const_args[0], op.lhs, lhs->Declaration()->source)) {
+ return nullptr;
+ }
+ if (!Convert(const_args[1], op.rhs, rhs->Declaration()->source)) {
+ return nullptr;
+ }
+ if (auto r = (const_eval_.*op.const_eval_fn)(op.result, const_args, expr->source)) {
+ value = r.Get();
+ } else {
+ return nullptr;
+ }
}
} else {
stage = sem::EvaluationStage::kRuntime;
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index 869cb19..451de0b 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -478,6 +478,9 @@
uint32_t current_scoping_depth_ = 0;
utils::UniqueVector<const sem::GlobalVariable*, 4>* resolved_overrides_ = nullptr;
utils::Hashset<TypeAndAddressSpace, 8> valid_type_storage_layouts_;
+ utils::Hashmap<const ast::Expression*, const ast::BinaryExpression*, 8>
+ logical_binary_lhs_to_parent_;
+ utils::Hashset<const ast::Expression*, 8> skip_const_eval_;
};
} // namespace tint::resolver
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 9d21314..a3bc938 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -1318,6 +1318,9 @@
bool Validator::EvaluationStage(const sem::Expression* expr,
sem::EvaluationStage latest_stage,
std::string_view constraint) const {
+ if (expr->Stage() == sem::EvaluationStage::kNotEvaluated) {
+ return true;
+ }
if (expr->Stage() > latest_stage) {
auto stage_name = [](sem::EvaluationStage stage) -> std::string {
switch (stage) {
@@ -1327,6 +1330,8 @@
return "an override-expression";
case sem::EvaluationStage::kConstant:
return "a const-expression";
+ case sem::EvaluationStage::kNotEvaluated:
+ return "an unevaluated expression";
}
return "<unknown>";
};
diff --git a/src/tint/sem/call.cc b/src/tint/sem/call.cc
index 1fb7de4..e89bc5f 100644
--- a/src/tint/sem/call.cc
+++ b/src/tint/sem/call.cc
@@ -32,7 +32,8 @@
target_(target),
arguments_(std::move(arguments)) {
// Check that the stage is no earlier than the target supports
- TINT_ASSERT(Semantic, target->Stage() <= stage);
+ TINT_ASSERT(Semantic,
+ (target->Stage() <= stage) || (stage == sem::EvaluationStage::kNotEvaluated));
}
Call::~Call() = default;
diff --git a/src/tint/sem/evaluation_stage.h b/src/tint/sem/evaluation_stage.h
index b5e554d..a41f345 100644
--- a/src/tint/sem/evaluation_stage.h
+++ b/src/tint/sem/evaluation_stage.h
@@ -22,6 +22,8 @@
/// The earliest point in time that an expression can be evaluated
enum class EvaluationStage {
+ /// Expression will not be evaluated
+ kNotEvaluated,
/// Expression can be evaluated at shader creation time
kConstant,
/// Expression can be evaluated at pipeline creation time
@@ -43,7 +45,7 @@
/// @param stages a list of EvaluationStage.
/// @returns the earliest stage supported by all the provided stages
inline EvaluationStage EarliestStage(std::initializer_list<EvaluationStage> stages) {
- auto earliest = EvaluationStage::kConstant;
+ auto earliest = EvaluationStage::kNotEvaluated;
for (auto stage : stages) {
earliest = std::max(stage, earliest);
}
diff --git a/src/tint/sem/expression_test.cc b/src/tint/sem/expression_test.cc
index f12f0ed..aa3dc00 100644
--- a/src/tint/sem/expression_test.cc
+++ b/src/tint/sem/expression_test.cc
@@ -47,7 +47,7 @@
sem::EvaluationStage::kRuntime, /* statement */ nullptr,
/* constant_value */ nullptr,
/* has_side_effects */ false, /* root_ident */ nullptr);
- auto* b = create<Materialize>(a, /* statement */ nullptr, &c);
+ auto* b = create<Materialize>(a, /* statement */ nullptr, c.Type(), &c);
EXPECT_EQ(a, a->UnwrapMaterialize());
EXPECT_EQ(a, b->UnwrapMaterialize());
diff --git a/src/tint/sem/materialize.cc b/src/tint/sem/materialize.cc
index c735f6d..b682463 100644
--- a/src/tint/sem/materialize.cc
+++ b/src/tint/sem/materialize.cc
@@ -19,10 +19,11 @@
namespace tint::sem {
Materialize::Materialize(const Expression* expr,
const Statement* statement,
+ const type::Type* type,
const Constant* constant)
: Base(/* declaration */ expr->Declaration(),
- /* type */ constant->Type(),
- /* stage */ EvaluationStage::kConstant, // Abstract can only be const-expr
+ /* type */ type,
+ /* stage */ constant ? EvaluationStage::kConstant : EvaluationStage::kNotEvaluated,
/* statement */ statement,
/* constant */ constant,
/* has_side_effects */ false,
diff --git a/src/tint/sem/materialize.h b/src/tint/sem/materialize.h
index 0cbac29..99fee52 100644
--- a/src/tint/sem/materialize.h
+++ b/src/tint/sem/materialize.h
@@ -30,8 +30,12 @@
/// Constructor
/// @param expr the inner expression, being materialized
/// @param statement the statement that owns this expression
- /// @param constant the constant value of this expression
- Materialize(const Expression* expr, const Statement* statement, const Constant* constant);
+ /// @param type concrete type to materialize to
+ /// @param constant the constant value of this expression or nullptr
+ Materialize(const Expression* expr,
+ const Statement* statement,
+ const type::Type* type,
+ const Constant* constant);
/// Destructor
~Materialize() override;
diff --git a/src/tint/transform/builtin_polyfill.cc b/src/tint/transform/builtin_polyfill.cc
index 5ae0830..cc5991d 100644
--- a/src/tint/transform/builtin_polyfill.cc
+++ b/src/tint/transform/builtin_polyfill.cc
@@ -801,7 +801,8 @@
bool made_changes = false;
for (auto* node : src->ASTNodes().Objects()) {
auto* expr = src->Sem().Get<sem::Expression>(node);
- if (!expr || expr->Stage() == sem::EvaluationStage::kConstant) {
+ if (!expr || expr->Stage() == sem::EvaluationStage::kConstant ||
+ expr->Stage() == sem::EvaluationStage::kNotEvaluated) {
continue; // Don't polyfill @const expressions
}
diff --git a/src/tint/transform/promote_initializers_to_let.cc b/src/tint/transform/promote_initializers_to_let.cc
index a6a5cf6..3b1ea7f 100644
--- a/src/tint/transform/promote_initializers_to_let.cc
+++ b/src/tint/transform/promote_initializers_to_let.cc
@@ -16,11 +16,14 @@
#include <utility>
+#include "src/tint/ast/traverse_expressions.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/type_initializer.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
+#include "src/tint/type/struct.h"
+#include "src/tint/utils/hashset.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteInitializersToLet);
@@ -36,87 +39,111 @@
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
- HoistToDeclBefore hoist_to_decl_before(ctx);
-
- bool any_promoted = false;
-
- // Hoists array and structure initializers to a constant variable, declared
- // just before the statement of usage.
- auto promote = [&](const sem::Expression* expr) {
- auto* sem_stmt = expr->Stmt();
- if (!sem_stmt) {
- // Expression is outside of a statement. This usually means the
- // expression is part of a global (module-scope) constant declaration.
- // These must be constexpr, and so cannot contain the type of
- // expressions that must be sanitized.
- return true;
+ // Returns true if the expression should be hoisted to a new let statement before the
+ // expression's statement.
+ auto should_hoist = [&](const sem::Expression* expr) {
+ if (!expr->Type()->IsAnyOf<type::Array, type::StructBase>()) {
+ // We only care about array and struct initializers
+ return false;
}
- auto* stmt = sem_stmt->Declaration();
+ // Check whether the expression is an array or structure constructor
+ {
+ // Follow const-chains
+ auto* root_expr = expr;
+ if (expr->Stage() == sem::EvaluationStage::kConstant) {
+ while (auto* user = root_expr->UnwrapMaterialize()->As<sem::VariableUser>()) {
+ root_expr = user->Variable()->Initializer();
+ }
+ }
- if (auto* src_var_decl = stmt->As<ast::VariableDeclStatement>()) {
- if (src_var_decl->variable->initializer == expr->Declaration()) {
- // This statement is just a variable declaration with the
- // initializer as the initializer value. This is what we're
- // attempting to transform to, and so ignore.
- return true;
+ auto* ctor = root_expr->UnwrapMaterialize()->As<sem::Call>();
+ if (!ctor || !ctor->Target()->Is<sem::TypeInitializer>()) {
+ // Root expression is not a type constructor. Not interested in this.
+ return false;
}
}
- auto* src_ty = expr->Type();
- if (!src_ty->IsAnyOf<type::Array, sem::Struct>()) {
- // We only care about array and struct initializers
- return true;
+ if (auto* src_var_decl = expr->Stmt()->Declaration()->As<ast::VariableDeclStatement>()) {
+ if (src_var_decl->variable->initializer == expr->Declaration()) {
+ // This statement is just a variable declaration with the initializer as the
+ // initializer value. This is what we're attempting to transform to, and so
+ // ignore.
+ return false;
+ }
}
- any_promoted = true;
- return hoist_to_decl_before.Add(expr, expr->Declaration(),
- HoistToDeclBefore::VariableKind::kLet);
+ return true;
};
+ // A list of expressions that should be hoisted.
+ utils::Vector<const sem::Expression*, 32> to_hoist;
+ // A set of expressions that are constant, which _may_ need to be hoisted.
+ utils::Hashset<const ast::Expression*, 32> const_chains;
+
+ // Walk the AST nodes. This order guarantees that leaf-expressions are visited first.
for (auto* node : src->ASTNodes().Objects()) {
- bool ok = Switch(
- node, //
- [&](const ast::CallExpression* expr) {
- if (auto* sem = src->Sem().Get(expr)) {
- auto* ctor = sem->UnwrapMaterialize()->As<sem::Call>();
- if (ctor->Target()->Is<sem::TypeInitializer>()) {
- return promote(sem);
- }
+ if (auto* sem = src->Sem().Get<sem::Expression>(node)) {
+ auto* stmt = sem->Stmt();
+ if (!stmt) {
+ // Expression is outside of a statement. This usually means the expression is part
+ // of a global (module-scope) constant declaration. These must be constexpr, and so
+ // cannot contain the type of expressions that must be sanitized.
+ continue;
+ }
+
+ if (sem->Stage() == sem::EvaluationStage::kConstant) {
+ // Expression is constant. We only need to hoist expressions if they're the
+ // outermost constant expression in a chain. Remove the immediate child nodes of the
+ // expression from const_chains, and add this expression to the const_chains. As we
+ // visit leaf-expressions first, this means the content of const_chains only
+ // contains the outer-most constant expressions.
+ auto* expr = sem->Declaration();
+ bool ok = ast::TraverseExpressions(
+ expr, b.Diagnostics(), [&](const ast::Expression* child) {
+ const_chains.Remove(child);
+ return child == expr ? ast::TraverseAction::Descend
+ : ast::TraverseAction::Skip;
+ });
+ if (!ok) {
+ return Program(std::move(b));
}
- return true;
- },
- [&](const ast::IdentifierExpression* expr) {
- if (auto* sem = src->Sem().Get(expr)) {
- if (auto* user = sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
- // Identifier resolves to a variable
- if (auto* stmt = user->Stmt()) {
- if (auto* decl = stmt->Declaration()->As<ast::VariableDeclStatement>();
- decl && decl->variable->Is<ast::Const>()) {
- // The identifier is used on the RHS of a 'const' declaration.
- // Ignore.
- return true;
- }
- }
- if (user->Variable()->Declaration()->Is<ast::Const>()) {
- // The identifier resolves to a 'const' variable, but isn't used to
- // initialize another 'const'. This needs promoting.
- return promote(user);
- }
- }
- }
- return true;
- },
- [&](Default) { return true; });
- if (!ok) {
- return Program(std::move(b));
+ const_chains.Add(expr);
+ } else if (should_hoist(sem)) {
+ to_hoist.Push(sem);
+ }
}
}
- if (!any_promoted) {
+ // After walking the full AST, const_chains only contains the outer-most constant expressions.
+ // Check if any of these need hoisting, and append those to to_hoist.
+ for (auto* expr : const_chains) {
+ if (auto* sem = src->Sem().Get(expr); should_hoist(sem)) {
+ to_hoist.Push(sem);
+ }
+ }
+
+ if (to_hoist.IsEmpty()) {
+ // Nothing to do. Skip.
return SkipTransform;
}
+ // The order of to_hoist is currently undefined. Sort by AST node id, which will make this
+ // deterministic.
+ to_hoist.Sort([&](auto* expr_a, auto* expr_b) {
+ return expr_a->Declaration()->node_id < expr_b->Declaration()->node_id;
+ });
+
+ // Hoist all the expression in to_hoist to a constant variable, declared just before the
+ // statement of usage.
+ HoistToDeclBefore hoist_to_decl_before(ctx);
+ for (auto* expr : to_hoist) {
+ if (!hoist_to_decl_before.Add(expr, expr->Declaration(),
+ HoistToDeclBefore::VariableKind::kLet)) {
+ return Program(std::move(b));
+ }
+ }
+
ctx.Clone();
return Program(std::move(b));
}
diff --git a/src/tint/transform/promote_initializers_to_let_test.cc b/src/tint/transform/promote_initializers_to_let_test.cc
index 536d04a..dd72671 100644
--- a/src/tint/transform/promote_initializers_to_let_test.cc
+++ b/src/tint/transform/promote_initializers_to_let_test.cc
@@ -30,7 +30,21 @@
EXPECT_EQ(expect, str(got));
}
-TEST_F(PromoteInitializersToLetTest, BasicArray) {
+TEST_F(PromoteInitializersToLetTest, BasicConstArray) {
+ auto* src = R"(
+fn f() {
+ const f0 = 1.0;
+ const f1 = 2.0;
+ const f2 = 3.0;
+ const f3 = 4.0;
+ var i = array<f32, 4u>(f0, f1, f2, f3)[2];
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
+
+TEST_F(PromoteInitializersToLetTest, BasicRuntimeArray) {
auto* src = R"(
fn f() {
var f0 = 1.0;
@@ -52,13 +66,12 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
}
-TEST_F(PromoteInitializersToLetTest, BasicStruct) {
+TEST_F(PromoteInitializersToLetTest, BasicConstStruct) {
auto* src = R"(
struct S {
a : i32,
@@ -71,6 +84,23 @@
}
)";
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
+
+TEST_F(PromoteInitializersToLetTest, BasicRuntimeStruct) {
+ auto* src = R"(
+struct S {
+ a : i32,
+ b : f32,
+ c : vec3<f32>,
+};
+
+fn f() {
+ let runtime_value = 1;
+ var x = S(runtime_value, 2.0, vec3<f32>()).b;
+}
+)";
+
auto* expect = R"(
struct S {
a : i32,
@@ -79,12 +109,12 @@
}
fn f() {
- let tint_symbol = S(1, 2.0, vec3<f32>());
+ let runtime_value = 1;
+ let tint_symbol = S(runtime_value, 2.0, vec3<f32>());
var x = tint_symbol.b;
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -93,7 +123,8 @@
TEST_F(PromoteInitializersToLetTest, BasicStruct_OutOfOrder) {
auto* src = R"(
fn f() {
- var x = S(1, 2.0, vec3<f32>()).b;
+ let runtime_value = 1;
+ var x = S(runtime_value, 2.0, vec3<f32>()).b;
}
struct S {
@@ -105,7 +136,8 @@
auto* expect = R"(
fn f() {
- let tint_symbol = S(1, 2.0, vec3<f32>());
+ let runtime_value = 1;
+ let tint_symbol = S(runtime_value, 2.0, vec3<f32>());
var x = tint_symbol.b;
}
@@ -116,7 +148,6 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -133,29 +164,29 @@
fn f() {
var f0 = 100.0;
var f1 = 100.0;
- var i = C[1];
+ var i = C[1]; // Not hoisted, as the final const value is not an array
}
)";
- auto* expect = R"(
-const f0 = 1.0;
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
-const f1 = 2.0;
-
-const C = array<f32, 2u>(f0, f1);
-
+TEST_F(PromoteInitializersToLetTest, GlobalConstBasicArray_OutOfOrder) {
+ auto* src = R"(
fn f() {
var f0 = 100.0;
var f1 = 100.0;
- let tint_symbol = C;
- var i = tint_symbol[1];
+ var i = C[1];
}
+
+const C = array<f32, 2u>(f0, f1);
+
+const f0 = 1.0;
+
+const f1 = 2.0;
)";
- DataMap data;
- auto got = Run<PromoteInitializersToLet>(src);
-
- EXPECT_EQ(expect, str(got));
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
}
TEST_F(PromoteInitializersToLetTest, GlobalConstArrayDynamicIndex) {
@@ -183,43 +214,6 @@
}
)";
- DataMap data;
- auto got = Run<PromoteInitializersToLet>(src);
-
- EXPECT_EQ(expect, str(got));
-}
-
-TEST_F(PromoteInitializersToLetTest, GlobalConstBasicArray_OutOfOrder) {
- auto* src = R"(
-fn f() {
- var f0 = 100.0;
- var f1 = 100.0;
- var i = C[1];
-}
-
-const C = array<f32, 2u>(f0, f1);
-
-const f0 = 1.0;
-
-const f1 = 2.0;
-)";
-
- auto* expect = R"(
-fn f() {
- var f0 = 100.0;
- var f1 = 100.0;
- let tint_symbol = C;
- var i = tint_symbol[1];
-}
-
-const C = array<f32, 2u>(f0, f1);
-
-const f0 = 1.0;
-
-const f1 = 2.0;
-)";
-
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -231,7 +225,21 @@
const f0 = 1.0;
const f1 = 2.0;
const C = array<f32, 2u>(f0, f1);
- var i = C[1];
+ var i = C[1]; // Not hoisted, as the final const value is not an array
+}
+)";
+
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
+
+TEST_F(PromoteInitializersToLetTest, LocalConstBasicArrayRuntimeIndex) {
+ auto* src = R"(
+fn f() {
+ const f0 = 1.0;
+ const f1 = 2.0;
+ const C = array<f32, 2u>(f0, f1);
+ let runtime_value = 1;
+ var i = C[runtime_value];
}
)";
@@ -240,12 +248,12 @@
const f0 = 1.0;
const f1 = 2.0;
const C = array<f32, 2u>(f0, f1);
+ let runtime_value = 1;
let tint_symbol = C;
- var i = tint_symbol[1];
+ var i = tint_symbol[runtime_value];
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -255,7 +263,7 @@
auto* src = R"(
fn f() {
var insert_after = 1;
- for(var i = array<f32, 4u>(0.0, 1.0, 2.0, 3.0)[2]; ; ) {
+ for(var i = array<f32, 4u>(0.0, 1.0, 2.0, 3.0)[insert_after]; ; ) {
break;
}
}
@@ -265,13 +273,12 @@
fn f() {
var insert_after = 1;
let tint_symbol = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
- for(var i = tint_symbol[2]; ; ) {
+ for(var i = tint_symbol[insert_after]; ; ) {
break;
}
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -281,8 +288,9 @@
auto* src = R"(
fn f() {
const arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
+ let runtime_value = 1;
var insert_after = 1;
- for(var i = arr[2]; ; ) {
+ for(var i = arr[runtime_value]; ; ) {
break;
}
}
@@ -291,15 +299,15 @@
auto* expect = R"(
fn f() {
const arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
+ let runtime_value = 1;
var insert_after = 1;
let tint_symbol = arr;
- for(var i = tint_symbol[2]; ; ) {
+ for(var i = tint_symbol[runtime_value]; ; ) {
break;
}
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -310,8 +318,9 @@
const arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
fn f() {
+ let runtime_value = 1;
var insert_after = 1;
- for(var i = arr[2]; ; ) {
+ for(var i = arr[runtime_value]; ; ) {
break;
}
}
@@ -321,15 +330,15 @@
const arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
fn f() {
+ let runtime_value = 1;
var insert_after = 1;
let tint_symbol = arr;
- for(var i = tint_symbol[2]; ; ) {
+ for(var i = tint_symbol[runtime_value]; ; ) {
break;
}
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -343,9 +352,13 @@
c : vec3<f32>,
};
+fn get_b_runtime(s : S) -> f32 {
+ return s.b;
+}
+
fn f() {
var insert_after = 1;
- for(var x = S(1, 2.0, vec3<f32>()).b; ; ) {
+ for(var x = get_b_runtime(S(1, 2.0, vec3<f32>())); ; ) {
break;
}
}
@@ -358,16 +371,19 @@
c : vec3<f32>,
}
+fn get_b_runtime(s : S) -> f32 {
+ return s.b;
+}
+
fn f() {
var insert_after = 1;
let tint_symbol = S(1, 2.0, vec3<f32>());
- for(var x = tint_symbol.b; ; ) {
+ for(var x = get_b_runtime(tint_symbol); ; ) {
break;
}
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -377,11 +393,15 @@
auto* src = R"(
fn f() {
var insert_after = 1;
- for(var x = S(1, 2.0, vec3<f32>()).b; ; ) {
+ for(var x = get_b_runtime(S(1, 2.0, vec3<f32>())); ; ) {
break;
}
}
+fn get_b_runtime(s : S) -> f32 {
+ return s.b;
+}
+
struct S {
a : i32,
b : f32,
@@ -393,11 +413,15 @@
fn f() {
var insert_after = 1;
let tint_symbol = S(1, 2.0, vec3<f32>());
- for(var x = tint_symbol.b; ; ) {
+ for(var x = get_b_runtime(tint_symbol); ; ) {
break;
}
}
+fn get_b_runtime(s : S) -> f32 {
+ return s.b;
+}
+
struct S {
a : i32,
b : f32,
@@ -405,7 +429,6 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -440,7 +463,6 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -449,9 +471,10 @@
TEST_F(PromoteInitializersToLetTest, LocalConstArrayInForLoopCond) {
auto* src = R"(
fn f() {
+ let runtime_value = 0;
const f = 1.0;
const arr = array<f32, 1u>(f);
- for(var i = f; i == arr[0]; i = i + 1.0) {
+ for(var i = f; i == arr[runtime_value]; i = i + 1.0) {
var marker = 1;
}
}
@@ -459,13 +482,14 @@
auto* expect = R"(
fn f() {
+ let runtime_value = 0;
const f = 1.0;
const arr = array<f32, 1u>(f);
{
var i = f;
loop {
let tint_symbol = arr;
- if (!((i == tint_symbol[0]))) {
+ if (!((i == tint_symbol[runtime_value]))) {
break;
}
{
@@ -480,7 +504,6 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -493,7 +516,8 @@
const arr = array<f32, 1u>(f);
fn F() {
- for(var i = f; i == arr[0]; i = i + 1.0) {
+ let runtime_value = 0;
+ for(var i = f; i == arr[runtime_value]; i = i + 1.0) {
var marker = 1;
}
}
@@ -505,11 +529,12 @@
const arr = array<f32, 1u>(f);
fn F() {
+ let runtime_value = 0;
{
var i = f;
loop {
let tint_symbol = arr;
- if (!((i == tint_symbol[0]))) {
+ if (!((i == tint_symbol[runtime_value]))) {
break;
}
{
@@ -524,7 +549,6 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -533,8 +557,9 @@
TEST_F(PromoteInitializersToLetTest, ArrayInForLoopCont) {
auto* src = R"(
fn f() {
+ let runtime_value = 0;
var f = 0.0;
- for(; f < 10.0; f = f + array<f32, 1u>(1.0)[0]) {
+ for(; f < 10.0; f = f + array<f32, 1u>(1.0)[runtime_value]) {
var marker = 1;
}
}
@@ -542,6 +567,7 @@
auto* expect = R"(
fn f() {
+ let runtime_value = 0;
var f = 0.0;
loop {
if (!((f < 10.0))) {
@@ -553,13 +579,12 @@
continuing {
let tint_symbol = array<f32, 1u>(1.0);
- f = (f + tint_symbol[0]);
+ f = (f + tint_symbol[runtime_value]);
}
}
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -568,9 +593,10 @@
TEST_F(PromoteInitializersToLetTest, LocalConstArrayInForLoopCont) {
auto* src = R"(
fn f() {
+ let runtime_value = 0;
const arr = array<f32, 1u>(1.0);
var f = 0.0;
- for(; f < 10.0; f = f + arr[0]) {
+ for(; f < 10.0; f = f + arr[runtime_value]) {
var marker = 1;
}
}
@@ -578,6 +604,7 @@
auto* expect = R"(
fn f() {
+ let runtime_value = 0;
const arr = array<f32, 1u>(1.0);
var f = 0.0;
loop {
@@ -590,13 +617,12 @@
continuing {
let tint_symbol = arr;
- f = (f + tint_symbol[0]);
+ f = (f + tint_symbol[runtime_value]);
}
}
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -607,8 +633,9 @@
const arr = array<f32, 1u>(1.0);
fn f() {
+ let runtime_value = 0;
var f = 0.0;
- for(; f < 10.0; f = f + arr[0]) {
+ for(; f < 10.0; f = f + arr[runtime_value]) {
var marker = 1;
}
}
@@ -618,6 +645,7 @@
const arr = array<f32, 1u>(1.0);
fn f() {
+ let runtime_value = 0;
var f = 0.0;
loop {
if (!((f < 10.0))) {
@@ -629,13 +657,12 @@
continuing {
let tint_symbol = arr;
- f = (f + tint_symbol[0]);
+ f = (f + tint_symbol[runtime_value]);
}
}
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -644,9 +671,10 @@
TEST_F(PromoteInitializersToLetTest, ArrayInForLoopInitCondCont) {
auto* src = R"(
fn f() {
- for(var f = array<f32, 1u>(0.0)[0];
- f < array<f32, 1u>(1.0)[0];
- f = f + array<f32, 1u>(2.0)[0]) {
+ let runtime_value = 0;
+ for(var f = array<f32, 1u>(0.0)[runtime_value];
+ f < array<f32, 1u>(1.0)[runtime_value];
+ f = f + array<f32, 1u>(2.0)[runtime_value]) {
var marker = 1;
}
}
@@ -654,12 +682,13 @@
auto* expect = R"(
fn f() {
+ let runtime_value = 0;
let tint_symbol = array<f32, 1u>(0.0);
{
- var f = tint_symbol[0];
+ var f = tint_symbol[runtime_value];
loop {
let tint_symbol_1 = array<f32, 1u>(1.0);
- if (!((f < tint_symbol_1[0]))) {
+ if (!((f < tint_symbol_1[runtime_value]))) {
break;
}
{
@@ -668,14 +697,13 @@
continuing {
let tint_symbol_2 = array<f32, 1u>(2.0);
- f = (f + tint_symbol_2[0]);
+ f = (f + tint_symbol_2[runtime_value]);
}
}
}
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -684,10 +712,11 @@
TEST_F(PromoteInitializersToLetTest, LocalConstArrayInForLoopInitCondCont) {
auto* src = R"(
fn f() {
+ let runtime_value = 0;
const arr_a = array<f32, 1u>(0.0);
const arr_b = array<f32, 1u>(1.0);
const arr_c = array<f32, 1u>(2.0);
- for(var f = arr_a[0]; f < arr_b[0]; f = f + arr_c[0]) {
+ for(var f = arr_a[runtime_value]; f < arr_b[runtime_value]; f = f + arr_c[runtime_value]) {
var marker = 1;
}
}
@@ -695,15 +724,16 @@
auto* expect = R"(
fn f() {
+ let runtime_value = 0;
const arr_a = array<f32, 1u>(0.0);
const arr_b = array<f32, 1u>(1.0);
const arr_c = array<f32, 1u>(2.0);
let tint_symbol = arr_a;
{
- var f = tint_symbol[0];
+ var f = tint_symbol[runtime_value];
loop {
let tint_symbol_1 = arr_b;
- if (!((f < tint_symbol_1[0]))) {
+ if (!((f < tint_symbol_1[runtime_value]))) {
break;
}
{
@@ -712,14 +742,13 @@
continuing {
let tint_symbol_2 = arr_c;
- f = (f + tint_symbol_2[0]);
+ f = (f + tint_symbol_2[runtime_value]);
}
}
}
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -751,7 +780,6 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -802,7 +830,6 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -811,15 +838,16 @@
TEST_F(PromoteInitializersToLetTest, LocalConstArrayInElseIfChain) {
auto* src = R"(
fn f() {
+ let runtime_value = 0;
const f = 1.0;
const arr = array<f32, 2u>(f, f);
if (true) {
var marker = 0;
} else if (true) {
var marker = 1;
- } else if (f == arr[0]) {
+ } else if (f == arr[runtime_value]) {
var marker = 2;
- } else if (f == arr[1]) {
+ } else if (f == arr[runtime_value + 1]) {
var marker = 3;
} else if (true) {
var marker = 4;
@@ -831,6 +859,7 @@
auto* expect = R"(
fn f() {
+ let runtime_value = 0;
const f = 1.0;
const arr = array<f32, 2u>(f, f);
if (true) {
@@ -839,11 +868,11 @@
var marker = 1;
} else {
let tint_symbol = arr;
- if ((f == tint_symbol[0])) {
+ if ((f == tint_symbol[runtime_value])) {
var marker = 2;
} else {
let tint_symbol_1 = arr;
- if ((f == tint_symbol_1[1])) {
+ if ((f == tint_symbol_1[(runtime_value + 1)])) {
var marker = 3;
} else if (true) {
var marker = 4;
@@ -855,7 +884,6 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -868,13 +896,14 @@
const arr = array<f32, 2u>(f, f);
fn F() {
+ let runtime_value = 0;
if (true) {
var marker = 0;
} else if (true) {
var marker = 1;
- } else if (f == arr[0]) {
+ } else if (f == arr[runtime_value]) {
var marker = 2;
- } else if (f == arr[1]) {
+ } else if (f == arr[runtime_value + 1]) {
var marker = 3;
} else if (true) {
var marker = 4;
@@ -890,17 +919,18 @@
const arr = array<f32, 2u>(f, f);
fn F() {
+ let runtime_value = 0;
if (true) {
var marker = 0;
} else if (true) {
var marker = 1;
} else {
let tint_symbol = arr;
- if ((f == tint_symbol[0])) {
+ if ((f == tint_symbol[runtime_value])) {
var marker = 2;
} else {
let tint_symbol_1 = arr;
- if ((f == tint_symbol_1[1])) {
+ if ((f == tint_symbol_1[(runtime_value + 1)])) {
var marker = 3;
} else if (true) {
var marker = 4;
@@ -912,35 +942,43 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
}
-TEST_F(PromoteInitializersToLetTest, ArrayInArrayArray) {
+TEST_F(PromoteInitializersToLetTest, ArrayInArrayArrayConstIndex) {
auto* src = R"(
fn f() {
var i = array<array<f32, 2u>, 2u>(array<f32, 2u>(1.0, 2.0), array<f32, 2u>(3.0, 4.0))[0][1];
}
)";
- auto* expect = R"(
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
+
+TEST_F(PromoteInitializersToLetTest, ArrayInArrayArrayRuntimeIndex) {
+ auto* src = R"(
fn f() {
- let tint_symbol = array<f32, 2u>(1.0, 2.0);
- let tint_symbol_1 = array<f32, 2u>(3.0, 4.0);
- let tint_symbol_2 = array<array<f32, 2u>, 2u>(tint_symbol, tint_symbol_1);
- var i = tint_symbol_2[0][1];
+ let runtime_value = 1;
+ var i = array<array<f32, 2u>, 2u>(array<f32, 2u>(1.0, 2.0), array<f32, 2u>(3.0, 4.0))[runtime_value][runtime_value + 1];
}
)";
- DataMap data;
+ auto* expect = R"(
+fn f() {
+ let runtime_value = 1;
+ let tint_symbol = array<array<f32, 2u>, 2u>(array<f32, 2u>(1.0, 2.0), array<f32, 2u>(3.0, 4.0));
+ var i = tint_symbol[runtime_value][(runtime_value + 1)];
+}
+)";
+
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
}
-TEST_F(PromoteInitializersToLetTest, LocalConstArrayInArrayArray) {
+TEST_F(PromoteInitializersToLetTest, LocalConstArrayInArrayArrayConstIndex) {
auto* src = R"(
fn f() {
const arr_0 = array<f32, 2u>(1.0, 2.0);
@@ -950,17 +988,31 @@
}
)";
+ EXPECT_FALSE(ShouldRun<PromoteInitializersToLet>(src));
+}
+
+TEST_F(PromoteInitializersToLetTest, LocalConstArrayInArrayArrayRuntimeIndex) {
+ auto* src = R"(
+fn f() {
+ let runtime_value = 1;
+ const arr_0 = array<f32, 2u>(1.0, 2.0);
+ const arr_1 = array<f32, 2u>(3.0, 4.0);
+ const arr_2 = array<array<f32, 2u>, 2u>(arr_0, arr_1);
+ var i = arr_2[runtime_value][runtime_value + 1];
+}
+)";
+
auto* expect = R"(
fn f() {
+ let runtime_value = 1;
const arr_0 = array<f32, 2u>(1.0, 2.0);
const arr_1 = array<f32, 2u>(3.0, 4.0);
const arr_2 = array<array<f32, 2u>, 2u>(arr_0, arr_1);
let tint_symbol = arr_2;
- var i = tint_symbol[0][1];
+ var i = tint_symbol[runtime_value][(runtime_value + 1)];
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -975,7 +1027,8 @@
const arr_2 = array<array<f32, 2u>, 2u>(arr_0, arr_1);
fn f() {
- var i = arr_2[0][1];
+ let runtime_value = 1;
+ var i = arr_2[runtime_value][runtime_value + 1];
}
)";
@@ -987,12 +1040,12 @@
const arr_2 = array<array<f32, 2u>, 2u>(arr_0, arr_1);
fn f() {
+ let runtime_value = 1;
let tint_symbol = arr_2;
- var i = tint_symbol[0][1];
+ var i = tint_symbol[runtime_value][(runtime_value + 1)];
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -1014,8 +1067,12 @@
a : S2,
};
+fn get_a(s : S3) -> S2 {
+ return s.a;
+}
+
fn f() {
- var x = S3(S2(1, S1(2), 3)).a.b.a;
+ var x = get_a(S3(S2(1, S1(2), 3))).b.a;
}
)";
@@ -1034,15 +1091,16 @@
a : S2,
}
+fn get_a(s : S3) -> S2 {
+ return s.a;
+}
+
fn f() {
- let tint_symbol = S1(2);
- let tint_symbol_1 = S2(1, tint_symbol, 3);
- let tint_symbol_2 = S3(tint_symbol_1);
- var x = tint_symbol_2.a.b.a;
+ let tint_symbol = S3(S2(1, S1(2), 3));
+ var x = get_a(tint_symbol).b.a;
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -1058,8 +1116,12 @@
a : array<S1, 3u>,
};
+fn get_a(s : S2) -> array<S1, 3u> {
+ return s.a;
+}
+
fn f() {
- var x = S2(array<S1, 3u>(S1(1), S1(2), S1(3))).a[1].a;
+ var x = get_a(S2(array<S1, 3u>(S1(1), S1(2), S1(3))))[1].a;
}
)";
@@ -1072,17 +1134,16 @@
a : array<S1, 3u>,
}
+fn get_a(s : S2) -> array<S1, 3u> {
+ return s.a;
+}
+
fn f() {
- let tint_symbol = S1(1);
- let tint_symbol_1 = S1(2);
- let tint_symbol_2 = S1(3);
- let tint_symbol_3 = array<S1, 3u>(tint_symbol, tint_symbol_1, tint_symbol_2);
- let tint_symbol_4 = S2(tint_symbol_3);
- var x = tint_symbol_4.a[1].a;
+ let tint_symbol = S2(array<S1, 3u>(S1(1), S1(2), S1(3)));
+ var x = get_a(tint_symbol)[1].a;
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -1091,7 +1152,11 @@
TEST_F(PromoteInitializersToLetTest, Mixed_OutOfOrder) {
auto* src = R"(
fn f() {
- var x = S2(array<S1, 3u>(S1(1), S1(2), S1(3))).a[1].a;
+ var x = get_a(S2(array<S1, 3u>(S1(1), S1(2), S1(3))))[1].a;
+}
+
+fn get_a(s : S2) -> array<S1, 3u> {
+ return s.a;
}
struct S2 {
@@ -1105,12 +1170,12 @@
auto* expect = R"(
fn f() {
- let tint_symbol = S1(1);
- let tint_symbol_1 = S1(2);
- let tint_symbol_2 = S1(3);
- let tint_symbol_3 = array<S1, 3u>(tint_symbol, tint_symbol_1, tint_symbol_2);
- let tint_symbol_4 = S2(tint_symbol_3);
- var x = tint_symbol_4.a[1].a;
+ let tint_symbol = S2(array<S1, 3u>(S1(1), S1(2), S1(3)));
+ var x = get_a(tint_symbol)[1].a;
+}
+
+fn get_a(s : S2) -> array<S1, 3u> {
+ return s.a;
}
struct S2 {
@@ -1122,7 +1187,6 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -1144,7 +1208,6 @@
auto* expect = src;
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -1166,7 +1229,6 @@
auto* expect = src;
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
@@ -1246,7 +1308,6 @@
}
)";
- DataMap data;
auto got = Run<PromoteInitializersToLet>(src);
EXPECT_EQ(expect, str(got));
diff --git a/src/tint/utils/hashset.h b/src/tint/utils/hashset.h
index 53f71f5..2d427a1 100644
--- a/src/tint/utils/hashset.h
+++ b/src/tint/utils/hashset.h
@@ -43,6 +43,18 @@
struct NoValue {};
return this->template Put<PutMode::kAdd>(std::forward<V>(value), NoValue{});
}
+
+ /// @returns the set entries of the map as a vector
+ /// @note the order of the returned vector is non-deterministic between compilers.
+ template <size_t N2 = N>
+ utils::Vector<KEY, N2> Vector() const {
+ utils::Vector<KEY, N2> out;
+ out.Reserve(this->Count());
+ for (auto& value : *this) {
+ out.Push(value);
+ }
+ return out;
+ }
};
} // namespace tint::utils
diff --git a/src/tint/utils/hashset_test.cc b/src/tint/utils/hashset_test.cc
index 64f0da3..c541fb5 100644
--- a/src/tint/utils/hashset_test.cc
+++ b/src/tint/utils/hashset_test.cc
@@ -91,6 +91,16 @@
EXPECT_THAT(set, testing::UnorderedElementsAre("one", "two", "three", "four"));
}
+TEST(Hashset, Vector) {
+ Hashset<std::string, 8> set;
+ set.Add("one");
+ set.Add("four");
+ set.Add("three");
+ set.Add("two");
+ auto vec = set.Vector();
+ EXPECT_THAT(vec, testing::UnorderedElementsAre("one", "two", "three", "four"));
+}
+
TEST(Hashset, Soak) {
std::mt19937 rnd;
std::unordered_set<std::string> reference;
diff --git a/src/tint/writer/glsl/generator_impl_sanitizer_test.cc b/src/tint/writer/glsl/generator_impl_sanitizer_test.cc
index 9c425a8..6bcca21 100644
--- a/src/tint/writer/glsl/generator_impl_sanitizer_test.cc
+++ b/src/tint/writer/glsl/generator_impl_sanitizer_test.cc
@@ -186,12 +186,14 @@
Member("b", ty.vec3<f32>()),
Member("c", ty.i32()),
});
- auto* struct_init = Construct(ty.Of(str), 1_i, vec3<f32>(2_f, 3_f, 4_f), 4_i);
+ auto* runtime_value = Var("runtime_value", Expr(3_f));
+ auto* struct_init = Construct(ty.Of(str), 1_i, vec3<f32>(2_f, runtime_value, 4_f), 4_i);
auto* struct_access = MemberAccessor(struct_init, "b");
auto* pos = Var("pos", ty.vec3<f32>(), struct_access);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
+ Decl(runtime_value),
Decl(pos),
},
utils::Vector{
@@ -213,7 +215,8 @@
};
void tint_symbol() {
- S tint_symbol_1 = S(1, vec3(2.0f, 3.0f, 4.0f), 4);
+ float runtime_value = 3.0f;
+ S tint_symbol_1 = S(1, vec3(2.0f, runtime_value, 4.0f), 4);
vec3 pos = tint_symbol_1.b;
}
diff --git a/src/tint/writer/hlsl/generator_impl_sanitizer_test.cc b/src/tint/writer/hlsl/generator_impl_sanitizer_test.cc
index dbe8a51..761eb82 100644
--- a/src/tint/writer/hlsl/generator_impl_sanitizer_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_sanitizer_test.cc
@@ -197,17 +197,19 @@
}
TEST_F(HlslSanitizerTest, PromoteStructInitializerToConstVar) {
+ auto* runtime_value = Var("runtime_value", Expr(3_f));
auto* str = Structure("S", utils::Vector{
Member("a", ty.i32()),
Member("b", ty.vec3<f32>()),
Member("c", ty.i32()),
});
- auto* struct_init = Construct(ty.Of(str), 1_i, vec3<f32>(2_f, 3_f, 4_f), 4_i);
+ auto* struct_init = Construct(ty.Of(str), 1_i, vec3<f32>(2_f, runtime_value, 4_f), 4_i);
auto* struct_access = MemberAccessor(struct_init, "b");
auto* pos = Var("pos", ty.vec3<f32>(), struct_access);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
+ Decl(runtime_value),
Decl(pos),
},
utils::Vector{
@@ -226,7 +228,8 @@
};
void main() {
- const S tint_symbol = {1, float3(2.0f, 3.0f, 4.0f), 4};
+ float runtime_value = 3.0f;
+ const S tint_symbol = {1, float3(2.0f, runtime_value, 4.0f), 4};
float3 pos = tint_symbol.b;
return;
}