Import Tint changes from Dawn
Changes:
- 841db7f515a57e5677d70d023c51586a096f0a6f Add libfuzzer_exports.h to Tint fuzzer common code by Ryan Harrison <rharrison@chromium.org>
- ff2cb02e920a07ae3b9d99709db7f38c9c6b0cf5 Add bitwise_expression to the WGSL parser. by dan sinclair <dsinclair@chromium.org>
- e86758cb26ed69b8e9502037798298f11bcd9edc reader/wgsl: Error if 'struct' has attributes by Ben Clayton <bclayton@google.com>
- ff1330240b4e2476ed8ae8f96b7bff2b05ba63e4 tint/transform: Fix index 0 accessing in DecomposeMemoryA... by Zhaoming Jiang <zhaoming.jiang@intel.com>
- dce63f571730edd14f0247fb4c41e7aed154cc90 tint/utils/UniqueVector: Use utils::Vector and utils::Has... by Ben Clayton <bclayton@google.com>
- b79238d7ec35de21aa333e31fb239addfee1e2a4 tint: Implement const eval of binary minus by Antonio Maiorano <amaiorano@google.com>
- eb0af9def79c2398f85dc49c512b03eedf117a57 Add optionally_typed_ident to WGSL parser. by dan sinclair <dsinclair@chromium.org>
- 0f0ba20208be96cf393f2bacf16e697df108457f tint/transform: Fix PromoteInitializersToLetTest for mate... by Ben Clayton <bclayton@google.com>
- 873f92e741517e28467c5befe7a1b1ab8433ca9d Convert assignment_statement to new WGSL grammar. by dan sinclair <dsinclair@chromium.org>
- 6c8dc15d64c51fe39c6d941a75469d9ac9376f4d Add core_lhs_expression and lhs_expression to parser. by dan sinclair <dsinclair@chromium.org>
- e13160efb61c34b2b816971e26437df6f792260a tint/utils: Add Hashmap and Hashset by Ben Clayton <bclayton@google.com>
- 81f06865235c255fc309fdb30a83978a1e39fd58 tint/utils/vector: Allow use of incomplete types by Ben Clayton <bclayton@google.com>
- 6e716c77accdd7f1fd015bceff35a1f8d9e3aecc Sync some WGSL grammar names to spec by dan sinclair <dsinclair@chromium.org>
- f8a34d08ddfcc66c3b956e2c9f18280be7fa8f08 tint: Add CheckedSub functions by Antonio Maiorano <amaiorano@google.com>
GitOrigin-RevId: 841db7f515a57e5677d70d023c51586a096f0a6f
Change-Id: I2ff87acfadfc2fcf489a1fe88a5637de1fbe1060
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/99660
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index cf3ebae..8e9a576 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -565,6 +565,8 @@
"utils/debugger.h",
"utils/enum_set.h",
"utils/hash.h",
+ "utils/hashmap.h",
+ "utils/hashset.h",
"utils/map.h",
"utils/math.h",
"utils/scoped_assignment.h",
@@ -1233,6 +1235,8 @@
"utils/defer_test.cc",
"utils/enum_set_test.cc",
"utils/hash_test.cc",
+ "utils/hashmap_test.cc",
+ "utils/hashset_test.cc",
"utils/io/command_test.cc",
"utils/io/tmpfile_test.cc",
"utils/map_test.cc",
@@ -1349,6 +1353,7 @@
"reader/wgsl/parser_impl_and_expression_test.cc",
"reader/wgsl/parser_impl_argument_expression_list_test.cc",
"reader/wgsl/parser_impl_assignment_stmt_test.cc",
+ "reader/wgsl/parser_impl_bitwise_expression_test.cc",
"reader/wgsl/parser_impl_break_stmt_test.cc",
"reader/wgsl/parser_impl_bug_cases_test.cc",
"reader/wgsl/parser_impl_call_stmt_test.cc",
@@ -1357,6 +1362,7 @@
"reader/wgsl/parser_impl_const_literal_test.cc",
"reader/wgsl/parser_impl_continue_stmt_test.cc",
"reader/wgsl/parser_impl_continuing_stmt_test.cc",
+ "reader/wgsl/parser_impl_core_lhs_expression_test.cc",
"reader/wgsl/parser_impl_depth_texture_test.cc",
"reader/wgsl/parser_impl_enable_directive_test.cc",
"reader/wgsl/parser_impl_equality_expression_test.cc",
@@ -1375,6 +1381,7 @@
"reader/wgsl/parser_impl_if_stmt_test.cc",
"reader/wgsl/parser_impl_inclusive_or_expression_test.cc",
"reader/wgsl/parser_impl_increment_decrement_stmt_test.cc",
+ "reader/wgsl/parser_impl_lhs_expression_test.cc",
"reader/wgsl/parser_impl_logical_and_expression_test.cc",
"reader/wgsl/parser_impl_logical_or_expression_test.cc",
"reader/wgsl/parser_impl_loop_stmt_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 2d25330..8612975 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -475,6 +475,8 @@
utils/crc32.h
utils/enum_set.h
utils/hash.h
+ utils/hashmap.h
+ utils/hashset.h
utils/map.h
utils/math.h
utils/scoped_assignment.h
@@ -863,6 +865,8 @@
utils/hash_test.cc
utils/io/command_test.cc
utils/io/tmpfile_test.cc
+ utils/hashmap_test.cc
+ utils/hashset_test.cc
utils/map_test.cc
utils/math_test.cc
utils/result_test.cc
@@ -944,6 +948,7 @@
reader/wgsl/parser_impl_and_expression_test.cc
reader/wgsl/parser_impl_argument_expression_list_test.cc
reader/wgsl/parser_impl_assignment_stmt_test.cc
+ reader/wgsl/parser_impl_bitwise_expression_test.cc
reader/wgsl/parser_impl_break_stmt_test.cc
reader/wgsl/parser_impl_bug_cases_test.cc
reader/wgsl/parser_impl_call_stmt_test.cc
@@ -952,6 +957,7 @@
reader/wgsl/parser_impl_const_literal_test.cc
reader/wgsl/parser_impl_continue_stmt_test.cc
reader/wgsl/parser_impl_continuing_stmt_test.cc
+ reader/wgsl/parser_impl_core_lhs_expression_test.cc
reader/wgsl/parser_impl_depth_texture_test.cc
reader/wgsl/parser_impl_enable_directive_test.cc
reader/wgsl/parser_impl_external_texture_test.cc
@@ -970,6 +976,7 @@
reader/wgsl/parser_impl_if_stmt_test.cc
reader/wgsl/parser_impl_inclusive_or_expression_test.cc
reader/wgsl/parser_impl_increment_decrement_stmt_test.cc
+ reader/wgsl/parser_impl_lhs_expression_test.cc
reader/wgsl/parser_impl_logical_and_expression_test.cc
reader/wgsl/parser_impl_logical_or_expression_test.cc
reader/wgsl/parser_impl_loop_stmt_test.cc
diff --git a/src/tint/ast/extension.h b/src/tint/ast/extension.h
index ce3d48a..eea9bec 100644
--- a/src/tint/ast/extension.h
+++ b/src/tint/ast/extension.h
@@ -50,7 +50,7 @@
Extension ParseExtension(std::string_view str);
// A unique vector of extensions
-using Extensions = utils::UniqueVector<Extension>;
+using Extensions = utils::UniqueVector<Extension, 4>;
} // namespace tint::ast
diff --git a/src/tint/ast/extension.h.tmpl b/src/tint/ast/extension.h.tmpl
index 395aeec..29feff9 100644
--- a/src/tint/ast/extension.h.tmpl
+++ b/src/tint/ast/extension.h.tmpl
@@ -25,7 +25,7 @@
{{ Eval "DeclareEnum" $enum}}
// A unique vector of extensions
-using Extensions = utils::UniqueVector<Extension>;
+using Extensions = utils::UniqueVector<Extension, 4>;
} // namespace tint::ast
diff --git a/src/tint/fuzzers/BUILD.gn b/src/tint/fuzzers/BUILD.gn
index 2b47762..48d5e66 100644
--- a/src/tint/fuzzers/BUILD.gn
+++ b/src/tint/fuzzers/BUILD.gn
@@ -85,6 +85,10 @@
"tint_reader_writer_fuzzer.h",
"transform_builder.h",
]
+
+ if (is_mac) {
+ sources += [ "//testing/libfuzzer/libfuzzer_exports.h" ]
+ }
}
source_set("tint_fuzzer_common_with_init_src") {
diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc
index e9c4d88..6073470 100644
--- a/src/tint/inspector/inspector.cc
+++ b/src/tint/inspector/inspector.cc
@@ -505,7 +505,7 @@
ResourceBinding::ResourceType::kExternalTexture);
}
-std::vector<sem::SamplerTexturePair> Inspector::GetSamplerTextureUses(
+utils::Vector<sem::SamplerTexturePair, 4> Inspector::GetSamplerTextureUses(
const std::string& entry_point) {
auto* func = FindEntryPointByName(entry_point);
if (!func) {
@@ -570,7 +570,7 @@
std::vector<std::string> Inspector::GetUsedExtensionNames() {
auto& extensions = program_->Sem().Module()->Extensions();
std::vector<std::string> out;
- out.reserve(extensions.size());
+ out.reserve(extensions.Length());
for (auto ext : extensions) {
out.push_back(utils::ToString(ext));
}
@@ -789,7 +789,7 @@
}
sampler_targets_ = std::make_unique<
- std::unordered_map<std::string, utils::UniqueVector<sem::SamplerTexturePair>>>();
+ std::unordered_map<std::string, utils::UniqueVector<sem::SamplerTexturePair, 4>>>();
auto& sem = program_->Sem();
@@ -849,7 +849,7 @@
for (auto* entry_point : entry_points) {
const auto& ep_name =
program_->Symbols().NameFor(entry_point->Declaration()->symbol);
- (*sampler_targets_)[ep_name].add(
+ (*sampler_targets_)[ep_name].Add(
{sampler_binding_point, texture_binding_point});
}
});
@@ -868,7 +868,7 @@
std::array<const sem::GlobalVariable*, N> globals{};
std::array<const sem::Parameter*, N> parameters{};
- utils::UniqueVector<const ast::CallExpression*> callsites;
+ utils::UniqueVector<const ast::CallExpression*, 8> callsites;
for (size_t i = 0; i < N; i++) {
const sem::Variable* source_var = sem.Get(exprs[i])->SourceVariable();
@@ -882,7 +882,7 @@
return;
}
for (auto* call : func->CallSites()) {
- callsites.add(call->Declaration());
+ callsites.Add(call->Declaration());
}
parameters[i] = param;
} else {
@@ -893,7 +893,7 @@
}
}
- if (callsites.size()) {
+ if (callsites.Length()) {
for (auto* call_expr : callsites) {
// Make a copy of the expressions for this callsite
std::array<const ast::Expression*, N> call_exprs = exprs;
diff --git a/src/tint/inspector/inspector.h b/src/tint/inspector/inspector.h
index 97707db..f3fe270 100644
--- a/src/tint/inspector/inspector.h
+++ b/src/tint/inspector/inspector.h
@@ -122,7 +122,7 @@
/// @param entry_point name of the entry point to get information about.
/// @returns vector of all of the sampler/texture sampling pairs that are used
/// by that entry point.
- std::vector<sem::SamplerTexturePair> GetSamplerTextureUses(const std::string& entry_point);
+ utils::Vector<sem::SamplerTexturePair, 4> GetSamplerTextureUses(const std::string& entry_point);
/// @param entry_point name of the entry point to get information about.
/// @param placeholder the sampler binding point to use for texture-only
@@ -153,7 +153,8 @@
private:
const Program* program_;
diag::List diagnostics_;
- std::unique_ptr<std::unordered_map<std::string, utils::UniqueVector<sem::SamplerTexturePair>>>
+ std::unique_ptr<
+ std::unordered_map<std::string, utils::UniqueVector<sem::SamplerTexturePair, 4>>>
sampler_targets_;
/// @param name name of the entry point to find
diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc
index 543239c..cb46833 100644
--- a/src/tint/inspector/inspector_test.cc
+++ b/src/tint/inspector/inspector_test.cc
@@ -2971,7 +2971,7 @@
auto result = inspector.GetSamplerTextureUses("main");
ASSERT_FALSE(inspector.has_error()) << inspector.error();
- ASSERT_EQ(0u, result.size());
+ ASSERT_EQ(0u, result.Length());
}
TEST_F(InspectorGetSamplerTextureUsesTest, Simple) {
@@ -2989,7 +2989,7 @@
auto result = inspector.GetSamplerTextureUses("main");
ASSERT_FALSE(inspector.has_error()) << inspector.error();
- ASSERT_EQ(1u, result.size());
+ ASSERT_EQ(1u, result.Length());
EXPECT_EQ(0u, result[0].sampler_binding_point.group);
EXPECT_EQ(1u, result[0].sampler_binding_point.binding);
@@ -3053,7 +3053,7 @@
auto result = inspector.GetSamplerTextureUses("main");
ASSERT_FALSE(inspector.has_error()) << inspector.error();
- ASSERT_EQ(1u, result.size());
+ ASSERT_EQ(1u, result.Length());
EXPECT_EQ(0u, result[0].sampler_binding_point.group);
EXPECT_EQ(1u, result[0].sampler_binding_point.binding);
@@ -3080,7 +3080,7 @@
auto result = inspector.GetSamplerTextureUses("main");
ASSERT_FALSE(inspector.has_error()) << inspector.error();
- ASSERT_EQ(1u, result.size());
+ ASSERT_EQ(1u, result.Length());
EXPECT_EQ(0u, result[0].sampler_binding_point.group);
EXPECT_EQ(1u, result[0].sampler_binding_point.binding);
@@ -3107,7 +3107,7 @@
auto result = inspector.GetSamplerTextureUses("main");
ASSERT_FALSE(inspector.has_error()) << inspector.error();
- ASSERT_EQ(1u, result.size());
+ ASSERT_EQ(1u, result.Length());
EXPECT_EQ(0u, result[0].sampler_binding_point.group);
EXPECT_EQ(1u, result[0].sampler_binding_point.binding);
@@ -3134,7 +3134,7 @@
auto result = inspector.GetSamplerTextureUses("main");
ASSERT_FALSE(inspector.has_error()) << inspector.error();
- ASSERT_EQ(1u, result.size());
+ ASSERT_EQ(1u, result.Length());
EXPECT_EQ(0u, result[0].sampler_binding_point.group);
EXPECT_EQ(1u, result[0].sampler_binding_point.binding);
@@ -3188,7 +3188,7 @@
auto result = inspector.GetSamplerTextureUses("via_call");
ASSERT_FALSE(inspector.has_error()) << inspector.error();
- ASSERT_EQ(1u, result.size());
+ ASSERT_EQ(1u, result.Length());
EXPECT_EQ(0u, result[0].sampler_binding_point.group);
EXPECT_EQ(1u, result[0].sampler_binding_point.binding);
@@ -3200,7 +3200,7 @@
auto result = inspector.GetSamplerTextureUses("via_ptr");
ASSERT_FALSE(inspector.has_error()) << inspector.error();
- ASSERT_EQ(1u, result.size());
+ ASSERT_EQ(1u, result.Length());
EXPECT_EQ(0u, result[0].sampler_binding_point.group);
EXPECT_EQ(1u, result[0].sampler_binding_point.binding);
@@ -3212,7 +3212,7 @@
auto result = inspector.GetSamplerTextureUses("direct");
ASSERT_FALSE(inspector.has_error()) << inspector.error();
- ASSERT_EQ(1u, result.size());
+ ASSERT_EQ(1u, result.Length());
EXPECT_EQ(0u, result[0].sampler_binding_point.group);
EXPECT_EQ(1u, result[0].sampler_binding_point.binding);
diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def
index 54db7f5..2b84d45 100644
--- a/src/tint/intrinsics.def
+++ b/src/tint/intrinsics.def
@@ -880,8 +880,8 @@
@const op ~ <T: ia_iu32>(T) -> T
@const op ~ <T: ia_iu32, N: num> (vec<N, T>) -> vec<N, T>
-@const op - <T: fia_fi32_f16>(T) -> T
-@const op - <T: fia_fi32_f16, N: num> (vec<N, T>) -> vec<N, T>
+@const("UnaryMinus") op - <T: fia_fi32_f16>(T) -> T
+@const("UnaryMinus") op - <T: fia_fi32_f16, N: num> (vec<N, T>) -> vec<N, T>
////////////////////////////////////////////////////////////////////////////////
// Binary Operators //
@@ -892,11 +892,11 @@
@const op + <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
@const op + <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
-op - <T: fiu32_f16>(T, T) -> T
-op - <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
-op - <T: fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
-op - <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
-op - <T: f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
+@const op - <T: fia_fiu32_f16>(T, T) -> T
+@const op - <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
+@const op - <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
+@const op - <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
+@const op - <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
op * <T: fiu32_f16>(T, T) -> T
op * <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
diff --git a/src/tint/number.h b/src/tint/number.h
index 3c8d5a0..a36ccfb 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -350,7 +350,7 @@
}
#else // TINT_HAS_OVERFLOW_BUILTINS
if (a.value >= 0) {
- if (AInt::kHighestValue - a.value < b.value) {
+ if (b.value > AInt::kHighestValue - a.value) {
return {};
}
} else {
@@ -372,6 +372,37 @@
return AFloat{result};
}
+/// @returns a - b, or an empty optional if the resulting value overflowed the AInt
+inline std::optional<AInt> CheckedSub(AInt a, AInt b) {
+ int64_t result;
+#ifdef TINT_HAS_OVERFLOW_BUILTINS
+ if (__builtin_sub_overflow(a.value, b.value, &result)) {
+ return {};
+ }
+#else // TINT_HAS_OVERFLOW_BUILTINS
+ if (b.value >= 0) {
+ if (a.value < AInt::kLowestValue + b.value) {
+ return {};
+ }
+ } else {
+ if (a.value > AInt::kHighestValue + b.value) {
+ return {};
+ }
+ }
+ result = a.value - b.value;
+#endif // TINT_HAS_OVERFLOW_BUILTINS
+ return AInt(result);
+}
+
+/// @returns a + b, or an empty optional if the resulting value overflowed the AFloat
+inline std::optional<AFloat> CheckedSub(AFloat a, AFloat b) {
+ auto result = a.value - b.value;
+ if (!std::isfinite(result)) {
+ return {};
+ }
+ return AFloat{result};
+}
+
/// @returns a * b, or an empty optional if the resulting value overflowed the AInt
inline std::optional<AInt> CheckedMul(AInt a, AInt b) {
int64_t result;
diff --git a/src/tint/number_test.cc b/src/tint/number_test.cc
index c245fb4..3182ad3 100644
--- a/src/tint/number_test.cc
+++ b/src/tint/number_test.cc
@@ -356,7 +356,9 @@
/////////////////////////////////////
}));
+#ifdef OVERFLOW
#undef OVERFLOW // corecrt_math.h :(
+#endif
#define OVERFLOW \
{}
@@ -428,6 +430,67 @@
////////////////////////////////////////////////////////////////////////
}));
+using CheckedSubTest_AInt = testing::TestWithParam<BinaryCheckedCase_AInt>;
+TEST_P(CheckedSubTest_AInt, Test) {
+ auto expect = std::get<0>(GetParam());
+ auto a = std::get<1>(GetParam());
+ auto b = std::get<2>(GetParam());
+ EXPECT_EQ(CheckedSub(a, b), expect) << std::hex << "0x" << a << " - 0x" << b;
+}
+INSTANTIATE_TEST_SUITE_P(
+ CheckedSubTest_AInt,
+ CheckedSubTest_AInt,
+ testing::ValuesIn(std::vector<BinaryCheckedCase_AInt>{
+ {AInt(0), AInt(0), AInt(0)},
+ {AInt(1), AInt(1), AInt(0)},
+ {AInt(0), AInt(1), AInt(1)},
+ {AInt(-2), AInt(-1), AInt(1)},
+ {AInt(1), AInt(2), AInt(1)},
+ {AInt(-3), AInt(-2), AInt(1)},
+ {AInt(0x100), AInt(0x300), AInt(0x200)},
+ {AInt(-0x300), AInt(-0x100), AInt(0x200)},
+ {AInt::Highest(), AInt(AInt::kHighestValue - 1), AInt(-1)},
+ {AInt::Lowest(), AInt(AInt::kLowestValue + 1), AInt(1)},
+ {AInt(0x00000000ffffffffll), AInt::Highest(), AInt(0x7fffffff00000000ll)},
+ {AInt::Highest(), AInt::Highest(), AInt(0)},
+ {AInt::Lowest(), AInt::Lowest(), AInt(0)},
+ {OVERFLOW, AInt::Lowest(), AInt(1)},
+ {OVERFLOW, AInt::Highest(), AInt(-1)},
+ {OVERFLOW, AInt::Lowest(), AInt(2)},
+ {OVERFLOW, AInt::Highest(), AInt(-2)},
+ {OVERFLOW, AInt::Lowest(), AInt(10000)},
+ {OVERFLOW, AInt::Highest(), AInt(-10000)},
+ {OVERFLOW, AInt::Lowest(), AInt::Highest()},
+ ////////////////////////////////////////////////////////////////////////
+ }));
+
+using CheckedSubTest_AFloat = testing::TestWithParam<BinaryCheckedCase_AFloat>;
+TEST_P(CheckedSubTest_AFloat, Test) {
+ auto expect = std::get<0>(GetParam());
+ auto a = std::get<1>(GetParam());
+ auto b = std::get<2>(GetParam());
+ EXPECT_EQ(CheckedSub(a, b), expect) << std::hex << "0x" << a << " - 0x" << b;
+}
+INSTANTIATE_TEST_SUITE_P(
+ CheckedSubTest_AFloat,
+ CheckedSubTest_AFloat,
+ testing::ValuesIn(std::vector<BinaryCheckedCase_AFloat>{
+ {AFloat(0), AFloat(0), AFloat(0)},
+ {AFloat(1), AFloat(1), AFloat(0)},
+ {AFloat(0), AFloat(1), AFloat(1)},
+ {AFloat(-2), AFloat(-1), AFloat(1)},
+ {AFloat(1), AFloat(2), AFloat(1)},
+ {AFloat(-3), AFloat(-2), AFloat(1)},
+ {AFloat(0x100), AFloat(0x300), AFloat(0x200)},
+ {AFloat(-0x300), AFloat(-0x100), AFloat(0x200)},
+ {AFloat::Highest(), AFloat(AFloat::kHighestValue - 1), AFloat(-1)},
+ {AFloat::Lowest(), AFloat(AFloat::kLowestValue + 1), AFloat(1)},
+ {AFloat::Highest(), AFloat::Highest(), AFloat(0)},
+ {AFloat::Lowest(), AFloat::Lowest(), AFloat(0)},
+ {OVERFLOW, AFloat::Lowest(), AFloat::Highest()},
+ ////////////////////////////////////////////////////////////////////////
+ }));
+
using CheckedMulTest_AInt = testing::TestWithParam<BinaryCheckedCase_AInt>;
TEST_P(CheckedMulTest_AInt, Test) {
auto expect = std::get<0>(GetParam());
diff --git a/src/tint/reader/spirv/entry_point_info.cc b/src/tint/reader/spirv/entry_point_info.cc
index 2677494..9e3604a 100644
--- a/src/tint/reader/spirv/entry_point_info.cc
+++ b/src/tint/reader/spirv/entry_point_info.cc
@@ -22,8 +22,8 @@
ast::PipelineStage the_stage,
bool the_owns_inner_implementation,
std::string the_inner_name,
- std::vector<uint32_t>&& the_inputs,
- std::vector<uint32_t>&& the_outputs,
+ utils::VectorRef<uint32_t> the_inputs,
+ utils::VectorRef<uint32_t> the_outputs,
GridSize the_wg_size)
: name(the_name),
stage(the_stage),
diff --git a/src/tint/reader/spirv/entry_point_info.h b/src/tint/reader/spirv/entry_point_info.h
index bc13759..9007742 100644
--- a/src/tint/reader/spirv/entry_point_info.h
+++ b/src/tint/reader/spirv/entry_point_info.h
@@ -16,9 +16,9 @@
#define SRC_TINT_READER_SPIRV_ENTRY_POINT_INFO_H_
#include <string>
-#include <vector>
#include "src/tint/ast/pipeline_stage.h"
+#include "src/tint/utils/vector.h"
namespace tint::reader::spirv {
@@ -48,8 +48,8 @@
ast::PipelineStage the_stage,
bool the_owns_inner_implementation,
std::string the_inner_name,
- std::vector<uint32_t>&& the_inputs,
- std::vector<uint32_t>&& the_outputs,
+ utils::VectorRef<uint32_t> the_inputs,
+ utils::VectorRef<uint32_t> the_outputs,
GridSize the_wg_size);
/// Copy constructor
/// @param other the other entry point info to be built from
@@ -75,9 +75,9 @@
/// The name of the inner implementation function of the entry point.
std::string inner_name;
/// IDs of pipeline input variables, sorted and without duplicates.
- std::vector<uint32_t> inputs;
+ utils::Vector<uint32_t, 8> inputs;
/// IDs of pipeline output variables, sorted and without duplicates.
- std::vector<uint32_t> outputs;
+ utils::Vector<uint32_t, 8> outputs;
/// If this is a compute shader, this is the workgroup size in the x, y,
/// and z dimensions set via LocalSize, or via the composite value
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index 4975387..a57984b 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -1303,7 +1303,7 @@
utils::Empty)));
// Pipeline outputs are mapped to the return value.
- if (ep_info_->outputs.empty()) {
+ if (ep_info_->outputs.IsEmpty()) {
// There is nothing to return.
return_type = ty_.Void()->Build(builder_);
} else {
diff --git a/src/tint/reader/spirv/parser_impl.cc b/src/tint/reader/spirv/parser_impl.cc
index e1bff31..f9117a5 100644
--- a/src/tint/reader/spirv/parser_impl.cc
+++ b/src/tint/reader/spirv/parser_impl.cc
@@ -878,17 +878,17 @@
TINT_ASSERT(Reader, !inner_implementation_name.empty());
TINT_ASSERT(Reader, ep_name != inner_implementation_name);
- utils::UniqueVector<uint32_t> inputs;
- utils::UniqueVector<uint32_t> outputs;
+ utils::UniqueVector<uint32_t, 8> inputs;
+ utils::UniqueVector<uint32_t, 8> outputs;
for (unsigned iarg = 3; iarg < entry_point.NumInOperands(); iarg++) {
const uint32_t var_id = entry_point.GetSingleWordInOperand(iarg);
if (const auto* var_inst = def_use_mgr_->GetDef(var_id)) {
switch (SpvStorageClass(var_inst->GetSingleWordInOperand(0))) {
case SpvStorageClassInput:
- inputs.add(var_id);
+ inputs.Add(var_id);
break;
case SpvStorageClassOutput:
- outputs.add(var_id);
+ outputs.Add(var_id);
break;
default:
break;
@@ -896,9 +896,9 @@
}
}
// Save the lists, in ID-sorted order.
- std::vector<uint32_t> sorted_inputs(inputs);
+ utils::Vector<uint32_t, 8> sorted_inputs(inputs);
std::sort(sorted_inputs.begin(), sorted_inputs.end());
- std::vector<uint32_t> sorted_outputs(outputs);
+ utils::Vector<uint32_t, 8> sorted_outputs(outputs);
std::sort(sorted_outputs.begin(), sorted_outputs.end());
const auto ast_stage = enum_converter_.ToPipelineStage(stage);
diff --git a/src/tint/reader/spirv/parser_impl_module_var_test.cc b/src/tint/reader/spirv/parser_impl_module_var_test.cc
index c88f0d6..8eb506b 100644
--- a/src/tint/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/tint/reader/spirv/parser_impl_module_var_test.cc
@@ -3428,13 +3428,13 @@
const auto& info_1000 = p->GetEntryPointInfo(1000);
EXPECT_EQ(1u, info_1000.size());
- EXPECT_TRUE(info_1000[0].inputs.empty());
- EXPECT_TRUE(info_1000[0].outputs.empty());
+ EXPECT_TRUE(info_1000[0].inputs.IsEmpty());
+ EXPECT_TRUE(info_1000[0].outputs.IsEmpty());
const auto& info_1100 = p->GetEntryPointInfo(1100);
EXPECT_EQ(1u, info_1100.size());
EXPECT_THAT(info_1100[0].inputs, ElementsAre(1));
- EXPECT_TRUE(info_1100[0].outputs.empty());
+ EXPECT_TRUE(info_1100[0].outputs.IsEmpty());
const auto& info_1200 = p->GetEntryPointInfo(1200);
EXPECT_EQ(1u, info_1200.size());
diff --git a/src/tint/reader/wgsl/lexer_test.cc b/src/tint/reader/wgsl/lexer_test.cc
index 6cac798..dfd1b85 100644
--- a/src/tint/reader/wgsl/lexer_test.cc
+++ b/src/tint/reader/wgsl/lexer_test.cc
@@ -1009,7 +1009,6 @@
INSTANTIATE_TEST_SUITE_P(LexerTest,
PunctuationTest,
testing::Values(TokenData{"&", Token::Type::kAnd},
- TokenData{"&&", Token::Type::kAndAnd},
TokenData{"->", Token::Type::kArrow},
TokenData{"@", Token::Type::kAttr},
TokenData{"/", Token::Type::kForwardSlash},
@@ -1087,7 +1086,8 @@
}
INSTANTIATE_TEST_SUITE_P(LexerTest,
SplittablePunctuationTest,
- testing::Values(TokenData{">=", Token::Type::kGreaterThanEqual},
+ testing::Values(TokenData{"&&", Token::Type::kAndAnd},
+ TokenData{">=", Token::Type::kGreaterThanEqual},
TokenData{">>", Token::Type::kShiftRight}));
using KeywordTest = testing::TestWithParam<TokenData>;
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index bf26f8e..e4e6812 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -47,6 +47,12 @@
namespace tint::reader::wgsl {
namespace {
+using Void = ParserImpl::Void;
+
+/// An instance of Void that can be used to signal success for functions that return Expect<Void> or
+/// Maybe<NoError>.
+static constexpr Void kSuccess;
+
template <typename T>
using Expect = ParserImpl::Expect<T>;
@@ -351,7 +357,7 @@
// global_directive
// : enable_directive
-Maybe<bool> ParserImpl::global_directive(bool have_parsed_decl) {
+Maybe<Void> ParserImpl::global_directive(bool have_parsed_decl) {
auto& p = peek();
auto ed = enable_directive();
if (ed.matched && have_parsed_decl) {
@@ -362,8 +368,8 @@
// enable_directive
// : enable name SEMICLON
-Maybe<bool> ParserImpl::enable_directive() {
- auto decl = sync(Token::Type::kSemicolon, [&]() -> Maybe<bool> {
+Maybe<Void> ParserImpl::enable_directive() {
+ auto decl = sync(Token::Type::kSemicolon, [&]() -> Maybe<Void> {
if (!match(Token::Type::kEnable)) {
return Failure::kNoMatch;
}
@@ -403,14 +409,14 @@
}
builder_.AST().AddEnable(create<ast::Enable>(name.source, extension));
- return true;
+ return kSuccess;
});
if (decl.errored) {
return Failure::kErrored;
}
if (decl.matched) {
- return true;
+ return kSuccess;
}
return Failure::kNoMatch;
@@ -424,9 +430,9 @@
// | struct_decl
// | function_decl
// | static_assert_statement SEMICOLON
-Maybe<bool> ParserImpl::global_decl() {
+Maybe<Void> ParserImpl::global_decl() {
if (match(Token::Type::kSemicolon) || match(Token::Type::kEOF)) {
- return true;
+ return kSuccess;
}
bool errored = false;
@@ -438,7 +444,7 @@
return Failure::kErrored;
}
- auto decl = sync(Token::Type::kSemicolon, [&]() -> Maybe<bool> {
+ auto decl = sync(Token::Type::kSemicolon, [&]() -> Maybe<Void> {
auto gv = global_variable_decl(attrs.value);
if (gv.errored) {
return Failure::kErrored;
@@ -449,7 +455,7 @@
}
builder_.AST().AddGlobalVariable(gv.value);
- return true;
+ return kSuccess;
}
auto gc = global_constant_decl(attrs.value);
@@ -466,7 +472,7 @@
}
builder_.AST().AddGlobalVariable(gc.value);
- return true;
+ return kSuccess;
}
auto ta = type_alias_decl();
@@ -479,7 +485,7 @@
}
builder_.AST().AddTypeDecl(ta.value);
- return true;
+ return kSuccess;
}
auto assertion = static_assert_statement();
@@ -491,7 +497,7 @@
if (!expect("static assertion declaration", Token::Type::kSemicolon)) {
return Failure::kErrored;
}
- return true;
+ return kSuccess;
}
return Failure::kNoMatch;
@@ -501,7 +507,10 @@
errored = true;
}
if (decl.matched) {
- return expect_attributes_consumed(attrs.value);
+ if (!expect_attributes_consumed(attrs.value)) {
+ return Failure::kErrored;
+ }
+ return kSuccess;
}
auto str = struct_decl();
@@ -510,7 +519,10 @@
}
if (str.matched) {
builder_.AST().AddTypeDecl(str.value);
- return true;
+ if (!expect_attributes_consumed(attrs.value)) {
+ return Failure::kErrored;
+ }
+ return kSuccess;
}
auto func = function_decl(attrs.value);
@@ -519,7 +531,7 @@
}
if (func.matched) {
builder_.AST().AddFunction(func.value);
- return true;
+ return kSuccess;
}
if (errored) {
@@ -594,8 +606,8 @@
}
// global_constant_decl :
-// | LET (ident | variable_ident_decl) global_const_initializer
-// | attribute* override (ident | variable_ident_decl) (equal expression)?
+// | LET optionally_typed_ident global_const_initializer
+// | attribute* override optionally_typed_ident (equal expression)?
// global_const_initializer
// : EQUAL const_expr
Maybe<const ast::Variable*> ParserImpl::global_constant_decl(AttributeList& attrs) {
@@ -615,7 +627,7 @@
return Failure::kNoMatch;
}
- auto decl = expect_ident_or_variable_ident_decl(use);
+ auto decl = expect_optionally_typed_ident(use);
if (decl.errored) {
return Failure::kErrored;
}
@@ -666,7 +678,7 @@
}
// variable_decl
-// : VAR variable_qualifier? (ident | variable_ident_decl)
+// : VAR variable_qualifier? optionally_typed_ident
Maybe<ParserImpl::VarDeclInfo> ParserImpl::variable_decl() {
Source source;
if (!match(Token::Type::kVar, &source)) {
@@ -682,7 +694,7 @@
vq = explicit_vq.value;
}
- auto decl = expect_ident_or_variable_ident_decl("variable declaration");
+ auto decl = expect_optionally_typed_ident("variable declaration");
if (decl.errored) {
return Failure::kErrored;
}
@@ -690,20 +702,20 @@
return VarDeclInfo{decl->source, decl->name, vq.storage_class, vq.access, decl->type};
}
-// texture_samplers
-// : sampler
-// | depth_texture
-// | sampled_texture LESS_THAN type_decl GREATER_THAN
-// | multisampled_texture LESS_THAN type_decl GREATER_THAN
-// | storage_texture LESS_THAN texel_format
-// COMMA access GREATER_THAN
-Maybe<const ast::Type*> ParserImpl::texture_samplers() {
- auto type = sampler();
+// texture_and_sampler_types
+// : sampler_type
+// | depth_texture_type
+// | sampled_texture_type LESS_THAN type_decl GREATER_THAN
+// | multisampled_texture_type LESS_THAN type_decl GREATER_THAN
+// | storage_texture_type LESS_THAN texel_format
+// COMMA access_mode GREATER_THAN
+Maybe<const ast::Type*> ParserImpl::texture_and_sampler_types() {
+ auto type = sampler_type();
if (type.matched) {
return type;
}
- type = depth_texture();
+ type = depth_texture_type();
if (type.matched) {
return type;
}
@@ -715,7 +727,7 @@
auto source_range = make_source_range();
- auto dim = sampled_texture();
+ auto dim = sampled_texture_type();
if (dim.matched) {
const char* use = "sampled texture type";
@@ -727,7 +739,7 @@
return builder_.ty.sampled_texture(source_range, dim.value, subtype.value);
}
- auto ms_dim = multisampled_texture();
+ auto ms_dim = multisampled_texture_type();
if (ms_dim.matched) {
const char* use = "multisampled texture type";
@@ -739,7 +751,7 @@
return builder_.ty.multisampled_texture(source_range, ms_dim.value, subtype.value);
}
- auto storage = storage_texture();
+ auto storage = storage_texture_type();
if (storage.matched) {
const char* use = "storage texture type";
using StorageTextureInfo = std::pair<tint::ast::TexelFormat, tint::ast::Access>;
@@ -753,7 +765,7 @@
return Failure::kErrored;
}
- auto access = expect_access("access control");
+ auto access = expect_access_mode("access control");
if (access.errored) {
return Failure::kErrored;
}
@@ -772,10 +784,10 @@
return Failure::kNoMatch;
}
-// sampler
+// sampler_type
// : SAMPLER
// | SAMPLER_COMPARISON
-Maybe<const ast::Type*> ParserImpl::sampler() {
+Maybe<const ast::Type*> ParserImpl::sampler_type() {
Source source;
if (match(Token::Type::kSampler, &source)) {
return builder_.ty.sampler(source, ast::SamplerKind::kSampler);
@@ -788,14 +800,14 @@
return Failure::kNoMatch;
}
-// sampled_texture
+// sampled_texture_type
// : TEXTURE_SAMPLED_1D
// | TEXTURE_SAMPLED_2D
// | TEXTURE_SAMPLED_2D_ARRAY
// | TEXTURE_SAMPLED_3D
// | TEXTURE_SAMPLED_CUBE
// | TEXTURE_SAMPLED_CUBE_ARRAY
-Maybe<const ast::TextureDimension> ParserImpl::sampled_texture() {
+Maybe<const ast::TextureDimension> ParserImpl::sampled_texture_type() {
if (match(Token::Type::kTextureSampled1d)) {
return ast::TextureDimension::k1d;
}
@@ -834,9 +846,9 @@
return Failure::kNoMatch;
}
-// multisampled_texture
+// multisampled_texture_type
// : TEXTURE_MULTISAMPLED_2D
-Maybe<const ast::TextureDimension> ParserImpl::multisampled_texture() {
+Maybe<const ast::TextureDimension> ParserImpl::multisampled_texture_type() {
if (match(Token::Type::kTextureMultisampled2d)) {
return ast::TextureDimension::k2d;
}
@@ -844,12 +856,12 @@
return Failure::kNoMatch;
}
-// storage_texture
+// storage_texture_type
// : TEXTURE_STORAGE_1D
// | TEXTURE_STORAGE_2D
// | TEXTURE_STORAGE_2D_ARRAY
// | TEXTURE_STORAGE_3D
-Maybe<const ast::TextureDimension> ParserImpl::storage_texture() {
+Maybe<const ast::TextureDimension> ParserImpl::storage_texture_type() {
if (match(Token::Type::kTextureStorage1d)) {
return ast::TextureDimension::k1d;
}
@@ -866,13 +878,13 @@
return Failure::kNoMatch;
}
-// depth_texture
+// depth_texture_type
// : TEXTURE_DEPTH_2D
// | TEXTURE_DEPTH_2D_ARRAY
// | TEXTURE_DEPTH_CUBE
// | TEXTURE_DEPTH_CUBE_ARRAY
// | TEXTURE_DEPTH_MULTISAMPLED_2D
-Maybe<const ast::Type*> ParserImpl::depth_texture() {
+Maybe<const ast::Type*> ParserImpl::depth_texture_type() {
Source source;
if (match(Token::Type::kTextureDepth2d, &source)) {
return builder_.ty.depth_texture(source, ast::TextureDimension::k2d);
@@ -918,7 +930,7 @@
return fmt;
}
-Expect<ParserImpl::TypedIdentifier> ParserImpl::expect_ident_or_variable_ident_decl_impl(
+Expect<ParserImpl::TypedIdentifier> ParserImpl::expect_ident_with_optional_type_decl(
std::string_view use,
bool allow_inferred) {
auto ident = expect_ident(use);
@@ -946,23 +958,24 @@
return TypedIdentifier{type.value, ident.value, ident.source};
}
-// (ident | variable_ident_decl)
-Expect<ParserImpl::TypedIdentifier> ParserImpl::expect_ident_or_variable_ident_decl(
+// optionally_typed_ident
+// : ident ( COLON typed_decl ) ?
+Expect<ParserImpl::TypedIdentifier> ParserImpl::expect_optionally_typed_ident(
std::string_view use) {
- return expect_ident_or_variable_ident_decl_impl(use, true);
+ return expect_ident_with_optional_type_decl(use, true);
}
-// variable_ident_decl
+// ident_with_type_decl
// : IDENT COLON type_decl
-Expect<ParserImpl::TypedIdentifier> ParserImpl::expect_variable_ident_decl(std::string_view use) {
- return expect_ident_or_variable_ident_decl_impl(use, false);
+Expect<ParserImpl::TypedIdentifier> ParserImpl::expect_ident_with_type_decl(std::string_view use) {
+ return expect_ident_with_optional_type_decl(use, false);
}
// access_mode
// : 'read'
// | 'write'
// | 'read_write'
-Expect<ast::Access> ParserImpl::expect_access(std::string_view use) {
+Expect<ast::Access> ParserImpl::expect_access_mode(std::string_view use) {
auto ident = expect_ident(use);
if (ident.errored) {
return Failure::kErrored;
@@ -996,7 +1009,7 @@
return Failure::kErrored;
}
if (match(Token::Type::kComma)) {
- auto ac = expect_access(use);
+ auto ac = expect_access_mode(use);
if (ac.errored) {
return Failure::kErrored;
}
@@ -1064,7 +1077,7 @@
// | MAT4x2 LESS_THAN type_decl GREATER_THAN
// | MAT4x3 LESS_THAN type_decl GREATER_THAN
// | MAT4x4 LESS_THAN type_decl GREATER_THAN
-// | texture_samplers
+// | texture_and_sampler_types
Maybe<const ast::Type*> ParserImpl::type_decl() {
auto& t = peek();
Source source;
@@ -1114,7 +1127,7 @@
return expect_type_decl_matrix(t);
}
- auto texture_or_sampler = texture_samplers();
+ auto texture_or_sampler = texture_and_sampler_types();
if (texture_or_sampler.errored) {
return Failure::kErrored;
}
@@ -1159,7 +1172,7 @@
}
if (match(Token::Type::kComma)) {
- auto ac = expect_access("access control");
+ auto ac = expect_access_mode("access control");
if (ac.errored) {
return Failure::kErrored;
}
@@ -1356,14 +1369,14 @@
}
// struct_member
-// : attribute* variable_ident_decl
+// : attribute* ident_with_type_decl
Expect<ast::StructMember*> ParserImpl::expect_struct_member() {
auto attrs = attribute_list();
if (attrs.errored) {
return Failure::kErrored;
}
- auto decl = expect_variable_ident_decl("struct member");
+ auto decl = expect_ident_with_type_decl("struct member");
if (decl.errored) {
return Failure::kErrored;
}
@@ -1519,11 +1532,11 @@
}
// param
-// : attribute_list* variable_ident_decl
+// : attribute_list* ident_with_type_decl
Expect<ast::Parameter*> ParserImpl::expect_param() {
auto attrs = attribute_list();
- auto decl = expect_variable_ident_decl("parameter");
+ auto decl = expect_ident_with_type_decl("parameter");
if (decl.errored) {
return Failure::kErrored;
}
@@ -1603,18 +1616,18 @@
}
// builtin_value_name
-// : 'vertex_index'
-// | 'instance_index'
-// | 'position'
-// | 'front_facing'
-// | 'frag_depth'
-// | 'local_invocation_id'
-// | 'local_invocation_index'
-// | 'global_invocation_id'
-// | 'workgroup_id'
-// | 'num_workgroups'
-// | 'sample_index'
-// | 'sample_mask'
+// : frag_depth
+// | front_facing
+// | global_invocation_id
+// | instance_index
+// | local_invocation_id
+// | local_invocation_index
+// | num_workgroups
+// | position
+// | sample_index
+// | sample_mask
+// | vertex_index
+// | workgroup_id
Expect<ast::BuiltinValue> ParserImpl::expect_builtin() {
auto ident = expect_ident("builtin");
if (ident.errored) {
@@ -1762,9 +1775,7 @@
// | break_statement SEMICOLON
// | continue_statement SEMICOLON
// | DISCARD SEMICOLON
-// | assignment_statement SEMICOLON
-// | increment_statement SEMICOLON
-// | decrement_statement SEMICOLON
+// | variable_updating_statement SEMICOLON
// | static_assert_statement SEMICOLON
Maybe<const ast::Statement*> ParserImpl::non_block_statement() {
auto stmt = [&]() -> Maybe<const ast::Statement*> {
@@ -1814,7 +1825,7 @@
}
// Note, this covers assignment, increment and decrement
- auto assign = assignment_statement();
+ auto assign = variable_updating_statement();
if (assign.errored) {
return Failure::kErrored;
}
@@ -1863,11 +1874,11 @@
// variable_statement
// : variable_decl
// | variable_decl EQUAL expression
-// | LET (ident | variable_ident_decl) EQUAL expression
-// | CONST (ident | variable_ident_decl) EQUAL expression
+// | LET optionally_typed_ident EQUAL expression
+// | CONST optionally_typed_ident EQUAL expression
Maybe<const ast::VariableDeclStatement*> ParserImpl::variable_statement() {
if (match(Token::Type::kConst)) {
- auto decl = expect_ident_or_variable_ident_decl("'const' declaration");
+ auto decl = expect_optionally_typed_ident("'const' declaration");
if (decl.errored) {
return Failure::kErrored;
}
@@ -1894,7 +1905,7 @@
}
if (match(Token::Type::kLet)) {
- auto decl = expect_ident_or_variable_ident_decl("'let' declaration");
+ auto decl = expect_optionally_typed_ident("'let' declaration");
if (decl.errored) {
return Failure::kErrored;
}
@@ -2209,7 +2220,7 @@
ForHeader::~ForHeader() = default;
-// (variable_statement | increment_statement | decrement_statement | assignment_statement |
+// (variable_statement | variable_updating_statement |
// func_call_statement)?
Maybe<const ast::Statement*> ParserImpl::for_header_initializer() {
auto call = func_call_statement();
@@ -2228,7 +2239,7 @@
return var.value;
}
- auto assign = assignment_statement();
+ auto assign = variable_updating_statement();
if (assign.errored) {
return Failure::kErrored;
}
@@ -2239,7 +2250,7 @@
return Failure::kNoMatch;
}
-// (increment_statement | decrement_statement | assignment_statement | func_call_statement)?
+// (variable_updating_statement | func_call_statement)?
Maybe<const ast::Statement*> ParserImpl::for_header_continuing() {
auto call_stmt = func_call_statement();
if (call_stmt.errored) {
@@ -2249,7 +2260,7 @@
return call_stmt.value;
}
- auto assign = assignment_statement();
+ auto assign = variable_updating_statement();
if (assign.errored) {
return Failure::kErrored;
}
@@ -2261,10 +2272,10 @@
}
// for_header
-// : (variable_statement | assignment_statement | func_call_statement)?
+// : (variable_statement | variable_updating_statement | func_call_statement)?
// SEMICOLON
// expression? SEMICOLON
-// (assignment_statement | func_call_statement)?
+// (variable_updating_statement | func_call_statement)?
Expect<std::unique_ptr<ForHeader>> ParserImpl::expect_for_header() {
auto initializer = for_header_initializer();
if (initializer.errored) {
@@ -2501,8 +2512,9 @@
// postfix_expression
// :
-// | BRACE_LEFT expression BRACE_RIGHT postfix_expr
-// | PERIOD IDENTIFIER postfix_expr
+// | BRACE_LEFT expression BRACE_RIGHT postfix_expression?
+// | PERIOD member_ident postfix_expression?
+// | PERIOD swizzle_name postfix_expression?
Maybe<const ast::Expression*> ParserImpl::postfix_expression(const ast::Expression* prefix) {
Source source;
@@ -2587,6 +2599,64 @@
});
}
+// bitwise_expression.post.unary_expression
+// : AND unary_expression (AND unary_expression)*
+// | OR unary_expression (OR unary_expression)*
+// | XOR unary_expression (XOR unary_expression)*
+Maybe<const ast::Expression*> ParserImpl::bitwise_expression_post_unary_expression(
+ const ast::Expression* lhs) {
+ auto& t = peek();
+ if (!t.Is(Token::Type::kAnd) && !t.Is(Token::Type::kOr) && !t.Is(Token::Type::kXor)) {
+ return Failure::kNoMatch;
+ }
+
+ ast::BinaryOp op = ast::BinaryOp::kXor;
+ if (t.Is(Token::Type::kAnd)) {
+ op = ast::BinaryOp::kAnd;
+ } else if (t.Is(Token::Type::kOr)) {
+ op = ast::BinaryOp::kOr;
+ }
+
+ while (continue_parsing()) {
+ auto& n = peek();
+ // Handle the case of `a & b &&c` where `&c` is a unary_expression
+ bool split = false;
+ if (op == ast::BinaryOp::kAnd && n.Is(Token::Type::kAndAnd)) {
+ next();
+ split_token(Token::Type::kAnd, Token::Type::kAnd);
+ split = true;
+ }
+
+ if (!n.Is(t.type())) {
+ if (n.Is(Token::Type::kAnd) || n.Is(Token::Type::kOr) || n.Is(Token::Type::kXor)) {
+ return add_error(n.source(), std::string("mixing '") + std::string(t.to_name()) +
+ "' and '" + std::string(n.to_name()) +
+ "' requires parenthesis");
+ }
+
+ return lhs;
+ }
+ // If forced to split an `&&` then we've already done the `next` above which consumes
+ // the `&`. The type check above will always fail because we only split if already consuming
+ // a `&` operator.
+ if (!split) {
+ next();
+ }
+
+ auto rhs = unary_expression();
+ if (rhs.errored) {
+ return Failure::kErrored;
+ }
+ if (!rhs.matched) {
+ return add_error(peek(), std::string("unable to parse right side of ") +
+ std::string(t.to_name()) + " expression");
+ }
+
+ lhs = create<ast::BinaryExpression>(t.source(), op, lhs, rhs.value);
+ }
+ return Failure::kErrored;
+}
+
// unary_expression
// : singular_expression
// | MINUS unary_expression
@@ -3114,16 +3184,87 @@
return Failure::kNoMatch;
}
-// assignment_statement
-// | lhs_expression ( EQUAL | compound_assignment_operator ) expression
-// | UNDERSCORE EQUAL expression
+// core_lhs_expression
+// : ident
+// | PAREN_LEFT lhs_expression PAREN_RIGHT
+Maybe<const ast::Expression*> ParserImpl::core_lhs_expression() {
+ auto& t = peek();
+ if (t.IsIdentifier()) {
+ next();
+
+ return create<ast::IdentifierExpression>(t.source(),
+ builder_.Symbols().Register(t.to_str()));
+ }
+
+ if (peek_is(Token::Type::kParenLeft)) {
+ return expect_paren_block("", [&]() -> Expect<const ast::Expression*> {
+ auto expr = lhs_expression();
+ if (expr.errored) {
+ return Failure::kErrored;
+ }
+ if (!expr.matched) {
+ return add_error(t, "invalid expression");
+ }
+ return expr.value;
+ });
+ }
+
+ return Failure::kNoMatch;
+}
+
+// lhs_expression
+// : ( STAR | AND )* core_lhs_expression postfix_expression?
+Maybe<const ast::Expression*> ParserImpl::lhs_expression() {
+ std::vector<const Token*> prefixes;
+ while (peek_is(Token::Type::kStar) || peek_is(Token::Type::kAnd) ||
+ peek_is(Token::Type::kAndAnd)) {
+ auto& t = next();
+
+ // If an '&&' is provided split into '&' and '&'
+ if (t.Is(Token::Type::kAndAnd)) {
+ split_token(Token::Type::kAnd, Token::Type::kAnd);
+ }
+
+ prefixes.push_back(&t);
+ }
+
+ auto core_expr = core_lhs_expression();
+ if (core_expr.errored) {
+ return Failure::kErrored;
+ } else if (!core_expr.matched) {
+ if (prefixes.empty()) {
+ return Failure::kNoMatch;
+ }
+
+ return add_error(peek(), "missing expression");
+ }
+
+ const auto* expr = core_expr.value;
+ for (auto it = prefixes.rbegin(); it != prefixes.rend(); ++it) {
+ auto& t = **it;
+ ast::UnaryOp op = ast::UnaryOp::kAddressOf;
+ if (t.Is(Token::Type::kStar)) {
+ op = ast::UnaryOp::kIndirection;
+ }
+ expr = create<ast::UnaryOpExpression>(t.source(), op, expr);
+ }
+
+ auto e = postfix_expression(expr);
+ if (e.errored) {
+ return Failure::kErrored;
+ }
+ return e.value;
+}
+
+// variable_updating_statement
+// : lhs_expression ( EQUAL | compound_assignment_operator ) expression
+// | lhs_expression MINUS_MINUS
+// | lhs_expression PLUS_PLUS
+// | UNDERSCORE EQUAL expression
//
-// increment_statement
-// | lhs_expression PLUS_PLUS
-//
-// decrement_statement
-// | lhs_expression MINUS_MINUS
-Maybe<const ast::Statement*> ParserImpl::assignment_statement() {
+// Note, this is a simplification of the recursive grammar statement with the `lhs_expression`
+// substituted back into the expression.
+Maybe<const ast::Statement*> ParserImpl::variable_updating_statement() {
auto& t = peek();
// tint:295 - Test for `ident COLON` - this is invalid grammar, and without
@@ -3133,36 +3274,47 @@
return add_error(peek(0).source(), "expected 'var' for variable declaration");
}
- auto lhs = unary_expression();
- if (lhs.errored) {
- return Failure::kErrored;
- }
- if (!lhs.matched) {
- Source source = t.source();
- if (!match(Token::Type::kUnderscore, &source)) {
- return Failure::kNoMatch;
- }
- lhs = create<ast::PhonyExpression>(source);
- }
+ const ast::Expression* lhs = nullptr;
+ ast::BinaryOp compound_op = ast::BinaryOp::kNone;
+ if (peek_is(Token::Type::kUnderscore)) {
+ next(); // Consume the peek.
- // Handle increment and decrement statements.
- // We do this here because the parsing of the LHS expression overlaps with
- // the assignment statement, and we cannot tell which we are parsing until we
- // hit the ++/--/= token.
- if (match(Token::Type::kPlusPlus)) {
- return create<ast::IncrementDecrementStatement>(t.source(), lhs.value, true);
- } else if (match(Token::Type::kMinusMinus)) {
- return create<ast::IncrementDecrementStatement>(t.source(), lhs.value, false);
- }
-
- auto compound_op = compound_assignment_operator();
- if (compound_op.errored) {
- return Failure::kErrored;
- }
- if (!compound_op.matched) {
if (!expect("assignment", Token::Type::kEqual)) {
return Failure::kErrored;
}
+
+ lhs = create<ast::PhonyExpression>(t.source());
+
+ } else {
+ auto lhs_result = lhs_expression();
+ if (lhs_result.errored) {
+ return Failure::kErrored;
+ }
+ if (!lhs_result.matched) {
+ return Failure::kNoMatch;
+ }
+
+ lhs = lhs_result.value;
+
+ // Handle increment and decrement statements.
+ if (match(Token::Type::kPlusPlus)) {
+ return create<ast::IncrementDecrementStatement>(t.source(), lhs, true);
+ }
+ if (match(Token::Type::kMinusMinus)) {
+ return create<ast::IncrementDecrementStatement>(t.source(), lhs, false);
+ }
+
+ auto compound_op_result = compound_assignment_operator();
+ if (compound_op_result.errored) {
+ return Failure::kErrored;
+ }
+ if (compound_op_result.matched) {
+ compound_op = compound_op_result.value;
+ } else {
+ if (!expect("assignment", Token::Type::kEqual)) {
+ return Failure::kErrored;
+ }
+ }
}
auto rhs = expression();
@@ -3173,12 +3325,10 @@
return add_error(peek(), "unable to parse right side of assignment");
}
- if (compound_op.value != ast::BinaryOp::kNone) {
- return create<ast::CompoundAssignmentStatement>(t.source(), lhs.value, rhs.value,
- compound_op.value);
- } else {
- return create<ast::AssignmentStatement>(t.source(), lhs.value, rhs.value);
+ if (compound_op != ast::BinaryOp::kNone) {
+ return create<ast::CompoundAssignmentStatement>(t.source(), lhs, rhs.value, compound_op);
}
+ return create<ast::AssignmentStatement>(t.source(), lhs, rhs.value);
}
// const_literal
@@ -3186,8 +3336,8 @@
// | FLOAT_LITERAL
// | bool_literal
//
-// bool_literal:
-// | TRUE
+// bool_literal
+// : TRUE
// | FALSE
Maybe<const ast::LiteralExpression*> ParserImpl::const_literal() {
auto& t = peek();
diff --git a/src/tint/reader/wgsl/parser_impl.h b/src/tint/reader/wgsl/parser_impl.h
index 14f86d5..1bde5a4 100644
--- a/src/tint/reader/wgsl/parser_impl.h
+++ b/src/tint/reader/wgsl/parser_impl.h
@@ -82,6 +82,10 @@
using StructMemberList = utils::Vector<const ast::StructMember*, 8>;
//! @endcond
+ /// Empty structure used by functions that do not return a value, but need to signal success /
+ /// error with Expect<Void> or Maybe<NoError>.
+ struct Void {};
+
/// Expect is the return type of the parser methods that are expected to
/// return a parsed value of type T, unless there was an parse error.
/// In the case of a parse error the called method will have called
@@ -388,13 +392,13 @@
/// Parses the `global_directive` grammar element, erroring on parse failure.
/// @param has_parsed_decl flag indicating if the parser has consumed a global declaration.
/// @return true on parse success, otherwise an error or no-match.
- Maybe<bool> global_directive(bool has_parsed_decl);
+ Maybe<Void> global_directive(bool has_parsed_decl);
/// Parses the `enable_directive` grammar element, erroring on parse failure.
/// @return true on parse success, otherwise an error or no-match.
- Maybe<bool> enable_directive();
+ Maybe<Void> enable_directive();
/// Parses the `global_decl` grammar element, erroring on parse failure.
/// @return true on parse success, otherwise an error or no-match.
- Maybe<bool> global_decl();
+ Maybe<Void> global_decl();
/// Parses a `global_variable_decl` grammar element with the initial
/// `variable_attribute_list*` provided as `attrs`
/// @returns the variable parsed or nullptr
@@ -410,21 +414,21 @@
/// Parses a `variable_decl` grammar element
/// @returns the parsed variable declaration info
Maybe<VarDeclInfo> variable_decl();
- /// Helper for parsing ident or variable_ident_decl. Should not be called directly,
+ /// Helper for parsing ident with an optional type declaration. Should not be called directly,
/// use the specific version below.
/// @param use a description of what was being parsed if an error was raised.
/// @param allow_inferred allow the identifier to be parsed without a type
/// @returns the parsed identifier, and possibly type, or empty otherwise
- Expect<TypedIdentifier> expect_ident_or_variable_ident_decl_impl(std::string_view use,
- bool allow_inferred);
+ Expect<TypedIdentifier> expect_ident_with_optional_type_decl(std::string_view use,
+ bool allow_inferred);
/// Parses a `ident` or a `variable_ident_decl` grammar element, erroring on parse failure.
/// @param use a description of what was being parsed if an error was raised.
/// @returns the identifier or empty otherwise.
- Expect<TypedIdentifier> expect_ident_or_variable_ident_decl(std::string_view use);
+ Expect<TypedIdentifier> expect_optionally_typed_ident(std::string_view use);
/// Parses a `variable_ident_decl` grammar element, erroring on parse failure.
/// @param use a description of what was being parsed if an error was raised.
/// @returns the identifier and type parsed or empty otherwise
- Expect<TypedIdentifier> expect_variable_ident_decl(std::string_view use);
+ Expect<TypedIdentifier> expect_ident_with_type_decl(std::string_view use);
/// Parses a `variable_qualifier` grammar element
/// @returns the variable qualifier information
Maybe<VariableQualifier> variable_qualifier();
@@ -453,26 +457,26 @@
/// by the declaration, then this vector is cleared before returning.
/// @returns the parsed function, nullptr otherwise
Maybe<const ast::Function*> function_decl(AttributeList& attrs);
- /// Parses a `texture_samplers` grammar element
+ /// Parses a `texture_and_sampler_types` grammar element
/// @returns the parsed Type or nullptr if none matched.
- Maybe<const ast::Type*> texture_samplers();
- /// Parses a `sampler` grammar element
+ Maybe<const ast::Type*> texture_and_sampler_types();
+ /// Parses a `sampler_type` grammar element
/// @returns the parsed Type or nullptr if none matched.
- Maybe<const ast::Type*> sampler();
- /// Parses a `multisampled_texture` grammar element
+ Maybe<const ast::Type*> sampler_type();
+ /// Parses a `multisampled_texture_type` grammar element
/// @returns returns the multisample texture dimension or kNone if none
/// matched.
- Maybe<const ast::TextureDimension> multisampled_texture();
- /// Parses a `sampled_texture` grammar element
+ Maybe<const ast::TextureDimension> multisampled_texture_type();
+ /// Parses a `sampled_texture_type` grammar element
/// @returns returns the sample texture dimension or kNone if none matched.
- Maybe<const ast::TextureDimension> sampled_texture();
- /// Parses a `storage_texture` grammar element
+ Maybe<const ast::TextureDimension> sampled_texture_type();
+ /// Parses a `storage_texture_type` grammar element
/// @returns returns the storage texture dimension.
/// Returns kNone if none matched.
- Maybe<const ast::TextureDimension> storage_texture();
- /// Parses a `depth_texture` grammar element
+ Maybe<const ast::TextureDimension> storage_texture_type();
+ /// Parses a `depth_texture_type` grammar element
/// @returns the parsed Type or nullptr if none matched.
- Maybe<const ast::Type*> depth_texture();
+ Maybe<const ast::Type*> depth_texture_type();
/// Parses a 'texture_external_type' grammar element
/// @returns the parsed Type or nullptr if none matched
Maybe<const ast::Type*> external_texture();
@@ -500,7 +504,7 @@
/// match a valid access control.
/// @param use a description of what was being parsed if an error was raised
/// @returns the parsed access control.
- Expect<ast::Access> expect_access(std::string_view use);
+ Expect<ast::Access> expect_access_mode(std::string_view use);
/// Parses an interpolation sample name identifier, erroring if the next token does not match a
/// valid sample name.
/// @returns the parsed sample name.
@@ -637,6 +641,11 @@
/// Parses the `equality_expression` grammar element
/// @returns the parsed expression or nullptr
Maybe<const ast::Expression*> equality_expression();
+ /// Parses the `bitwise_expression.post.unary_expression` grammar element
+ /// @param lhs the left side of the expression
+ /// @returns the parsed expression or nullptr
+ Maybe<const ast::Expression*> bitwise_expression_post_unary_expression(
+ const ast::Expression* lhs);
/// Parses the recursive part of the `and_expression`, erroring on parse
/// failure.
/// @param lhs the left side of the expression
@@ -683,9 +692,15 @@
/// Parses a `compound_assignment_operator` grammar element
/// @returns the parsed compound assignment operator
Maybe<ast::BinaryOp> compound_assignment_operator();
- /// Parses a `assignment_statement` grammar element
+ /// Parses a `core_lhs_expression` grammar element
+ /// @returns the parsed expression or a non-kMatched failure
+ Maybe<const ast::Expression*> core_lhs_expression();
+ /// Parses a `lhs_expression` grammar element
+ /// @returns the parsed expression or a non-kMatched failure
+ Maybe<const ast::Expression*> lhs_expression();
+ /// Parses a `variable_updating_statement` grammar element
/// @returns the parsed assignment or nullptr
- Maybe<const ast::Statement*> assignment_statement();
+ Maybe<const ast::Statement*> variable_updating_statement();
/// Parses one or more attribute lists.
/// @return the parsed attribute list, or an empty list on error.
Maybe<AttributeList> attribute_list();
diff --git a/src/tint/reader/wgsl/parser_impl_assignment_stmt_test.cc b/src/tint/reader/wgsl/parser_impl_assignment_stmt_test.cc
index 4581d98..1f2b191 100644
--- a/src/tint/reader/wgsl/parser_impl_assignment_stmt_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_assignment_stmt_test.cc
@@ -19,7 +19,7 @@
TEST_F(ParserImplTest, AssignmentStmt_Parses_ToVariable) {
auto p = parser("a = 123");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -42,7 +42,7 @@
TEST_F(ParserImplTest, AssignmentStmt_Parses_ToMember) {
auto p = parser("a.b.c[2].d = 123");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -92,7 +92,7 @@
TEST_F(ParserImplTest, AssignmentStmt_Parses_ToPhony) {
auto p = parser("_ = 123i");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -111,6 +111,38 @@
ASSERT_TRUE(a->lhs->Is<ast::PhonyExpression>());
}
+TEST_F(ParserImplTest, AssignmentStmt_Phony_CompoundOpFails) {
+ auto p = parser("_ += 123i");
+ auto e = p->variable_updating_statement();
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_EQ(p->error(), "1:3: expected '=' for assignment");
+}
+
+TEST_F(ParserImplTest, AssignmentStmt_Phony_IncrementFails) {
+ auto p = parser("_ ++");
+ auto e = p->variable_updating_statement();
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_EQ(p->error(), "1:3: expected '=' for assignment");
+}
+
+TEST_F(ParserImplTest, AssignmentStmt_Phony_EqualIncrementFails) {
+ auto p = parser("_ = ++");
+ auto e = p->variable_updating_statement();
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_EQ(
+ p->error(),
+ "1:5: prefix increment and decrement operators are reserved for a future WGSL version");
+}
+
struct CompoundData {
std::string str;
ast::BinaryOp op;
@@ -119,7 +151,7 @@
TEST_P(CompoundOpTest, CompoundOp) {
auto params = GetParam();
auto p = parser("a " + params.str + " 123u");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -155,7 +187,7 @@
TEST_F(ParserImplTest, AssignmentStmt_MissingEqual) {
auto p = parser("a.b.c[2].d 123");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_TRUE(p->has_error());
@@ -165,7 +197,7 @@
TEST_F(ParserImplTest, AssignmentStmt_Compound_MissingEqual) {
auto p = parser("a + 123");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_TRUE(p->has_error());
@@ -175,7 +207,7 @@
TEST_F(ParserImplTest, AssignmentStmt_InvalidLHS) {
auto p = parser("if (true) {} = 123");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_FALSE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -184,7 +216,7 @@
TEST_F(ParserImplTest, AssignmentStmt_InvalidRHS) {
auto p = parser("a.b.c[2].d = if (true) {}");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
@@ -194,7 +226,7 @@
TEST_F(ParserImplTest, AssignmentStmt_InvalidCompoundOp) {
auto p = parser("a &&= true");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
diff --git a/src/tint/reader/wgsl/parser_impl_bitwise_expression_test.cc b/src/tint/reader/wgsl/parser_impl_bitwise_expression_test.cc
new file mode 100644
index 0000000..a5f675c
--- /dev/null
+++ b/src/tint/reader/wgsl/parser_impl_bitwise_expression_test.cc
@@ -0,0 +1,345 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
+
+namespace tint::reader::wgsl {
+namespace {
+
+TEST_F(ParserImplTest, BitwiseExpr_NoOp) {
+ auto p = parser("a true");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_FALSE(e.matched);
+ EXPECT_FALSE(e.errored);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_EQ(e.value, nullptr);
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_Or_Parses) {
+ auto p = parser("a | true");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(e.value, nullptr);
+
+ EXPECT_EQ(e->source.range.begin.line, 1u);
+ EXPECT_EQ(e->source.range.begin.column, 3u);
+ EXPECT_EQ(e->source.range.end.line, 1u);
+ EXPECT_EQ(e->source.range.end.column, 4u);
+
+ ASSERT_TRUE(e->Is<ast::BinaryExpression>());
+ auto* rel = e->As<ast::BinaryExpression>();
+ EXPECT_EQ(ast::BinaryOp::kOr, rel->op);
+
+ ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
+ auto* ident = rel->lhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
+
+ ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
+ ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_Or_Parses_Multiple) {
+ auto p = parser("a | true | b");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(e.value, nullptr);
+
+ // lhs: (a | true)
+ // rhs: b
+
+ ASSERT_TRUE(e->Is<ast::BinaryExpression>());
+ auto* rel = e->As<ast::BinaryExpression>();
+ EXPECT_EQ(ast::BinaryOp::kOr, rel->op);
+
+ ASSERT_TRUE(rel->rhs->Is<ast::IdentifierExpression>());
+ auto* ident = rel->rhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("b"));
+
+ ASSERT_TRUE(rel->lhs->Is<ast::BinaryExpression>());
+
+ // lhs: a
+ // rhs: true
+ rel = rel->lhs->As<ast::BinaryExpression>();
+
+ ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
+ ident = rel->lhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
+
+ ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
+ ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_Or_MixedWithAnd_Invalid) {
+ auto p = parser("a | b & c");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:7: mixing '|' and '&' requires parenthesis");
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_Or_MixedWithXor_Invalid) {
+ auto p = parser("a | b ^ c");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:7: mixing '|' and '^' requires parenthesis");
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_Or_InvalidRHS) {
+ auto p = parser("true | if (a) {}");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:8: unable to parse right side of | expression");
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_Xor_Parses) {
+ auto p = parser("a ^ true");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(e.value, nullptr);
+
+ EXPECT_EQ(e->source.range.begin.line, 1u);
+ EXPECT_EQ(e->source.range.begin.column, 3u);
+ EXPECT_EQ(e->source.range.end.line, 1u);
+ EXPECT_EQ(e->source.range.end.column, 4u);
+
+ ASSERT_TRUE(e->Is<ast::BinaryExpression>());
+ auto* rel = e->As<ast::BinaryExpression>();
+ EXPECT_EQ(ast::BinaryOp::kXor, rel->op);
+
+ ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
+ auto* ident = rel->lhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
+
+ ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
+ ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_Xor_Parses_Multiple) {
+ auto p = parser("a ^ true ^ b");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(e.value, nullptr);
+
+ // lhs: (a ^ true)
+ // rhs: b
+ ASSERT_TRUE(e->Is<ast::BinaryExpression>());
+ auto* rel = e->As<ast::BinaryExpression>();
+ EXPECT_EQ(ast::BinaryOp::kXor, rel->op);
+
+ ASSERT_TRUE(rel->rhs->Is<ast::IdentifierExpression>());
+ auto* ident = rel->rhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("b"));
+
+ ASSERT_TRUE(rel->lhs->Is<ast::BinaryExpression>());
+
+ // lhs: a
+ // rhs: true
+ rel = rel->lhs->As<ast::BinaryExpression>();
+
+ ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
+ ident = rel->lhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
+
+ ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
+ ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_Xor_MixedWithOr_Invalid) {
+ auto p = parser("a ^ b | c");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:7: mixing '^' and '|' requires parenthesis");
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_Xor_MixedWithAnd_Invalid) {
+ auto p = parser("a ^ b & c");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:7: mixing '^' and '&' requires parenthesis");
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_Xor_InvalidRHS) {
+ auto p = parser("true ^ if (a) {}");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_EQ(e.value, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:8: unable to parse right side of ^ expression");
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_And_Parses) {
+ auto p = parser("a & true");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(e.value, nullptr);
+
+ EXPECT_EQ(e->source.range.begin.line, 1u);
+ EXPECT_EQ(e->source.range.begin.column, 3u);
+ EXPECT_EQ(e->source.range.end.line, 1u);
+ EXPECT_EQ(e->source.range.end.column, 4u);
+
+ ASSERT_TRUE(e->Is<ast::BinaryExpression>());
+ auto* rel = e->As<ast::BinaryExpression>();
+ EXPECT_EQ(ast::BinaryOp::kAnd, rel->op);
+
+ ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
+ auto* ident = rel->lhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Register("a"));
+
+ ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
+ ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_And_Parses_Multiple) {
+ auto p = parser("a & true & b");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(e.value, nullptr);
+
+ // lhs: (a & true)
+ // rhs: b
+ ASSERT_TRUE(e->Is<ast::BinaryExpression>());
+ auto* rel = e->As<ast::BinaryExpression>();
+ EXPECT_EQ(ast::BinaryOp::kAnd, rel->op);
+
+ ASSERT_TRUE(rel->rhs->Is<ast::IdentifierExpression>());
+ auto* ident = rel->rhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Register("b"));
+
+ ASSERT_TRUE(rel->lhs->Is<ast::BinaryExpression>());
+
+ // lhs: a
+ // rhs: true
+ rel = rel->lhs->As<ast::BinaryExpression>();
+
+ ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
+ ident = rel->lhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Register("a"));
+
+ ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
+ ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_And_Parses_AndAnd) {
+ auto p = parser("a & true &&b");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(e.value, nullptr);
+
+ // lhs: (a & true)
+ // rhs: &b
+ ASSERT_TRUE(e->Is<ast::BinaryExpression>());
+ auto* rel = e->As<ast::BinaryExpression>();
+ EXPECT_EQ(ast::BinaryOp::kAnd, rel->op);
+
+ ASSERT_TRUE(rel->rhs->Is<ast::UnaryOpExpression>());
+ auto* unary = rel->rhs->As<ast::UnaryOpExpression>();
+ EXPECT_EQ(ast::UnaryOp::kAddressOf, unary->op);
+
+ ASSERT_TRUE(unary->expr->Is<ast::IdentifierExpression>());
+ auto* ident = unary->expr->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Register("b"));
+
+ ASSERT_TRUE(rel->lhs->Is<ast::BinaryExpression>());
+
+ // lhs: a
+ // rhs: true
+ rel = rel->lhs->As<ast::BinaryExpression>();
+
+ ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
+ ident = rel->lhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Register("a"));
+
+ ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
+ ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_And_MixedWithOr_Invalid) {
+ auto p = parser("a & b | c");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:7: mixing '&' and '|' requires parenthesis");
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_And_MixedWithXor_Invalid) {
+ auto p = parser("a & b ^ c");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:7: mixing '&' and '^' requires parenthesis");
+}
+
+TEST_F(ParserImplTest, BitwiseExpr_And_InvalidRHS) {
+ auto p = parser("true & if (a) {}");
+ auto lhs = p->unary_expression();
+ auto e = p->bitwise_expression_post_unary_expression(lhs.value);
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:8: unable to parse right side of & expression");
+}
+
+} // namespace
+} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/parser_impl_core_lhs_expression_test.cc b/src/tint/reader/wgsl/parser_impl_core_lhs_expression_test.cc
new file mode 100644
index 0000000..81984eb
--- /dev/null
+++ b/src/tint/reader/wgsl/parser_impl_core_lhs_expression_test.cc
@@ -0,0 +1,93 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
+
+namespace tint::reader::wgsl {
+namespace {
+
+TEST_F(ParserImplTest, CoreLHS_NoMatch) {
+ auto p = parser("123");
+ auto e = p->core_lhs_expression();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_FALSE(e.errored);
+ EXPECT_FALSE(e.matched);
+}
+
+TEST_F(ParserImplTest, CoreLHS_Ident) {
+ auto p = parser("identifier");
+ auto e = p->core_lhs_expression();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_FALSE(e.errored);
+ EXPECT_TRUE(e.matched);
+ ASSERT_NE(e.value, nullptr);
+ ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
+
+ auto* ident = e->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("identifier"));
+}
+
+TEST_F(ParserImplTest, CoreLHS_ParenStmt) {
+ auto p = parser("(a)");
+ auto e = p->core_lhs_expression();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_FALSE(e.errored);
+ EXPECT_TRUE(e.matched);
+ ASSERT_NE(e.value, nullptr);
+ ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
+
+ auto* ident = e->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
+}
+
+TEST_F(ParserImplTest, CoreLHS_MissingRightParen) {
+ auto p = parser("(a");
+ auto e = p->core_lhs_expression();
+ ASSERT_TRUE(p->has_error());
+ ASSERT_TRUE(e.errored);
+ EXPECT_FALSE(e.matched);
+ ASSERT_EQ(e.value, nullptr);
+ EXPECT_EQ(p->error(), "1:3: expected ')'");
+}
+
+TEST_F(ParserImplTest, CoreLHS_InvalidLHSExpression) {
+ auto p = parser("(if (a() {})");
+ auto e = p->core_lhs_expression();
+ ASSERT_TRUE(p->has_error());
+ ASSERT_TRUE(e.errored);
+ EXPECT_FALSE(e.matched);
+ ASSERT_EQ(e.value, nullptr);
+ EXPECT_EQ(p->error(), "1:1: invalid expression");
+}
+
+TEST_F(ParserImplTest, CoreLHS_MissingLHSExpression) {
+ auto p = parser("()");
+ auto e = p->core_lhs_expression();
+ ASSERT_TRUE(p->has_error());
+ ASSERT_TRUE(e.errored);
+ EXPECT_FALSE(e.matched);
+ ASSERT_EQ(e.value, nullptr);
+ EXPECT_EQ(p->error(), "1:1: invalid expression");
+}
+
+TEST_F(ParserImplTest, CoreLHS_Invalid) {
+ auto p = parser("1234");
+ auto e = p->core_lhs_expression();
+ ASSERT_FALSE(p->has_error());
+ ASSERT_FALSE(e.errored);
+ EXPECT_FALSE(e.matched);
+}
+
+} // namespace
+} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/parser_impl_depth_texture_test.cc b/src/tint/reader/wgsl/parser_impl_depth_texture_test.cc
index 6c70bd3..78c4b8d 100644
--- a/src/tint/reader/wgsl/parser_impl_depth_texture_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_depth_texture_test.cc
@@ -20,7 +20,7 @@
TEST_F(ParserImplTest, DepthTextureType_Invalid) {
auto p = parser("1234");
- auto t = p->depth_texture();
+ auto t = p->depth_texture_type();
EXPECT_FALSE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_FALSE(p->has_error());
@@ -28,7 +28,7 @@
TEST_F(ParserImplTest, DepthTextureType_2d) {
auto p = parser("texture_depth_2d");
- auto t = p->depth_texture();
+ auto t = p->depth_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
ASSERT_NE(t.value, nullptr);
@@ -41,7 +41,7 @@
TEST_F(ParserImplTest, DepthTextureType_2dArray) {
auto p = parser("texture_depth_2d_array");
- auto t = p->depth_texture();
+ auto t = p->depth_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
ASSERT_NE(t.value, nullptr);
@@ -54,7 +54,7 @@
TEST_F(ParserImplTest, DepthTextureType_Cube) {
auto p = parser("texture_depth_cube");
- auto t = p->depth_texture();
+ auto t = p->depth_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
ASSERT_NE(t.value, nullptr);
@@ -67,7 +67,7 @@
TEST_F(ParserImplTest, DepthTextureType_CubeArray) {
auto p = parser("texture_depth_cube_array");
- auto t = p->depth_texture();
+ auto t = p->depth_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
ASSERT_NE(t.value, nullptr);
@@ -80,7 +80,7 @@
TEST_F(ParserImplTest, DepthTextureType_Multisampled2d) {
auto p = parser("texture_depth_multisampled_2d");
- auto t = p->depth_texture();
+ auto t = p->depth_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
ASSERT_NE(t.value, nullptr);
diff --git a/src/tint/reader/wgsl/parser_impl_global_decl_test.cc b/src/tint/reader/wgsl/parser_impl_global_decl_test.cc
index d9001c5..6c943c8 100644
--- a/src/tint/reader/wgsl/parser_impl_global_decl_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_global_decl_test.cc
@@ -233,6 +233,17 @@
}
}
+TEST_F(ParserImplTest, GlobalDecl_Struct_UnexpectedAttribute) {
+ auto p = parser("@vertex struct S { i : i32 }");
+
+ auto s = p->global_decl();
+ EXPECT_TRUE(s.errored);
+ EXPECT_FALSE(s.matched);
+
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:2: unexpected attributes");
+}
+
TEST_F(ParserImplTest, GlobalDecl_StaticAssert_WithParen) {
auto p = parser("static_assert(true);");
p->global_decl();
diff --git a/src/tint/reader/wgsl/parser_impl_increment_decrement_stmt_test.cc b/src/tint/reader/wgsl/parser_impl_increment_decrement_stmt_test.cc
index d4e52b2..79220e4 100644
--- a/src/tint/reader/wgsl/parser_impl_increment_decrement_stmt_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_increment_decrement_stmt_test.cc
@@ -19,7 +19,7 @@
TEST_F(ParserImplTest, IncrementDecrementStmt_Increment) {
auto p = parser("a++");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -38,7 +38,7 @@
TEST_F(ParserImplTest, IncrementDecrementStmt_Decrement) {
auto p = parser("a--");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -57,7 +57,7 @@
TEST_F(ParserImplTest, IncrementDecrementStmt_Parenthesized) {
auto p = parser("(a)++");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -76,7 +76,7 @@
TEST_F(ParserImplTest, IncrementDecrementStmt_ToMember) {
auto p = parser("a.b.c[2].d++");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -121,7 +121,7 @@
TEST_F(ParserImplTest, IncrementDecrementStmt_InvalidLHS) {
auto p = parser("{}++");
- auto e = p->assignment_statement();
+ auto e = p->variable_updating_statement();
EXPECT_FALSE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
diff --git a/src/tint/reader/wgsl/parser_impl_lhs_expression_test.cc b/src/tint/reader/wgsl/parser_impl_lhs_expression_test.cc
new file mode 100644
index 0000000..6611f3e
--- /dev/null
+++ b/src/tint/reader/wgsl/parser_impl_lhs_expression_test.cc
@@ -0,0 +1,138 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
+
+namespace tint::reader::wgsl {
+namespace {
+
+TEST_F(ParserImplTest, LHSExpression_NoPrefix) {
+ auto p = parser("a");
+ auto e = p->lhs_expression();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_FALSE(e.errored);
+ EXPECT_TRUE(e.matched);
+ ASSERT_NE(e.value, nullptr);
+ ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
+}
+
+TEST_F(ParserImplTest, LHSExpression_NoMatch) {
+ auto p = parser("123");
+ auto e = p->lhs_expression();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_FALSE(e.errored);
+ EXPECT_FALSE(e.matched);
+ ASSERT_EQ(e.value, nullptr);
+}
+
+TEST_F(ParserImplTest, LHSExpression_And) {
+ auto p = parser("&a");
+ auto e = p->lhs_expression();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_FALSE(e.errored);
+ EXPECT_TRUE(e.matched);
+ ASSERT_NE(e.value, nullptr);
+ ASSERT_TRUE(e->Is<ast::UnaryOpExpression>());
+
+ auto* u = e->As<ast::UnaryOpExpression>();
+ EXPECT_EQ(u->op, ast::UnaryOp::kAddressOf);
+ EXPECT_TRUE(u->expr->Is<ast::IdentifierExpression>());
+}
+
+TEST_F(ParserImplTest, LHSExpression_Star) {
+ auto p = parser("*a");
+ auto e = p->lhs_expression();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_FALSE(e.errored);
+ EXPECT_TRUE(e.matched);
+ ASSERT_NE(e.value, nullptr);
+ ASSERT_TRUE(e->Is<ast::UnaryOpExpression>());
+
+ auto* u = e->As<ast::UnaryOpExpression>();
+ EXPECT_EQ(u->op, ast::UnaryOp::kIndirection);
+ EXPECT_TRUE(u->expr->Is<ast::IdentifierExpression>());
+}
+
+TEST_F(ParserImplTest, LHSExpression_InvalidCoreLHSExpr) {
+ auto p = parser("*123");
+ auto e = p->lhs_expression();
+ ASSERT_TRUE(p->has_error());
+ ASSERT_TRUE(e.errored);
+ EXPECT_FALSE(e.matched);
+ ASSERT_EQ(e.value, nullptr);
+ EXPECT_EQ(p->error(), "1:2: missing expression");
+}
+
+TEST_F(ParserImplTest, LHSExpression_Multiple) {
+ auto p = parser("*&**&&*a");
+ auto e = p->lhs_expression();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_FALSE(e.errored);
+ EXPECT_TRUE(e.matched);
+ ASSERT_NE(e.value, nullptr);
+
+ std::vector<ast::UnaryOp> results = {ast::UnaryOp::kIndirection, ast::UnaryOp::kAddressOf,
+ ast::UnaryOp::kIndirection, ast::UnaryOp::kIndirection,
+ ast::UnaryOp::kAddressOf, ast::UnaryOp::kAddressOf,
+ ast::UnaryOp::kIndirection};
+
+ auto* expr = e.value;
+ for (auto op : results) {
+ ASSERT_TRUE(expr->Is<ast::UnaryOpExpression>());
+
+ auto* u = expr->As<ast::UnaryOpExpression>();
+ EXPECT_EQ(u->op, op);
+
+ expr = u->expr;
+ }
+
+ EXPECT_TRUE(expr->Is<ast::IdentifierExpression>());
+}
+
+TEST_F(ParserImplTest, LHSExpression_PostfixExpression) {
+ auto p = parser("*a.foo");
+ auto e = p->lhs_expression();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_FALSE(e.errored);
+ EXPECT_TRUE(e.matched);
+ ASSERT_NE(e.value, nullptr);
+ ASSERT_TRUE(e->Is<ast::MemberAccessorExpression>());
+
+ auto* access = e->As<ast::MemberAccessorExpression>();
+ ASSERT_TRUE(access->structure->Is<ast::UnaryOpExpression>());
+
+ auto* u = access->structure->As<ast::UnaryOpExpression>();
+ EXPECT_EQ(u->op, ast::UnaryOp::kIndirection);
+
+ ASSERT_TRUE(u->expr->Is<ast::IdentifierExpression>());
+ auto* struct_ident = u->expr->As<ast::IdentifierExpression>();
+ EXPECT_EQ(struct_ident->symbol, p->builder().Symbols().Get("a"));
+
+ ASSERT_TRUE(access->member->Is<ast::IdentifierExpression>());
+ auto* member_ident = access->member->As<ast::IdentifierExpression>();
+ EXPECT_EQ(member_ident->symbol, p->builder().Symbols().Get("foo"));
+}
+
+TEST_F(ParserImplTest, LHSExpression_InvalidPostfixExpression) {
+ auto p = parser("*a.if");
+ auto e = p->lhs_expression();
+ ASSERT_TRUE(p->has_error());
+ ASSERT_TRUE(e.errored);
+ EXPECT_FALSE(e.matched);
+ ASSERT_EQ(e.value, nullptr);
+ EXPECT_EQ(p->error(), "1:4: expected identifier for member accessor");
+}
+
+} // namespace
+} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/parser_impl_sampled_texture_test.cc b/src/tint/reader/wgsl/parser_impl_sampled_texture_test.cc
index cf9f089..6aa1aab 100644
--- a/src/tint/reader/wgsl/parser_impl_sampled_texture_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_sampled_texture_test.cc
@@ -19,7 +19,7 @@
TEST_F(ParserImplTest, SampledTextureType_Invalid) {
auto p = parser("1234");
- auto t = p->sampled_texture();
+ auto t = p->sampled_texture_type();
EXPECT_FALSE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_FALSE(p->has_error());
@@ -27,7 +27,7 @@
TEST_F(ParserImplTest, SampledTextureType_1d) {
auto p = parser("texture_1d");
- auto t = p->sampled_texture();
+ auto t = p->sampled_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_EQ(t.value, ast::TextureDimension::k1d);
@@ -36,7 +36,7 @@
TEST_F(ParserImplTest, SampledTextureType_2d) {
auto p = parser("texture_2d");
- auto t = p->sampled_texture();
+ auto t = p->sampled_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_EQ(t.value, ast::TextureDimension::k2d);
@@ -45,7 +45,7 @@
TEST_F(ParserImplTest, SampledTextureType_2dArray) {
auto p = parser("texture_2d_array");
- auto t = p->sampled_texture();
+ auto t = p->sampled_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_EQ(t.value, ast::TextureDimension::k2dArray);
@@ -54,7 +54,7 @@
TEST_F(ParserImplTest, SampledTextureType_3d) {
auto p = parser("texture_3d");
- auto t = p->sampled_texture();
+ auto t = p->sampled_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_EQ(t.value, ast::TextureDimension::k3d);
@@ -63,7 +63,7 @@
TEST_F(ParserImplTest, SampledTextureType_Cube) {
auto p = parser("texture_cube");
- auto t = p->sampled_texture();
+ auto t = p->sampled_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_EQ(t.value, ast::TextureDimension::kCube);
@@ -72,7 +72,7 @@
TEST_F(ParserImplTest, SampledTextureType_kCubeArray) {
auto p = parser("texture_cube_array");
- auto t = p->sampled_texture();
+ auto t = p->sampled_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_EQ(t.value, ast::TextureDimension::kCubeArray);
diff --git a/src/tint/reader/wgsl/parser_impl_sampler_test.cc b/src/tint/reader/wgsl/parser_impl_sampler_test.cc
index 7f1e564..bd3c2d3 100644
--- a/src/tint/reader/wgsl/parser_impl_sampler_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_sampler_test.cc
@@ -19,7 +19,7 @@
TEST_F(ParserImplTest, SamplerType_Invalid) {
auto p = parser("1234");
- auto t = p->sampler();
+ auto t = p->sampler_type();
EXPECT_FALSE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_EQ(t.value, nullptr);
@@ -28,7 +28,7 @@
TEST_F(ParserImplTest, SamplerType_Sampler) {
auto p = parser("sampler");
- auto t = p->sampler();
+ auto t = p->sampler_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
ASSERT_NE(t.value, nullptr);
@@ -40,7 +40,7 @@
TEST_F(ParserImplTest, SamplerType_ComparisonSampler) {
auto p = parser("sampler_comparison");
- auto t = p->sampler();
+ auto t = p->sampler_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
ASSERT_NE(t.value, nullptr);
diff --git a/src/tint/reader/wgsl/parser_impl_storage_texture_test.cc b/src/tint/reader/wgsl/parser_impl_storage_texture_test.cc
index 6297a1e..528f3a4 100644
--- a/src/tint/reader/wgsl/parser_impl_storage_texture_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_storage_texture_test.cc
@@ -19,7 +19,7 @@
TEST_F(ParserImplTest, StorageTextureType_Invalid) {
auto p = parser("abc");
- auto t = p->storage_texture();
+ auto t = p->storage_texture_type();
EXPECT_FALSE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_FALSE(p->has_error());
@@ -27,7 +27,7 @@
TEST_F(ParserImplTest, StorageTextureType_1d) {
auto p = parser("texture_storage_1d");
- auto t = p->storage_texture();
+ auto t = p->storage_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_EQ(t.value, ast::TextureDimension::k1d);
@@ -36,7 +36,7 @@
TEST_F(ParserImplTest, StorageTextureType_2d) {
auto p = parser("texture_storage_2d");
- auto t = p->storage_texture();
+ auto t = p->storage_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_EQ(t.value, ast::TextureDimension::k2d);
@@ -45,7 +45,7 @@
TEST_F(ParserImplTest, StorageTextureType_2dArray) {
auto p = parser("texture_storage_2d_array");
- auto t = p->storage_texture();
+ auto t = p->storage_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_EQ(t.value, ast::TextureDimension::k2dArray);
@@ -54,7 +54,7 @@
TEST_F(ParserImplTest, StorageTextureType_3d) {
auto p = parser("texture_storage_3d");
- auto t = p->storage_texture();
+ auto t = p->storage_texture_type();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
EXPECT_EQ(t.value, ast::TextureDimension::k3d);
diff --git a/src/tint/reader/wgsl/parser_impl_texture_sampler_test.cc b/src/tint/reader/wgsl/parser_impl_texture_sampler_test.cc
index 162b41c..1143c52 100644
--- a/src/tint/reader/wgsl/parser_impl_texture_sampler_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_texture_sampler_test.cc
@@ -22,7 +22,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_Invalid) {
auto p = parser("1234");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
EXPECT_FALSE(t.errored);
@@ -31,7 +31,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_Sampler) {
auto p = parser("sampler");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
@@ -43,7 +43,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_SamplerComparison) {
auto p = parser("sampler_comparison");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
@@ -55,7 +55,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_DepthTexture) {
auto p = parser("texture_depth_2d");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
@@ -68,7 +68,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_SampledTexture_F32) {
auto p = parser("texture_1d<f32>");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
@@ -82,7 +82,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_SampledTexture_I32) {
auto p = parser("texture_2d<i32>");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
@@ -96,7 +96,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_SampledTexture_U32) {
auto p = parser("texture_3d<u32>");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
@@ -110,7 +110,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_SampledTexture_MissingType) {
auto p = parser("texture_1d<>");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_TRUE(p->has_error());
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
@@ -120,7 +120,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_SampledTexture_MissingLessThan) {
auto p = parser("texture_1d");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_TRUE(p->has_error());
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
@@ -130,7 +130,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_SampledTexture_MissingGreaterThan) {
auto p = parser("texture_1d<u32");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_TRUE(p->has_error());
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
@@ -140,7 +140,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_MultisampledTexture_I32) {
auto p = parser("texture_multisampled_2d<i32>");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
@@ -154,7 +154,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_MultisampledTexture_MissingType) {
auto p = parser("texture_multisampled_2d<>");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_TRUE(p->has_error());
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
@@ -164,7 +164,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_MultisampledTexture_MissingLessThan) {
auto p = parser("texture_multisampled_2d");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
EXPECT_TRUE(t.errored);
@@ -173,7 +173,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_MultisampledTexture_MissingGreaterThan) {
auto p = parser("texture_multisampled_2d<u32");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
EXPECT_TRUE(t.errored);
@@ -182,7 +182,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_StorageTexture_Readonly1dRg32Float) {
auto p = parser("texture_storage_1d<rg32float, read>");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
@@ -198,7 +198,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_StorageTexture_Writeonly2dR32Uint) {
auto p = parser("texture_storage_2d<r32uint, write>");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(t.matched);
EXPECT_FALSE(t.errored);
@@ -214,7 +214,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_StorageTexture_InvalidType) {
auto p = parser("texture_storage_1d<abc, read>");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
EXPECT_TRUE(t.errored);
@@ -223,7 +223,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_StorageTexture_InvalidAccess) {
auto p = parser("texture_storage_1d<r32float, abc>");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
EXPECT_TRUE(t.errored);
@@ -232,7 +232,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_StorageTexture_MissingType) {
auto p = parser("texture_storage_1d<>");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
EXPECT_TRUE(t.errored);
@@ -241,7 +241,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_StorageTexture_MissingLessThan) {
auto p = parser("texture_storage_1d");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
EXPECT_TRUE(t.errored);
@@ -250,7 +250,7 @@
TEST_F(ParserImplTest, TextureSamplerTypes_StorageTexture_MissingGreaterThan) {
auto p = parser("texture_storage_1d<r32uint, read");
- auto t = p->texture_samplers();
+ auto t = p->texture_and_sampler_types();
EXPECT_EQ(t.value, nullptr);
EXPECT_FALSE(t.matched);
EXPECT_TRUE(t.errored);
diff --git a/src/tint/reader/wgsl/parser_impl_variable_ident_decl_test.cc b/src/tint/reader/wgsl/parser_impl_variable_ident_decl_test.cc
index 2679983..00dd59b 100644
--- a/src/tint/reader/wgsl/parser_impl_variable_ident_decl_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_variable_ident_decl_test.cc
@@ -19,7 +19,7 @@
TEST_F(ParserImplTest, VariableIdentDecl_Parses) {
auto p = parser("my_var : f32");
- auto decl = p->expect_variable_ident_decl("test");
+ auto decl = p->expect_ident_with_type_decl("test");
ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(decl.errored);
ASSERT_EQ(decl->name, "my_var");
@@ -32,7 +32,7 @@
TEST_F(ParserImplTest, VariableIdentDecl_Parses_AllowInferredType) {
auto p = parser("my_var : f32");
- auto decl = p->expect_ident_or_variable_ident_decl("test");
+ auto decl = p->expect_optionally_typed_ident("test");
ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(decl.errored);
ASSERT_EQ(decl->name, "my_var");
@@ -45,14 +45,14 @@
TEST_F(ParserImplTest, VariableIdentDecl_Inferred_Parse_Failure) {
auto p = parser("my_var = 1.0");
- auto decl = p->expect_variable_ident_decl("test");
+ auto decl = p->expect_ident_with_type_decl("test");
ASSERT_TRUE(p->has_error());
ASSERT_EQ(p->error(), "1:8: expected ':' for test");
}
TEST_F(ParserImplTest, VariableIdentDecl_Inferred_Parses_AllowInferredType) {
auto p = parser("my_var = 1.0");
- auto decl = p->expect_ident_or_variable_ident_decl("test");
+ auto decl = p->expect_optionally_typed_ident("test");
ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(decl.errored);
ASSERT_EQ(decl->name, "my_var");
@@ -63,7 +63,7 @@
TEST_F(ParserImplTest, VariableIdentDecl_MissingIdent) {
auto p = parser(": f32");
- auto decl = p->expect_variable_ident_decl("test");
+ auto decl = p->expect_ident_with_type_decl("test");
ASSERT_TRUE(p->has_error());
ASSERT_TRUE(decl.errored);
ASSERT_EQ(p->error(), "1:1: expected identifier for test");
@@ -71,7 +71,7 @@
TEST_F(ParserImplTest, VariableIdentDecl_MissingIdent_AllowInferredType) {
auto p = parser(": f32");
- auto decl = p->expect_ident_or_variable_ident_decl("test");
+ auto decl = p->expect_optionally_typed_ident("test");
ASSERT_TRUE(p->has_error());
ASSERT_TRUE(decl.errored);
ASSERT_EQ(p->error(), "1:1: expected identifier for test");
@@ -79,7 +79,7 @@
TEST_F(ParserImplTest, VariableIdentDecl_MissingType) {
auto p = parser("my_var :");
- auto decl = p->expect_variable_ident_decl("test");
+ auto decl = p->expect_ident_with_type_decl("test");
ASSERT_TRUE(p->has_error());
ASSERT_TRUE(decl.errored);
ASSERT_EQ(p->error(), "1:9: invalid type for test");
@@ -87,7 +87,7 @@
TEST_F(ParserImplTest, VariableIdentDecl_MissingType_AllowInferredType) {
auto p = parser("my_var :");
- auto decl = p->expect_ident_or_variable_ident_decl("test");
+ auto decl = p->expect_optionally_typed_ident("test");
ASSERT_TRUE(p->has_error());
ASSERT_TRUE(decl.errored);
ASSERT_EQ(p->error(), "1:9: invalid type for test");
@@ -95,7 +95,7 @@
TEST_F(ParserImplTest, VariableIdentDecl_InvalidIdent) {
auto p = parser("123 : f32");
- auto decl = p->expect_variable_ident_decl("test");
+ auto decl = p->expect_ident_with_type_decl("test");
ASSERT_TRUE(p->has_error());
ASSERT_TRUE(decl.errored);
ASSERT_EQ(p->error(), "1:1: expected identifier for test");
@@ -103,7 +103,7 @@
TEST_F(ParserImplTest, VariableIdentDecl_InvalidIdent_AllowInferredType) {
auto p = parser("123 : f32");
- auto decl = p->expect_ident_or_variable_ident_decl("test");
+ auto decl = p->expect_optionally_typed_ident("test");
ASSERT_TRUE(p->has_error());
ASSERT_TRUE(decl.errored);
ASSERT_EQ(p->error(), "1:1: expected identifier for test");
diff --git a/src/tint/reader/wgsl/token.h b/src/tint/reader/wgsl/token.h
index 68cb6c6..d8b93c6 100644
--- a/src/tint/reader/wgsl/token.h
+++ b/src/tint/reader/wgsl/token.h
@@ -379,12 +379,16 @@
/// @returns true if the token can be split during parse into component tokens
bool IsSplittable() const {
- return Is(Token::Type::kShiftRight) || Is(Token::Type::kGreaterThanEqual);
+ return Is(Token::Type::kShiftRight) || Is(Token::Type::kGreaterThanEqual) ||
+ Is(Token::Type::kAndAnd);
}
/// @returns the source information for this token
Source source() const { return source_; }
+ /// @returns the type of the token
+ Type type() const { return type_; }
+
/// Returns the string value of the token
/// @return std::string
std::string to_str() const;
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 5144e63..1e66e5a 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -731,9 +731,9 @@
return TransformElements(builder, transform, args[0]);
}
-ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type*,
- utils::VectorRef<const sem::Constant*> args,
- const Source&) {
+ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type*,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source&) {
auto transform = [&](const sem::Constant* c) {
auto create = [&](auto i) {
// For signed integrals, avoid C++ UB by not negating the
@@ -801,6 +801,51 @@
return r;
}
+ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source) {
+ auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+ auto create = [&](auto i, auto j) -> const Constant* {
+ using NumberT = decltype(i);
+ using T = UnwrapNumber<NumberT>;
+
+ auto subtract_values = [](T lhs, T rhs) {
+ if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
+ // Ensure no UB for signed underflow
+ using UT = std::make_unsigned_t<T>;
+ return static_cast<T>(static_cast<UT>(lhs) - static_cast<UT>(rhs));
+ } else {
+ return lhs - rhs;
+ }
+ };
+
+ NumberT result;
+ if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
+ // Check for over/underflow for abstract values
+ if (auto r = CheckedSub(i, j)) {
+ result = r->value;
+ } else {
+ AddError("'" + std::to_string(subtract_values(i.value, j.value)) +
+ "' cannot be represented as '" +
+ ty->FriendlyName(builder.Symbols()) + "'",
+ source);
+ return nullptr;
+ }
+ } else {
+ result = subtract_values(i.value, j.value);
+ }
+ return CreateElement(builder, c0->Type(), result);
+ };
+ return Dispatch_fia_fiu32_f16(create, c0, c1);
+ };
+
+ auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
+ if (builder.Diagnostics().contains_errors()) {
+ return utils::Failure;
+ }
+ return r;
+}
+
ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h
index dbc3dbd..9cd38d7 100644
--- a/src/tint/resolver/const_eval.h
+++ b/src/tint/resolver/const_eval.h
@@ -199,14 +199,14 @@
utils::VectorRef<const sem::Constant*> args,
const Source& source);
- /// Minus operator '-'
+ /// Unary minus operator '-'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
- ConstantResult OpMinus(const sem::Type* ty,
- utils::VectorRef<const sem::Constant*> args,
- const Source& source);
+ ConstantResult OpUnaryMinus(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
////////////////////////////////////////////////////////////////////////////
// Binary Operators
@@ -221,6 +221,15 @@
utils::VectorRef<const sem::Constant*> args,
const Source& source);
+ /// Minus operator '-'
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the conversion
+ /// @return the result value, or null if the value cannot be calculated
+ ConstantResult OpMinus(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
////////////////////////////////////////////////////////////////////////////
// Builtins
////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc
index 6623a5a..1509b5d 100644
--- a/src/tint/resolver/const_eval_test.cc
+++ b/src/tint/resolver/const_eval_test.cc
@@ -3236,6 +3236,43 @@
OpAddFloatCases<f32>(),
OpAddFloatCases<f16>()))));
+template <typename T>
+std::vector<Case> OpSubIntCases() {
+ static_assert(IsInteger<UnwrapNumber<T>>);
+ return {
+ C(T{0}, T{0}, T{0}),
+ C(T{3}, T{2}, T{1}),
+ C(T{T::Lowest() + 1}, T{1}, T::Lowest()),
+ C(T{T::Highest() - 1}, Negate(T{1}), T::Highest()),
+ C(Negate(T{1}), T::Highest(), T::Lowest()),
+ C(T::Lowest(), T{1}, T::Highest(), true),
+ C(T::Highest(), Negate(T{1}), T::Lowest(), true),
+ };
+}
+template <typename T>
+std::vector<Case> OpSubFloatCases() {
+ static_assert(IsFloatingPoint<UnwrapNumber<T>>);
+ return {
+ C(T{0}, T{0}, T{0}),
+ C(T{3}, T{2}, T{1}),
+ C(T::Highest(), T{1}, T{T::Highest() - 1}),
+ C(T::Lowest(), Negate(T{1}), T{T::Lowest() + 1}),
+ C(T{0}, T::Highest(), T::Lowest()),
+ C(T::Highest(), Negate(T::Highest()), T::Inf(), true),
+ C(T::Lowest(), T::Highest(), -T::Inf(), true),
+ };
+}
+INSTANTIATE_TEST_SUITE_P(Sub,
+ ResolverConstEvalBinaryOpTest,
+ testing::Combine(testing::Values(ast::BinaryOp::kSubtract),
+ testing::ValuesIn(Concat( //
+ OpSubIntCases<AInt>(),
+ OpSubIntCases<i32>(),
+ OpSubIntCases<u32>(),
+ OpSubFloatCases<AFloat>(),
+ OpSubFloatCases<f32>(),
+ OpSubFloatCases<f16>()))));
+
TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) {
GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a));
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index 6dd35a5..823dcad 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -491,7 +491,7 @@
bool Run(const ast::Module& module) {
// Reserve container memory
graph_.resolved_symbols.reserve(module.GlobalDeclarations().Length());
- sorted_.reserve(module.GlobalDeclarations().Length());
+ sorted_.Reserve(module.GlobalDeclarations().Length());
// Collect all the named globals from the AST module
GatherGlobals(module);
@@ -505,7 +505,7 @@
// Dump the dependency graph if TINT_DUMP_DEPENDENCY_GRAPH is non-zero
DumpDependencyGraph();
- graph_.ordered_globals = std::move(sorted_);
+ graph_.ordered_globals = sorted_.Release();
return !diagnostics_.contains_errors();
}
@@ -632,7 +632,7 @@
// Make sure all 'enable' directives go before any other global declarations.
for (auto* global : declaration_order_) {
if (auto* enable = global->node->As<ast::Enable>()) {
- sorted_.add(enable);
+ sorted_.Add(enable);
}
}
@@ -641,31 +641,31 @@
// Skip 'enable' directives here, as they are already added.
continue;
}
- utils::UniqueVector<const Global*> stack;
+ utils::UniqueVector<const Global*, 8> stack;
TraverseDependencies(
global,
[&](const Global* g) { // Enter
- if (!stack.add(g)) {
- CyclicDependencyFound(g, stack);
+ if (!stack.Add(g)) {
+ CyclicDependencyFound(g, stack.Release());
return false;
}
- if (sorted_.contains(g->node)) {
+ if (sorted_.Contains(g->node)) {
// Visited this global already.
// stack was pushed, but exit() will not be called when we return
// false, so pop here.
- stack.pop_back();
+ stack.Pop();
return false;
}
return true;
},
[&](const Global* g) { // Exit. Only called if Enter returned true.
- sorted_.add(g->node);
- stack.pop_back();
+ sorted_.Add(g->node);
+ stack.Pop();
});
- sorted_.add(global->node);
+ sorted_.Add(global->node);
- if (!stack.empty()) {
+ if (!stack.IsEmpty()) {
// Each stack.push() must have a corresponding stack.pop_back().
TINT_ICE(Resolver, diagnostics_)
<< "stack not empty after returning from TraverseDependencies()";
@@ -691,12 +691,12 @@
/// @param root is the global that starts the cyclic dependency, which must be
/// found in `stack`.
/// @param stack is the global dependency stack that contains a loop.
- void CyclicDependencyFound(const Global* root, const std::vector<const Global*>& stack) {
+ void CyclicDependencyFound(const Global* root, utils::VectorRef<const Global*> stack) {
std::stringstream msg;
msg << "cyclic dependency found: ";
constexpr size_t kLoopNotStarted = ~0u;
size_t loop_start = kLoopNotStarted;
- for (size_t i = 0; i < stack.size(); i++) {
+ for (size_t i = 0; i < stack.Length(); i++) {
auto* e = stack[i];
if (loop_start == kLoopNotStarted && e == root) {
loop_start = i;
@@ -707,9 +707,9 @@
}
msg << "'" << NameOf(root->node) << "'";
AddError(diagnostics_, msg.str(), root->node->source);
- for (size_t i = loop_start; i < stack.size(); i++) {
+ for (size_t i = loop_start; i < stack.Length(); i++) {
auto* from = stack[i];
- auto* to = (i + 1 < stack.size()) ? stack[i + 1] : stack[loop_start];
+ auto* to = (i + 1 < stack.Length()) ? stack[i + 1] : stack[loop_start];
auto info = DepInfoFor(from, to);
AddNote(diagnostics_,
KindOf(from->node) + " '" + NameOf(from->node) + "' " + info.action + " " +
@@ -764,7 +764,7 @@
std::vector<Global*> declaration_order_;
/// Globals in sorted dependency order. Populated by SortGlobals().
- utils::UniqueVector<const ast::Node*> sorted_;
+ utils::UniqueVector<const ast::Node*, 64> sorted_;
};
} // namespace
diff --git a/src/tint/resolver/dependency_graph.h b/src/tint/resolver/dependency_graph.h
index 0554817..9f5ddc5 100644
--- a/src/tint/resolver/dependency_graph.h
+++ b/src/tint/resolver/dependency_graph.h
@@ -46,7 +46,7 @@
DependencyGraph& output);
/// All globals in dependency-sorted order.
- std::vector<const ast::Node*> ordered_globals;
+ utils::Vector<const ast::Node*, 32> ordered_globals;
/// Map of ast::IdentifierExpression or ast::TypeName to a type, function, or
/// variable that declares the symbol.
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index 372d91a..3c4a6ac 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -489,8 +489,8 @@
ASSERT_NE(sem_x, nullptr);
ASSERT_NE(sem_y, nullptr);
- EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_x));
- EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_y));
+ EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().Contains(sem_x));
+ EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().Contains(sem_y));
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32) {
diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl
index 16c177c..c589b2f 100644
--- a/src/tint/resolver/intrinsic_table.inl
+++ b/src/tint/resolver/intrinsic_table.inl
@@ -10952,60 +10952,60 @@
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[737],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMinus,
},
{
/* [233] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[735],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMinus,
},
{
/* [234] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[733],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMinus,
},
{
/* [235] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[731],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMinus,
},
{
/* [236] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
- /* template types */ &kTemplateTypes[11],
+ /* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[729],
/* return matcher indices */ &kMatcherIndices[10],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpMinus,
},
{
/* [237] */
@@ -13237,7 +13237,7 @@
/* parameters */ &kParameters[862],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ &ConstEval::OpMinus,
+ /* const eval */ &ConstEval::OpUnaryMinus,
},
{
/* [423] */
@@ -13249,7 +13249,7 @@
/* parameters */ &kParameters[863],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ &ConstEval::OpMinus,
+ /* const eval */ &ConstEval::OpUnaryMinus,
},
{
/* [424] */
@@ -14620,11 +14620,11 @@
},
{
/* [1] */
- /* op -<T : fiu32_f16>(T, T) -> T */
- /* op -<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
- /* op -<T : fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
- /* op -<T : fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
- /* op -<T : f32_f16, N : num, M : num>(mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T> */
+ /* op -<T : fia_fiu32_f16>(T, T) -> T */
+ /* op -<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
+ /* op -<T : fia_fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
+ /* op -<T : fia_fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
+ /* op -<T : fa_f32_f16, N : num, M : num>(mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T> */
/* num overloads */ 5,
/* overloads */ &kOverloads[232],
},
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 2edab1f..cd8bdc8 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -164,7 +164,7 @@
return false;
}
- if (!enabled_extensions_.contains(ast::Extension::kChromiumDisableUniformityAnalysis)) {
+ if (!enabled_extensions_.Contains(ast::Extension::kChromiumDisableUniformityAnalysis)) {
if (!AnalyzeUniformity(builder_, dependencies_)) {
// TODO(jrprice): Reject programs that fail uniformity analysis.
}
@@ -194,7 +194,7 @@
[&](const ast::U32*) { return builder_->create<sem::U32>(); },
[&](const ast::F16* t) -> sem::F16* {
// Validate if f16 type is allowed.
- if (!enabled_extensions_.contains(ast::Extension::kF16)) {
+ if (!enabled_extensions_.Contains(ast::Extension::kF16)) {
AddError("f16 used without 'f16' extension enabled", t->source);
return nullptr;
}
@@ -2082,7 +2082,7 @@
return nullptr;
}
- if ((ty->Is<sem::F16>()) && (!enabled_extensions_.contains(tint::ast::Extension::kF16))) {
+ if ((ty->Is<sem::F16>()) && (!enabled_extensions_.Contains(tint::ast::Extension::kF16))) {
AddError("f16 literal used without 'f16' extension enabled", literal->source);
return nullptr;
}
@@ -2442,7 +2442,7 @@
}
bool Resolver::Enable(const ast::Enable* enable) {
- enabled_extensions_.add(enable->extension);
+ enabled_extensions_.Add(enable->extension);
return true;
}
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index ecc251a..35ebcf5 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -800,7 +800,7 @@
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
const auto& vars = func_sem->TransitivelyReferencedGlobals();
- ASSERT_EQ(vars.size(), 3u);
+ ASSERT_EQ(vars.Length(), 3u);
EXPECT_EQ(vars[0]->Declaration(), wg_var);
EXPECT_EQ(vars[1]->Declaration(), sb_var);
EXPECT_EQ(vars[2]->Declaration(), priv_var);
@@ -835,7 +835,7 @@
EXPECT_EQ(func2_sem->Parameters().Length(), 0u);
const auto& vars = func2_sem->TransitivelyReferencedGlobals();
- ASSERT_EQ(vars.size(), 3u);
+ ASSERT_EQ(vars.Length(), 3u);
EXPECT_EQ(vars[0]->Declaration(), wg_var);
EXPECT_EQ(vars[1]->Declaration(), sb_var);
EXPECT_EQ(vars[2]->Declaration(), priv_var);
@@ -853,7 +853,7 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
- EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().size(), 0u);
+ EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().Length(), 0u);
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
}
@@ -868,7 +868,7 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
- EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().size(), 0u);
+ EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().Length(), 0u);
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
}
@@ -879,7 +879,7 @@
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
- EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().size(), 0u);
+ EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().Length(), 0u);
EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
}
@@ -2006,7 +2006,7 @@
const sem::Function* sf = Sem().Get(f);
auto pairs = sf->TextureSamplerPairs();
- ASSERT_EQ(pairs.size(), 1u);
+ ASSERT_EQ(pairs.Length(), 1u);
EXPECT_TRUE(pairs[0].first != nullptr);
EXPECT_TRUE(pairs[0].second != nullptr);
}
@@ -2026,12 +2026,12 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto inner_pairs = Sem().Get(inner_func)->TextureSamplerPairs();
- ASSERT_EQ(inner_pairs.size(), 1u);
+ ASSERT_EQ(inner_pairs.Length(), 1u);
EXPECT_TRUE(inner_pairs[0].first != nullptr);
EXPECT_TRUE(inner_pairs[0].second != nullptr);
auto outer_pairs = Sem().Get(outer_func)->TextureSamplerPairs();
- ASSERT_EQ(outer_pairs.size(), 1u);
+ ASSERT_EQ(outer_pairs.Length(), 1u);
EXPECT_TRUE(outer_pairs[0].first != nullptr);
EXPECT_TRUE(outer_pairs[0].second != nullptr);
}
@@ -2055,17 +2055,17 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto inner_pairs_1 = Sem().Get(inner_func_1)->TextureSamplerPairs();
- ASSERT_EQ(inner_pairs_1.size(), 1u);
+ ASSERT_EQ(inner_pairs_1.Length(), 1u);
EXPECT_TRUE(inner_pairs_1[0].first != nullptr);
EXPECT_TRUE(inner_pairs_1[0].second != nullptr);
auto inner_pairs_2 = Sem().Get(inner_func_2)->TextureSamplerPairs();
- ASSERT_EQ(inner_pairs_1.size(), 1u);
+ ASSERT_EQ(inner_pairs_1.Length(), 1u);
EXPECT_TRUE(inner_pairs_2[0].first != nullptr);
EXPECT_TRUE(inner_pairs_2[0].second != nullptr);
auto outer_pairs = Sem().Get(outer_func)->TextureSamplerPairs();
- ASSERT_EQ(outer_pairs.size(), 1u);
+ ASSERT_EQ(outer_pairs.Length(), 1u);
EXPECT_TRUE(outer_pairs[0].first != nullptr);
EXPECT_TRUE(outer_pairs[0].second != nullptr);
}
@@ -2092,17 +2092,17 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto inner_pairs_1 = Sem().Get(inner_func_1)->TextureSamplerPairs();
- ASSERT_EQ(inner_pairs_1.size(), 1u);
+ ASSERT_EQ(inner_pairs_1.Length(), 1u);
EXPECT_TRUE(inner_pairs_1[0].first != nullptr);
EXPECT_TRUE(inner_pairs_1[0].second != nullptr);
auto inner_pairs_2 = Sem().Get(inner_func_2)->TextureSamplerPairs();
- ASSERT_EQ(inner_pairs_2.size(), 1u);
+ ASSERT_EQ(inner_pairs_2.Length(), 1u);
EXPECT_TRUE(inner_pairs_2[0].first != nullptr);
EXPECT_TRUE(inner_pairs_2[0].second != nullptr);
auto outer_pairs = Sem().Get(outer_func)->TextureSamplerPairs();
- ASSERT_EQ(outer_pairs.size(), 2u);
+ ASSERT_EQ(outer_pairs.Length(), 2u);
EXPECT_TRUE(outer_pairs[0].first == inner_pairs_1[0].first);
EXPECT_TRUE(outer_pairs[0].second == inner_pairs_1[0].second);
EXPECT_TRUE(outer_pairs[1].first == inner_pairs_2[0].first);
@@ -2119,7 +2119,7 @@
const sem::Function* sf = Sem().Get(f);
auto pairs = sf->TextureSamplerPairs();
- ASSERT_EQ(pairs.size(), 1u);
+ ASSERT_EQ(pairs.Length(), 1u);
EXPECT_TRUE(pairs[0].first != nullptr);
EXPECT_TRUE(pairs[0].second == nullptr);
}
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index 9ccb7ef..e4c31c9 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -102,14 +102,14 @@
uint32_t arg_index;
/// The set of edges from this node to other nodes in the graph.
- utils::UniqueVector<Node*> edges;
+ utils::UniqueVector<Node*, 4> edges;
/// The node that this node was visited from, or nullptr if not visited.
Node* visited_from = nullptr;
/// Add an edge to the `to` node.
/// @param to the destination node
- void AddEdge(Node* to) { edges.add(to); }
+ void AddEdge(Node* to) { edges.Add(to); }
};
/// ParameterInfo holds information about the uniformity requirements and effects for a particular
@@ -337,13 +337,13 @@
// Look at which nodes are reachable from "RequiredToBeUniform".
{
- utils::UniqueVector<Node*> reachable;
+ utils::UniqueVector<Node*, 4> reachable;
Traverse(current_function_->required_to_be_uniform, &reachable);
- if (reachable.contains(current_function_->may_be_non_uniform)) {
+ if (reachable.Contains(current_function_->may_be_non_uniform)) {
MakeError(*current_function_, current_function_->may_be_non_uniform);
return false;
}
- if (reachable.contains(current_function_->cf_start)) {
+ if (reachable.Contains(current_function_->cf_start)) {
current_function_->callsite_tag = CallSiteRequiredToBeUniform;
}
@@ -351,7 +351,7 @@
// was reachable.
for (size_t i = 0; i < func->params.Length(); i++) {
auto* param = func->params[i];
- if (reachable.contains(current_function_->variables.Get(sem_.Get(param)))) {
+ if (reachable.Contains(current_function_->variables.Get(sem_.Get(param)))) {
current_function_->parameters[i].tag = ParameterRequiredToBeUniform;
}
}
@@ -359,9 +359,9 @@
// Look at which nodes are reachable from "CF_return"
{
- utils::UniqueVector<Node*> reachable;
+ utils::UniqueVector<Node*, 4> reachable;
Traverse(current_function_->cf_return, &reachable);
- if (reachable.contains(current_function_->may_be_non_uniform)) {
+ if (reachable.Contains(current_function_->may_be_non_uniform)) {
current_function_->function_tag = SubsequentControlFlowMayBeNonUniform;
}
@@ -369,7 +369,7 @@
// each parameter node that was reachable.
for (size_t i = 0; i < func->params.Length(); i++) {
auto* param = func->params[i];
- if (reachable.contains(current_function_->variables.Get(sem_.Get(param)))) {
+ if (reachable.Contains(current_function_->variables.Get(sem_.Get(param)))) {
current_function_->parameters[i].tag =
ParameterRequiredToBeUniformForSubsequentControlFlow;
}
@@ -378,9 +378,9 @@
// If "Value_return" exists, look at which nodes are reachable from it
if (current_function_->value_return) {
- utils::UniqueVector<Node*> reachable;
+ utils::UniqueVector<Node*, 4> reachable;
Traverse(current_function_->value_return, &reachable);
- if (reachable.contains(current_function_->may_be_non_uniform)) {
+ if (reachable.Contains(current_function_->may_be_non_uniform)) {
current_function_->function_tag = ReturnValueMayBeNonUniform;
}
@@ -388,7 +388,7 @@
// parameter node that was reachable.
for (size_t i = 0; i < func->params.Length(); i++) {
auto* param = func->params[i];
- if (reachable.contains(current_function_->variables.Get(sem_.Get(param)))) {
+ if (reachable.Contains(current_function_->variables.Get(sem_.Get(param)))) {
current_function_->parameters[i].tag =
ParameterRequiredToBeUniformForReturnValue;
}
@@ -404,16 +404,16 @@
// Reset "visited" state for all nodes.
current_function_->ResetVisited();
- utils::UniqueVector<Node*> reachable;
+ utils::UniqueVector<Node*, 4> reachable;
Traverse(current_function_->parameters[i].pointer_return_value, &reachable);
- if (reachable.contains(current_function_->may_be_non_uniform)) {
+ if (reachable.Contains(current_function_->may_be_non_uniform)) {
current_function_->parameters[i].pointer_may_become_non_uniform = true;
}
// Check every other parameter to see if they feed into this parameter's final value.
for (size_t j = 0; j < func->params.Length(); j++) {
auto* param_source = sem_.Get<sem::Parameter>(func->params[j]);
- if (reachable.contains(current_function_->parameters[j].init_value)) {
+ if (reachable.Contains(current_function_->parameters[j].init_value)) {
current_function_->parameters[i].pointer_param_output_sources.push_back(
param_source);
}
@@ -1356,7 +1356,7 @@
/// recording which node they were reached from.
/// @param source the starting node
/// @param reachable the set of reachable nodes to populate, if required
- void Traverse(Node* source, utils::UniqueVector<Node*>* reachable = nullptr) {
+ void Traverse(Node* source, utils::UniqueVector<Node*, 4>* reachable = nullptr) {
std::vector<Node*> to_visit{source};
while (!to_visit.empty()) {
@@ -1364,7 +1364,7 @@
to_visit.pop_back();
if (reachable) {
- reachable->add(node);
+ reachable->Add(node);
}
for (auto* to : node->edges) {
if (to->visited_from == nullptr) {
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index f3a02ed..6382d34 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -141,7 +141,7 @@
callback(f);
return;
}
- if (f->TransitivelyCalledFunctions().contains(to)) {
+ if (f->TransitivelyCalledFunctions().Contains(to)) {
TraverseCallChain(diagnostics, f, to, callback);
callback(f);
return;
@@ -519,7 +519,7 @@
const ast::Extensions& enabled_extensions,
ValidTypeStorageLayouts& layouts) const {
if (var->StorageClass() == ast::StorageClass::kPushConstant &&
- !enabled_extensions.contains(ast::Extension::kChromiumExperimentalPushConstant) &&
+ !enabled_extensions.Contains(ast::Extension::kChromiumExperimentalPushConstant) &&
IsValidationEnabled(var->Declaration()->attributes,
ast::DisabledValidation::kIgnoreStorageClass)) {
AddError(
@@ -1723,7 +1723,7 @@
return true;
}
- if (!enabled_extensions.contains(extension)) {
+ if (!enabled_extensions.Contains(extension)) {
AddError("cannot call built-in function '" + std::string(builtin->str()) +
"' without extension " + utils::ToString(extension),
call->Declaration()->source);
diff --git a/src/tint/sem/function.h b/src/tint/sem/function.h
index c90d749..a09f81c 100644
--- a/src/tint/sem/function.h
+++ b/src/tint/sem/function.h
@@ -81,7 +81,7 @@
}
/// @returns all directly referenced global variables
- const utils::UniqueVector<const GlobalVariable*>& DirectlyReferencedGlobals() const {
+ const utils::UniqueVector<const GlobalVariable*, 4>& DirectlyReferencedGlobals() const {
return directly_referenced_globals_;
}
@@ -89,12 +89,12 @@
/// Note: Implicitly adds this global to the transtively-called globals.
/// @param global the module-scope variable
void AddDirectlyReferencedGlobal(const sem::GlobalVariable* global) {
- directly_referenced_globals_.add(global);
- transitively_referenced_globals_.add(global);
+ directly_referenced_globals_.Add(global);
+ transitively_referenced_globals_.Add(global);
}
/// @returns all transitively referenced global variables
- const utils::UniqueVector<const GlobalVariable*>& TransitivelyReferencedGlobals() const {
+ const utils::UniqueVector<const GlobalVariable*, 8>& TransitivelyReferencedGlobals() const {
return transitively_referenced_globals_;
}
@@ -102,29 +102,29 @@
/// variable.
/// @param global the module-scoped variable
void AddTransitivelyReferencedGlobal(const sem::GlobalVariable* global) {
- transitively_referenced_globals_.add(global);
+ transitively_referenced_globals_.Add(global);
}
/// @returns the list of functions that this function transitively calls.
- const utils::UniqueVector<const Function*>& TransitivelyCalledFunctions() const {
+ const utils::UniqueVector<const Function*, 8>& TransitivelyCalledFunctions() const {
return transitively_called_functions_;
}
/// Records that this function transitively calls `function`.
/// @param function the function this function transitively calls
void AddTransitivelyCalledFunction(const Function* function) {
- transitively_called_functions_.add(function);
+ transitively_called_functions_.Add(function);
}
/// @returns the list of builtins that this function directly calls.
- const utils::UniqueVector<const Builtin*>& DirectlyCalledBuiltins() const {
+ const utils::UniqueVector<const Builtin*, 4>& DirectlyCalledBuiltins() const {
return directly_called_builtins_;
}
/// Records that this function transitively calls `builtin`.
/// @param builtin the builtin this function directly calls
void AddDirectlyCalledBuiltin(const Builtin* builtin) {
- directly_called_builtins_.add(builtin);
+ directly_called_builtins_.Add(builtin);
}
/// Adds the given texture/sampler pair to the list of unique pairs
@@ -134,12 +134,14 @@
/// @param texture the texture (must be non-null)
/// @param sampler the sampler (null indicates a texture-only reference)
void AddTextureSamplerPair(const sem::Variable* texture, const sem::Variable* sampler) {
- texture_sampler_pairs_.add(VariablePair(texture, sampler));
+ texture_sampler_pairs_.Add(VariablePair(texture, sampler));
}
/// @returns the list of texture/sampler pairs that this function uses
/// (directly or indirectly).
- const std::vector<VariablePair>& TextureSamplerPairs() const { return texture_sampler_pairs_; }
+ const utils::Vector<VariablePair, 8>& TextureSamplerPairs() const {
+ return texture_sampler_pairs_;
+ }
/// @returns the list of direct calls to functions / builtins made by this
/// function
@@ -253,17 +255,20 @@
sem::Behaviors& Behaviors() { return behaviors_; }
private:
+ Function(const Function&) = delete;
+ Function(Function&&) = delete;
+
VariableBindings TransitivelyReferencedSamplerVariablesImpl(ast::SamplerKind kind) const;
VariableBindings TransitivelyReferencedSampledTextureVariablesImpl(bool multisampled) const;
const ast::Function* const declaration_;
sem::WorkgroupSize workgroup_size_;
- utils::UniqueVector<const GlobalVariable*> directly_referenced_globals_;
- utils::UniqueVector<const GlobalVariable*> transitively_referenced_globals_;
- utils::UniqueVector<const Function*> transitively_called_functions_;
- utils::UniqueVector<const Builtin*> directly_called_builtins_;
- utils::UniqueVector<VariablePair> texture_sampler_pairs_;
+ utils::UniqueVector<const GlobalVariable*, 4> directly_referenced_globals_;
+ utils::UniqueVector<const GlobalVariable*, 8> transitively_referenced_globals_;
+ utils::UniqueVector<const Function*, 8> transitively_called_functions_;
+ utils::UniqueVector<const Builtin*, 4> directly_called_builtins_;
+ utils::UniqueVector<VariablePair, 8> texture_sampler_pairs_;
std::vector<const Call*> direct_calls_;
std::vector<const Call*> callsites_;
std::vector<const Function*> ancestor_entry_points_;
diff --git a/src/tint/sem/module.cc b/src/tint/sem/module.cc
index 7c60650..38f0ec6 100644
--- a/src/tint/sem/module.cc
+++ b/src/tint/sem/module.cc
@@ -21,7 +21,7 @@
namespace tint::sem {
-Module::Module(std::vector<const ast::Node*> dep_ordered_decls, ast::Extensions extensions)
+Module::Module(utils::VectorRef<const ast::Node*> dep_ordered_decls, ast::Extensions extensions)
: dep_ordered_decls_(std::move(dep_ordered_decls)), extensions_(std::move(extensions)) {}
Module::~Module() = default;
diff --git a/src/tint/sem/module.h b/src/tint/sem/module.h
index dffe003..b451c5b 100644
--- a/src/tint/sem/module.h
+++ b/src/tint/sem/module.h
@@ -15,10 +15,9 @@
#ifndef SRC_TINT_SEM_MODULE_H_
#define SRC_TINT_SEM_MODULE_H_
-#include <vector>
-
#include "src/tint/ast/extension.h"
#include "src/tint/sem/node.h"
+#include "src/tint/utils/vector.h"
// Forward declarations
namespace tint::ast {
@@ -34,13 +33,13 @@
/// Constructor
/// @param dep_ordered_decls the dependency-ordered module-scope declarations
/// @param extensions the list of enabled extensions in the module
- Module(std::vector<const ast::Node*> dep_ordered_decls, ast::Extensions extensions);
+ Module(utils::VectorRef<const ast::Node*> dep_ordered_decls, ast::Extensions extensions);
/// Destructor
~Module() override;
/// @returns the dependency-ordered global declarations for the module
- const std::vector<const ast::Node*>& DependencyOrderedDeclarations() const {
+ const utils::Vector<const ast::Node*, 64>& DependencyOrderedDeclarations() const {
return dep_ordered_decls_;
}
@@ -48,7 +47,7 @@
const ast::Extensions& Extensions() const { return extensions_; }
private:
- const std::vector<const ast::Node*> dep_ordered_decls_;
+ const utils::Vector<const ast::Node*, 64> dep_ordered_decls_;
ast::Extensions extensions_;
};
diff --git a/src/tint/transform/combine_samplers.cc b/src/tint/transform/combine_samplers.cc
index 2f5101e..ca9005b 100644
--- a/src/tint/transform/combine_samplers.cc
+++ b/src/tint/transform/combine_samplers.cc
@@ -167,7 +167,7 @@
ctx.ReplaceAll([&](const ast::Function* src) -> const ast::Function* {
if (auto* func = sem.Get(src)) {
auto pairs = func->TextureSamplerPairs();
- if (pairs.empty()) {
+ if (pairs.IsEmpty()) {
return nullptr;
}
utils::Vector<const ast::Parameter*, 8> params;
diff --git a/src/tint/transform/decompose_memory_access.cc b/src/tint/transform/decompose_memory_access.cc
index d725a4b..99fe94e 100644
--- a/src/tint/transform/decompose_memory_access.cc
+++ b/src/tint/transform/decompose_memory_access.cc
@@ -328,7 +328,7 @@
/// @returns an Offset for the given ast::Expression
const Offset* ToOffset(const ast::Expression* expr) {
if (auto* lit = expr->As<ast::IntLiteralExpression>()) {
- if (lit->value > 0) {
+ if (lit->value >= 0) {
return offsets_.Create<OffsetLiteral>(static_cast<uint32_t>(lit->value));
}
}
diff --git a/src/tint/transform/disable_uniformity_analysis.cc b/src/tint/transform/disable_uniformity_analysis.cc
index 7a30023..918b1f1 100644
--- a/src/tint/transform/disable_uniformity_analysis.cc
+++ b/src/tint/transform/disable_uniformity_analysis.cc
@@ -28,7 +28,7 @@
DisableUniformityAnalysis::~DisableUniformityAnalysis() = default;
bool DisableUniformityAnalysis::ShouldRun(const Program* program, const DataMap&) const {
- return !program->Sem().Module()->Extensions().contains(
+ return !program->Sem().Module()->Extensions().Contains(
ast::Extension::kChromiumDisableUniformityAnalysis);
}
diff --git a/src/tint/transform/promote_initializers_to_let.cc b/src/tint/transform/promote_initializers_to_let.cc
index b57fb25..ec7ba95 100644
--- a/src/tint/transform/promote_initializers_to_let.cc
+++ b/src/tint/transform/promote_initializers_to_let.cc
@@ -75,19 +75,22 @@
return true;
},
[&](const ast::IdentifierExpression* expr) {
- if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(expr)) {
- // Identifier resolves to a variable
- if (auto* stmt = user->Stmt()) {
- if (auto* decl = stmt->Declaration()->As<ast::VariableDeclStatement>();
- decl && decl->variable->Is<ast::Const>()) {
- // The identifier is used on the RHS of a 'const' declaration. Ignore.
- return true;
+ if (auto* sem = ctx.src->Sem().Get(expr)) {
+ if (auto* user = sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
+ // Identifier resolves to a variable
+ if (auto* stmt = user->Stmt()) {
+ if (auto* decl = stmt->Declaration()->As<ast::VariableDeclStatement>();
+ decl && decl->variable->Is<ast::Const>()) {
+ // The identifier is used on the RHS of a 'const' declaration.
+ // Ignore.
+ return true;
+ }
}
- }
- if (user->Variable()->Declaration()->Is<ast::Const>()) {
- // The identifier resolves to a 'const' variable, but isn't used to
- // initialize another 'const'. This needs promoting.
- return promote(user);
+ if (user->Variable()->Declaration()->Is<ast::Const>()) {
+ // The identifier resolves to a 'const' variable, but isn't used to
+ // initialize another 'const'. This needs promoting.
+ return promote(user);
+ }
}
}
return true;
diff --git a/src/tint/transform/promote_initializers_to_let_test.cc b/src/tint/transform/promote_initializers_to_let_test.cc
index d4d9fb2..536d04a 100644
--- a/src/tint/transform/promote_initializers_to_let_test.cc
+++ b/src/tint/transform/promote_initializers_to_let_test.cc
@@ -158,6 +158,37 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(PromoteInitializersToLetTest, GlobalConstArrayDynamicIndex) {
+ auto* src = R"(
+const TRI_VERTICES = array(
+ vec4(0., 0., 0., 1.),
+ vec4(0., 1., 0., 1.),
+ vec4(1., 1., 0., 1.),
+);
+
+@vertex
+fn vs_main(@builtin(vertex_index) in_vertex_index: u32) -> @builtin(position) vec4<f32> {
+ // note: TRI_VERTICES requires a materialize before the dynamic index.
+ return TRI_VERTICES[in_vertex_index];
+}
+)";
+
+ auto* expect = R"(
+const TRI_VERTICES = array(vec4(0.0, 0.0, 0.0, 1.0), vec4(0.0, 1.0, 0.0, 1.0), vec4(1.0, 1.0, 0.0, 1.0));
+
+@vertex
+fn vs_main(@builtin(vertex_index) in_vertex_index : u32) -> @builtin(position) vec4<f32> {
+ let tint_symbol = TRI_VERTICES;
+ return tint_symbol[in_vertex_index];
+}
+)";
+
+ DataMap data;
+ auto got = Run<PromoteInitializersToLet>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(PromoteInitializersToLetTest, GlobalConstBasicArray_OutOfOrder) {
auto* src = R"(
fn f() {
diff --git a/src/tint/transform/spirv_atomic.cc b/src/tint/transform/spirv_atomic.cc
index b79f60f..bece2d6 100644
--- a/src/tint/transform/spirv_atomic.cc
+++ b/src/tint/transform/spirv_atomic.cc
@@ -48,7 +48,7 @@
ProgramBuilder& b = *ctx.dst;
std::unordered_map<const ast::Struct*, ForkedStruct> forked_structs;
std::unordered_set<const sem::Variable*> atomic_variables;
- utils::UniqueVector<const sem::Expression*> atomic_expressions;
+ utils::UniqueVector<const sem::Expression*, 8> atomic_expressions;
public:
/// Constructor
@@ -92,7 +92,7 @@
// Keep track of this expression. We'll need to modify the source variable /
// structure to be atomic.
- atomic_expressions.add(ctx.src->Sem().Get(args[0]));
+ atomic_expressions.Add(ctx.src->Sem().Get(args[0]));
}
// Remove the stub from the output program
@@ -153,7 +153,7 @@
}
void ProcessAtomicExpressions() {
- for (size_t i = 0; i < atomic_expressions.size(); i++) {
+ for (size_t i = 0; i < atomic_expressions.Length(); i++) {
Switch(
atomic_expressions[i], //
[&](const sem::VariableUser* user) {
@@ -162,7 +162,7 @@
ctx.Replace(v->type, AtomicTypeFor(user->Variable()->Type()));
}
if (auto* ctor = user->Variable()->Constructor()) {
- atomic_expressions.add(ctor);
+ atomic_expressions.Add(ctor);
}
},
[&](const sem::StructMemberAccess* access) {
@@ -170,14 +170,14 @@
// atomic.
auto* member = access->Member();
Fork(member->Struct()->Declaration()).atomic_members.emplace(member->Index());
- atomic_expressions.add(access->Object());
+ atomic_expressions.Add(access->Object());
},
[&](const sem::IndexAccessorExpression* index) {
- atomic_expressions.add(index->Object());
+ atomic_expressions.Add(index->Object());
},
[&](const sem::Expression* e) {
if (auto* unary = e->Declaration()->As<ast::UnaryOpExpression>()) {
- atomic_expressions.add(ctx.src->Sem().Get(unary->expr));
+ atomic_expressions.Add(ctx.src->Sem().Get(unary->expr));
}
});
}
diff --git a/src/tint/transform/zero_init_workgroup_memory.cc b/src/tint/transform/zero_init_workgroup_memory.cc
index 963ca01..94df3b9 100644
--- a/src/tint/transform/zero_init_workgroup_memory.cc
+++ b/src/tint/transform/zero_init_workgroup_memory.cc
@@ -75,7 +75,7 @@
};
/// A list of unique ArrayIndex
- using ArrayIndices = utils::UniqueVector<ArrayIndex, ArrayIndex::Hasher>;
+ using ArrayIndices = utils::UniqueVector<ArrayIndex, 4, ArrayIndex::Hasher>;
/// Expression holds information about an expression that is being built for a
/// statement will zero workgroup values.
@@ -193,7 +193,7 @@
ArrayIndices array_indices;
for (auto& s : stmts) {
for (auto& idx : s.array_indices) {
- array_indices.add(idx);
+ array_indices.Add(idx);
}
}
@@ -311,7 +311,7 @@
auto division = num_values;
auto a = get_expr(modulo);
auto array_indices = a.array_indices;
- array_indices.add(ArrayIndex{modulo, division});
+ array_indices.Add(ArrayIndex{modulo, division});
auto index = utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
[&] { return b.Symbols().New("i"); });
return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices};
diff --git a/src/tint/utils/hashmap.h b/src/tint/utils/hashmap.h
new file mode 100644
index 0000000..81bebf2
--- /dev/null
+++ b/src/tint/utils/hashmap.h
@@ -0,0 +1,305 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_UTILS_HASHMAP_H_
+#define SRC_TINT_UTILS_HASHMAP_H_
+
+#include <functional>
+#include <optional>
+#include <utility>
+
+#include "src/tint/utils/hashset.h"
+
+namespace tint::utils {
+
+/// An unordered map that uses a robin-hood hashing algorithm.
+///
+/// Hashmap internally wraps a Hashset for providing a store for key-value pairs.
+///
+/// @see Hashset
+template <typename K,
+ typename V,
+ size_t N,
+ typename HASH = std::hash<K>,
+ typename EQUAL = std::equal_to<K>>
+class Hashmap {
+ /// LazyCreator is a transient structure used to late-build the Entry::value, when inserted into
+ /// the underlying Hashset.
+ ///
+ /// LazyCreator holds a #key, and a #create function used to build the final Entry::value.
+ /// The #create function must be of the signature `V()`.
+ ///
+ /// LazyCreator can be compared to Entry and hashed, allowing them to be passed to
+ /// Hashset::Insert(). If the set does not contain an existing entry with #key,
+ /// Hashset::Insert() will construct a new Entry passing the rvalue LazyCreator as the
+ /// constructor argument, which in turn calls the #create function to generate the entry value.
+ ///
+ /// @see Entry
+ /// @see Hasher
+ /// @see Equality
+ template <typename CREATE>
+ struct LazyCreator {
+ /// The key of the entry to insert into the map
+ const K& key;
+ /// The value creation function
+ CREATE create;
+ };
+
+ /// Entry holds a key and value pair, and is used as the element type of the underlying Hashset.
+ /// Entries are compared and hashed using only the #key.
+ /// @see Hasher
+ /// @see Equality
+ struct Entry {
+ /// Constructor from a key and value pair
+ Entry(K k, V v) : key(std::move(k)), value(std::move(v)) {}
+
+ /// Copy-constructor.
+ Entry(const Entry&) = default;
+
+ /// Move-constructor.
+ Entry(Entry&&) = default;
+
+ /// Constructor from a LazyCreator.
+ /// The constructor invokes the LazyCreator::create function to build the #value.
+ /// @see LazyCreator
+ template <typename CREATE>
+ Entry(const LazyCreator<CREATE>& creator) // NOLINT(runtime/explicit)
+ : key(creator.key), value(creator.create()) {}
+
+ /// Assignment operator from a LazyCreator.
+ /// The assignment invokes the LazyCreator::create function to build the #value.
+ /// @see LazyCreator
+ template <typename CREATE>
+ Entry& operator=(LazyCreator<CREATE>&& creator) {
+ key = std::move(creator.key);
+ value = creator.create();
+ return *this;
+ }
+
+ /// Copy-assignment operator
+ Entry& operator=(const Entry&) = default;
+
+ /// Move-assignment operator
+ Entry& operator=(Entry&&) = default;
+
+ K key; /// The map entry key
+ V value; /// The map entry value
+ };
+
+ /// Hash provider for the underlying Hashset.
+ /// Provides hash functions for an Entry, K or LazyCreator.
+ /// The hash functions only consider the key of an entry.
+ struct Hasher {
+ /// Calculates a hash from an Entry
+ size_t operator()(const Entry& entry) const { return HASH()(entry.key); }
+ /// Calculates a hash from a K
+ size_t operator()(const K& key) const { return HASH()(key); }
+ /// Calculates a hash from a LazyCreator
+ template <typename CREATE>
+ size_t operator()(const LazyCreator<CREATE>& lc) const {
+ return HASH()(lc.key);
+ }
+ };
+
+ /// Equality provider for the underlying Hashset.
+ /// Provides equality functions for an Entry, K or LazyCreator to an Entry.
+ /// The equality functions only consider the key for equality.
+ struct Equality {
+ /// Compares an Entry to an Entry for equality.
+ bool operator()(const Entry& a, const Entry& b) const { return EQUAL()(a.key, b.key); }
+ /// Compares a K to an Entry for equality.
+ bool operator()(const K& a, const Entry& b) const { return EQUAL()(a, b.key); }
+ /// Compares a LazyCreator to an Entry for equality.
+ template <typename CREATE>
+ bool operator()(const LazyCreator<CREATE>& lc, const Entry& b) const {
+ return EQUAL()(lc.key, b.key);
+ }
+ };
+
+ /// The underlying set
+ using Set = Hashset<Entry, N, Hasher, Equality>;
+
+ public:
+ /// A Key and Value const-reference pair.
+ struct KeyValue {
+ /// key of a map entry
+ const K& key;
+ /// value of a map entry
+ const V& value;
+
+ /// Equality operator
+ /// @param other the other KeyValue
+ /// @returns true if the key and value of this KeyValue are equal to other's.
+ bool operator==(const KeyValue& other) const {
+ return key == other.key && value == other.value;
+ }
+ };
+
+ /// STL-style alias to KeyValue.
+ /// Used by gmock for the `ElementsAre` checks.
+ using value_type = KeyValue;
+
+ /// Iterator for the map
+ class Iterator {
+ public:
+ /// @returns the key of the entry pointed to by this iterator
+ const K& Key() const { return it->key; }
+
+ /// @returns the value of the entry pointed to by this iterator
+ const V& Value() const { return it->value; }
+
+ /// Increments the iterator
+ /// @returns this iterator
+ Iterator& operator++() {
+ ++it;
+ return *this;
+ }
+
+ /// Equality operator
+ /// @param other the other iterator to compare this iterator to
+ /// @returns true if this iterator is equal to other
+ bool operator==(const Iterator& other) const { return it == other.it; }
+
+ /// Inequality operator
+ /// @param other the other iterator to compare this iterator to
+ /// @returns true if this iterator is not equal to other
+ bool operator!=(const Iterator& other) const { return it != other.it; }
+
+ /// @returns a pair of key and value for the entry pointed to by this iterator
+ KeyValue operator*() const { return {Key(), Value()}; }
+
+ private:
+ /// Friend class
+ friend class Hashmap;
+
+ /// Underlying iterator type
+ using SetIterator = typename Set::Iterator;
+
+ explicit Iterator(SetIterator i) : it(i) {}
+
+ SetIterator it;
+ };
+
+ /// Removes all entries from the map.
+ void Clear() { set_.Clear(); }
+
+ /// Adds the key-value pair to the map, if the map does not already contain an entry with a key
+ /// equal to `key`.
+ /// @param key the entry's key to add to the map
+ /// @param value the entry's value to add to the map
+ /// @returns true if the entry was added to the map, false if there was already an entry in the
+ /// map with a key equal to `key`.
+ template <typename KEY, typename VALUE>
+ bool Add(KEY&& key, VALUE&& value) {
+ return set_.Add(Entry{std::forward<KEY>(key), std::forward<VALUE>(value)});
+ }
+
+ /// Adds the key-value pair to the map, replacing any entry with a key equal to `key`.
+ /// @param key the entry's key to add to the map
+ /// @param value the entry's value to add to the map
+ template <typename KEY, typename VALUE>
+ void Replace(KEY&& key, VALUE&& value) {
+ set_.Replace(Entry{std::forward<KEY>(key), std::forward<VALUE>(value)});
+ }
+
+ /// Searches for an entry with the given key value.
+ /// @param key the entry's key value to search for.
+ /// @returns the value of the entry with the given key, or no value if the entry was not found.
+ std::optional<V> Get(const K& key) {
+ if (auto* entry = set_.Find(key)) {
+ return entry->value;
+ }
+ return std::nullopt;
+ }
+
+ /// Searches for an entry with the given key value, adding and returning the result of
+ /// calling `create` if the entry was not found.
+ /// @param key the entry's key value to search for.
+ /// @param create the create function to call if the map does not contain the key.
+ /// @returns the value of the entry.
+ template <typename CREATE>
+ V& GetOrCreate(const K& key, CREATE&& create) {
+ LazyCreator<CREATE> lc{key, std::forward<CREATE>(create)};
+ auto res = set_.Add(std::move(lc));
+ return res.entry->value;
+ }
+
+ /// Searches for an entry with the given key value, adding and returning a newly created
+ /// zero-initialized value if the entry was not found.
+ /// @param key the entry's key value to search for.
+ /// @returns the value of the entry.
+ V& GetOrZero(const K& key) {
+ auto zero = [] { return V{}; };
+ LazyCreator<decltype(zero)> lc{key, zero};
+ auto res = set_.Add(std::move(lc));
+ return res.entry->value;
+ }
+
+ /// Searches for an entry with the given key value.
+ /// @param key the entry's key value to search for.
+ /// @returns the a pointer to the value of the entry with the given key, or nullptr if the entry
+ /// was not found.
+ /// @warning the pointer must not be used after the map is mutated
+ V* Find(const K& key) {
+ if (auto* entry = set_.Find(key)) {
+ return &entry->value;
+ }
+ return nullptr;
+ }
+
+ /// Searches for an entry with the given key value.
+ /// @param key the entry's key value to search for.
+ /// @returns the a pointer to the value of the entry with the given key, or nullptr if the entry
+ /// was not found.
+ /// @warning the pointer must not be used after the map is mutated
+ const V* Find(const K& key) const {
+ if (auto* entry = set_.Find(key)) {
+ return &entry->value;
+ }
+ return nullptr;
+ }
+
+ /// Removes an entry from the set with a key equal to `key`.
+ /// @param key the entry key value to remove.
+ /// @returns true if an entry was removed.
+ bool Remove(const K& key) { return set_.Remove(key); }
+
+ /// Checks whether an entry exists in the map with a key equal to `key`.
+ /// @param key the entry key value to search for.
+ /// @returns true if the map contains an entry with the given key.
+ bool Contains(const K& key) const { return set_.Contains(key); }
+
+ /// Pre-allocates memory so that the map can hold at least `capacity` entries.
+ /// @param capacity the new capacity of the map.
+ void Reserve(size_t capacity) { set_.Reserve(capacity); }
+
+ /// @returns the number of entries in the map.
+ size_t Count() const { return set_.Count(); }
+
+ /// @returns true if the map contains no entries.
+ bool IsEmpty() const { return set_.IsEmpty(); }
+
+ /// @returns an iterator to the start of the map
+ Iterator begin() const { return Iterator{set_.begin()}; }
+
+ /// @returns an iterator to the end of the map
+ Iterator end() const { return Iterator{set_.end()}; }
+
+ private:
+ Set set_;
+};
+
+} // namespace tint::utils
+
+#endif // SRC_TINT_UTILS_HASHMAP_H_
diff --git a/src/tint/utils/hashmap_test.cc b/src/tint/utils/hashmap_test.cc
new file mode 100644
index 0000000..45e929b
--- /dev/null
+++ b/src/tint/utils/hashmap_test.cc
@@ -0,0 +1,179 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/utils/hashmap.h"
+
+#include <array>
+#include <random>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+
+#include "gmock/gmock.h"
+
+namespace tint::utils {
+namespace {
+
+constexpr std::array kPrimes{
+ 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53,
+ 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131,
+ 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223,
+ 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311,
+ 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409,
+};
+
+TEST(Hashmap, Empty) {
+ Hashmap<std::string, int, 8> map;
+ EXPECT_EQ(map.Count(), 0u);
+}
+
+TEST(Hashmap, AddRemove) {
+ Hashmap<std::string, std::string, 8> map;
+ EXPECT_TRUE(map.Add("hello", "world"));
+ EXPECT_EQ(map.Get("hello"), "world");
+ EXPECT_EQ(map.Count(), 1u);
+ EXPECT_TRUE(map.Contains("hello"));
+ EXPECT_FALSE(map.Contains("world"));
+ EXPECT_FALSE(map.Add("hello", "cat"));
+ EXPECT_EQ(map.Count(), 1u);
+ EXPECT_TRUE(map.Remove("hello"));
+ EXPECT_EQ(map.Count(), 0u);
+ EXPECT_FALSE(map.Contains("hello"));
+ EXPECT_FALSE(map.Contains("world"));
+}
+
+TEST(Hashmap, ReplaceRemove) {
+ Hashmap<std::string, std::string, 8> map;
+ map.Replace("hello", "world");
+ EXPECT_EQ(map.Get("hello"), "world");
+ EXPECT_EQ(map.Count(), 1u);
+ EXPECT_TRUE(map.Contains("hello"));
+ EXPECT_FALSE(map.Contains("world"));
+ map.Replace("hello", "cat");
+ EXPECT_EQ(map.Get("hello"), "cat");
+ EXPECT_EQ(map.Count(), 1u);
+ EXPECT_TRUE(map.Remove("hello"));
+ EXPECT_EQ(map.Count(), 0u);
+ EXPECT_FALSE(map.Contains("hello"));
+ EXPECT_FALSE(map.Contains("world"));
+}
+
+TEST(Hashmap, Iterator) {
+ using Map = Hashmap<int, std::string, 8>;
+ using KV = typename Map::KeyValue;
+ Map map;
+ map.Add(1, "one");
+ map.Add(4, "four");
+ map.Add(3, "three");
+ map.Add(2, "two");
+ EXPECT_THAT(map, testing::UnorderedElementsAre(KV{1, "one"}, KV{2, "two"}, KV{3, "three"},
+ KV{4, "four"}));
+}
+
+TEST(Hashmap, AddMany) {
+ Hashmap<int, std::string, 8> map;
+ for (size_t i = 0; i < kPrimes.size(); i++) {
+ int prime = kPrimes[i];
+ ASSERT_TRUE(map.Add(prime, std::to_string(prime))) << "i: " << i;
+ ASSERT_FALSE(map.Add(prime, std::to_string(prime))) << "i: " << i;
+ ASSERT_EQ(map.Count(), i + 1);
+ }
+ ASSERT_EQ(map.Count(), kPrimes.size());
+ for (int prime : kPrimes) {
+ ASSERT_TRUE(map.Contains(prime)) << prime;
+ ASSERT_EQ(map.Get(prime), std::to_string(prime)) << prime;
+ }
+}
+
+TEST(Hashmap, GetOrCreate) {
+ Hashmap<int, std::string, 8> map;
+ EXPECT_EQ(map.GetOrCreate(0, [&] { return "zero"; }), "zero");
+ EXPECT_EQ(map.Count(), 1u);
+ EXPECT_EQ(map.Get(0), "zero");
+
+ bool create_called = false;
+ EXPECT_EQ(map.GetOrCreate(0,
+ [&] {
+ create_called = true;
+ return "oh noes";
+ }),
+ "zero");
+ EXPECT_FALSE(create_called);
+ EXPECT_EQ(map.Count(), 1u);
+ EXPECT_EQ(map.Get(0), "zero");
+
+ EXPECT_EQ(map.GetOrCreate(1, [&] { return "one"; }), "one");
+ EXPECT_EQ(map.Count(), 2u);
+ EXPECT_EQ(map.Get(1), "one");
+}
+
+TEST(Hashmap, Soak) {
+ std::mt19937 rnd;
+ std::unordered_map<std::string, std::string> reference;
+ Hashmap<std::string, std::string, 8> map;
+ for (size_t i = 0; i < 1000000; i++) {
+ std::string key = std::to_string(rnd() & 64);
+ std::string value = "V" + key;
+ switch (rnd() % 7) {
+ case 0: { // Add
+ auto expected = reference.emplace(key, value).second;
+ EXPECT_EQ(map.Add(key, value), expected) << "i:" << i;
+ EXPECT_EQ(map.Get(key), value) << "i:" << i;
+ EXPECT_TRUE(map.Contains(key)) << "i:" << i;
+ break;
+ }
+ case 1: { // Replace
+ reference[key] = value;
+ map.Replace(key, value);
+ EXPECT_EQ(map.Get(key), value) << "i:" << i;
+ EXPECT_TRUE(map.Contains(key)) << "i:" << i;
+ break;
+ }
+ case 2: { // Remove
+ auto expected = reference.erase(key) != 0;
+ EXPECT_EQ(map.Remove(key), expected) << "i:" << i;
+ EXPECT_FALSE(map.Get(key).has_value()) << "i:" << i;
+ EXPECT_FALSE(map.Contains(key)) << "i:" << i;
+ break;
+ }
+ case 3: { // Contains
+ auto expected = reference.count(key) != 0;
+ EXPECT_EQ(map.Contains(key), expected) << "i:" << i;
+ break;
+ }
+ case 4: { // Get
+ if (reference.count(key) != 0) {
+ auto expected = reference[key];
+ EXPECT_EQ(map.Get(key), expected) << "i:" << i;
+ } else {
+ EXPECT_FALSE(map.Get(key).has_value()) << "i:" << i;
+ }
+ break;
+ }
+ case 5: { // Copy / Move
+ Hashmap<std::string, std::string, 8> tmp(map);
+ map = std::move(tmp);
+ break;
+ }
+ case 6: { // Clear
+ reference.clear();
+ map.Clear();
+ break;
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace tint::utils
diff --git a/src/tint/utils/hashset.h b/src/tint/utils/hashset.h
new file mode 100644
index 0000000..f88a304
--- /dev/null
+++ b/src/tint/utils/hashset.h
@@ -0,0 +1,508 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_UTILS_HASHSET_H_
+#define SRC_TINT_UTILS_HASHSET_H_
+
+#include <stddef.h>
+#include <algorithm>
+#include <functional>
+#include <optional>
+#include <tuple>
+#include <utility>
+
+#include "src/tint/debug.h"
+#include "src/tint/utils/vector.h"
+
+namespace tint::utils {
+
+/// Action taken by Hashset::Insert()
+enum class AddAction {
+ /// Insert() added a new entry to the Hashset
+ kAdded,
+ /// Insert() replaced an existing entry in the Hashset
+ kReplaced,
+ /// Insert() found an existing entry, which was not replaced.
+ kKeptExisting,
+};
+
+/// An unordered set that uses a robin-hood hashing algorithm.
+/// @see the fantastic tutorial: https://programming.guide/robin-hood-hashing.html
+template <typename T, size_t N, typename HASH = std::hash<T>, typename EQUAL = std::equal_to<T>>
+class Hashset {
+ /// A slot is a single entry in the underlying vector.
+ /// A slot can either be empty or filled with a value. If the slot is empty, #hash and #distance
+ /// will be zero.
+ struct Slot {
+ template <typename V>
+ bool Equals(size_t value_hash, const V& val) const {
+ return value_hash == hash && EQUAL()(val, value.value());
+ }
+
+ /// The slot value. If this does not contain a value, then the slot is vacant.
+ std::optional<T> value;
+ /// The precomputed hash of value.
+ size_t hash = 0;
+ size_t distance = 0;
+ };
+
+ /// The target length of the underlying vector length in relation to the number of entries in
+ /// the set, expressed as a percentage. For example a value of `150` would mean there would be
+ /// at least 50% more slots than the number of set entries.
+ static constexpr size_t kRehashFactor = 150;
+
+ /// @returns the target slot vector size to hold `n` set entries.
+ static constexpr size_t NumSlots(size_t count) { return (count * kRehashFactor) / 100; }
+
+ /// The fixed-size slot vector length, based on N and kRehashFactor.
+ static constexpr size_t kNumFixedSlots = NumSlots(N);
+
+ /// The minimum number of slots for the set.
+ static constexpr size_t kMinSlots = std::max<size_t>(kNumFixedSlots, 4);
+
+ public:
+ /// Iterator for entries in the set
+ class Iterator {
+ public:
+ /// @returns the value pointed to by this iterator
+ const T* operator->() const { return ¤t->value.value(); }
+
+ /// Increments the iterator
+ /// @returns this iterator
+ Iterator& operator++() {
+ if (current == end) {
+ return *this;
+ }
+ current++;
+ SkipToNextValue();
+ return *this;
+ }
+
+ /// Equality operator
+ /// @param other the other iterator to compare this iterator to
+ /// @returns true if this iterator is equal to other
+ bool operator==(const Iterator& other) const { return current == other.current; }
+
+ /// Inequality operator
+ /// @param other the other iterator to compare this iterator to
+ /// @returns true if this iterator is not equal to other
+ bool operator!=(const Iterator& other) const { return current != other.current; }
+
+ /// @returns a reference to the value at the iterator
+ const T& operator*() const { return current->value.value(); }
+
+ private:
+ /// Friend class
+ friend class Hashset;
+
+ Iterator(const Slot* c, const Slot* e) : current(c), end(e) { SkipToNextValue(); }
+
+ /// Moves the iterator forward, stopping at the next slot that is not empty.
+ void SkipToNextValue() {
+ while (current != end && !current->value.has_value()) {
+ current++;
+ }
+ }
+
+ const Slot* current; /// The slot the iterator is pointing to
+ const Slot* end; /// One past the last slot in the set
+ };
+
+ /// Type of `T`.
+ using value_type = T;
+
+ /// Constructor
+ Hashset() { slots_.Resize(kMinSlots); }
+
+ /// Copy constructor
+ /// @param other the other Hashset to copy
+ Hashset(const Hashset& other) = default;
+
+ /// Move constructor
+ /// @param other the other Hashset to move
+ Hashset(Hashset&& other) = default;
+
+ /// Destructor
+ ~Hashset() { Clear(); }
+
+ /// Copy-assignment operator
+ /// @param other the other Hashset to copy
+ /// @returns this so calls can be chained
+ Hashset& operator=(const Hashset& other) = default;
+
+ /// Move-assignment operator
+ /// @param other the other Hashset to move
+ /// @returns this so calls can be chained
+ Hashset& operator=(Hashset&& other) = default;
+
+ /// Removes all entries from the set.
+ void Clear() {
+ slots_.Clear(); // Destructs all entries
+ slots_.Resize(kMinSlots);
+ count_ = 0;
+ }
+
+ /// Result of Add()
+ struct AddResult {
+ /// Whether the insert replaced or added a new entry to the set.
+ AddAction action = AddAction::kAdded;
+ /// A pointer to the inserted entry.
+ /// @warning do not modify this pointer in a way that would cause the equality or hash of
+ /// the entry to change. Doing this will corrupt the Hashset.
+ T* entry = nullptr;
+
+ /// @returns true if the entry was added to the set, or an existing entry was replaced.
+ operator bool() const { return action != AddAction::kKeptExisting; }
+ };
+
+ /// Adds a value to the set, if the set does not already contain an entry equal to `value`.
+ /// @param value the value to add to the set.
+ /// @returns A AddResult describing the result of the add
+ /// @warning do not modify the inserted entry in a way that would cause the equality of hash of
+ /// the entry to change. Doing this will corrupt the Hashset.
+ template <typename V>
+ AddResult Add(V&& value) {
+ return Put<PutMode::kAdd>(std::forward<V>(value));
+ }
+
+ /// Adds a value to the set, replacing any entry equal to `value`.
+ /// @param value the value to add to the set.
+ /// @returns A AddResult describing the result of the replace
+ template <typename V>
+ AddResult Replace(V&& value) {
+ return Put<PutMode::kReplace>(std::forward<V>(value));
+ }
+
+ /// Removes an entry from the set.
+ /// @param value the value to remove from the set.
+ /// @returns true if an entry was removed.
+ template <typename V>
+ bool Remove(const V& value) {
+ const auto [found, start] = IndexOf(value);
+ if (!found) {
+ return false;
+ }
+
+ // Shuffle the entries backwards until we either find a free slot, or a slot that has zero
+ // distance.
+ Slot* prev = nullptr;
+ Scan(start, [&](size_t, size_t index) {
+ auto& slot = slots_[index];
+ if (prev) {
+ // note: `distance == 0` also includes empty slots.
+ if (slot.distance == 0) {
+ // Clear the previous slot, and stop shuffling.
+ *prev = {};
+ return Action::kStop;
+ } else {
+ // Shuffle the slot backwards.
+ prev->value = std::move(slot.value);
+ prev->hash = slot.hash;
+ prev->distance = slot.distance - 1;
+ }
+ }
+ prev = &slot;
+ return Action::kContinue;
+ });
+
+ // Entry was removed.
+ count_--;
+
+ return true;
+ }
+
+ /// @param value the value to search for.
+ /// @returns the value of the entry that is equal to `value`, or no value if the entry was not
+ /// found.
+ template <typename V>
+ std::optional<T> Get(const V& value) const {
+ if (const auto [found, index] = IndexOf(value); found) {
+ return slots_[index].value.value();
+ }
+ return std::nullopt;
+ }
+
+ /// @param value the value to search for.
+ /// @returns a pointer to the entry that is equal to the given value, or nullptr if the set does
+ /// not contain the given value.
+ template <typename V>
+ const T* Find(const V& value) const {
+ const auto [found, index] = IndexOf(value);
+ return found ? &slots_[index].value.value() : nullptr;
+ }
+
+ /// @param value the value to search for.
+ /// @returns a pointer to the entry that is equal to the given value, or nullptr if the set does
+ /// not contain the given value.
+ /// @warning do not modify the inserted entry in a way that would cause the equality of hash of
+ /// the entry to change. Doing this will corrupt the Hashset.
+ template <typename V>
+ T* Find(const V& value) {
+ const auto [found, index] = IndexOf(value);
+ return found ? &slots_[index].value.value() : nullptr;
+ }
+
+ /// Checks whether an entry exists in the set
+ /// @param value the value to search for.
+ /// @returns true if the set contains an entry with the given value.
+ template <typename V>
+ bool Contains(const V& value) const {
+ const auto [found, _] = IndexOf(value);
+ return found;
+ }
+
+ /// Pre-allocates memory so that the set can hold at least `capacity` entries.
+ /// @param capacity the new capacity of the set.
+ void Reserve(size_t capacity) {
+ // Calculate the number of slots required to hold `capacity` entries.
+ const size_t num_slots = std::max(NumSlots(capacity), kMinSlots);
+ if (slots_.Length() >= num_slots) {
+ // Already have enough slots.
+ return;
+ }
+
+ // Move all the values out of the set and into a vector.
+ Vector<T, N> values;
+ values.Reserve(count_);
+ for (auto& slot : slots_) {
+ if (slot.value.has_value()) {
+ values.Push(std::move(slot.value.value()));
+ }
+ }
+
+ // Clear the set, grow the number of slots.
+ Clear();
+ slots_.Resize(num_slots);
+
+ // As the number of slots has grown, the slot indices will have changed from before, so
+ // re-add all the values back into the set.
+ for (auto& value : values) {
+ Add(std::move(value));
+ }
+ }
+
+ /// @returns the number of entries in the set.
+ size_t Count() const { return count_; }
+
+ /// @returns true if the set contains no entries.
+ bool IsEmpty() const { return count_ == 0; }
+
+ /// @returns an iterator to the start of the set.
+ Iterator begin() const { return Iterator{slots_.begin(), slots_.end()}; }
+
+ /// @returns an iterator to the end of the set.
+ Iterator end() const { return Iterator{slots_.end(), slots_.end()}; }
+
+ /// A debug function for checking that the set is in good health.
+ /// Asserts if the set is corrupted.
+ void ValidateIntegrity() const {
+ size_t num_alive = 0;
+ for (size_t slot_idx = 0; slot_idx < slots_.Length(); slot_idx++) {
+ const auto& slot = slots_[slot_idx];
+ if (slot.value.has_value()) {
+ num_alive++;
+ auto const [index, hash] = Hash(slot.value.value());
+ TINT_ASSERT(Utils, hash == slot.hash);
+ TINT_ASSERT(Utils, slot_idx == Wrap(index + slot.distance));
+ }
+ }
+ TINT_ASSERT(Utils, num_alive == count_);
+ }
+
+ private:
+ /// The behaviour of Put() when an entry already exists with the given key.
+ enum class PutMode {
+ /// Do not replace existing entries with the new value.
+ kAdd,
+ /// Replace existing entries with the new value.
+ kReplace,
+ };
+ /// The common implementation for Add() and Replace()
+ /// @param value the value to add to the set.
+ /// @returns A AddResult describing the result of the insertion
+ template <PutMode MODE, typename V>
+ AddResult Put(V&& value) {
+ // Ensure the set can fit a new entry
+ if (ShouldRehash(count_ + 1)) {
+ Reserve((count_ + 1) * 2);
+ }
+
+ const auto hash = Hash(value);
+
+ AddResult result{};
+ Scan(hash.scan_start, [&](size_t distance, size_t index) {
+ auto& slot = slots_[index];
+ if (!slot.value.has_value()) {
+ // Found an empty slot.
+ // Place value directly into the slot, and we're done.
+ slot.value.emplace(std::forward<V>(value));
+ slot.hash = hash.value;
+ slot.distance = distance;
+ count_++;
+ result = AddResult{AddAction::kAdded, &slot.value.value()};
+ return Action::kStop;
+ }
+
+ // Slot has an entry
+
+ if (slot.Equals(hash.value, value)) {
+ // Slot is equal to value. Replace or preserve?
+ if constexpr (MODE == PutMode::kReplace) {
+ slot.value = std::forward<V>(value);
+ result = AddResult{AddAction::kReplaced, &slot.value.value()};
+ } else {
+ result = AddResult{AddAction::kKeptExisting, &slot.value.value()};
+ }
+ return Action::kStop;
+ }
+
+ if (slot.distance < distance) {
+ // Existing slot has a closer distance than the value we're attempting to insert.
+ // Steal from the rich!
+ // Move the current slot to a temporary (evicted), and put the value into the slot.
+ Slot evicted{std::forward<V>(value), hash.value, distance};
+ std::swap(evicted, slot);
+
+ // Find a new home for the evicted slot.
+ evicted.distance++; // We've already swapped at index.
+ InsertShuffle(Wrap(index + 1), std::move(evicted));
+
+ count_++;
+ result = AddResult{AddAction::kAdded, &slot.value.value()};
+
+ return Action::kStop;
+ }
+ return Action::kContinue;
+ });
+
+ return result;
+ }
+
+ /// Return type of the Scan() callback.
+ enum class Action {
+ /// Continue scanning for a slot
+ kContinue,
+ /// Immediately stop scanning for a slot
+ kStop,
+ };
+
+ /// Sequentially visits each of the slots starting with the slot with the index `start`, calling
+ /// the callback function `f` for each slot until `f` returns Action::kStop.
+ /// `f` must be a function with the signature `Action(size_t distance, size_t index)`.
+ /// `f` must return Action::kStop within one whole cycle of the slots.
+ template <typename F>
+ void Scan(size_t start, F&& f) const {
+ size_t index = start;
+ for (size_t distance = 0; distance < slots_.Length(); distance++) {
+ if (f(distance, index) == Action::kStop) {
+ return;
+ }
+ index = Wrap(index + 1);
+ }
+ tint::diag::List diags;
+ TINT_ICE(Utils, diags) << "Hashset::Scan() looped entire set without finding a slot";
+ }
+
+ /// HashResult is the return value of Hash()
+ struct HashResult {
+ /// The target (zero-distance) slot index for the value.
+ size_t scan_start;
+ /// The calculated hash of the value.
+ size_t value;
+ };
+
+ /// @returns a tuple holding the target slot index for the given value, and the hash of the
+ /// value, respectively.
+ template <typename V>
+ HashResult Hash(const V& value) const {
+ size_t hash = HASH()(value);
+ size_t index = Wrap(hash);
+ return {index, hash};
+ }
+
+ /// Looks for the value in the set.
+ /// @returns a tuple holding a boolean representing whether the value was found in the set, and
+ /// if found, the index of the slot that holds the value.
+ template <typename V>
+ std::tuple<bool, size_t> IndexOf(const V& value) const {
+ const auto hash = Hash(value);
+
+ bool found = false;
+ size_t idx = 0;
+
+ Scan(hash.scan_start, [&](size_t distance, size_t index) {
+ auto& slot = slots_[index];
+ if (!slot.value.has_value()) {
+ return Action::kStop;
+ }
+ if (slot.Equals(hash.value, value)) {
+ found = true;
+ idx = index;
+ return Action::kStop;
+ }
+ if (slot.distance < distance) {
+ // If the slot distance is less than the current probe distance, then the slot must
+ // be for entry that has an index that comes after value. In this situation, we know
+ // that the set does not contain the value, as it would have been found before this
+ // slot. The "Lookup" section of https://programming.guide/robin-hood-hashing.html
+ // suggests that the condition should inverted, but this is wrong.
+ return Action::kStop;
+ }
+ return Action::kContinue;
+ });
+
+ return {found, idx};
+ }
+
+ /// Shuffles slots for an insertion that has been placed one slot before `start`.
+ /// @param evicted the slot content that was evicted for the insertion.
+ void InsertShuffle(size_t start, Slot evicted) {
+ Scan(start, [&](size_t, size_t index) {
+ auto& slot = slots_[index];
+
+ if (!slot.value.has_value()) {
+ // Empty slot found for evicted.
+ slot = std::move(evicted);
+ return Action::kStop; // We're done.
+ }
+
+ if (slot.distance < evicted.distance) {
+ // Occupied slot has shorter distance to evicted.
+ // Swap slot and evicted.
+ std::swap(slot, evicted);
+ }
+
+ // evicted moves further from the target slot...
+ evicted.distance++;
+
+ return Action::kContinue;
+ });
+ }
+
+ /// @returns true if the set should grow the slot vector, and rehash the items.
+ bool ShouldRehash(size_t count) const { return NumSlots(count) > slots_.Length(); }
+
+ /// Wrap returns the index value modulo the number of slots.
+ size_t Wrap(size_t index) const { return index % slots_.Length(); }
+
+ /// The vector of slots. The vector length is equal to its capacity.
+ Vector<Slot, kNumFixedSlots> slots_;
+
+ /// The number of entries in the set.
+ size_t count_ = 0;
+};
+
+} // namespace tint::utils
+
+#endif // SRC_TINT_UTILS_HASHSET_H_
diff --git a/src/tint/utils/hashset_test.cc b/src/tint/utils/hashset_test.cc
new file mode 100644
index 0000000..4213b32
--- /dev/null
+++ b/src/tint/utils/hashset_test.cc
@@ -0,0 +1,142 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/utils/hashset.h"
+
+#include <array>
+#include <random>
+#include <string>
+#include <tuple>
+#include <unordered_set>
+
+#include "gmock/gmock.h"
+
+namespace tint::utils {
+namespace {
+
+constexpr std::array kPrimes{
+ 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53,
+ 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131,
+ 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223,
+ 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311,
+ 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409,
+};
+
+TEST(Hashset, Empty) {
+ Hashset<std::string, 8> set;
+ EXPECT_EQ(set.Count(), 0u);
+}
+
+TEST(Hashset, AddRemove) {
+ Hashset<std::string, 8> set;
+ EXPECT_TRUE(set.Add("hello"));
+ EXPECT_EQ(set.Count(), 1u);
+ EXPECT_TRUE(set.Contains("hello"));
+ EXPECT_FALSE(set.Contains("world"));
+ EXPECT_FALSE(set.Add("hello"));
+ EXPECT_EQ(set.Count(), 1u);
+ EXPECT_TRUE(set.Remove("hello"));
+ EXPECT_EQ(set.Count(), 0u);
+ EXPECT_FALSE(set.Contains("hello"));
+ EXPECT_FALSE(set.Contains("world"));
+}
+
+TEST(Hashset, AddMany) {
+ Hashset<int, 8> set;
+ for (size_t i = 0; i < kPrimes.size(); i++) {
+ int prime = kPrimes[i];
+ ASSERT_TRUE(set.Add(prime)) << "i: " << i;
+ ASSERT_FALSE(set.Add(prime)) << "i: " << i;
+ ASSERT_EQ(set.Count(), i + 1);
+ set.ValidateIntegrity();
+ }
+ ASSERT_EQ(set.Count(), kPrimes.size());
+ for (int prime : kPrimes) {
+ ASSERT_TRUE(set.Contains(prime)) << prime;
+ }
+}
+
+TEST(Hashset, Iterator) {
+ Hashset<std::string, 8> set;
+ set.Add("one");
+ set.Add("four");
+ set.Add("three");
+ set.Add("two");
+ EXPECT_THAT(set, testing::UnorderedElementsAre("one", "two", "three", "four"));
+}
+
+TEST(Hashset, Soak) {
+ std::mt19937 rnd;
+ std::unordered_set<std::string> reference;
+ Hashset<std::string, 8> set;
+ for (size_t i = 0; i < 1000000; i++) {
+ std::string value = std::to_string(rnd() & 0x100);
+ switch (rnd() % 8) {
+ case 0: { // Add
+ auto expected = reference.emplace(value).second;
+ ASSERT_EQ(set.Add(value), expected) << "i: " << i;
+ ASSERT_TRUE(set.Contains(value)) << "i: " << i;
+ break;
+ }
+ case 1: { // Replace
+ reference.emplace(value);
+ set.Replace(value);
+ ASSERT_TRUE(set.Contains(value)) << "i: " << i;
+ break;
+ }
+ case 2: { // Remove
+ auto expected = reference.erase(value) != 0;
+ ASSERT_EQ(set.Remove(value), expected) << "i: " << i;
+ ASSERT_FALSE(set.Contains(value)) << "i: " << i;
+ break;
+ }
+ case 3: { // Contains
+ auto expected = reference.count(value) != 0;
+ ASSERT_EQ(set.Contains(value), expected) << "i: " << i;
+ break;
+ }
+ case 4: { // Get
+ if (reference.count(value) != 0) {
+ ASSERT_TRUE(set.Get(value).has_value()) << "i: " << i;
+ ASSERT_EQ(set.Get(value), value) << "i: " << i;
+ } else {
+ ASSERT_FALSE(set.Get(value).has_value()) << "i: " << i;
+ }
+ break;
+ }
+ case 5: { // Find
+ if (reference.count(value) != 0) {
+ ASSERT_EQ(*set.Find(value), value) << "i: " << i;
+ } else {
+ ASSERT_EQ(set.Find(value), nullptr) << "i: " << i;
+ }
+ break;
+ }
+ case 6: { // Copy / Move
+ Hashset<std::string, 8> tmp(set);
+ set = std::move(tmp);
+ break;
+ }
+ case 7: { // Clear
+ reference.clear();
+ set.Clear();
+ break;
+ }
+ }
+ set.ValidateIntegrity();
+ }
+}
+
+} // namespace
+} // namespace tint::utils
diff --git a/src/tint/utils/unique_vector.h b/src/tint/utils/unique_vector.h
index d4018b2..bda090c 100644
--- a/src/tint/utils/unique_vector.h
+++ b/src/tint/utils/unique_vector.h
@@ -21,17 +21,15 @@
#include <utility>
#include <vector>
+#include "src/tint/utils/hashset.h"
+#include "src/tint/utils/vector.h"
+
namespace tint::utils {
/// UniqueVector is an ordered container that only contains unique items.
/// Attempting to add a duplicate is a no-op.
-template <typename T, typename HASH = std::hash<T>, typename EQUAL = std::equal_to<T>>
+template <typename T, size_t N, typename HASH = std::hash<T>, typename EQUAL = std::equal_to<T>>
struct UniqueVector {
- /// The iterator returned by begin() and end()
- using ConstIterator = typename std::vector<T>::const_iterator;
- /// The iterator returned by rbegin() and rend()
- using ConstReverseIterator = typename std::vector<T>::const_reverse_iterator;
-
/// Constructor
UniqueVector() = default;
@@ -40,7 +38,7 @@
/// elements will be removed.
explicit UniqueVector(std::vector<T>&& v) {
for (auto& el : v) {
- add(el);
+ Add(el);
}
}
@@ -48,10 +46,9 @@
/// already contain the given item.
/// @param item the item to append to the end of the vector
/// @returns true if the item was added, otherwise false.
- bool add(const T& item) {
- if (set.count(item) == 0) {
- vector.emplace_back(item);
- set.emplace(item);
+ bool Add(const T& item) {
+ if (set.Add(item)) {
+ vector.Push(item);
return true;
}
return false;
@@ -59,7 +56,7 @@
/// @returns true if the vector contains `item`
/// @param item the item
- bool contains(const T& item) const { return set.count(item); }
+ bool Contains(const T& item) const { return set.Contains(item); }
/// @param i the index of the element to retrieve
/// @returns the element at the index `i`
@@ -70,48 +67,50 @@
const T& operator[](size_t i) const { return vector[i]; }
/// @returns true if the vector is empty
- bool empty() const { return vector.empty(); }
+ bool IsEmpty() const { return vector.IsEmpty(); }
/// @returns the number of items in the vector
- size_t size() const { return vector.size(); }
+ size_t Length() const { return vector.Length(); }
/// @returns the pointer to the first element in the vector, or nullptr if the vector is empty.
- const T* data() const { return vector.empty() ? nullptr : vector.data(); }
+ const T* Data() const { return vector.IsEmpty() ? nullptr : &vector[0]; }
/// @returns an iterator to the beginning of the vector
- ConstIterator begin() const { return vector.begin(); }
+ auto begin() const { return vector.begin(); }
/// @returns an iterator to the end of the vector
- ConstIterator end() const { return vector.end(); }
+ auto end() const { return vector.end(); }
/// @returns an iterator to the beginning of the reversed vector
- ConstReverseIterator rbegin() const { return vector.rbegin(); }
+ auto rbegin() const { return vector.rbegin(); }
/// @returns an iterator to the end of the reversed vector
- ConstReverseIterator rend() const { return vector.rend(); }
+ auto rend() const { return vector.rend(); }
/// @returns a const reference to the internal vector
- operator const std::vector<T>&() const { return vector; }
+ operator const Vector<T, N>&() const { return vector; }
+
+ /// @returns the std::move()'d vector.
+ /// @note The UniqueVector must not be used after calling this method
+ VectorRef<T> Release() { return std::move(vector); }
/// Pre-allocates `count` elements in the vector and set
/// @param count the number of elements to pre-allocate
- void reserve(size_t count) {
- vector.reserve(count);
- set.reserve(count);
+ void Reserve(size_t count) {
+ vector.Reserve(count);
+ set.Reserve(count);
}
/// Removes the last element from the vector
/// @returns the popped element
- T pop_back() {
- auto el = std::move(vector.back());
- set.erase(el);
- vector.pop_back();
- return el;
+ T Pop() {
+ set.Remove(vector.Back());
+ return vector.Pop();
}
private:
- std::vector<T> vector;
- std::unordered_set<T, HASH, EQUAL> set;
+ Vector<T, N> vector;
+ Hashset<T, N, HASH, EQUAL> set;
};
} // namespace tint::utils
diff --git a/src/tint/utils/unique_vector_test.cc b/src/tint/utils/unique_vector_test.cc
index 035ebf8..9b015c2 100644
--- a/src/tint/utils/unique_vector_test.cc
+++ b/src/tint/utils/unique_vector_test.cc
@@ -13,6 +13,9 @@
// limitations under the License.
#include "src/tint/utils/unique_vector.h"
+
+#include <vector>
+
#include "src/tint/utils/reverse.h"
#include "gtest/gtest.h"
@@ -21,16 +24,16 @@
namespace {
TEST(UniqueVectorTest, Empty) {
- UniqueVector<int> unique_vec;
- EXPECT_EQ(unique_vec.size(), 0u);
- EXPECT_EQ(unique_vec.empty(), true);
+ UniqueVector<int, 4> unique_vec;
+ EXPECT_EQ(unique_vec.Length(), 0u);
+ EXPECT_EQ(unique_vec.IsEmpty(), true);
EXPECT_EQ(unique_vec.begin(), unique_vec.end());
}
TEST(UniqueVectorTest, MoveConstructor) {
- UniqueVector<int> unique_vec(std::vector<int>{0, 3, 2, 1, 2});
- EXPECT_EQ(unique_vec.size(), 4u);
- EXPECT_EQ(unique_vec.empty(), false);
+ UniqueVector<int, 4> unique_vec(std::vector<int>{0, 3, 2, 1, 2});
+ EXPECT_EQ(unique_vec.Length(), 4u);
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
EXPECT_EQ(unique_vec[0], 0);
EXPECT_EQ(unique_vec[1], 3);
EXPECT_EQ(unique_vec[2], 2);
@@ -38,12 +41,12 @@
}
TEST(UniqueVectorTest, AddUnique) {
- UniqueVector<int> unique_vec;
- unique_vec.add(0);
- unique_vec.add(1);
- unique_vec.add(2);
- EXPECT_EQ(unique_vec.size(), 3u);
- EXPECT_EQ(unique_vec.empty(), false);
+ UniqueVector<int, 4> unique_vec;
+ unique_vec.Add(0);
+ unique_vec.Add(1);
+ unique_vec.Add(2);
+ EXPECT_EQ(unique_vec.Length(), 3u);
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
int i = 0;
for (auto n : unique_vec) {
EXPECT_EQ(n, i);
@@ -59,15 +62,15 @@
}
TEST(UniqueVectorTest, AddDuplicates) {
- UniqueVector<int> unique_vec;
- unique_vec.add(0);
- unique_vec.add(0);
- unique_vec.add(0);
- unique_vec.add(1);
- unique_vec.add(1);
- unique_vec.add(2);
- EXPECT_EQ(unique_vec.size(), 3u);
- EXPECT_EQ(unique_vec.empty(), false);
+ UniqueVector<int, 4> unique_vec;
+ unique_vec.Add(0);
+ unique_vec.Add(0);
+ unique_vec.Add(0);
+ unique_vec.Add(1);
+ unique_vec.Add(1);
+ unique_vec.Add(2);
+ EXPECT_EQ(unique_vec.Length(), 3u);
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
int i = 0;
for (auto n : unique_vec) {
EXPECT_EQ(n, i);
@@ -83,17 +86,17 @@
}
TEST(UniqueVectorTest, AsVector) {
- UniqueVector<int> unique_vec;
- unique_vec.add(0);
- unique_vec.add(0);
- unique_vec.add(0);
- unique_vec.add(1);
- unique_vec.add(1);
- unique_vec.add(2);
+ UniqueVector<int, 4> unique_vec;
+ unique_vec.Add(0);
+ unique_vec.Add(0);
+ unique_vec.Add(0);
+ unique_vec.Add(1);
+ unique_vec.Add(1);
+ unique_vec.Add(2);
- const std::vector<int>& vec = unique_vec;
- EXPECT_EQ(vec.size(), 3u);
- EXPECT_EQ(unique_vec.empty(), false);
+ const utils::Vector<int, 4>& vec = unique_vec;
+ EXPECT_EQ(vec.Length(), 3u);
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
int i = 0;
for (auto n : vec) {
EXPECT_EQ(n, i);
@@ -106,46 +109,46 @@
}
TEST(UniqueVectorTest, PopBack) {
- UniqueVector<int> unique_vec;
- unique_vec.add(0);
- unique_vec.add(2);
- unique_vec.add(1);
+ UniqueVector<int, 4> unique_vec;
+ unique_vec.Add(0);
+ unique_vec.Add(2);
+ unique_vec.Add(1);
- EXPECT_EQ(unique_vec.pop_back(), 1);
- EXPECT_EQ(unique_vec.size(), 2u);
- EXPECT_EQ(unique_vec.empty(), false);
+ EXPECT_EQ(unique_vec.Pop(), 1);
+ EXPECT_EQ(unique_vec.Length(), 2u);
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
EXPECT_EQ(unique_vec[0], 0);
EXPECT_EQ(unique_vec[1], 2);
- EXPECT_EQ(unique_vec.pop_back(), 2);
- EXPECT_EQ(unique_vec.size(), 1u);
- EXPECT_EQ(unique_vec.empty(), false);
+ EXPECT_EQ(unique_vec.Pop(), 2);
+ EXPECT_EQ(unique_vec.Length(), 1u);
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
EXPECT_EQ(unique_vec[0], 0);
- unique_vec.add(1);
+ unique_vec.Add(1);
- EXPECT_EQ(unique_vec.size(), 2u);
- EXPECT_EQ(unique_vec.empty(), false);
+ EXPECT_EQ(unique_vec.Length(), 2u);
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
EXPECT_EQ(unique_vec[0], 0);
EXPECT_EQ(unique_vec[1], 1);
- EXPECT_EQ(unique_vec.pop_back(), 1);
- EXPECT_EQ(unique_vec.size(), 1u);
- EXPECT_EQ(unique_vec.empty(), false);
+ EXPECT_EQ(unique_vec.Pop(), 1);
+ EXPECT_EQ(unique_vec.Length(), 1u);
+ EXPECT_EQ(unique_vec.IsEmpty(), false);
EXPECT_EQ(unique_vec[0], 0);
- EXPECT_EQ(unique_vec.pop_back(), 0);
- EXPECT_EQ(unique_vec.size(), 0u);
- EXPECT_EQ(unique_vec.empty(), true);
+ EXPECT_EQ(unique_vec.Pop(), 0);
+ EXPECT_EQ(unique_vec.Length(), 0u);
+ EXPECT_EQ(unique_vec.IsEmpty(), true);
}
TEST(UniqueVectorTest, Data) {
- UniqueVector<int> unique_vec;
- EXPECT_EQ(unique_vec.data(), nullptr);
+ UniqueVector<int, 4> unique_vec;
+ EXPECT_EQ(unique_vec.Data(), nullptr);
- unique_vec.add(42);
- EXPECT_EQ(unique_vec.data(), &unique_vec[0]);
- EXPECT_EQ(*unique_vec.data(), 42);
+ unique_vec.Add(42);
+ EXPECT_EQ(unique_vec.Data(), &unique_vec[0]);
+ EXPECT_EQ(*unique_vec.Data(), 42);
}
} // namespace
diff --git a/src/tint/utils/vector.h b/src/tint/utils/vector.h
index 518a693..b718840 100644
--- a/src/tint/utils/vector.h
+++ b/src/tint/utils/vector.h
@@ -104,22 +104,42 @@
auto rend() const { return std::reverse_iterator<const T*>(begin()); }
};
+namespace detail {
+
+/// Private implementation of tint::utils::CanReinterpretSlice.
+/// Specialized for the case of TO equal to FROM, which is the common case, and avoids inspection of
+/// the base classes, which can be troublesome if the slice is of an incomplete type.
+template <typename TO, typename FROM>
+struct CanReinterpretSlice {
+ /// True if a slice of FROM can be reinterpreted as a slice of TO
+ static constexpr bool value =
+ // Both TO and FROM are pointers
+ (std::is_pointer_v<TO> && std::is_pointer_v<FROM>)&& //
+ // const can only be applied, not removed
+ (std::is_const_v<std::remove_pointer_t<TO>> ||
+ !std::is_const_v<std::remove_pointer_t<FROM>>)&& //
+ // TO and FROM are both Castable
+ IsCastable<std::remove_pointer_t<FROM>, std::remove_pointer_t<TO>> && //
+ // FROM is of, or derives from TO
+ traits::IsTypeOrDerived<std::remove_pointer_t<FROM>, std::remove_pointer_t<TO>>;
+};
+
+/// Specialization of 'CanReinterpretSlice' for when TO and FROM are equal types.
+template <typename T>
+struct CanReinterpretSlice<T, T> {
+ /// Always `true` as TO and FROM are the same type.
+ static constexpr bool value = true;
+};
+
+} // namespace detail
+
/// Evaluates whether a `vector<FROM>` and be reinterpreted as a `vector<TO>`.
/// Vectors can be reinterpreted if both `FROM` and `TO` are pointers to a type that derives from
/// CastableBase, and the pointee type of `TO` is of the same type as, or is an ancestor of the
/// pointee type of `FROM`. Vectors of non-`const` Castable pointers can be converted to a vector of
/// `const` Castable pointers.
template <typename TO, typename FROM>
-static constexpr bool CanReinterpretSlice =
- // TO and FROM are both pointer types
- std::is_pointer_v<TO> && std::is_pointer_v<FROM> && //
- // const can only be applied, not removed
- (std::is_const_v<std::remove_pointer_t<TO>> ||
- !std::is_const_v<std::remove_pointer_t<FROM>>)&& //
- // TO and FROM are both Castable
- IsCastable<std::remove_pointer_t<FROM>, std::remove_pointer_t<TO>> &&
- // FROM is of, or derives from TO
- traits::IsTypeOrDerived<std::remove_pointer_t<FROM>, std::remove_pointer_t<TO>>;
+static constexpr bool CanReinterpretSlice = detail::CanReinterpretSlice<TO, FROM>::value;
/// Reinterprets `const Slice<FROM>*` as `const Slice<TO>*`
/// @param slice a pointer to the slice to reinterpret
@@ -440,6 +460,21 @@
/// @returns the end for a reverse iterator
auto rend() const { return impl_.slice.rend(); }
+ /// Equality operator
+ /// @param other the other vector
+ /// @returns true if this vector is the same length as `other`, and all elements are equal.
+ bool operator==(const Vector& other) const {
+ const size_t len = Length();
+ if (len == other.Length()) {
+ for (size_t i = 0; i < len; i++) {
+ if ((*this)[i] != other[i]) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
private:
/// Friend class (differing specializations of this class)
template <typename, size_t>