Import Tint changes from Dawn
Changes:
- 659b5b7727d5e1000271fd75ba99bdca2647a54b tint: add const eval matrix accessor unit tests by Antonio Maiorano <amaiorano@google.com>
- c73d67397624cab512a1bd680370a146a2640a3a tint: add const eval vector accessor unit tests by Antonio Maiorano <amaiorano@google.com>
- b71898ea30e5b10b7f9e96d80b5944fc14538b91 tint: add const eval array accessor unit tests by Antonio Maiorano <amaiorano@google.com>
- dfa92a9cb68366a8655ac26b51af0c740bcf9b72 tint: Fix ProgramBuilder::WrapInFunction overload not bei... by Antonio Maiorano <amaiorano@google.com>
- 64c243e9e675069871cfbb829196be79ca3ebae1 tint: add missing unit tests for const eval vector constr... by Antonio Maiorano <amaiorano@google.com>
- 6b4622fb0721f08323db2ce576f16f2226439c9d tint: add const eval array constructor tests by Antonio Maiorano <amaiorano@google.com>
- 906fc9df206d668191e9660a16688e27eb3d97ce tint/uniformity: Add a NameFor helper by James Price <jrprice@google.com>
- a84ebc3af9c2086581008e7a05444698e1a6eb15 tint: Add forward declaration for CastableBase by James Price <jrprice@google.com>
- 0244804193ccb20b6bbcc08caa385c72bff9547e tint/uniformity: Avoid string allocations for node tags by James Price <jrprice@google.com>
- 857b1580c74193cdab72d77583859f5b96cab30f tint/uniformity: Handle pointer uniformity by James Price <jrprice@google.com>
- 1d77e2531c5e54385fa95608cc4e9f426a5e0bac tint: add const eval of swizzle tests by Antonio Maiorano <amaiorano@google.com>
- ffb322a096a1d0f1e2166e6e26d86ab36e31dbd6 tint: add bool member to const eval struct member access ... by Antonio Maiorano <amaiorano@google.com>
- 309b10a8c5f91c671f5bd586001b8a0db8ae61dd tint: add const eval struct zero init tests by Antonio Maiorano <amaiorano@google.com>
- 0890ecabdae58876413d9b3dff5e13f1d7e8e60f tint: add const eval zero init tests for scalars, vectors... by Antonio Maiorano <amaiorano@google.com>
- 994b70feb919b49acd46ae5d71a1af5ee3192bd5 tint: add AFloat and AInt const eval scalar constructor t... by Antonio Maiorano <amaiorano@google.com>
- bc44620d68bc66b1e664269b7365ea162a08b45c tint: implement short-circuiting of const eval bitcast by Antonio Maiorano <amaiorano@google.com>
- eb34a764a8c486f59f8c50e1faeed8dac04caa7c tint/uniformity: Fix handling of continuing block by James Price <jrprice@google.com>
- efc9df4695dac323230d37a6868455acd7fda395 tint/uniformity: Fix issues with for-loops by James Price <jrprice@google.com>
- 056618541fc5502e475e35cbacc10f4985f0eeb9 tint: const eval of bitcast operator by Antonio Maiorano <amaiorano@google.com>
- ffa83ad1f74d664b4382bf07c4102b6a6ff3a67b tint: make utils::Bitcast not trigger gcc warning by Antonio Maiorano <amaiorano@google.com>
GitOrigin-RevId: 659b5b7727d5e1000271fd75ba99bdca2647a54b
Change-Id: I4562368a6302698b47a28ceeb63c0cebd542f456
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/115320
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index acb05e0..d942b86 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1186,6 +1186,7 @@
"resolver/compound_assignment_validation_test.cc",
"resolver/compound_statement_test.cc",
"resolver/const_eval_binary_op_test.cc",
+ "resolver/const_eval_bitcast_test.cc",
"resolver/const_eval_builtin_test.cc",
"resolver/const_eval_construction_test.cc",
"resolver/const_eval_conversion_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 0777746..e158941 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -894,6 +894,7 @@
resolver/compound_assignment_validation_test.cc
resolver/compound_statement_test.cc
resolver/const_eval_binary_op_test.cc
+ resolver/const_eval_bitcast_test.cc
resolver/const_eval_builtin_test.cc
resolver/const_eval_construction_test.cc
resolver/const_eval_conversion_test.cc
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 5396cdd..4dac010 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -149,6 +149,10 @@
};
} // namespace detail
+// Forward declare metafunction that evaluates to true iff T can be wrapped in a statement.
+template <typename T, typename = void>
+struct CanWrapInStatement;
+
/// ProgramBuilder is a mutable builder for a Program.
/// To construct a Program, populate the builder and then `std::move` it to a
/// Program.
@@ -3326,12 +3330,13 @@
/// by the Resolver.
/// @param args a mix of ast::Expression, ast::Statement, ast::Variables.
/// @returns the function
- template <typename... ARGS>
+ template <typename... ARGS,
+ typename = traits::EnableIf<(CanWrapInStatement<ARGS>::value && ...)>>
const ast::Function* WrapInFunction(ARGS&&... args) {
utils::Vector stmts{
WrapInStatement(std::forward<ARGS>(args))...,
};
- return WrapInFunction(utils::VectorRef<const ast::Statement*>{std::move(stmts)});
+ return WrapInFunction(std::move(stmts));
}
/// @param stmts a list of ast::Statement that will be wrapped by a function,
/// so that each statement is reachable by the Resolver.
@@ -3411,6 +3416,17 @@
return builder->ID();
}
+// Primary template for metafunction that evaluates to true iff T can be wrapped in a statement.
+template <typename T, typename /* = void */>
+struct CanWrapInStatement : std::false_type {};
+
+// Specialization of CanWrapInStatement
+template <typename T>
+struct CanWrapInStatement<
+ T,
+ std::void_t<decltype(std::declval<ProgramBuilder>().WrapInStatement(std::declval<T>()))>>
+ : std::true_type {};
+
} // namespace tint
#endif // SRC_TINT_PROGRAM_BUILDER_H_
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 29c57c7..b52f4aa 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -71,6 +71,17 @@
/// Helper that calls `f` passing in the value of all `cs`.
/// Calls `f` with all constants cast to the type of the first `cs` argument.
template <typename F, typename... CONSTANTS>
+auto Dispatch_fiu32(F&& f, CONSTANTS&&... cs) {
+ return Switch(
+ First(cs...)->Type(), //
+ [&](const type::F32*) { return f(cs->template ValueAs<f32>()...); },
+ [&](const type::I32*) { return f(cs->template ValueAs<i32>()...); },
+ [&](const type::U32*) { return f(cs->template ValueAs<u32>()...); });
+}
+
+/// Helper that calls `f` passing in the value of all `cs`.
+/// Calls `f` with all constants cast to the type of the first `cs` argument.
+template <typename F, typename... CONSTANTS>
auto Dispatch_ia_iu32(F&& f, CONSTANTS&&... cs) {
return Switch(
First(cs...)->Type(), //
@@ -1319,9 +1330,30 @@
return builder.create<constant::Composite>(ty, std::move(values));
}
-ConstEval::Result ConstEval::Bitcast(const type::Type*, const sem::Expression*) {
- // TODO(crbug.com/tint/1581): Implement @const intrinsics
- return nullptr;
+ConstEval::Result ConstEval::Bitcast(const type::Type* ty,
+ const constant::Value* value,
+ const Source& source) {
+ auto* el_ty = type::Type::DeepestElementOf(ty);
+ auto transform = [&](const constant::Value* c0) {
+ auto create = [&](auto e) {
+ return Switch(
+ el_ty,
+ [&](const type::U32*) { //
+ auto r = utils::Bitcast<u32>(e);
+ return CreateScalar(builder, source, el_ty, r);
+ },
+ [&](const type::I32*) { //
+ auto r = utils::Bitcast<i32>(e);
+ return CreateScalar(builder, source, el_ty, r);
+ },
+ [&](const type::F32*) { //
+ auto r = utils::Bitcast<f32>(e);
+ return CreateScalar(builder, source, el_ty, r);
+ });
+ };
+ return Dispatch_fiu32(create, c0);
+ };
+ return TransformElements(builder, ty, transform, value);
}
ConstEval::Result ConstEval::OpComplement(const type::Type* ty,
diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h
index d75f5de..dbb1c2e 100644
--- a/src/tint/resolver/const_eval.h
+++ b/src/tint/resolver/const_eval.h
@@ -80,10 +80,11 @@
Result ArrayOrStructInit(const type::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// @param ty the target type
- /// @param expr the input expression
+ /// @param value the value being converted
+ /// @param source the source location
/// @return the bit-cast of the given expression to the given type, or null if the value cannot
/// be calculated
- Result Bitcast(const type::Type* ty, const sem::Expression* expr);
+ Result Bitcast(const type::Type* ty, const constant::Value* value, const Source& source);
/// @param obj the object being indexed
/// @param idx the index expression
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index e7e306a..7ea998a 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -1775,8 +1775,7 @@
// Short-Circuit Bitcast
////////////////////////////////////////////////
-// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
-TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_And_Invalid_Bitcast) {
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 0) && (bitcast<f32>(a) == 0.0);
@@ -1791,8 +1790,7 @@
ValidateAnd(Sem(), binary);
}
-// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
-TEST_F(ResolverConstEvalTest, DISABLED_NonShortCircuit_And_Invalid_Bitcast) {
+TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 1) && (bitcast<f32>(a) == 0.0);
@@ -1804,11 +1802,10 @@
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: value not representable as f32 message here");
+ EXPECT_EQ(r()->error(), "12:34 error: value inf cannot be represented as 'f32'");
}
-// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
-TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_And_Error_Bitcast) {
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 0) && (bitcast<f32>(a) == 0i);
@@ -1820,11 +1817,15 @@
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload message here)");
+ EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator == (f32, i32)
+
+2 candidate operators:
+ operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+)");
}
-// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
-TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_Or_Invalid_Bitcast) {
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 1) || (bitcast<f32>(a) == 0.0);
@@ -1839,8 +1840,7 @@
ValidateOr(Sem(), binary);
}
-// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
-TEST_F(ResolverConstEvalTest, DISABLED_NonShortCircuit_Or_Invalid_Bitcast) {
+TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 0) || (bitcast<f32>(a) == 0.0);
@@ -1852,11 +1852,10 @@
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: value not representable as f32 message here");
+ EXPECT_EQ(r()->error(), "12:34 error: value inf cannot be represented as 'f32'");
}
-// @TODO(crbug.com/tint/1581): Enable once const eval of bitcast is implemented
-TEST_F(ResolverConstEvalTest, DISABLED_ShortCircuit_Or_Error_Bitcast) {
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Bitcast) {
// const one = 1;
// const a = 0x7F800000;
// const result = (one == 1) || (bitcast<f32>(a) == 0i);
@@ -1868,7 +1867,12 @@
GlobalConst("result", binary);
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload message here)");
+ EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator == (f32, i32)
+
+2 candidate operators:
+ operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+ operator == (vecN<T>, vecN<T>) -> vecN<bool> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
+)");
}
////////////////////////////////////////////////
diff --git a/src/tint/resolver/const_eval_bitcast_test.cc b/src/tint/resolver/const_eval_bitcast_test.cc
new file mode 100644
index 0000000..50d4693
--- /dev/null
+++ b/src/tint/resolver/const_eval_bitcast_test.cc
@@ -0,0 +1,193 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/const_eval_test.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::resolver {
+namespace {
+
+struct Case {
+ Value input;
+ struct Success {
+ Value value;
+ };
+ struct Failure {
+ builder::CreatePtrs create_ptrs;
+ };
+ utils::Result<Success, Failure> expected;
+};
+
+static std::ostream& operator<<(std::ostream& o, const Case& c) {
+ o << "input: " << c.input;
+ if (c.expected) {
+ o << ", expected: " << c.expected.Get().value;
+ } else {
+ o << ", expected failed bitcast to " << c.expected.Failure().create_ptrs;
+ }
+ return o;
+}
+
+template <typename TO, typename FROM>
+Case Success(FROM input, TO expected) {
+ return Case{input, Case::Success{expected}};
+}
+
+template <typename TO, typename FROM>
+Case Failure(FROM input) {
+ return Case{input, Case::Failure{builder::CreatePtrsFor<TO>()}};
+}
+
+using ResolverConstEvalBitcastTest = ResolverTestWithParam<Case>;
+
+TEST_P(ResolverConstEvalBitcastTest, Test) {
+ const auto& input = GetParam().input;
+ const auto& expected = GetParam().expected;
+
+ // Get the target type CreatePtrs
+ builder::CreatePtrs target_create_ptrs;
+ if (expected) {
+ target_create_ptrs = expected.Get().value.create_ptrs;
+ } else {
+ target_create_ptrs = expected.Failure().create_ptrs;
+ }
+
+ auto* target_ty = target_create_ptrs.ast(*this);
+ ASSERT_NE(target_ty, nullptr);
+ auto* input_val = input.Expr(*this);
+ const ast::Expression* expr = Bitcast(Source{{12, 34}}, target_ty, input_val);
+
+ WrapInFunction(expr);
+
+ auto* target_sem_ty = target_create_ptrs.sem(*this);
+
+ if (expected) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TYPE(sem->Type(), target_sem_ty);
+ ASSERT_NE(sem->ConstantValue(), nullptr);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty);
+
+ auto expected_values = expected.Get().value.args;
+ auto got_values = ScalarsFrom(sem->ConstantValue());
+ EXPECT_EQ(expected_values, got_values);
+ } else {
+ ASSERT_FALSE(r()->Resolve());
+ EXPECT_THAT(r()->error(), testing::StartsWith("12:34 error:"));
+ EXPECT_THAT(r()->error(), testing::HasSubstr("cannot be represented as"));
+ }
+}
+
+const u32 nan_as_u32 = utils::Bitcast<u32>(std::numeric_limits<float>::quiet_NaN());
+const i32 nan_as_i32 = utils::Bitcast<i32>(std::numeric_limits<float>::quiet_NaN());
+const u32 inf_as_u32 = utils::Bitcast<u32>(std::numeric_limits<float>::infinity());
+const i32 inf_as_i32 = utils::Bitcast<i32>(std::numeric_limits<float>::infinity());
+const u32 neg_inf_as_u32 = utils::Bitcast<u32>(-std::numeric_limits<float>::infinity());
+const i32 neg_inf_as_i32 = utils::Bitcast<i32>(-std::numeric_limits<float>::infinity());
+
+INSTANTIATE_TEST_SUITE_P(Bitcast,
+ ResolverConstEvalBitcastTest,
+ testing::ValuesIn({
+ // Bitcast to same (concrete) type, no change
+ Success(Val(0_u), Val(0_u)), //
+ Success(Val(0_i), Val(0_i)), //
+ Success(Val(0_f), Val(0_f)), //
+ Success(Val(123_u), Val(123_u)), //
+ Success(Val(123_i), Val(123_i)), //
+ Success(Val(123.456_f), Val(123.456_f)), //
+ Success(Val(u32::Highest()), Val(u32::Highest())), //
+ Success(Val(u32::Lowest()), Val(u32::Lowest())), //
+ Success(Val(i32::Highest()), Val(i32::Highest())), //
+ Success(Val(i32::Lowest()), Val(i32::Lowest())), //
+ Success(Val(f32::Highest()), Val(f32::Highest())), //
+ Success(Val(f32::Lowest()), Val(f32::Lowest())), //
+
+ // Bitcast to different type
+ Success(Val(0_u), Val(0_i)), //
+ Success(Val(0_u), Val(0_f)), //
+ Success(Val(0_i), Val(0_u)), //
+ Success(Val(0_i), Val(0_f)), //
+ Success(Val(0.0_f), Val(0_i)), //
+ Success(Val(0.0_f), Val(0_u)), //
+ Success(Val(1_u), Val(1_i)), //
+ Success(Val(1_u), Val(1.4013e-45_f)), //
+ Success(Val(1_i), Val(1_u)), //
+ Success(Val(1_i), Val(1.4013e-45_f)), //
+ Success(Val(1.0_f), Val(0x3F800000_u)), //
+ Success(Val(1.0_f), Val(0x3F800000_i)), //
+ Success(Val(123_u), Val(123_i)), //
+ Success(Val(123_u), Val(1.7236e-43_f)), //
+ Success(Val(123_i), Val(123_u)), //
+ Success(Val(123_i), Val(1.7236e-43_f)), //
+ Success(Val(123.0_f), Val(0x42F60000_u)), //
+ Success(Val(123.0_f), Val(0x42F60000_i)), //
+
+ // Bitcast from abstract materializes lhs first,
+ // so same results as above.
+ Success(Val(0_a), Val(0_i)), //
+ Success(Val(0_a), Val(0_f)), //
+ Success(Val(0_a), Val(0_u)), //
+ Success(Val(0_a), Val(0_f)), //
+ Success(Val(0_a), Val(0_i)), //
+ Success(Val(0_a), Val(0_u)), //
+ Success(Val(1_a), Val(1_i)), //
+ Success(Val(1_a), Val(1.4013e-45_f)), //
+ Success(Val(1_a), Val(1_u)), //
+ Success(Val(1_a), Val(1.4013e-45_f)), //
+ Success(Val(1.0_a), Val(0x3F800000_u)), //
+ Success(Val(1.0_a), Val(0x3F800000_i)), //
+ Success(Val(123_a), Val(123_i)), //
+ Success(Val(123_a), Val(1.7236e-43_f)), //
+ Success(Val(123_a), Val(123_u)), //
+ Success(Val(123_a), Val(1.7236e-43_f)), //
+ Success(Val(123.0_a), Val(0x42F60000_u)), //
+ Success(Val(123.0_a), Val(0x42F60000_i)), //
+
+ // u32 <-> i32 sign bit
+ Success(Val(0xFFFFFFFF_u), Val(-1_i)), //
+ Success(Val(-1_i), Val(0xFFFFFFFF_u)), //
+ Success(Val(0x80000000_u), Val(i32::Lowest())), //
+ Success(Val(i32::Lowest()), Val(0x80000000_u)), //
+
+ // Vector tests
+ Success(Vec(0_u, 1_u, 123_u), Vec(0_i, 1_i, 123_i)),
+ Success(Vec(0.0_f, 1.0_f, 123.0_f),
+ Vec(0_i, 0x3F800000_i, 0x42F60000_i)),
+
+ // Unrepresentable
+ Failure<f32>(Val(nan_as_u32)), //
+ Failure<f32>(Val(nan_as_i32)), //
+ Failure<f32>(Val(inf_as_u32)), //
+ Failure<f32>(Val(inf_as_i32)), //
+ Failure<f32>(Val(neg_inf_as_u32)), //
+ Failure<f32>(Val(neg_inf_as_i32)), //
+ Failure<builder::vec2<f32>>(Vec(nan_as_u32, 0_u)), //
+ Failure<builder::vec2<f32>>(Vec(nan_as_i32, 0_i)), //
+ Failure<builder::vec2<f32>>(Vec(inf_as_u32, 0_u)), //
+ Failure<builder::vec2<f32>>(Vec(inf_as_i32, 0_i)), //
+ Failure<builder::vec2<f32>>(Vec(neg_inf_as_u32, 0_u)), //
+ Failure<builder::vec2<f32>>(Vec(neg_inf_as_i32, 0_i)), //
+ Failure<builder::vec2<f32>>(Vec(0_u, nan_as_u32)), //
+ Failure<builder::vec2<f32>>(Vec(0_i, nan_as_i32)), //
+ Failure<builder::vec2<f32>>(Vec(0_u, inf_as_u32)), //
+ Failure<builder::vec2<f32>>(Vec(0_i, inf_as_i32)), //
+ Failure<builder::vec2<f32>>(Vec(0_u, neg_inf_as_u32)), //
+ Failure<builder::vec2<f32>>(Vec(0_i, neg_inf_as_i32)), //
+ }));
+
+} // namespace
+} // namespace tint::resolver
diff --git a/src/tint/resolver/const_eval_construction_test.cc b/src/tint/resolver/const_eval_construction_test.cc
index 93871b4..306911f 100644
--- a/src/tint/resolver/const_eval_construction_test.cc
+++ b/src/tint/resolver/const_eval_construction_test.cc
@@ -19,6 +19,40 @@
namespace tint::resolver {
namespace {
+TEST_F(ResolverConstEvalTest, Scalar_AFloat) {
+ auto* expr = Expr(99.0_a);
+ auto* a = Const("a", expr);
+ WrapInFunction(a);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Type()->Is<type::AbstractFloat>());
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue()->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->ValueAs<AFloat>(), 99.0f);
+}
+
+TEST_F(ResolverConstEvalTest, Scalar_AInt) {
+ auto* expr = Expr(99_a);
+ auto* a = Const("a", expr);
+ WrapInFunction(a);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Type()->Is<type::AbstractInt>());
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue()->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->ValueAs<AInt>(), 99);
+}
+
TEST_F(ResolverConstEvalTest, Scalar_i32) {
auto* expr = Expr(99_i);
WrapInFunction(expr);
@@ -102,6 +136,86 @@
EXPECT_EQ(sem->ConstantValue()->ValueAs<bool>(), true);
}
+namespace ZeroInit {
+struct Case {
+ builder::ast_type_func_ptr type;
+};
+template <typename T>
+Case C() {
+ return Case{builder::DataType<T>::AST};
+}
+using ResolverConstEvalZeroInitTest = ResolverTestWithParam<Case>;
+TEST_P(ResolverConstEvalZeroInitTest, Test) {
+ Enable(ast::Extension::kF16);
+ auto& param = GetParam();
+ auto* ty = param.type(*this);
+ auto* expr = Construct(ty);
+ auto* a = Const("a", expr);
+ WrapInFunction(a);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+
+ EXPECT_TRUE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->AllZero());
+
+ if (sem->Type()->is_scalar()) {
+ EXPECT_EQ(sem->ConstantValue()->Index(0), nullptr);
+ EXPECT_EQ(sem->ConstantValue()->ValueAs<f32>(), 0.0f);
+ } else if (auto* vec = sem->Type()->As<type::Vector>()) {
+ for (size_t i = 0; i < vec->Width(); ++i) {
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(i)->ValueAs<f32>(), 0.0f);
+ }
+ } else if (auto* mat = sem->Type()->As<type::Matrix>()) {
+ for (size_t i = 0; i < mat->columns(); ++i) {
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AllZero());
+ for (size_t j = 0; j < mat->rows(); ++j) {
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->Index(j)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->Index(j)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->Index(j)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(i)->Index(j)->ValueAs<f32>(), 0.0f);
+ }
+ }
+ } else if (auto* arr = sem->Type()->As<type::Array>()) {
+ for (size_t i = 0; i < *(arr->ConstantCount()); ++i) {
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(i)->ValueAs<f32>(), 0.0f);
+ }
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ZeroInit,
+ ResolverConstEvalZeroInitTest,
+ testing::ValuesIn({
+ C<u32>(),
+ C<i32>(),
+ C<f32>(),
+ C<f16>(),
+ C<bool>(),
+ C<builder::vec3<u32>>(),
+ C<builder::vec3<i32>>(),
+ C<builder::vec3<f32>>(),
+ C<builder::vec3<f16>>(),
+ C<builder::mat2x2<f32>>(),
+ C<builder::mat2x2<f16>>(),
+ C<builder::array<3, u32>>(),
+ C<builder::array<3, i32>>(),
+ C<builder::array<3, f32>>(),
+ C<builder::array<3, f16>>(),
+ C<builder::array<3, bool>>(),
+ }));
+
+} // namespace ZeroInit
+
TEST_F(ResolverConstEvalTest, Vec3_ZeroInit_i32) {
auto* expr = vec3<i32>();
WrapInFunction(expr);
@@ -437,6 +551,74 @@
EXPECT_EQ(sem->ConstantValue()->Index(2)->ValueAs<bool>(), true);
}
+TEST_F(ResolverConstEvalTest, Vec3_FullConstruct_AInt) {
+ auto* expr = vec3<AInt>(1_a, 2_a, 3_a);
+ auto* a = Const("a", expr);
+ WrapInFunction(a);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* vec = sem->Type()->As<type::Vector>();
+ ASSERT_NE(vec, nullptr);
+ EXPECT_TRUE(vec->type()->Is<type::AbstractInt>());
+ EXPECT_EQ(vec->Width(), 3u);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_FALSE(sem->ConstantValue()->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->ValueAs<AInt>(), 1);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->ValueAs<AInt>(), 2);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->ValueAs<AInt>(), 3);
+}
+
+TEST_F(ResolverConstEvalTest, Vec3_FullConstruct_AFloat) {
+ auto* expr = vec3<AFloat>(1.0_a, 2.0_a, 3.0_a);
+ auto* a = Const("a", expr);
+ WrapInFunction(a);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* vec = sem->Type()->As<type::Vector>();
+ ASSERT_NE(vec, nullptr);
+ EXPECT_TRUE(vec->type()->Is<type::AbstractFloat>());
+ EXPECT_EQ(vec->Width(), 3u);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_FALSE(sem->ConstantValue()->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->ValueAs<AFloat>(), 1.0f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->ValueAs<AFloat>(), 2.0f);
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
+ EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
+ EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->ValueAs<AFloat>(), 3.0f);
+}
+
TEST_F(ResolverConstEvalTest, Vec3_FullConstruct_i32) {
auto* expr = vec3<i32>(1_i, 2_i, 3_i);
WrapInFunction(expr);
@@ -1509,6 +1691,102 @@
EXPECT_EQ(sem->ConstantValue()->Index(3)->ValueAs<i32>(), 40_i);
}
+namespace ArrayInit {
+struct Case {
+ Value input;
+};
+static Case C(Value input) {
+ return Case{std::move(input)};
+}
+static std::ostream& operator<<(std::ostream& o, const Case& c) {
+ return o << "input: " << c.input;
+}
+
+using ResolverConstEvalArrayInitTest = ResolverTestWithParam<Case>;
+TEST_P(ResolverConstEvalArrayInitTest, Test) {
+ Enable(ast::Extension::kF16);
+ auto& param = GetParam();
+ auto* expr = param.input.Expr(*this);
+ auto* a = Const("a", expr);
+ WrapInFunction(a);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* arr = sem->Type()->As<type::Array>();
+ ASSERT_NE(arr, nullptr);
+
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ // Constant values should match input values
+ CheckConstant(sem->ConstantValue(), param.input);
+}
+template <typename T>
+std::vector<Case> ArrayInitCases() {
+ return {
+ C(Array(T(0))), //
+ C(Array(T(0))), //
+ C(Array(T(0), T(1))), //
+ C(Array(T(0), T(1), T(2))), //
+ C(Array(T(2), T(1), T(0))), //
+ C(Array(T(2), T(0), T(1))), //
+ };
+}
+INSTANTIATE_TEST_SUITE_P( //
+ ArrayInit,
+ ResolverConstEvalArrayInitTest,
+ testing::ValuesIn(Concat(ArrayInitCases<AInt>(), //
+ ArrayInitCases<AFloat>(), //
+ ArrayInitCases<i32>(), //
+ ArrayInitCases<u32>(), //
+ ArrayInitCases<f32>(), //
+ ArrayInitCases<f16>(), //
+ ArrayInitCases<bool>())));
+} // namespace ArrayInit
+
+TEST_F(ResolverConstEvalTest, ArrayInit_Nested_f32) {
+ auto inner_ty = [&] { return ty.array<f32, 2>(); };
+ auto outer_ty = ty.array(inner_ty(), Expr(3_i));
+
+ auto* expr = Construct(outer_ty, //
+ Construct(inner_ty(), 1_f, 2_f), //
+ Construct(inner_ty(), 3_f, 4_f), //
+ Construct(inner_ty(), 5_f, 6_f));
+
+ WrapInFunction(expr);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* outer_arr = sem->Type()->As<type::Array>();
+ ASSERT_NE(outer_arr, nullptr);
+ EXPECT_TRUE(outer_arr->ElemType()->Is<type::Array>());
+ EXPECT_TRUE(outer_arr->ElemType()->As<type::Array>()->ElemType()->Is<type::F32>());
+
+ auto* arr = sem->ConstantValue();
+ EXPECT_FALSE(arr->AllEqual());
+ EXPECT_FALSE(arr->AnyZero());
+ EXPECT_FALSE(arr->AllZero());
+
+ EXPECT_FALSE(arr->Index(0)->AllEqual());
+ EXPECT_FALSE(arr->Index(0)->AnyZero());
+ EXPECT_FALSE(arr->Index(0)->AllZero());
+ EXPECT_FALSE(arr->Index(1)->AllEqual());
+ EXPECT_FALSE(arr->Index(1)->AnyZero());
+ EXPECT_FALSE(arr->Index(1)->AllZero());
+ EXPECT_FALSE(arr->Index(2)->AllEqual());
+ EXPECT_FALSE(arr->Index(2)->AnyZero());
+ EXPECT_FALSE(arr->Index(2)->AllZero());
+
+ EXPECT_EQ(arr->Index(0)->Index(0)->ValueAs<f32>(), 1.0f);
+ EXPECT_EQ(arr->Index(0)->Index(1)->ValueAs<f32>(), 2.0f);
+ EXPECT_EQ(arr->Index(1)->Index(0)->ValueAs<f32>(), 3.0f);
+ EXPECT_EQ(arr->Index(1)->Index(1)->ValueAs<f32>(), 4.0f);
+ EXPECT_EQ(arr->Index(2)->Index(0)->ValueAs<f32>(), 5.0f);
+ EXPECT_EQ(arr->Index(2)->Index(1)->ValueAs<f32>(), 6.0f);
+}
+
TEST_F(ResolverConstEvalTest, Array_f32_Elements) {
auto* expr = Construct(ty.array<f32, 4>(), 10_f, 20_f, 30_f, 40_f);
WrapInFunction(expr);
@@ -1613,6 +1891,104 @@
EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(1)->ValueAs<f32>(), 4_f);
}
+TEST_F(ResolverConstEvalTest, Struct_ZeroInit) {
+ Enable(ast::Extension::kF16);
+ auto* s = Structure("S", utils::Vector{
+ Member("a", ty.i32()),
+ Member("b", ty.u32()),
+ Member("c", ty.f32()),
+ Member("d", ty.f16()),
+ Member("e", ty.bool_()),
+ });
+
+ auto* expr = Construct(ty.Of(s));
+ auto* a = Const("a", expr);
+ WrapInFunction(a);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* str = sem->Type()->As<type::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().Length(), 5u);
+
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_FALSE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->AllZero());
+
+ EXPECT_TRUE(sem->ConstantValue()->Index(0)->Type()->Is<type::I32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(0)->ValueAs<i32>(), 0);
+ EXPECT_TRUE(sem->ConstantValue()->Index(1)->Type()->Is<type::U32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(1)->ValueAs<u32>(), 0u);
+ EXPECT_TRUE(sem->ConstantValue()->Index(2)->Type()->Is<type::F32>());
+ EXPECT_EQ(sem->ConstantValue()->Index(2)->ValueAs<f32>(), 0.0f);
+ EXPECT_TRUE(sem->ConstantValue()->Index(3)->Type()->Is<type::F16>());
+ EXPECT_EQ(sem->ConstantValue()->Index(3)->ValueAs<f16>(), 0.0f);
+ EXPECT_TRUE(sem->ConstantValue()->Index(4)->Type()->Is<type::Bool>());
+ EXPECT_EQ(sem->ConstantValue()->Index(4)->ValueAs<bool>(), false);
+
+ for (size_t i = 0; i < str->Members().Length(); ++i) {
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->Index(i)->AllZero());
+ }
+}
+
+TEST_F(ResolverConstEvalTest, Struct_Nested_ZeroInit) {
+ Enable(ast::Extension::kF16);
+ auto* inner = Structure("Inner", utils::Vector{
+ Member("a", ty.i32()),
+ Member("b", ty.u32()),
+ Member("c", ty.f32()),
+ Member("d", ty.f16()),
+ Member("e", ty.bool_()),
+ });
+ auto* s = Structure("s", //
+ utils::Vector{
+ Member("inner", ty.Of(inner)),
+ });
+
+ auto* expr = Construct(ty.Of(s));
+ auto* a = Const("a", expr);
+ WrapInFunction(a);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* str = sem->Type()->As<type::Struct>();
+ ASSERT_NE(str, nullptr);
+ EXPECT_EQ(str->Members().Length(), 1u);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ EXPECT_TRUE(sem->ConstantValue()->AllEqual());
+ EXPECT_TRUE(sem->ConstantValue()->AnyZero());
+ EXPECT_TRUE(sem->ConstantValue()->AllZero());
+
+ auto* inner_struct = sem->ConstantValue()->Index(0);
+ EXPECT_FALSE(inner_struct->AllEqual());
+ EXPECT_TRUE(inner_struct->AnyZero());
+ EXPECT_TRUE(inner_struct->AllZero());
+
+ EXPECT_TRUE(inner_struct->Index(0)->Type()->Is<type::I32>());
+ EXPECT_EQ(inner_struct->Index(0)->ValueAs<i32>(), 0);
+ EXPECT_TRUE(inner_struct->Index(1)->Type()->Is<type::U32>());
+ EXPECT_EQ(inner_struct->Index(1)->ValueAs<u32>(), 0u);
+ EXPECT_TRUE(inner_struct->Index(2)->Type()->Is<type::F32>());
+ EXPECT_EQ(inner_struct->Index(2)->ValueAs<f32>(), 0.0f);
+ EXPECT_TRUE(inner_struct->Index(3)->Type()->Is<type::F16>());
+ EXPECT_EQ(inner_struct->Index(3)->ValueAs<f16>(), 0.0f);
+ EXPECT_TRUE(inner_struct->Index(4)->Type()->Is<type::Bool>());
+ EXPECT_EQ(inner_struct->Index(4)->ValueAs<bool>(), false);
+
+ for (size_t i = 0; i < str->Members().Length(); ++i) {
+ EXPECT_TRUE(inner_struct->Index(i)->AllEqual());
+ EXPECT_TRUE(inner_struct->Index(i)->AnyZero());
+ EXPECT_TRUE(inner_struct->Index(i)->AllZero());
+ }
+}
+
TEST_F(ResolverConstEvalTest, Struct_I32s_ZeroInit) {
Structure(
"S", utils::Vector{Member("m1", ty.i32()), Member("m2", ty.i32()), Member("m3", ty.i32())});
diff --git a/src/tint/resolver/const_eval_indexing_test.cc b/src/tint/resolver/const_eval_indexing_test.cc
index 7f9ae6f..9229f98 100644
--- a/src/tint/resolver/const_eval_indexing_test.cc
+++ b/src/tint/resolver/const_eval_indexing_test.cc
@@ -51,6 +51,78 @@
EXPECT_EQ(r()->error(), "12:34 error: index -3 out of bounds [0..2]");
}
+namespace Swizzle {
+struct Case {
+ Value input;
+ const char* swizzle;
+ Value expected;
+};
+
+static Case C(Value input, const char* swizzle, Value expected) {
+ return Case{std::move(input), swizzle, std::move(expected)};
+}
+
+static std::ostream& operator<<(std::ostream& o, const Case& c) {
+ return o << "input: " << c.input << ", swizzle: " << c.swizzle << ", expected: " << c.expected;
+}
+
+using ResolverConstEvalSwizzleTest = ResolverTestWithParam<Case>;
+TEST_P(ResolverConstEvalSwizzleTest, Test) {
+ Enable(ast::Extension::kF16);
+ auto& param = GetParam();
+ auto* expr = MemberAccessor(param.input.Expr(*this), param.swizzle);
+ auto* a = Const("a", expr);
+ WrapInFunction(a);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+
+ CheckConstant(sem->ConstantValue(), param.expected);
+}
+template <typename T>
+std::vector<Case> SwizzleCases() {
+ return {
+ C(Vec(T(0), T(1), T(2)), "xyz", Vec(T(0), T(1), T(2))),
+ C(Vec(T(0), T(1), T(2)), "xzy", Vec(T(0), T(2), T(1))),
+ C(Vec(T(0), T(1), T(2)), "yxz", Vec(T(1), T(0), T(2))),
+ C(Vec(T(0), T(1), T(2)), "yzx", Vec(T(1), T(2), T(0))),
+ C(Vec(T(0), T(1), T(2)), "zxy", Vec(T(2), T(0), T(1))),
+ C(Vec(T(0), T(1), T(2)), "zyx", Vec(T(2), T(1), T(0))),
+ C(Vec(T(0), T(1), T(2)), "xy", Vec(T(0), T(1))),
+ C(Vec(T(0), T(1), T(2)), "xz", Vec(T(0), T(2))),
+ C(Vec(T(0), T(1), T(2)), "yx", Vec(T(1), T(0))),
+ C(Vec(T(0), T(1), T(2)), "yz", Vec(T(1), T(2))),
+ C(Vec(T(0), T(1), T(2)), "zx", Vec(T(2), T(0))),
+ C(Vec(T(0), T(1), T(2)), "zy", Vec(T(2), T(1))),
+ C(Vec(T(0), T(1), T(2)), "xxxx", Vec(T(0), T(0), T(0), T(0))),
+ C(Vec(T(0), T(1), T(2)), "yyyy", Vec(T(1), T(1), T(1), T(1))),
+ C(Vec(T(0), T(1), T(2)), "zzzz", Vec(T(2), T(2), T(2), T(2))),
+ C(Vec(T(0), T(1), T(2)), "xxx", Vec(T(0), T(0), T(0))),
+ C(Vec(T(0), T(1), T(2)), "yyy", Vec(T(1), T(1), T(1))),
+ C(Vec(T(0), T(1), T(2)), "zzz", Vec(T(2), T(2), T(2))),
+ C(Vec(T(0), T(1), T(2)), "xx", Vec(T(0), T(0))),
+ C(Vec(T(0), T(1), T(2)), "yy", Vec(T(1), T(1))),
+ C(Vec(T(0), T(1), T(2)), "zz", Vec(T(2), T(2))),
+ C(Vec(T(0), T(1), T(2)), "x", Val(T(0))),
+ C(Vec(T(0), T(1), T(2)), "y", Val(T(1))),
+ C(Vec(T(0), T(1), T(2)), "z", Val(T(2))),
+ };
+}
+INSTANTIATE_TEST_SUITE_P(Swizzle,
+ ResolverConstEvalSwizzleTest,
+ testing::ValuesIn(Concat(SwizzleCases<AInt>(), //
+ SwizzleCases<AFloat>(), //
+ SwizzleCases<f32>(), //
+ SwizzleCases<f16>(), //
+ SwizzleCases<i32>(), //
+ SwizzleCases<u32>(), //
+ SwizzleCases<bool>() //
+ )));
+} // namespace Swizzle
+
TEST_F(ResolverConstEvalTest, Vec3_Swizzle_Scalar) {
auto* expr = MemberAccessor(vec3<i32>(1_i, 2_i, 3_i), "y");
WrapInFunction(expr);
diff --git a/src/tint/resolver/const_eval_member_access_test.cc b/src/tint/resolver/const_eval_member_access_test.cc
index 705ed27..374cf46 100644
--- a/src/tint/resolver/const_eval_member_access_test.cc
+++ b/src/tint/resolver/const_eval_member_access_test.cc
@@ -19,11 +19,12 @@
namespace tint::resolver {
namespace {
-TEST_F(ResolverConstEvalTest, MemberAccess) {
+TEST_F(ResolverConstEvalTest, StructMemberAccess) {
Structure("Inner", utils::Vector{
Member("i1", ty.i32()),
Member("i2", ty.u32()),
Member("i3", ty.f32()),
+ Member("i4", ty.bool_()),
});
Structure("Outer", utils::Vector{
@@ -31,7 +32,7 @@
Member("o2", ty.type_name("Inner")),
});
auto* outer_expr = Construct(ty.type_name("Outer"), //
- Construct(ty.type_name("Inner"), 1_i, 2_u, 3_f),
+ Construct(ty.type_name("Inner"), 1_i, 2_u, 3_f, true),
Construct(ty.type_name("Inner")));
auto* o1_expr = MemberAccessor(outer_expr, "o1");
auto* i2_expr = MemberAccessor(o1_expr, "i2");
@@ -59,6 +60,7 @@
EXPECT_EQ(o1->ConstantValue()->Index(0)->ValueAs<i32>(), 1_i);
EXPECT_EQ(o1->ConstantValue()->Index(1)->ValueAs<u32>(), 2_u);
EXPECT_EQ(o1->ConstantValue()->Index(2)->ValueAs<f32>(), 3_f);
+ EXPECT_EQ(o1->ConstantValue()->Index(2)->ValueAs<bool>(), true);
auto* i2 = Sem().Get(i2_expr);
ASSERT_NE(i2->ConstantValue(), nullptr);
@@ -94,5 +96,310 @@
EXPECT_EQ(c1->Index(0)->ValueAs<AFloat>(), 3.0);
EXPECT_EQ(c1->Index(1)->ValueAs<AFloat>(), 4.0);
}
+
+TEST_F(ResolverConstEvalTest, MatrixMemberAccess_AFloat) {
+ auto* c =
+ Const("a", Construct(ty.mat(nullptr, 2, 3), //
+ Construct(ty.vec(nullptr, 3), Expr(1.0_a), Expr(2.0_a), Expr(3.0_a)),
+ Construct(ty.vec(nullptr, 3), Expr(4.0_a), Expr(5.0_a), Expr(6.0_a))));
+
+ auto* col_0 = Const("col_0", IndexAccessor("a", Expr(0_i)));
+ auto* col_1 = Const("col_1", IndexAccessor("a", Expr(1_i)));
+ auto* e00 = Const("e00", IndexAccessor("col_0", Expr(0_i)));
+ auto* e01 = Const("e01", IndexAccessor("col_0", Expr(1_i)));
+ auto* e02 = Const("e02", IndexAccessor("col_0", Expr(2_i)));
+ auto* e10 = Const("e10", IndexAccessor("col_1", Expr(0_i)));
+ auto* e11 = Const("e11", IndexAccessor("col_1", Expr(1_i)));
+ auto* e12 = Const("e12", IndexAccessor("col_1", Expr(2_i)));
+
+ (void)col_0;
+ (void)col_1;
+
+ WrapInFunction(c, col_0, col_1, e00, e01, e02, e10, e11, e12);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(c);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Type()->Is<type::Matrix>());
+ auto* cv = sem->ConstantValue();
+ EXPECT_TYPE(cv->Type(), sem->Type());
+ EXPECT_TRUE(cv->Index(0)->Type()->Is<type::Vector>());
+ EXPECT_TRUE(cv->Index(0)->Index(0)->Type()->Is<type::AbstractFloat>());
+ EXPECT_FALSE(cv->AllEqual());
+ EXPECT_FALSE(cv->AnyZero());
+ EXPECT_FALSE(cv->AllZero());
+
+ auto* sem_col0 = Sem().Get(col_0);
+ ASSERT_NE(sem_col0, nullptr);
+ EXPECT_TRUE(sem_col0->Type()->Is<type::Vector>());
+ EXPECT_EQ(sem_col0->ConstantValue()->Index(0)->ValueAs<AFloat>(), 1.0);
+ EXPECT_EQ(sem_col0->ConstantValue()->Index(1)->ValueAs<AFloat>(), 2.0);
+ EXPECT_EQ(sem_col0->ConstantValue()->Index(2)->ValueAs<AFloat>(), 3.0);
+
+ auto* sem_col1 = Sem().Get(col_1);
+ ASSERT_NE(sem_col1, nullptr);
+ EXPECT_TRUE(sem_col1->Type()->Is<type::Vector>());
+ EXPECT_EQ(sem_col1->ConstantValue()->Index(0)->ValueAs<AFloat>(), 4.0);
+ EXPECT_EQ(sem_col1->ConstantValue()->Index(1)->ValueAs<AFloat>(), 5.0);
+ EXPECT_EQ(sem_col1->ConstantValue()->Index(2)->ValueAs<AFloat>(), 6.0);
+
+ auto* sem_e00 = Sem().Get(e00);
+ ASSERT_NE(sem_e00, nullptr);
+ EXPECT_TRUE(sem_e00->Type()->Is<type::AbstractFloat>());
+ EXPECT_EQ(sem_e00->ConstantValue()->ValueAs<AFloat>(), 1.0);
+
+ auto* sem_e01 = Sem().Get(e01);
+ ASSERT_NE(sem_e01, nullptr);
+ EXPECT_TRUE(sem_e01->Type()->Is<type::AbstractFloat>());
+ EXPECT_EQ(sem_e01->ConstantValue()->ValueAs<AFloat>(), 2.0);
+
+ auto* sem_e02 = Sem().Get(e02);
+ ASSERT_NE(sem_e02, nullptr);
+ EXPECT_TRUE(sem_e02->Type()->Is<type::AbstractFloat>());
+ EXPECT_EQ(sem_e02->ConstantValue()->ValueAs<AFloat>(), 3.0);
+
+ auto* sem_e10 = Sem().Get(e10);
+ ASSERT_NE(sem_e10, nullptr);
+ EXPECT_TRUE(sem_e10->Type()->Is<type::AbstractFloat>());
+ EXPECT_EQ(sem_e10->ConstantValue()->ValueAs<AFloat>(), 4.0);
+
+ auto* sem_e11 = Sem().Get(e11);
+ ASSERT_NE(sem_e11, nullptr);
+ EXPECT_TRUE(sem_e11->Type()->Is<type::AbstractFloat>());
+ EXPECT_EQ(sem_e11->ConstantValue()->ValueAs<AFloat>(), 5.0);
+
+ auto* sem_e12 = Sem().Get(e12);
+ ASSERT_NE(sem_e12, nullptr);
+ EXPECT_TRUE(sem_e12->Type()->Is<type::AbstractFloat>());
+ EXPECT_EQ(sem_e12->ConstantValue()->ValueAs<AFloat>(), 6.0);
+}
+
+TEST_F(ResolverConstEvalTest, MatrixMemberAccess_f32) {
+ auto* c =
+ Const("a", Construct(ty.mat(nullptr, 2, 3), //
+ Construct(ty.vec(nullptr, 3), Expr(1.0_f), Expr(2.0_f), Expr(3.0_f)),
+ Construct(ty.vec(nullptr, 3), Expr(4.0_f), Expr(5.0_f), Expr(6.0_f))));
+
+ auto* col_0 = Const("col_0", IndexAccessor("a", Expr(0_i)));
+ auto* col_1 = Const("col_1", IndexAccessor("a", Expr(1_i)));
+ auto* e00 = Const("e00", IndexAccessor("col_0", Expr(0_i)));
+ auto* e01 = Const("e01", IndexAccessor("col_0", Expr(1_i)));
+ auto* e02 = Const("e02", IndexAccessor("col_0", Expr(2_i)));
+ auto* e10 = Const("e10", IndexAccessor("col_1", Expr(0_i)));
+ auto* e11 = Const("e11", IndexAccessor("col_1", Expr(1_i)));
+ auto* e12 = Const("e12", IndexAccessor("col_1", Expr(2_i)));
+
+ (void)col_0;
+ (void)col_1;
+
+ WrapInFunction(c, col_0, col_1, e00, e01, e02, e10, e11, e12);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(c);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_TRUE(sem->Type()->Is<type::Matrix>());
+ auto* cv = sem->ConstantValue();
+ EXPECT_TYPE(cv->Type(), sem->Type());
+ EXPECT_TRUE(cv->Index(0)->Type()->Is<type::Vector>());
+ EXPECT_TRUE(cv->Index(0)->Index(0)->Type()->Is<type::F32>());
+ EXPECT_FALSE(cv->AllEqual());
+ EXPECT_FALSE(cv->AnyZero());
+ EXPECT_FALSE(cv->AllZero());
+
+ auto* sem_col0 = Sem().Get(col_0);
+ ASSERT_NE(sem_col0, nullptr);
+ EXPECT_TRUE(sem_col0->Type()->Is<type::Vector>());
+ EXPECT_EQ(sem_col0->ConstantValue()->Index(0)->ValueAs<f32>(), 1.0f);
+ EXPECT_EQ(sem_col0->ConstantValue()->Index(1)->ValueAs<f32>(), 2.0f);
+ EXPECT_EQ(sem_col0->ConstantValue()->Index(2)->ValueAs<f32>(), 3.0f);
+
+ auto* sem_col1 = Sem().Get(col_1);
+ ASSERT_NE(sem_col1, nullptr);
+ EXPECT_TRUE(sem_col1->Type()->Is<type::Vector>());
+ EXPECT_EQ(sem_col1->ConstantValue()->Index(0)->ValueAs<f32>(), 4.0f);
+ EXPECT_EQ(sem_col1->ConstantValue()->Index(1)->ValueAs<f32>(), 5.0f);
+ EXPECT_EQ(sem_col1->ConstantValue()->Index(2)->ValueAs<f32>(), 6.0f);
+
+ auto* sem_e00 = Sem().Get(e00);
+ ASSERT_NE(sem_e00, nullptr);
+ EXPECT_TRUE(sem_e00->Type()->Is<type::F32>());
+ EXPECT_EQ(sem_e00->ConstantValue()->ValueAs<f32>(), 1.0f);
+
+ auto* sem_e01 = Sem().Get(e01);
+ ASSERT_NE(sem_e01, nullptr);
+ EXPECT_TRUE(sem_e01->Type()->Is<type::F32>());
+ EXPECT_EQ(sem_e01->ConstantValue()->ValueAs<f32>(), 2.0f);
+
+ auto* sem_e02 = Sem().Get(e02);
+ ASSERT_NE(sem_e02, nullptr);
+ EXPECT_TRUE(sem_e02->Type()->Is<type::F32>());
+ EXPECT_EQ(sem_e02->ConstantValue()->ValueAs<f32>(), 3.0f);
+
+ auto* sem_e10 = Sem().Get(e10);
+ ASSERT_NE(sem_e10, nullptr);
+ EXPECT_TRUE(sem_e10->Type()->Is<type::F32>());
+ EXPECT_EQ(sem_e10->ConstantValue()->ValueAs<f32>(), 4.0f);
+
+ auto* sem_e11 = Sem().Get(e11);
+ ASSERT_NE(sem_e11, nullptr);
+ EXPECT_TRUE(sem_e11->Type()->Is<type::F32>());
+ EXPECT_EQ(sem_e11->ConstantValue()->ValueAs<f32>(), 5.0f);
+
+ auto* sem_e12 = Sem().Get(e12);
+ ASSERT_NE(sem_e12, nullptr);
+ EXPECT_TRUE(sem_e12->Type()->Is<type::F32>());
+ EXPECT_EQ(sem_e12->ConstantValue()->ValueAs<f32>(), 6.0f);
+}
+
+namespace ArrayAccess {
+struct Case {
+ Value input;
+};
+static Case C(Value input) {
+ return Case{std::move(input)};
+}
+static std::ostream& operator<<(std::ostream& o, const Case& c) {
+ return o << "input: " << c.input;
+}
+
+using ResolverConstEvalArrayAccessTest = ResolverTestWithParam<Case>;
+TEST_P(ResolverConstEvalArrayAccessTest, Test) {
+ Enable(ast::Extension::kF16);
+
+ auto& param = GetParam();
+ auto* expr = param.input.Expr(*this);
+ auto* a = Const("a", expr);
+
+ utils::Vector<const ast::IndexAccessorExpression*, 4> index_accessors;
+ for (size_t i = 0; i < param.input.args.Length(); ++i) {
+ auto* index = IndexAccessor("a", Expr(i32(i)));
+ index_accessors.Push(index);
+ }
+
+ utils::Vector<const ast::Statement*, 5> stmts;
+ stmts.Push(WrapInStatement(a));
+ for (auto* ia : index_accessors) {
+ stmts.Push(WrapInStatement(ia));
+ }
+ WrapInFunction(std::move(stmts));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* arr = sem->Type()->As<type::Array>();
+ ASSERT_NE(arr, nullptr);
+
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ for (size_t i = 0; i < index_accessors.Length(); ++i) {
+ auto* ia_sem = Sem().Get(index_accessors[i]);
+ ASSERT_NE(ia_sem, nullptr);
+ ASSERT_NE(ia_sem->ConstantValue(), nullptr);
+ EXPECT_EQ(ia_sem->ConstantValue()->ValueAs<AInt>(), i);
+ }
+}
+template <typename T>
+std::vector<Case> ArrayAccessCases() {
+ if constexpr (std::is_same_v<T, bool>) {
+ return {
+ C(Array(false, true)),
+ };
+ } else {
+ return {
+ C(Array(T(0))), //
+ C(Array(T(0), T(1))), //
+ C(Array(T(0), T(1), T(2))), //
+ C(Array(T(0), T(1), T(2), T(3))), //
+ C(Array(T(0), T(1), T(2), T(3), T(4))), //
+ };
+ }
+}
+INSTANTIATE_TEST_SUITE_P( //
+ ArrayAccess,
+ ResolverConstEvalArrayAccessTest,
+ testing::ValuesIn(Concat(ArrayAccessCases<AInt>(), //
+ ArrayAccessCases<AFloat>(), //
+ ArrayAccessCases<i32>(), //
+ ArrayAccessCases<u32>(), //
+ ArrayAccessCases<f32>(), //
+ ArrayAccessCases<f16>(), //
+ ArrayAccessCases<bool>())));
+} // namespace ArrayAccess
+
+namespace VectorAccess {
+struct Case {
+ Value input;
+};
+static Case C(Value input) {
+ return Case{std::move(input)};
+}
+static std::ostream& operator<<(std::ostream& o, const Case& c) {
+ return o << "input: " << c.input;
+}
+
+using ResolverConstEvalVectorAccessTest = ResolverTestWithParam<Case>;
+TEST_P(ResolverConstEvalVectorAccessTest, Test) {
+ Enable(ast::Extension::kF16);
+
+ auto& param = GetParam();
+ auto* expr = param.input.Expr(*this);
+ auto* a = Const("a", expr);
+
+ utils::Vector<const ast::IndexAccessorExpression*, 4> index_accessors;
+ for (size_t i = 0; i < param.input.args.Length(); ++i) {
+ auto* index = IndexAccessor("a", Expr(i32(i)));
+ index_accessors.Push(index);
+ }
+
+ utils::Vector<const ast::Statement*, 5> stmts;
+ stmts.Push(WrapInStatement(a));
+ for (auto* ia : index_accessors) {
+ stmts.Push(WrapInStatement(ia));
+ }
+ WrapInFunction(std::move(stmts));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem = Sem().Get(expr);
+ ASSERT_NE(sem, nullptr);
+ auto* vec = sem->Type()->As<type::Vector>();
+ ASSERT_NE(vec, nullptr);
+
+ EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
+ for (size_t i = 0; i < index_accessors.Length(); ++i) {
+ auto* ia_sem = Sem().Get(index_accessors[i]);
+ ASSERT_NE(ia_sem, nullptr);
+ ASSERT_NE(ia_sem->ConstantValue(), nullptr);
+ EXPECT_EQ(ia_sem->ConstantValue()->ValueAs<AInt>(), i);
+ }
+}
+template <typename T>
+std::vector<Case> VectorAccessCases() {
+ if constexpr (std::is_same_v<T, bool>) {
+ return {
+ C(Vec(false, true)),
+ };
+ } else {
+ return {
+ C(Vec(T(0), T(1))), //
+ C(Vec(T(0), T(1), T(2))), //
+ C(Vec(T(0), T(1), T(2), T(3))), //
+ };
+ }
+}
+INSTANTIATE_TEST_SUITE_P( //
+ VectorAccess,
+ ResolverConstEvalVectorAccessTest,
+ testing::ValuesIn(Concat(VectorAccessCases<AInt>(), //
+ VectorAccessCases<AFloat>(), //
+ VectorAccessCases<i32>(), //
+ VectorAccessCases<u32>(), //
+ VectorAccessCases<f32>(), //
+ VectorAccessCases<f16>(), //
+ VectorAccessCases<bool>())));
+} // namespace VectorAccess
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h
index c31d89f..001d3d1 100644
--- a/src/tint/resolver/const_eval_test.h
+++ b/src/tint/resolver/const_eval_test.h
@@ -245,6 +245,7 @@
return ss.str();
}
+using builder::Array;
using builder::IsValue;
using builder::Mat;
using builder::Val;
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index 8de7af9..49da494 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -789,11 +789,11 @@
// let a = abstract_expr;
kLet,
- // bitcast<f32>(abstract_expr)
- kBitcastF32Arg,
+ // bitcast<i32>(abstract_expr)
+ kBitcastI32Arg,
- // bitcast<vec3<f32>>(abstract_expr)
- kBitcastVec3F32Arg,
+ // bitcast<vec3<i32>>(abstract_expr)
+ kBitcastVec3I32Arg,
// array<i32, abstract_expr>()
kArrayLength,
@@ -825,10 +825,10 @@
return o << "var";
case Method::kLet:
return o << "let";
- case Method::kBitcastF32Arg:
- return o << "bitcast-f32-arg";
- case Method::kBitcastVec3F32Arg:
- return o << "bitcast-vec3-f32-arg";
+ case Method::kBitcastI32Arg:
+ return o << "bitcast-i32-arg";
+ case Method::kBitcastVec3I32Arg:
+ return o << "bitcast-vec3-i32-arg";
case Method::kArrayLength:
return o << "array-length";
case Method::kSwitch:
@@ -903,12 +903,12 @@
WrapInFunction(Decl(Let("a", abstract_expr())));
break;
}
- case Method::kBitcastF32Arg: {
- WrapInFunction(Bitcast<f32>(abstract_expr()));
+ case Method::kBitcastI32Arg: {
+ WrapInFunction(Bitcast<i32>(abstract_expr()));
break;
}
- case Method::kBitcastVec3F32Arg: {
- WrapInFunction(Bitcast(ty.vec3<f32>(), abstract_expr()));
+ case Method::kBitcastVec3I32Arg: {
+ WrapInFunction(Bitcast(ty.vec3<i32>(), abstract_expr()));
break;
}
case Method::kArrayLength: {
@@ -977,7 +977,7 @@
constexpr Method kScalarMethods[] = {
Method::kLet,
Method::kVar,
- Method::kBitcastF32Arg,
+ Method::kBitcastI32Arg,
Method::kTintMaterializeBuiltin,
};
@@ -985,7 +985,7 @@
constexpr Method kVectorMethods[] = {
Method::kLet,
Method::kVar,
- Method::kBitcastVec3F32Arg,
+ Method::kBitcastVec3I32Arg,
Method::kRuntimeIndex,
Method::kTintMaterializeBuiltin,
};
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 5d9dfca..cec6e5a 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1957,24 +1957,27 @@
if (!ty) {
return nullptr;
}
-
- const constant::Value* val = nullptr;
- // TODO(crbug.com/tint/1582): short circuit 'expr' once const eval of Bitcast is implemented.
- if (auto r = const_eval_.Bitcast(ty, inner)) {
- val = r.Get();
- } else {
- return nullptr;
- }
- auto stage = sem::EvaluationStage::kRuntime; // TODO(crbug.com/tint/1581)
- auto* sem = builder_->create<sem::Expression>(expr, ty, stage, current_statement_,
- std::move(val), inner->HasSideEffects());
-
- sem->Behaviors() = inner->Behaviors();
-
if (!validator_.Bitcast(expr, ty)) {
return nullptr;
}
+ auto stage = inner->Stage();
+ if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+ stage = sem::EvaluationStage::kNotEvaluated;
+ }
+
+ const constant::Value* value = nullptr;
+ if (stage == sem::EvaluationStage::kConstant) {
+ if (auto r = const_eval_.Bitcast(ty, inner->ConstantValue(), expr->source)) {
+ value = r.Get();
+ } else {
+ return nullptr;
+ }
+ }
+
+ auto* sem = builder_->create<sem::Expression>(expr, ty, stage, current_statement_,
+ std::move(value), inner->HasSideEffects());
+ sem->Behaviors() = inner->Behaviors();
return sem;
}
@@ -2047,8 +2050,8 @@
return nullptr;
}
- auto stage = args_stage; // The evaluation stage of the call
- const constant::Value* value = nullptr; // The constant value for the call
+ auto stage = args_stage; // The evaluation stage of the call
+ const constant::Value* value = nullptr; // The constant value for the call
if (stage == sem::EvaluationStage::kConstant) {
if (auto r = const_eval_.ArrayOrStructInit(ty, args)) {
value = r.Get();
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index 86bc5d5..53c187f 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -755,15 +755,18 @@
static_assert(IsDataTypeSpecializedFor<T>, "No DataType<T> specialization exists");
using EL_TY = typename builder::DataType<T>::ElementType;
return Value{
- std::move(args), CreatePtrsFor<T>().expr, tint::IsAbstract<EL_TY>,
- tint::IsIntegral<EL_TY>, tint::FriendlyName<EL_TY>(),
+ std::move(args), //
+ CreatePtrsFor<T>(), //
+ tint::IsAbstract<EL_TY>, //
+ tint::IsIntegral<EL_TY>, //
+ tint::FriendlyName<EL_TY>(),
};
}
/// Creates an `ast::Expression` for the type T passing in previously stored args
/// @param b the ProgramBuilder
/// @returns an expression node
- const ast::Expression* Expr(ProgramBuilder& b) const { return (*create)(b, args); }
+ const ast::Expression* Expr(ProgramBuilder& b) const { return (*create_ptrs.expr)(b, args); }
/// Prints this value to the output stream
/// @param o the output stream
@@ -782,8 +785,8 @@
/// The arguments used to construct the value
utils::Vector<Scalar, 4> args;
- /// Function used to construct an expression with the given value
- builder::ast_expr_func_ptr create;
+ /// CreatePtrs for value's type used to create an expression with `args`
+ builder::CreatePtrs create_ptrs;
/// True if the element type is abstract
bool is_abstract = false;
/// True if the element type is an integer
@@ -809,14 +812,28 @@
}
/// Creates a Value of DataType<vec<N, T>> from N scalar `args`
-template <typename... T>
-Value Vec(T... args) {
- using FirstT = std::tuple_element_t<0, std::tuple<T...>>;
+template <typename... Ts>
+Value Vec(Ts... args) {
+ using FirstT = std::tuple_element_t<0, std::tuple<Ts...>>;
+ static_assert(sizeof...(args) >= 2 && sizeof...(args) <= 4, "Invalid vector size");
+ static_assert(std::conjunction_v<std::is_same<FirstT, Ts>...>,
+ "Vector args must all be the same type");
constexpr size_t N = sizeof...(args);
utils::Vector<Scalar, sizeof...(args)> v{args...};
return Value::Create<vec<N, FirstT>>(std::move(v));
}
+/// Creates a Value of DataType<array<N, T>> from N scalar `args`
+template <typename... Ts>
+Value Array(Ts... args) {
+ using FirstT = std::tuple_element_t<0, std::tuple<Ts...>>;
+ static_assert(std::conjunction_v<std::is_same<FirstT, Ts>...>,
+ "Array args must all be the same type");
+ constexpr size_t N = sizeof...(args);
+ utils::Vector<Scalar, sizeof...(args)> v{args...};
+ return Value::Create<array<N, FirstT>>(std::move(v));
+}
+
/// Creates a Value of DataType<mat<C,R,T> from C*R scalar `args`
template <size_t C, size_t R, typename T>
Value Mat(const T (&m_in)[C][R]) {
@@ -879,7 +896,6 @@
}
return Value::Create<mat<C, R, T>>(std::move(m));
}
-
} // namespace builder
} // namespace tint::resolver
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index 491360f..5f148f0 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -27,6 +27,7 @@
#include "src/tint/sem/function.h"
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/info.h"
+#include "src/tint/sem/load.h"
#include "src/tint/sem/loop_statement.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/switch_statement.h"
@@ -76,8 +77,8 @@
/// ParameterTag describes the uniformity requirements of values passed to a function parameter.
enum ParameterTag {
- ParameterRequiredToBeUniform,
- ParameterRequiredToBeUniformForReturnValue,
+ ParameterValueRequiredToBeUniform,
+ ParameterContentsRequiredToBeUniform,
ParameterNoRestriction,
};
@@ -97,7 +98,8 @@
/// information.
enum Type {
kRegular,
- kFunctionCallArgument,
+ kFunctionCallArgumentValue,
+ kFunctionCallArgumentContents,
kFunctionCallPointerArgumentResult,
kFunctionCallReturnValue,
};
@@ -122,7 +124,10 @@
/// Add an edge to the `to` node.
/// @param to the destination node
- void AddEdge(Node* to) { edges.Add(to); }
+ void AddEdge(Node* to) {
+ TINT_ASSERT(Resolver, to != nullptr);
+ edges.Add(to);
+ }
};
/// ParameterInfo holds information about the uniformity requirements and effects for a particular
@@ -130,18 +135,25 @@
struct ParameterInfo {
/// The semantic node in corresponds to this parameter.
const sem::Parameter* sem;
- /// The parameter's uniformity requirements.
- ParameterTag tag = ParameterNoRestriction;
+ /// The parameter's direct uniformity requirements.
+ ParameterTag tag_direct = ParameterNoRestriction;
+ /// The parameter's uniformity requirements that affect the function return value.
+ ParameterTag tag_retval = ParameterNoRestriction;
/// Will be `true` if this function may cause the contents of this pointer parameter to become
/// non-uniform.
bool pointer_may_become_non_uniform = false;
/// The parameters that are required to be uniform for the contents of this pointer parameter to
/// be uniform at function exit.
- utils::Vector<const sem::Parameter*, 8> pointer_param_output_sources;
- /// The node in the graph that corresponds to this parameter's initial value.
- Node* init_value;
- /// The node in the graph that corresponds to this parameter's output value (or nullptr).
- Node* pointer_return_value = nullptr;
+ utils::Vector<const sem::Parameter*, 8> ptr_output_source_param_values;
+ /// The pointer parameters whose contents are required to be uniform for the contents of this
+ /// pointer parameter to be uniform at function exit.
+ utils::Vector<const sem::Parameter*, 8> ptr_output_source_param_contents;
+ /// The node in the graph that corresponds to this parameter's (immutable) value.
+ Node* value;
+ /// The node in the graph that corresponds to this pointer parameter's initial contents.
+ Node* ptr_input_contents = nullptr;
+ /// The node in the graph that corresponds to this pointer parameter's contents on return.
+ Node* ptr_output_contents = nullptr;
};
/// FunctionInfo holds information about the uniformity requirements and effects for a particular
@@ -156,11 +168,11 @@
function_tag = NoRestriction;
// Create special nodes.
- required_to_be_uniform = CreateNode("RequiredToBeUniform");
- may_be_non_uniform = CreateNode("MayBeNonUniform");
- cf_start = CreateNode("CF_start");
+ required_to_be_uniform = CreateNode({"RequiredToBeUniform"});
+ may_be_non_uniform = CreateNode({"MayBeNonUniform"});
+ cf_start = CreateNode({"CF_start"});
if (func->return_type) {
- value_return = CreateNode("Value_return");
+ value_return = CreateNode({"Value_return"});
}
// Create nodes for parameters.
@@ -171,16 +183,19 @@
auto* sem = builder->Sem().Get<sem::Parameter>(param);
parameters[i].sem = sem;
- Node* node_init;
+ parameters[i].value = CreateNode({"param_", param_name});
if (sem->Type()->Is<type::Pointer>()) {
- node_init = CreateNode("ptrparam_" + name + "_init");
- parameters[i].pointer_return_value = CreateNode("ptrparam_" + name + "_return");
+ // Create extra nodes for a pointer parameter's initial contents and its contents
+ // when the function returns.
+ parameters[i].ptr_input_contents =
+ CreateNode({"ptrparam_", param_name, "_input_contents"});
+ parameters[i].ptr_output_contents =
+ CreateNode({"ptrparam_", param_name, "_output_contents"});
+ variables.Set(sem, parameters[i].ptr_input_contents);
local_var_decls.Add(sem);
} else {
- node_init = CreateNode("param_" + name);
+ variables.Set(sem, parameters[i].value);
}
- parameters[i].init_value = node_init;
- variables.Set(sem, node_init);
}
}
@@ -209,8 +224,10 @@
/// Map from variables to their value nodes in the graph, scoped with respect to control flow.
ScopeStack<const sem::Variable*, Node*> variables;
- /// The set of a local read-write vars that are in scope at any given point in the process.
- /// Includes pointer parameters.
+ /// The set of mutable variables declared in the function that are in scope at any given point
+ /// in the analysis. This includes the contents of parameters to the function that are pointers.
+ /// This is used by the analysis for if statements and loops to know which variables need extra
+ /// nodes to capture their state when entering/exiting those constructs.
utils::Hashset<const sem::Variable*, 8> local_var_decls;
/// The set of partial pointer variables - pointers that point to a subobject (into an array or
@@ -238,15 +255,20 @@
void RemoveLoopSwitchInfoFor(const sem::Statement* stmt) { loop_switch_infos.Remove(stmt); }
/// Create a new node.
- /// @param tag a tag used to identify the node for debugging purposes
+ /// @param tag_list a string list that will be used to identify the node for debugging purposes
/// @param ast the optional AST node that this node corresponds to
/// @returns the new node
- Node* CreateNode([[maybe_unused]] std::string tag, const ast::Node* ast = nullptr) {
+ Node* CreateNode([[maybe_unused]] std::initializer_list<std::string_view> tag_list,
+ const ast::Node* ast = nullptr) {
auto* node = nodes.Create(ast);
#if TINT_DUMP_UNIFORMITY_GRAPH
// Make the tag unique and set it.
// This only matters if we're dumping the graph.
+ std::string tag = "";
+ for (auto& t : tag_list) {
+ tag += t;
+ }
std::string unique_tag = tag;
int suffix = 0;
while (tags_.Contains(unique_tag)) {
@@ -329,11 +351,20 @@
FunctionInfo* current_function_;
/// Create a new node.
- /// @param tag a tag used to identify the node for debugging purposes.
+ /// @param tag_list a string list that will be used to identify the node for debugging purposes
/// @param ast the optional AST node that this node corresponds to
/// @returns the new node
- Node* CreateNode(std::string tag, const ast::Node* ast = nullptr) {
- return current_function_->CreateNode(std::move(tag), ast);
+ inline Node* CreateNode(std::initializer_list<std::string_view> tag_list,
+ const ast::Node* ast = nullptr) {
+ return current_function_->CreateNode(std::move(tag_list), ast);
+ }
+
+ /// Get the symbol name of an AST node.
+ /// @param ast the AST node to get the symbol name of
+ /// @returns the symbol name
+ template <typename T>
+ inline std::string NameFor(const T* ast) {
+ return builder_->Symbols().NameFor(ast->symbol);
}
/// Process a function.
@@ -360,6 +391,25 @@
std::cout << "\n}\n";
#endif
+ /// Helper to generate a tag for the uniformity requirements of the parameter at `index`.
+ auto get_param_tag = [&](utils::UniqueVector<Node*, 4>& reachable, size_t index) {
+ auto* param = sem_.Get(func->params[index]);
+ auto& param_info = current_function_->parameters[index];
+ if (param->Type()->Is<type::Pointer>()) {
+ // For pointers, we distinguish between requiring uniformity of the contents versus
+ // the pointer itself.
+ if (reachable.Contains(param_info.ptr_input_contents)) {
+ return ParameterContentsRequiredToBeUniform;
+ } else if (reachable.Contains(param_info.value)) {
+ return ParameterValueRequiredToBeUniform;
+ }
+ } else if (reachable.Contains(current_function_->variables.Get(param))) {
+ // For non-pointers, the requirement is always on the value.
+ return ParameterValueRequiredToBeUniform;
+ }
+ return ParameterNoRestriction;
+ };
+
// Look at which nodes are reachable from "RequiredToBeUniform".
{
utils::UniqueVector<Node*, 4> reachable;
@@ -372,17 +422,13 @@
current_function_->callsite_tag = CallSiteRequiredToBeUniform;
}
- // Set the parameter tag to ParameterRequiredToBeUniform for each parameter node that
- // was reachable.
+ // Set the tags to capture the direct uniformity requirements of each parameter.
for (size_t i = 0; i < func->params.Length(); i++) {
- auto* param = func->params[i];
- if (reachable.Contains(current_function_->variables.Get(sem_.Get(param)))) {
- current_function_->parameters[i].tag = ParameterRequiredToBeUniform;
- }
+ current_function_->parameters[i].tag_direct = get_param_tag(reachable, i);
}
}
- // If "Value_return" exists, look at which nodes are reachable from it
+ // If "Value_return" exists, look at which nodes are reachable from it.
if (current_function_->value_return) {
utils::UniqueVector<Node*, 4> reachable;
Traverse(current_function_->value_return, &reachable);
@@ -390,20 +436,17 @@
current_function_->function_tag = ReturnValueMayBeNonUniform;
}
- // Set the parameter tag to ParameterRequiredToBeUniformForReturnValue for each
- // parameter node that was reachable.
+ // Set the tags to capture the uniformity requirements of each parameter with respect to
+ // the function return value.
for (size_t i = 0; i < func->params.Length(); i++) {
- auto* param = func->params[i];
- if (reachable.Contains(current_function_->variables.Get(sem_.Get(param)))) {
- current_function_->parameters[i].tag =
- ParameterRequiredToBeUniformForReturnValue;
- }
+ current_function_->parameters[i].tag_retval = get_param_tag(reachable, i);
}
}
// Traverse the graph for each pointer parameter.
for (size_t i = 0; i < func->params.Length(); i++) {
- if (current_function_->parameters[i].pointer_return_value == nullptr) {
+ auto& param_info = current_function_->parameters[i];
+ if (param_info.ptr_output_contents == nullptr) {
continue;
}
@@ -411,17 +454,21 @@
current_function_->ResetVisited();
utils::UniqueVector<Node*, 4> reachable;
- Traverse(current_function_->parameters[i].pointer_return_value, &reachable);
+ Traverse(param_info.ptr_output_contents, &reachable);
if (reachable.Contains(current_function_->may_be_non_uniform)) {
- current_function_->parameters[i].pointer_may_become_non_uniform = true;
+ param_info.pointer_may_become_non_uniform = true;
}
- // Check every other parameter to see if they feed into this parameter's final value.
+ // Check every parameter to see if it feeds into this parameter's output value.
+ // This includes checking this parameter (as it may feed into its own output value), so
+ // we do not skip the `i==j` case.
for (size_t j = 0; j < func->params.Length(); j++) {
- auto* param_source = sem_.Get<sem::Parameter>(func->params[j]);
- if (reachable.Contains(current_function_->parameters[j].init_value)) {
- current_function_->parameters[i].pointer_param_output_sources.Push(
- param_source);
+ auto tag = get_param_tag(reachable, j);
+ auto* source_param = sem_.Get<sem::Parameter>(func->params[j]);
+ if (tag == ParameterContentsRequiredToBeUniform) {
+ param_info.ptr_output_source_param_contents.Push(source_param);
+ } else if (tag == ParameterValueRequiredToBeUniform) {
+ param_info.ptr_output_source_param_values.Push(source_param);
}
}
}
@@ -462,12 +509,28 @@
}
}
+ auto* parent = sem_.Get(b)->Parent();
+ auto* loop = parent ? parent->As<sem::LoopStatement>() : nullptr;
+ if (loop) {
+ // We've reached the end of a loop body. If there is a continuing block,
+ // process it before ending the block so that any variables declared in the
+ // loop body are visible to the continuing block.
+ if (auto* continuing =
+ loop->Declaration()->As<ast::LoopStatement>()->continuing) {
+ auto& loop_body_behavior = sem_.Get(b)->Behaviors();
+ if (loop_body_behavior.Contains(sem::Behavior::kNext) ||
+ loop_body_behavior.Contains(sem::Behavior::kContinue)) {
+ cf = ProcessStatement(cf, continuing);
+ }
+ }
+ }
+
if (sem_.Get<sem::FunctionBlockStatement>(b)) {
// We've reached the end of the function body.
// Add edges from pointer parameter outputs to their current value.
- for (auto param : current_function_->parameters) {
- if (param.pointer_return_value) {
- param.pointer_return_value->AddEdge(
+ for (auto& param : current_function_->parameters) {
+ if (param.ptr_output_contents) {
+ param.ptr_output_contents->AddEdge(
current_function_->variables.Get(param.sem));
}
}
@@ -512,8 +575,8 @@
// Add an edge from the variable exit node to its value at this point.
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
- auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
- return CreateNode(name + "_value_" + info.type + "_exit");
+ auto name = NameFor(var->Declaration());
+ return CreateNode({name, "_value_", info.type, "_exit"});
});
exit_node->AddEdge(current_function_->variables.Get(var));
}
@@ -529,7 +592,7 @@
auto [_, v_cond] = ProcessExpression(cf, b->condition);
// Add a diagnostic node to capture the control flow change.
- auto* v = current_function_->CreateNode("break_if_stmt", b);
+ auto* v = CreateNode({"break_if_stmt"}, b);
v->affects_control_flow = true;
v->AddEdge(v_cond);
@@ -548,8 +611,8 @@
// Add an edge from the variable exit node to its value at this point.
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
- auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
- return CreateNode(name + "_value_" + info.type + "_exit");
+ auto name = NameFor(var->Declaration());
+ return CreateNode({name, "_value_", info.type, "_exit"});
});
exit_node->AddEdge(current_function_->variables.Get(var));
@@ -558,7 +621,7 @@
auto* sem_break_if = sem_.Get(b);
if (sem_break_if->Behaviors() != sem::Behaviors{sem::Behavior::kNext}) {
- auto* cf_end = CreateNode("break_if_CFend");
+ auto* cf_end = CreateNode({"break_if_CFend"});
cf_end->AddEdge(v);
return cf_end;
}
@@ -572,9 +635,11 @@
[&](const ast::CompoundAssignmentStatement* c) {
// The compound assignment statement `a += b` is equivalent to `a = a + b`.
- auto [cf1, v1] = ProcessExpression(cf, c->lhs);
+ // Note: we set load_rule=true when evaluating the LHS the first time, as the
+ // resolver does not add a load node for it.
+ auto [cf1, v1] = ProcessExpression(cf, c->lhs, /* load_rule */ true);
auto [cf2, v2] = ProcessExpression(cf1, c->rhs);
- auto* result = CreateNode("binary_expr_result");
+ auto* result = CreateNode({"binary_expr_result"});
result->AddEdge(v1);
result->AddEdge(v2);
@@ -591,20 +656,11 @@
auto& info = current_function_->LoopSwitchInfoFor(parent);
// Propagate assignments to the loop input nodes.
- for (auto* var : current_function_->local_var_decls) {
- // Skip variables that were declared inside this loop.
- if (auto* lv = var->As<sem::LocalVariable>();
- lv &&
- lv->Statement()->FindFirstParent([&](auto* s) { return s == parent; })) {
- continue;
- }
-
- // Add an edge from the variable's loop input node to its value at this point.
- auto in_node = info.var_in_nodes.Find(var);
- TINT_ASSERT(Resolver, in_node != nullptr);
- auto* out_node = current_function_->variables.Get(var);
- if (out_node != *in_node) {
- (*in_node)->AddEdge(out_node);
+ for (auto v : info.var_in_nodes) {
+ auto* in_node = v.value;
+ auto* out_node = current_function_->variables.Get(v.key);
+ if (out_node != in_node) {
+ in_node->AddEdge(out_node);
}
}
return cf;
@@ -614,7 +670,7 @@
[&](const ast::ForLoopStatement* f) {
auto* sem_loop = sem_.Get(f);
- auto* cfx = CreateNode("loop_start");
+ auto* cfx = CreateNode({"loop_start"});
// Insert the initializer before the loop.
auto* cf_init = cf;
@@ -628,8 +684,7 @@
// Create input nodes for any variables declared before this loop.
for (auto* v : current_function_->local_var_decls) {
- auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
- auto* in_node = CreateNode(name + "_value_forloop_in");
+ auto* in_node = CreateNode({NameFor(v->Declaration()), "_value_forloop_in"});
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node);
@@ -638,7 +693,7 @@
// Insert the condition at the start of the loop body.
if (f->condition) {
auto [cf_cond, v] = ProcessExpression(cfx, f->condition);
- auto* cf_condition_end = CreateNode("for_condition_CFend", f);
+ auto* cf_condition_end = CreateNode({"for_condition_CFend"}, f);
cf_condition_end->affects_control_flow = true;
cf_condition_end->AddEdge(v);
cf_start = cf_condition_end;
@@ -646,8 +701,8 @@
// Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
- auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
- return CreateNode(name + "_value_" + info.type + "_exit");
+ auto name = NameFor(var->Declaration());
+ return CreateNode({name, "_value_", info.type, "_exit"});
});
exit_node->AddEdge(current_function_->variables.Get(var));
}
@@ -677,6 +732,13 @@
current_function_->variables.Set(v.key, v.value);
}
+ if (f->initializer) {
+ // Remove variables declared in the for-loop initializer from the current scope.
+ if (auto* decl = f->initializer->As<ast::VariableDeclStatement>()) {
+ current_function_->local_var_decls.Remove(sem_.Get(decl->variable));
+ }
+ }
+
current_function_->RemoveLoopSwitchInfoFor(sem_loop);
if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
@@ -688,7 +750,7 @@
[&](const ast::WhileStatement* w) {
auto* sem_loop = sem_.Get(w);
- auto* cfx = CreateNode("loop_start");
+ auto* cfx = CreateNode({"loop_start"});
auto* cf_start = cf;
@@ -697,8 +759,7 @@
// Create input nodes for any variables declared before this loop.
for (auto* v : current_function_->local_var_decls) {
- auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
- auto* in_node = CreateNode(name + "_value_forloop_in");
+ auto* in_node = CreateNode({NameFor(v->Declaration()), "_value_forloop_in"});
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node);
@@ -707,7 +768,7 @@
// Insert the condition at the start of the loop body.
{
auto [cf_cond, v] = ProcessExpression(cfx, w->condition);
- auto* cf_condition_end = CreateNode("while_condition_CFend", w);
+ auto* cf_condition_end = CreateNode({"while_condition_CFend"}, w);
cf_condition_end->affects_control_flow = true;
cf_condition_end->AddEdge(v);
cf_start = cf_condition_end;
@@ -716,8 +777,8 @@
// Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
- auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
- return CreateNode(name + "_value_" + info.type + "_exit");
+ auto name = NameFor(var->Declaration());
+ return CreateNode({name, "_value_", info.type, "_exit"});
});
exit_node->AddEdge(current_function_->variables.Get(var));
}
@@ -753,7 +814,7 @@
auto [_, v_cond] = ProcessExpression(cf, i->condition);
// Add a diagnostic node to capture the control flow change.
- auto* v = current_function_->CreateNode("if_stmt", i);
+ auto* v = CreateNode({"if_stmt"}, i);
v->affects_control_flow = true;
v->AddEdge(v_cond);
@@ -800,8 +861,7 @@
}
// Create an exit node for the variable.
- auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
- auto* out_node = CreateNode(name + "_value_if_exit");
+ auto* out_node = CreateNode({NameFor(var->Declaration()), "_value_if_exit"});
// Add edges to the assigned value or the initial value.
// Only add edges if the behavior for that block contains 'Next'.
@@ -824,7 +884,7 @@
}
if (sem_if->Behaviors() != sem::Behaviors{sem::Behavior::kNext}) {
- auto* cf_end = CreateNode("if_CFend");
+ auto* cf_end = CreateNode({"if_CFend"});
cf_end->AddEdge(cf1);
if (cf2) {
cf_end->AddEdge(cf2);
@@ -836,8 +896,10 @@
[&](const ast::IncrementDecrementStatement* i) {
// The increment/decrement statement `i++` is equivalent to `i = i + 1`.
- auto [cf1, v1] = ProcessExpression(cf, i->lhs);
- auto* result = CreateNode("incdec_result");
+ // Note: we set load_rule=true when evaluating the LHS the first time, as the
+ // resolver does not add a load node for it.
+ auto [cf1, v1] = ProcessExpression(cf, i->lhs, /* load_rule */ true);
+ auto* result = CreateNode({"incdec_result"});
result->AddEdge(v1);
result->AddEdge(cf1);
@@ -848,27 +910,25 @@
[&](const ast::LoopStatement* l) {
auto* sem_loop = sem_.Get(l);
- auto* cfx = CreateNode("loop_start");
+ auto* cfx = CreateNode({"loop_start"});
auto& info = current_function_->LoopSwitchInfoFor(sem_loop);
info.type = "loop";
// Create input nodes for any variables declared before this loop.
for (auto* v : current_function_->local_var_decls) {
- auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
- auto* in_node = CreateNode(name + "_value_loop_in", v->Declaration());
+ auto name = NameFor(v->Declaration());
+ auto* in_node = CreateNode({name, "_value_loop_in"}, v->Declaration());
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node);
}
+ // Note: The continuing block is processed as a special case at the end of
+ // processing the loop body BlockStatement. This is so that variable declarations
+ // inside the loop body are visible to the continuing statement.
auto* cf1 = ProcessStatement(cfx, l->body);
- if (l->continuing) {
- auto* cf2 = ProcessStatement(cf1, l->continuing);
- cfx->AddEdge(cf2);
- } else {
- cfx->AddEdge(cf1);
- }
+ cfx->AddEdge(cf1);
cfx->AddEdge(cf);
// Add edges from variable loop input nodes to their values at the end of the loop.
@@ -906,9 +966,9 @@
}
// Add edges from each pointer parameter output to its current value.
- for (auto param : current_function_->parameters) {
- if (param.pointer_return_value) {
- param.pointer_return_value->AddEdge(
+ for (auto& param : current_function_->parameters) {
+ if (param.ptr_output_contents) {
+ param.ptr_output_contents->AddEdge(
current_function_->variables.Get(param.sem));
}
}
@@ -921,13 +981,13 @@
auto [cfx, v_cond] = ProcessExpression(cf, s->condition);
// Add a diagnostic node to capture the control flow change.
- auto* v = current_function_->CreateNode("switch_stmt", s);
+ auto* v = CreateNode({"switch_stmt"}, s);
v->affects_control_flow = true;
v->AddEdge(v_cond);
Node* cf_end = nullptr;
if (sem_switch->Behaviors() != sem::Behaviors{sem::Behavior::kNext}) {
- cf_end = CreateNode("switch_CFend");
+ cf_end = CreateNode({"switch_CFend"});
}
auto& info = current_function_->LoopSwitchInfoFor(sem_switch);
@@ -956,8 +1016,8 @@
// Add an edge from the variable exit node to its new value.
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
- auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
- return CreateNode(name + "_value_" + info.type + "_exit");
+ auto name = NameFor(var->Declaration());
+ return CreateNode({name, "_value_", info.type, "_exit"});
});
exit_node->AddEdge(current_function_->variables.Get(var));
}
@@ -1019,9 +1079,11 @@
/// Process an identifier expression.
/// @param cf the input control flow node
/// @param ident the identifier expression to process
+ /// @param load_rule true if the load rule is being invoked on this identifier
/// @returns a pair of (control flow node, value node)
std::pair<Node*, Node*> ProcessIdentExpression(Node* cf,
- const ast::IdentifierExpression* ident) {
+ const ast::IdentifierExpression* ident,
+ bool load_rule = false) {
// Helper to check if the entry point attribute of `obj` indicates non-uniformity.
auto has_nonuniform_entry_point_attribute = [](auto* obj) {
// Only the num_workgroups and workgroup_id builtins are uniform.
@@ -1034,9 +1096,9 @@
return true;
};
- auto name = builder_->Symbols().NameFor(ident->symbol);
- auto* sem = sem_.Get(ident)->Unwrap()->As<sem::VariableUser>()->Variable();
- auto* node = CreateNode(name + "_ident_expr", ident);
+ auto* var_user = sem_.Get(ident)->Unwrap()->As<sem::VariableUser>();
+ auto* sem = var_user->Variable();
+ auto* node = CreateNode({NameFor(ident), "_ident_expr"}, ident);
return Switch(
sem,
@@ -1063,28 +1125,77 @@
return std::make_pair(cf, node);
}
} else {
- auto* x = current_function_->variables.Get(param);
node->AddEdge(cf);
- node->AddEdge(x);
+
+ auto* current_value = current_function_->variables.Get(param);
+ if (param->Type()->Is<type::Pointer>()) {
+ if (load_rule) {
+ // We are loading from the pointer, so add an edge to its contents.
+ node->AddEdge(current_value);
+ } else {
+ // This is a pointer parameter that we are not loading from, so add an
+ // edge to the pointer value itself.
+ node->AddEdge(current_function_->parameters[param->Index()].value);
+ }
+ } else {
+ // The parameter is a value, so add an edge to it.
+ node->AddEdge(current_value);
+ }
+
return std::make_pair(cf, node);
}
},
[&](const sem::GlobalVariable* global) {
- if (!global->Declaration()->Is<ast::Var>() ||
- global->Access() == ast::Access::kRead) {
- node->AddEdge(cf);
- } else {
+ // Loads from global read-write variables may be non-uniform.
+ if (global->Declaration()->Is<ast::Var>() &&
+ global->Access() != ast::Access::kRead && load_rule) {
node->AddEdge(current_function_->may_be_non_uniform);
+ } else {
+ node->AddEdge(cf);
}
return std::make_pair(cf, node);
},
[&](const sem::LocalVariable* local) {
node->AddEdge(cf);
- if (auto* x = current_function_->variables.Get(local)) {
- node->AddEdge(x);
+
+ auto* local_value = current_function_->variables.Get(local);
+ if (local->Type()->Is<type::Pointer>()) {
+ if (load_rule) {
+ // We are loading from the pointer, so add an edge to its contents.
+ auto* root = var_user->RootIdentifier();
+ if (root->Is<sem::GlobalVariable>()) {
+ if (root->Access() != ast::Access::kRead) {
+ // The contents of a mutable global variable is always non-uniform.
+ node->AddEdge(current_function_->may_be_non_uniform);
+ }
+ } else {
+ node->AddEdge(current_function_->variables.Get(root));
+ }
+
+ // The uniformity of the contents also depends on the uniformity of the
+ // pointer itself. For a pointer captured in a let declaration, this will
+ // come from the value node of that declaration.
+ node->AddEdge(local_value);
+ } else {
+ // The variable is a pointer that we are not loading from, so add an edge to
+ // the pointer value itself.
+ node->AddEdge(local_value);
+ }
+ } else if (local->Type()->Is<type::Reference>()) {
+ if (load_rule) {
+ // We are loading from the reference, so add an edge to its contents.
+ node->AddEdge(local_value);
+ } else {
+ // References to local variables (i.e. var declarations) are always uniform,
+ // so no other edges needed.
+ }
+ } else {
+ // The identifier is a value declaration, so add an edge to it.
+ node->AddEdge(local_value);
}
+
return std::make_pair(cf, node);
},
@@ -1098,8 +1209,17 @@
/// Process an expression.
/// @param cf the input control flow node
/// @param expr the expression to process
+ /// @param load_rule true if the load rule is being invoked on this expression
/// @returns a pair of (control flow node, value node)
- std::pair<Node*, Node*> ProcessExpression(Node* cf, const ast::Expression* expr) {
+ std::pair<Node*, Node*> ProcessExpression(Node* cf,
+ const ast::Expression* expr,
+ bool load_rule = false) {
+ if (sem_.Get<sem::Load>(expr)) {
+ // Set the load-rule flag to indicate that identifier expressions in this sub-tree
+ // should add edges to the contents of the variables that they refer to.
+ load_rule = true;
+ }
+
return Switch(
expr,
@@ -1109,7 +1229,7 @@
auto [cf1, v1] = ProcessExpression(cf, b->lhs);
// Add a diagnostic node to capture the control flow change.
- auto* v1_cf = current_function_->CreateNode("short_circuit_op", b);
+ auto* v1_cf = CreateNode({"short_circuit_op"}, b);
v1_cf->affects_control_flow = true;
v1_cf->AddEdge(v1);
@@ -1118,7 +1238,7 @@
} else {
auto [cf1, v1] = ProcessExpression(cf, b->lhs);
auto [cf2, v2] = ProcessExpression(cf1, b->rhs);
- auto* result = CreateNode("binary_expr_result", b);
+ auto* result = CreateNode({"binary_expr_result"}, b);
result->AddEdge(v1);
result->AddEdge(v2);
return std::pair<Node*, Node*>(cf2, result);
@@ -1129,12 +1249,14 @@
[&](const ast::CallExpression* c) { return ProcessCall(cf, c); },
- [&](const ast::IdentifierExpression* i) { return ProcessIdentExpression(cf, i); },
+ [&](const ast::IdentifierExpression* i) {
+ return ProcessIdentExpression(cf, i, load_rule);
+ },
[&](const ast::IndexAccessorExpression* i) {
- auto [cf1, v1] = ProcessExpression(cf, i->object);
+ auto [cf1, v1] = ProcessExpression(cf, i->object, load_rule);
auto [cf2, v2] = ProcessExpression(cf1, i->index);
- auto* result = CreateNode("index_accessor_result");
+ auto* result = CreateNode({"index_accessor_result"});
result->AddEdge(v1);
result->AddEdge(v2);
return std::pair<Node*, Node*>(cf2, result);
@@ -1143,21 +1265,11 @@
[&](const ast::LiteralExpression*) { return std::make_pair(cf, cf); },
[&](const ast::MemberAccessorExpression* m) {
- return ProcessExpression(cf, m->structure);
+ return ProcessExpression(cf, m->structure, load_rule);
},
[&](const ast::UnaryOpExpression* u) {
- if (u->op == ast::UnaryOp::kIndirection) {
- // Cut the analysis short, since we only need to know the originating variable
- // which is being accessed.
- auto* root_ident = sem_.Get(u)->RootIdentifier();
- auto* value = current_function_->variables.Get(root_ident);
- if (!value) {
- value = cf;
- }
- return std::pair<Node*, Node*>(cf, value);
- }
- return ProcessExpression(cf, u->expr);
+ return ProcessExpression(cf, u->expr, load_rule);
},
[&](Default) {
@@ -1202,13 +1314,12 @@
expr,
[&](const ast::IdentifierExpression* i) {
- auto name = builder_->Symbols().NameFor(i->symbol);
auto* sem = sem_.Get(i)->UnwrapLoad()->As<sem::VariableUser>();
if (sem->Variable()->Is<sem::GlobalVariable>()) {
return std::make_pair(cf, current_function_->may_be_non_uniform);
} else if (auto* local = sem->Variable()->As<sem::LocalVariable>()) {
// Create a new value node for this variable.
- auto* value = CreateNode(name + "_lvalue");
+ auto* value = CreateNode({NameFor(i), "_lvalue"});
auto* old_value = current_function_->variables.Set(local, value);
// If i is part of an expression that is a partial reference to a variable (e.g.
@@ -1245,8 +1356,7 @@
// Cut the analysis short, since we only need to know the originating variable
// that is being written to.
auto* root_ident = sem_.Get(u)->RootIdentifier();
- auto name = builder_->Symbols().NameFor(root_ident->Declaration()->symbol);
- auto* deref = CreateNode(name + "_deref");
+ auto* deref = CreateNode({NameFor(root_ident->Declaration()), "_deref"});
auto* old_value = current_function_->variables.Set(root_ident, deref);
if (old_value) {
@@ -1276,7 +1386,7 @@
std::pair<Node*, Node*> ProcessCall(Node* cf, const ast::CallExpression* call) {
std::string name;
if (call->target.name) {
- name = builder_->Symbols().NameFor(call->target.name->symbol);
+ name = NameFor(call->target.name);
} else {
name = call->target.type->FriendlyName(builder_->Symbols());
}
@@ -1284,29 +1394,53 @@
// Process call arguments
Node* cf_last_arg = cf;
utils::Vector<Node*, 8> args;
+ utils::Vector<Node*, 8> ptrarg_contents;
+ ptrarg_contents.Resize(call->args.Length());
for (size_t i = 0; i < call->args.Length(); i++) {
auto [cf_i, arg_i] = ProcessExpression(cf_last_arg, call->args[i]);
// Capture the index of this argument in a new node.
// Note: This is an additional node that isn't described in the specification, for the
// purpose of providing diagnostic information.
- Node* arg_node = CreateNode(name + "_arg_" + std::to_string(i), call);
- arg_node->type = Node::kFunctionCallArgument;
+ Node* arg_node = CreateNode({name, "_arg_", std::to_string(i)}, call);
+ arg_node->type = Node::kFunctionCallArgumentValue;
arg_node->arg_index = static_cast<uint32_t>(i);
arg_node->AddEdge(arg_i);
+ // For pointer arguments, create an additional node to represent the contents of that
+ // pointer prior to the function call.
+ auto* sem_arg = sem_.Get(call->args[i]);
+ if (sem_arg->Type()->Is<type::Pointer>()) {
+ auto* arg_contents =
+ CreateNode({name, "_ptrarg_", std::to_string(i), "_contents"}, call);
+ arg_contents->type = Node::kFunctionCallArgumentContents;
+ arg_contents->arg_index = static_cast<uint32_t>(i);
+
+ auto* root = sem_arg->RootIdentifier();
+ if (root->Is<sem::GlobalVariable>()) {
+ if (root->Access() != ast::Access::kRead) {
+ // The contents of a mutable global variable is always non-uniform.
+ arg_contents->AddEdge(current_function_->may_be_non_uniform);
+ }
+ } else {
+ arg_contents->AddEdge(current_function_->variables.Get(root));
+ }
+ arg_contents->AddEdge(arg_node);
+ ptrarg_contents[i] = arg_contents;
+ }
+
cf_last_arg = cf_i;
args.Push(arg_node);
}
// Note: This is an additional node that isn't described in the specification, for the
// purpose of providing diagnostic information.
- Node* call_node = CreateNode(name + "_call", call);
+ Node* call_node = CreateNode({name, "_call"}, call);
call_node->AddEdge(cf_last_arg);
- Node* result = CreateNode(name + "_return_value", call);
+ Node* result = CreateNode({name, "_return_value"}, call);
result->type = Node::kFunctionCallReturnValue;
- Node* cf_after = CreateNode("CF_after_" + name, call);
+ Node* cf_after = CreateNode({"CF_after_", name}, call);
// Get tags for the callee.
CallSiteTag callsite_tag = CallSiteNoRestriction;
@@ -1316,8 +1450,8 @@
Switch(
sem->Target(),
[&](const sem::Builtin* builtin) {
- // Most builtins have no restrictions. The exceptions are barriers, derivatives, and
- // some texture sampling builtins.
+ // Most builtins have no restrictions. The exceptions are barriers, derivatives,
+ // some texture sampling builtins, and atomics.
if (builtin->IsBarrier()) {
callsite_tag = CallSiteRequiredToBeUniform;
} else if (builtin->IsDerivative() ||
@@ -1326,6 +1460,9 @@
builtin->Type() == sem::BuiltinType::kTextureSampleCompare) {
callsite_tag = CallSiteRequiredToBeUniform;
function_tag = ReturnValueMayBeNonUniform;
+ } else if (builtin->IsAtomic()) {
+ callsite_tag = CallSiteNoRestriction;
+ function_tag = ReturnValueMayBeNonUniform;
} else {
callsite_tag = CallSiteNoRestriction;
function_tag = NoRestriction;
@@ -1366,24 +1503,42 @@
// For each argument, add edges based on parameter tags.
for (size_t i = 0; i < args.Length(); i++) {
if (func_info) {
- switch (func_info->parameters[i].tag) {
- case ParameterRequiredToBeUniform:
+ auto& param_info = func_info->parameters[i];
+
+ // Capture the direct uniformity requirements.
+ switch (param_info.tag_direct) {
+ case ParameterValueRequiredToBeUniform:
current_function_->required_to_be_uniform->AddEdge(args[i]);
break;
- case ParameterRequiredToBeUniformForReturnValue:
+ case ParameterContentsRequiredToBeUniform: {
+ current_function_->required_to_be_uniform->AddEdge(ptrarg_contents[i]);
+ break;
+ }
+ case ParameterNoRestriction:
+ break;
+ }
+ // Capture the effects of this parameter on the return value.
+ switch (param_info.tag_retval) {
+ case ParameterValueRequiredToBeUniform:
result->AddEdge(args[i]);
break;
+ case ParameterContentsRequiredToBeUniform: {
+ result->AddEdge(ptrarg_contents[i]);
+ break;
+ }
case ParameterNoRestriction:
break;
}
+ // Capture the effects of other call parameters on the contents of this parameter
+ // after the call returns.
auto* sem_arg = sem_.Get(call->args[i]);
if (sem_arg->Type()->Is<type::Pointer>()) {
auto* ptr_result =
- CreateNode(name + "_ptrarg_" + std::to_string(i) + "_result", call);
+ CreateNode({name, "_ptrarg_", std::to_string(i), "_result"}, call);
ptr_result->type = Node::kFunctionCallPointerArgumentResult;
ptr_result->arg_index = static_cast<uint32_t>(i);
- if (func_info->parameters[i].pointer_may_become_non_uniform) {
+ if (param_info.pointer_may_become_non_uniform) {
ptr_result->AddEdge(current_function_->may_be_non_uniform);
} else {
// Add edge to the call to catch when it's called in non-uniform control
@@ -1391,10 +1546,14 @@
ptr_result->AddEdge(call_node);
// Add edges from the resulting pointer value to any other arguments that
- // feed it.
- for (auto* source : func_info->parameters[i].pointer_param_output_sources) {
+ // feed it. We distinguish between requirements on the source arguments
+ // value versus its contents for pointer arguments.
+ for (auto* source : param_info.ptr_output_source_param_values) {
ptr_result->AddEdge(args[source->Index()]);
}
+ for (auto* source : param_info.ptr_output_source_param_contents) {
+ ptr_result->AddEdge(ptrarg_contents[source->Index()]);
+ }
}
// Update the current stored value for this pointer argument.
@@ -1405,12 +1564,7 @@
} else {
// All builtin function parameters are RequiredToBeUniformForReturnValue, as are
// parameters for type initializers and type conversions.
- // The arrayLength() builtin is a special case, as there is currently no way for it
- // to have a non-uniform return value.
- auto* builtin = sem->Target()->As<sem::Builtin>();
- if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) {
- result->AddEdge(args[i]);
- }
+ result->AddEdge(args[i]);
}
}
@@ -1539,8 +1693,7 @@
auto* var = sem_.Get(ident)->UnwrapLoad()->As<sem::VariableUser>()->Variable();
std::string var_type = get_var_type(var);
diagnostics_.add_note(diag::System::Resolver,
- "reading from " + var_type + "'" +
- builder_->Symbols().NameFor(ident->symbol) +
+ "reading from " + var_type + "'" + NameFor(ident) +
"' may result in a non-uniform value",
ident->source);
},
@@ -1548,14 +1701,12 @@
auto* var = sem_.Get(v);
std::string var_type = get_var_type(var);
diagnostics_.add_note(diag::System::Resolver,
- "reading from " + var_type + "'" +
- builder_->Symbols().NameFor(v->symbol) +
+ "reading from " + var_type + "'" + NameFor(v) +
"' may result in a non-uniform value",
v->source);
},
[&](const ast::CallExpression* c) {
- auto target_name = builder_->Symbols().NameFor(
- c->target.name->As<ast::IdentifierExpression>()->symbol);
+ auto target_name = NameFor(c->target.name);
switch (non_uniform_source->type) {
case Node::kFunctionCallReturnValue: {
diagnostics_.add_note(
@@ -1563,6 +1714,26 @@
"return value of '" + target_name + "' may be non-uniform", c->source);
break;
}
+ case Node::kFunctionCallArgumentContents: {
+ auto* arg = c->args[non_uniform_source->arg_index];
+ auto* var = sem_.Get(arg)->RootIdentifier();
+ std::string var_type = get_var_type(var);
+ diagnostics_.add_note(diag::System::Resolver,
+ "reading from " + var_type + "'" +
+ NameFor(var->Declaration()) +
+ "' may result in a non-uniform value",
+ var->Declaration()->source);
+ break;
+ }
+ case Node::kFunctionCallArgumentValue: {
+ auto* arg = c->args[non_uniform_source->arg_index];
+ // TODO(jrprice): Which output? (return value vs another pointer argument).
+ diagnostics_.add_note(diag::System::Resolver,
+ "passing non-uniform pointer to '" + target_name +
+ "' may produce a non-uniform output",
+ arg->source);
+ break;
+ }
case Node::kFunctionCallPointerArgumentResult: {
diagnostics_.add_note(
diag::System::Resolver,
@@ -1623,13 +1794,12 @@
if (auto* builtin = target->As<sem::Builtin>()) {
func_name = builtin->str();
} else if (auto* user = target->As<sem::Function>()) {
- func_name = builder_->Symbols().NameFor(user->Declaration()->symbol);
+ func_name = NameFor(user->Declaration());
}
- if (cause->type == Node::kFunctionCallArgument) {
+ if (cause->type == Node::kFunctionCallArgumentValue) {
// The requirement was on a function parameter.
- auto param_name = builder_->Symbols().NameFor(
- target->Parameters()[cause->arg_index]->Declaration()->symbol);
+ auto param_name = NameFor(target->Parameters()[cause->arg_index]->Declaration());
report(call->args[cause->arg_index]->source,
"parameter '" + param_name + "' of '" + func_name + "' must be uniform");
@@ -1637,7 +1807,21 @@
// parameter is required to be uniform.
if (auto* user = target->As<sem::Function>()) {
auto next_function = functions_.Find(user->Declaration());
- Node* next_cause = next_function->parameters[cause->arg_index].init_value;
+ Node* next_cause = next_function->parameters[cause->arg_index].value;
+ MakeError(*next_function, next_cause, true);
+ }
+ } else if (cause->type == Node::kFunctionCallArgumentContents) {
+ // The requirement was on the contents of a function parameter.
+ auto param_name = NameFor(target->Parameters()[cause->arg_index]->Declaration());
+ report(call->args[cause->arg_index]->source, "contents of parameter '" + param_name +
+ "' of '" + func_name +
+ "' must be uniform");
+
+ // If this is a call to a user-defined function, add a note to show the reason that the
+ // parameter is required to be uniform.
+ if (auto* user = target->As<sem::Function>()) {
+ auto next_function = functions_.Find(user->Declaration());
+ Node* next_cause = next_function->parameters[cause->arg_index].ptr_input_contents;
MakeError(*next_function, next_cause, true);
}
} else {
diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc
index fd632df..3afa094 100644
--- a/src/tint/resolver/uniformity_test.cc
+++ b/src/tint/resolver/uniformity_test.cc
@@ -1545,6 +1545,71 @@
RunTest(src, true);
}
+TEST_F(UniformityAnalysisTest, Loop_NonUniformValueDeclaredInBody_UnreachableContinuing) {
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+ var condition = true;
+ loop {
+ var v = non_uniform;
+ if (condition) {
+ break;
+ } else {
+ break;
+ }
+
+ continuing {
+ if (v == 0) {
+ workgroupBarrier();
+ }
+ }
+ }
+}
+)";
+
+ RunTest(src, true);
+}
+
+TEST_F(UniformityAnalysisTest, Loop_NonUniformValueDeclaredInBody_MaybeReachesContinuing) {
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+ var condition = true;
+ loop {
+ var v = non_uniform;
+ if (condition) {
+ continue;
+ } else {
+ break;
+ }
+
+ continuing {
+ if (v == 0) {
+ workgroupBarrier();
+ }
+ }
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:16:9 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:15:7 note: control flow depends on non-uniform value
+ if (v == 0) {
+ ^^
+
+test:7:13 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ var v = non_uniform;
+ ^^^^^^^^^^^
+)");
+}
+
TEST_F(UniformityAnalysisTest, Loop_NonUniformBreakInBody_Reconverge) {
// Loops reconverge at exit, so test that we can call workgroupBarrier() after a loop that
// contains a non-uniform conditional break.
@@ -1939,6 +2004,43 @@
)");
}
+TEST_F(UniformityAnalysisTest,
+ ForLoop_InitializerVarBecomesNonUniformBeforeConditionalContinue_BarrierAtStart) {
+ // Use a variable declared in a for-loop initializer for a conditional barrier in a loop, assign
+ // a non-uniform value to that variable later in that loop and then execute a continue.
+ // Tests that variables declared in the for-loop initializer are properly tracked.
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+ for (var i = 0; i < 10; i++) {
+ if (i < 5) {
+ workgroupBarrier();
+ }
+ if (true) {
+ i = non_uniform;
+ continue;
+ }
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:7:7 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:5:3 note: control flow depends on non-uniform value
+ for (var i = 0; i < 10; i++) {
+ ^^^
+
+test:10:11 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ i = non_uniform;
+ ^^^^^^^^^^^
+)");
+}
+
TEST_F(UniformityAnalysisTest, ForLoop_NonUniformCondition_Reconverge) {
// Loops reconverge at exit, so test that we can call workgroupBarrier() after a loop that has a
// non-uniform condition.
@@ -1955,6 +2057,41 @@
RunTest(src, true);
}
+TEST_F(UniformityAnalysisTest, ForLoop_VarDeclaredInBody) {
+ // Make sure that we can declare a variable inside the loop body without causing issues for
+ // tracking local variables across iterations.
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> n : i32;
+
+fn foo() {
+ var outer : i32;
+ for (var i = 0; i < n; i = i + 1) {
+ var inner : i32;
+ }
+}
+)";
+
+ RunTest(src, true);
+}
+
+TEST_F(UniformityAnalysisTest, ForLoop_InitializerScope) {
+ // Make sure that variables declared in a for-loop initializer are properly removed from the
+ // local variable list, otherwise a parent control-flow statement will try to add edges to nodes
+ // that no longer exist.
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> n : i32;
+
+fn foo() {
+ if (n == 5) {
+ for (var i = 0; i < n; i = i + 1) {
+ }
+ }
+}
+)";
+
+ RunTest(src, true);
+}
+
TEST_F(UniformityAnalysisTest, While_CallInside_UniformCondition) {
std::string src = R"(
@group(0) @binding(0) var<storage, read> n : i32;
@@ -3523,7 +3660,7 @@
)");
}
-TEST_F(UniformityAnalysisTest, LoadNonUniformThroughCapturedPointer) {
+TEST_F(UniformityAnalysisTest, LoadNonUniformLocalThroughCapturedPointer) {
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
@@ -3552,7 +3689,7 @@
)");
}
-TEST_F(UniformityAnalysisTest, LoadNonUniformThroughPointerParameter) {
+TEST_F(UniformityAnalysisTest, LoadNonUniformLocalThroughPointerParameter) {
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
@@ -3570,7 +3707,167 @@
RunTest(src, false);
EXPECT_EQ(error_,
- R"(test:12:7 warning: parameter 'p' of 'bar' must be uniform
+ R"(test:12:7 warning: contents of parameter 'p' of 'bar' must be uniform
+ bar(&v);
+ ^
+
+test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:11:11 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ var v = non_uniform;
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, LoadNonUniformGlobalThroughCapturedPointer) {
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+ let pv = &non_uniform;
+ if (*pv == 0) {
+ workgroupBarrier();
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:7:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:6:3 note: control flow depends on non-uniform value
+ if (*pv == 0) {
+ ^^
+
+test:6:8 note: reading from 'pv' may result in a non-uniform value
+ if (*pv == 0) {
+ ^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, LoadNonUniformGlobalThroughPointerParameter) {
+ std::string src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn bar(p : ptr<storage, i32, read_write>) {
+ if (*p == 0) {
+ workgroupBarrier();
+ }
+}
+
+fn foo() {
+ bar(&non_uniform);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:13:7 warning: contents of parameter 'p' of 'bar' must be uniform
+ bar(&non_uniform);
+ ^
+
+test:8:5 note: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:4:48 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, LoadNonUniformGlobalThroughPointerParameter_ViaReturnValue) {
+ std::string src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn bar(p : ptr<storage, i32, read_write>) -> i32 {
+ return *p;
+}
+
+fn foo() {
+ if (0 == bar(&non_uniform)) {
+ workgroupBarrier();
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:12:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:11:3 note: control flow depends on non-uniform value
+ if (0 == bar(&non_uniform)) {
+ ^^
+
+test:4:48 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, LoadNonUniformThroughPointerParameter_BecomesUniformAfterUse) {
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn bar(p : ptr<function, i32>) {
+ if (*p == 0) {
+ workgroupBarrier();
+ }
+ *p = 0;
+}
+
+fn foo() {
+ var v = non_uniform;
+ bar(&v);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:13:7 warning: contents of parameter 'p' of 'bar' must be uniform
+ bar(&v);
+ ^
+
+test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:12:11 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ var v = non_uniform;
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, LoadNonUniformThroughPointerParameter_BecomesUniformAfterCall) {
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn bar(p : ptr<function, i32>) {
+ if (*p == 0) {
+ workgroupBarrier();
+ }
+}
+
+fn foo() {
+ var v = non_uniform;
+ bar(&v);
+ v = 0;
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:12:7 warning: contents of parameter 'p' of 'bar' must be uniform
bar(&v);
^
@@ -3628,6 +3925,198 @@
RunTest(src, true);
}
+TEST_F(UniformityAnalysisTest, LoadUniformThroughNonUniformPointer) {
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn foo() {
+ // The contents of `v` are uniform.
+ var v = array<i32, 4>();
+ // The pointer `p` is non-uniform.
+ let p = &v[non_uniform];
+ if (*p == 0) {
+ workgroupBarrier();
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:10:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:9:3 note: control flow depends on non-uniform value
+ if (*p == 0) {
+ ^^
+
+test:8:14 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ let p = &v[non_uniform];
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, LoadUniformThroughNonUniformPointer_ViaParameter) {
+ std::string src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn bar(p : ptr<function, array<i32, 4>>) {
+ // The pointer `p` is non-uniform.
+ let local_p = &(*p)[non_uniform];
+ if (*local_p == 0) {
+ workgroupBarrier();
+ }
+}
+
+fn foo() {
+ // The contents of `v` are uniform.
+ var v = array<i32, 4>();
+ let p = &v;
+ bar(p);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:10:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:9:3 note: control flow depends on non-uniform value
+ if (*local_p == 0) {
+ ^^
+
+test:8:23 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ let local_p = &(*p)[non_uniform];
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, LoadUniformThroughNonUniformPointer_ViaParameterChain) {
+ std::string src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn zoo(p : ptr<function, i32>) {
+ if (*p == 0) {
+ workgroupBarrier();
+ }
+}
+
+fn bar(p : ptr<function, i32>) {
+ zoo(p);
+}
+
+fn foo() {
+ // The contents of `v` are uniform.
+ var v = array<i32, 4>();
+ // The pointer `p` is non-uniform.
+ let p = &v[non_uniform];
+ bar(p);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:21:7 warning: contents of parameter 'p' of 'bar' must be uniform
+ bar(p);
+ ^
+
+test:13:7 note: contents of parameter 'p' of 'zoo' must be uniform
+ zoo(p);
+ ^
+
+test:8:5 note: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:20:14 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ let p = &v[non_uniform];
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, LoadNonUniformThroughUniformPointer) {
+ std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+@group(0) @binding(1) var<storage, read> uniform_idx : i32;
+
+fn foo() {
+ // The contents of `v` are non-uniform.
+ var v = array<i32, 4>(0, 0, 0, non_uniform);
+ // The pointer `p` is uniform.
+ let p = &v[uniform_idx];
+ if (*p == 0) {
+ workgroupBarrier();
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:11:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:10:3 note: control flow depends on non-uniform value
+ if (*p == 0) {
+ ^^
+
+test:7:34 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ var v = array<i32, 4>(0, 0, 0, non_uniform);
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, LoadNonUniformThroughUniformPointer_ViaParameter) {
+ std::string src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+@group(0) @binding(1) var<storage, read> uniform_idx : i32;
+
+fn zoo(p : ptr<function, i32>) {
+ if (*p == 0) {
+ workgroupBarrier();
+ }
+}
+
+fn bar(p : ptr<function, i32>) {
+ zoo(p);
+}
+
+fn foo() {
+ // The contents of `v` are non-uniform.
+ var v = array<i32, 4>(0, 0, 0, non_uniform);
+ // The pointer `p` is uniform.
+ let p = &v[uniform_idx];
+ bar(p);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:22:7 warning: contents of parameter 'p' of 'bar' must be uniform
+ bar(p);
+ ^
+
+test:14:7 note: contents of parameter 'p' of 'zoo' must be uniform
+ zoo(p);
+ ^
+
+test:9:5 note: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:19:34 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+ var v = array<i32, 4>(0, 0, 0, non_uniform);
+ ^^^^^^^^^^^
+)");
+}
+
TEST_F(UniformityAnalysisTest, StoreNonUniformAfterCapturingPointer) {
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
@@ -6918,6 +7407,190 @@
RunTest(src, true);
}
+TEST_F(UniformityAnalysisTest, ArrayLength_OnPtrArg) {
+ std::string src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> arr : array<f32>;
+
+fn bar(p : ptr<storage, array<f32>, read_write>) {
+ for (var i = 0u; i < arrayLength(p); i++) {
+ workgroupBarrier();
+ }
+}
+
+fn foo() {
+ bar(&arr);
+}
+)";
+
+ RunTest(src, true);
+}
+
+TEST_F(UniformityAnalysisTest, ArrayLength_PtrArgRequiredToBeUniformForRetval_Pass) {
+ std::string src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> arr : array<f32>;
+
+fn length(p : ptr<storage, array<f32>, read_write>) -> u32 {
+ return arrayLength(p);
+}
+
+fn foo() {
+ for (var i = 0u; i < length(&arr); i++) {
+ workgroupBarrier();
+ }
+}
+)";
+
+ RunTest(src, true);
+}
+
+// TODO(jrprice): This test requires variable pointers.
+TEST_F(UniformityAnalysisTest, DISABLED_ArrayLength_PtrArgRequiredToBeUniformForRetval_Fail) {
+ std::string src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+@group(0) @binding(1) var<storage, read_write> arr1 : array<f32>;
+@group(0) @binding(2) var<storage, read_write> arr2 : array<f32>;
+
+fn length(p : ptr<storage, array<f32>, read_write>) -> u32 {
+ return arrayLength(p);
+}
+
+fn foo() {
+ let non_uniform_ptr = select(&arr1, &arr2, non_uniform == 0);
+ let len = length(non_uniform_ptr);
+ if (len > 10) {
+ workgroupBarrier();
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:16:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:15:3 note: control flow depends on non-uniform value
+ if (len > 10) {
+ ^^
+
+test:14:20 note: passing non-uniform pointer to 'length' may produce a non-uniform output
+ let len = length(non_uniform_ptr, &len);
+ ^^^^^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, ArrayLength_PtrArgRequiredToBeUniformForOtherPtrResult_Pass) {
+ std::string src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> arr : array<f32>;
+
+fn length(p : ptr<storage, array<f32>, read_write>, out : ptr<function, u32>) {
+ *out = arrayLength(p);
+}
+
+fn foo() {
+ var len : u32;
+ length(&arr, &len);
+ if (len > 10) {
+ workgroupBarrier();
+ }
+}
+)";
+
+ RunTest(src, true);
+}
+
+// TODO(jrprice): This test requires variable pointers.
+TEST_F(UniformityAnalysisTest,
+ DISABLED_ArrayLength_PtrArgRequiredToBeUniformForOtherPtrResult_Fail) {
+ std::string src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+@group(0) @binding(1) var<storage, read_write> arr1 : array<f32>;
+@group(0) @binding(2) var<storage, read_write> arr2 : array<f32>;
+
+fn length(p : ptr<storage, array<f32>, read_write>, out : ptr<function, u32>) {
+ *out = arrayLength(p);
+}
+
+fn foo() {
+ var len : u32;
+ let non_uniform_ptr = select(&arr1, &arr2, non_uniform == 0);
+ length(non_uniform_ptr, &len);
+ if (len > 10) {
+ workgroupBarrier();
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:17:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:16:3 note: control flow depends on non-uniform value
+ if (len > 10) {
+ ^^
+
+test:15:10 note: passing non-uniform pointer to 'length' may produce a non-uniform output
+ length(non_uniform_ptr, &len);
+ ^^^^^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, ArrayLength_PtrArgRequiresUniformityAndAffectsReturnValue) {
+ // Test that a single pointer argument can directly require uniformity as well as affecting the
+ // uniformity of the return value.
+ std::string src = R"(
+enable chromium_experimental_full_ptr_parameters;
+
+@group(0) @binding(0) var<storage, read_write> arr : array<u32>;
+
+fn bar(p : ptr<storage, array<u32>, read_write>) -> u32 {
+ // This requires `p` to always be uniform.
+ if (arrayLength(p) == 10) {
+ workgroupBarrier();
+ }
+
+ // This requires the contents of `p` to be uniform in order for the return value to be uniform.
+ return (*p)[0];
+}
+
+fn foo() {
+ let p = &arr;
+ // We pass a uniform pointer, so the direct uniformity requirement on the parameter is satisfied.
+ if (0 == bar(p)) {
+ // This will fail as the return value of `p` is non-uniform due to non-uniform contents of `p`.
+ workgroupBarrier();
+ }
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:21:5 warning: 'workgroupBarrier' must only be called from uniform control flow
+ workgroupBarrier();
+ ^^^^^^^^^^^^^^^^
+
+test:19:3 note: control flow depends on non-uniform value
+ if (0 == bar(p)) {
+ ^^
+
+test:4:48 note: reading from read_write storage buffer 'arr' may result in a non-uniform value
+@group(0) @binding(0) var<storage, read_write> arr : array<u32>;
+ ^^^
+)");
+}
+
TEST_F(UniformityAnalysisTest, WorkgroupAtomics) {
std::string src = R"(
var<workgroup> a : atomic<i32>;
@@ -6939,9 +7612,9 @@
if (atomicAdd(&a, 1) == 1) {
^^
-test:5:18 note: reading from workgroup storage variable 'a' may result in a non-uniform value
+test:5:7 note: return value of 'atomicAdd' may be non-uniform
if (atomicAdd(&a, 1) == 1) {
- ^
+ ^^^^^^^^^
)");
}
@@ -6966,9 +7639,9 @@
if (atomicAdd(&a, 1) == 1) {
^^
-test:5:18 note: reading from read_write storage buffer 'a' may result in a non-uniform value
+test:5:7 note: return value of 'atomicAdd' may be non-uniform
if (atomicAdd(&a, 1) == 1) {
- ^
+ ^^^^^^^^^
)");
}
diff --git a/src/tint/sem/type_mappings.h b/src/tint/sem/type_mappings.h
index 6502f9a..f8aca4d 100644
--- a/src/tint/sem/type_mappings.h
+++ b/src/tint/sem/type_mappings.h
@@ -18,6 +18,9 @@
#include <type_traits>
// Forward declarations
+namespace tint {
+class CastableBase;
+} // namespace tint
namespace tint::ast {
class Array;
class Expression;
diff --git a/src/tint/utils/bitcast.h b/src/tint/utils/bitcast.h
index 4450336..4f72bb8 100644
--- a/src/tint/utils/bitcast.h
+++ b/src/tint/utils/bitcast.h
@@ -15,7 +15,9 @@
#ifndef SRC_TINT_UTILS_BITCAST_H_
#define SRC_TINT_UTILS_BITCAST_H_
+#include <cstddef>
#include <cstring>
+#include <type_traits>
namespace tint::utils {
@@ -29,8 +31,20 @@
template <typename TO, typename FROM>
inline TO Bitcast(FROM&& from) {
static_assert(sizeof(FROM) == sizeof(TO));
+ // gcc warns in cases where either TO or FROM are classes, even if they are trivially
+ // copyable, with for example:
+ //
+ // error: ‘void* memcpy(void*, const void*, size_t)’ copying an object of
+ // non-trivial type ‘struct tint::Number<unsigned int>’ from an array of ‘float’
+ // [-Werror=class-memaccess]
+ //
+ // We avoid this by asserting that both types are indeed trivially copyable, and casting both
+ // args to std::byte*.
+ static_assert(std::is_trivially_copyable_v<std::decay_t<FROM>>);
+ static_assert(std::is_trivially_copyable_v<std::decay_t<TO>>);
TO to;
- memcpy(&to, &from, sizeof(TO));
+ memcpy(reinterpret_cast<std::byte*>(&to), reinterpret_cast<const std::byte*>(&from),
+ sizeof(TO));
return to;
}
diff --git a/src/tint/writer/glsl/generator_impl_bitcast_test.cc b/src/tint/writer/glsl/generator_impl_bitcast_test.cc
index a56c5dd..fa14367 100644
--- a/src/tint/writer/glsl/generator_impl_bitcast_test.cc
+++ b/src/tint/writer/glsl/generator_impl_bitcast_test.cc
@@ -22,36 +22,39 @@
using GlslGeneratorImplTest_Bitcast = TestHelper;
TEST_F(GlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Float) {
- auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr(1_i));
- WrapInFunction(bitcast);
+ auto* a = Let("a", Expr(1_i));
+ auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr("a"));
+ WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
- EXPECT_EQ(out.str(), "intBitsToFloat(1)");
+ EXPECT_EQ(out.str(), "intBitsToFloat(a)");
}
TEST_F(GlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Int) {
- auto* bitcast = create<ast::BitcastExpression>(ty.i32(), Expr(1_u));
- WrapInFunction(bitcast);
+ auto* a = Let("a", Expr(1_u));
+ auto* bitcast = create<ast::BitcastExpression>(ty.i32(), Expr("a"));
+ WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
- EXPECT_EQ(out.str(), "int(1u)");
+ EXPECT_EQ(out.str(), "int(a)");
}
TEST_F(GlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Uint) {
- auto* bitcast = create<ast::BitcastExpression>(ty.u32(), Expr(1_i));
- WrapInFunction(bitcast);
+ auto* a = Let("a", Expr(1_i));
+ auto* bitcast = create<ast::BitcastExpression>(ty.u32(), Expr("a"));
+ WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
- EXPECT_EQ(out.str(), "uint(1)");
+ EXPECT_EQ(out.str(), "uint(a)");
}
} // namespace
diff --git a/src/tint/writer/hlsl/generator_impl_bitcast_test.cc b/src/tint/writer/hlsl/generator_impl_bitcast_test.cc
index 8305d64..de0f9de 100644
--- a/src/tint/writer/hlsl/generator_impl_bitcast_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_bitcast_test.cc
@@ -22,36 +22,39 @@
using HlslGeneratorImplTest_Bitcast = TestHelper;
TEST_F(HlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Float) {
- auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr(1_i));
- WrapInFunction(bitcast);
+ auto* a = Let("a", Expr(1_i));
+ auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr("a"));
+ WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
- EXPECT_EQ(out.str(), "asfloat(1)");
+ EXPECT_EQ(out.str(), "asfloat(a)");
}
TEST_F(HlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Int) {
- auto* bitcast = create<ast::BitcastExpression>(ty.i32(), Expr(1_u));
- WrapInFunction(bitcast);
+ auto* a = Let("a", Expr(1_u));
+ auto* bitcast = create<ast::BitcastExpression>(ty.i32(), Expr("a"));
+ WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
- EXPECT_EQ(out.str(), "asint(1u)");
+ EXPECT_EQ(out.str(), "asint(a)");
}
TEST_F(HlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Uint) {
- auto* bitcast = create<ast::BitcastExpression>(ty.u32(), Expr(1_i));
- WrapInFunction(bitcast);
+ auto* a = Let("a", Expr(1_i));
+ auto* bitcast = create<ast::BitcastExpression>(ty.u32(), Expr("a"));
+ WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
- EXPECT_EQ(out.str(), "asuint(1)");
+ EXPECT_EQ(out.str(), "asuint(a)");
}
} // namespace
diff --git a/src/tint/writer/msl/generator_impl_bitcast_test.cc b/src/tint/writer/msl/generator_impl_bitcast_test.cc
index c398558..0393c8a 100644
--- a/src/tint/writer/msl/generator_impl_bitcast_test.cc
+++ b/src/tint/writer/msl/generator_impl_bitcast_test.cc
@@ -22,14 +22,15 @@
using MslGeneratorImplTest = TestHelper;
TEST_F(MslGeneratorImplTest, EmitExpression_Bitcast) {
- auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr(1_i));
- WrapInFunction(bitcast);
+ auto* a = Let("a", Expr(1_i));
+ auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr("a"));
+ WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
- EXPECT_EQ(out.str(), "as_type<float>(1)");
+ EXPECT_EQ(out.str(), "as_type<float>(a)");
}
} // namespace