Import Tint changes from Dawn
Changes:
- 7517e213cf031d0216cca3c9e5ab4f4b27ac49a4 Update `workgroup_size` to use `expression`. by dan sinclair <dsinclair@chromium.org>
- 436b4a2bd06e0b058c25e69ffd8c2cd650b1752d Drop `maybe_` from method names. by dan sinclair <dsinclair@chromium.org>
- 202362f47db4f81df9dcee8d97a4fef295cf4579 Convert to new expression grammar by dan sinclair <dsinclair@chromium.org>
- acdf6e1a53cac328a863006369111b7f74402c09 Remove `ast::VariableBindingPoint` in favour of `sem::Bin... by dan sinclair <dsinclair@chromium.org>
- 93c2d559a12bf72c7eb96b6a2a5a5bb77ecdc4b4 tint/utils: Default to using Hasher instead of std::hash by Ben Clayton <bclayton@google.com>
- 18dc315ccb830b4c52dd6b0fcab47d255bd59897 tint::CloneContext: Use utils::Hashset and utils::Hashmap by Ben Clayton <bclayton@google.com>
- b90b6bff1d46056c7471fa39907d10b9fa6636d2 tint: Minor no-op cleanup changes by Ben Clayton <bclayton@google.com>
- e3f2005b2ddc8467adacaa84808f00f936148412 tint/utils: Cleanup & optimize hash utilities by Ben Clayton <bclayton@google.com>
- e5d337171a1d20fa89878a09892ca9d9228be6dd tint: Castable - optimize IsAnyOf() by Ben Clayton <bclayton@google.com>
- cdcc85973bb0c9871100940d44b8e31680c21e2c tint: Clean up AddSpirvBlockAttribute by Ben Clayton <bclayton@google.com>
- 4964d9bc20dabc5d8773d588ec742556ada1b7f1 Convert `@align` to hold an expression. by dan sinclair <dsinclair@chromium.org>
GitOrigin-RevId: 7517e213cf031d0216cca3c9e5ab4f4b27ac49a4
Change-Id: I7f30978e56d735371e311f209b7ca1107387e421
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/100280
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 8283b45..d140e1e 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1350,7 +1350,6 @@
sources = [
"reader/wgsl/lexer_test.cc",
"reader/wgsl/parser_impl_additive_expression_test.cc",
- "reader/wgsl/parser_impl_and_expression_test.cc",
"reader/wgsl/parser_impl_argument_expression_list_test.cc",
"reader/wgsl/parser_impl_assignment_stmt_test.cc",
"reader/wgsl/parser_impl_bitwise_expression_test.cc",
@@ -1367,10 +1366,8 @@
"reader/wgsl/parser_impl_depth_texture_test.cc",
"reader/wgsl/parser_impl_element_count_expression_test.cc",
"reader/wgsl/parser_impl_enable_directive_test.cc",
- "reader/wgsl/parser_impl_equality_expression_test.cc",
"reader/wgsl/parser_impl_error_msg_test.cc",
"reader/wgsl/parser_impl_error_resync_test.cc",
- "reader/wgsl/parser_impl_exclusive_or_expression_test.cc",
"reader/wgsl/parser_impl_expression_test.cc",
"reader/wgsl/parser_impl_external_texture_test.cc",
"reader/wgsl/parser_impl_for_stmt_test.cc",
@@ -1382,11 +1379,8 @@
"reader/wgsl/parser_impl_global_decl_test.cc",
"reader/wgsl/parser_impl_global_variable_decl_test.cc",
"reader/wgsl/parser_impl_if_stmt_test.cc",
- "reader/wgsl/parser_impl_inclusive_or_expression_test.cc",
"reader/wgsl/parser_impl_increment_decrement_stmt_test.cc",
"reader/wgsl/parser_impl_lhs_expression_test.cc",
- "reader/wgsl/parser_impl_logical_and_expression_test.cc",
- "reader/wgsl/parser_impl_logical_or_expression_test.cc",
"reader/wgsl/parser_impl_loop_stmt_test.cc",
"reader/wgsl/parser_impl_math_expression_test.cc",
"reader/wgsl/parser_impl_multiplicative_expression_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 6703d65..5d77470 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -945,7 +945,6 @@
reader/wgsl/lexer_test.cc
reader/wgsl/parser_test.cc
reader/wgsl/parser_impl_additive_expression_test.cc
- reader/wgsl/parser_impl_and_expression_test.cc
reader/wgsl/parser_impl_argument_expression_list_test.cc
reader/wgsl/parser_impl_assignment_stmt_test.cc
reader/wgsl/parser_impl_bitwise_expression_test.cc
@@ -962,10 +961,8 @@
reader/wgsl/parser_impl_depth_texture_test.cc
reader/wgsl/parser_impl_element_count_expression_test.cc
reader/wgsl/parser_impl_enable_directive_test.cc
- reader/wgsl/parser_impl_equality_expression_test.cc
reader/wgsl/parser_impl_error_msg_test.cc
reader/wgsl/parser_impl_error_resync_test.cc
- reader/wgsl/parser_impl_exclusive_or_expression_test.cc
reader/wgsl/parser_impl_expression_test.cc
reader/wgsl/parser_impl_external_texture_test.cc
reader/wgsl/parser_impl_for_stmt_test.cc
@@ -977,11 +974,8 @@
reader/wgsl/parser_impl_global_decl_test.cc
reader/wgsl/parser_impl_global_variable_decl_test.cc
reader/wgsl/parser_impl_if_stmt_test.cc
- reader/wgsl/parser_impl_inclusive_or_expression_test.cc
reader/wgsl/parser_impl_increment_decrement_stmt_test.cc
reader/wgsl/parser_impl_lhs_expression_test.cc
- reader/wgsl/parser_impl_logical_and_expression_test.cc
- reader/wgsl/parser_impl_logical_or_expression_test.cc
reader/wgsl/parser_impl_loop_stmt_test.cc
reader/wgsl/parser_impl_math_expression_test.cc
reader/wgsl/parser_impl_multiplicative_expression_test.cc
diff --git a/src/tint/ast/struct_member_align_attribute.cc b/src/tint/ast/struct_member_align_attribute.cc
index d8ed4fc..e188e7b 100644
--- a/src/tint/ast/struct_member_align_attribute.cc
+++ b/src/tint/ast/struct_member_align_attribute.cc
@@ -26,7 +26,7 @@
StructMemberAlignAttribute::StructMemberAlignAttribute(ProgramID pid,
NodeID nid,
const Source& src,
- uint32_t a)
+ const ast::Expression* a)
: Base(pid, nid, src), align(a) {}
StructMemberAlignAttribute::~StructMemberAlignAttribute() = default;
@@ -38,7 +38,8 @@
const StructMemberAlignAttribute* StructMemberAlignAttribute::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source);
- return ctx->dst->create<StructMemberAlignAttribute>(src, align);
+ auto* align_ = ctx->Clone(align);
+ return ctx->dst->create<StructMemberAlignAttribute>(src, align_);
}
} // namespace tint::ast
diff --git a/src/tint/ast/struct_member_align_attribute.h b/src/tint/ast/struct_member_align_attribute.h
index efff21b..2043b01 100644
--- a/src/tint/ast/struct_member_align_attribute.h
+++ b/src/tint/ast/struct_member_align_attribute.h
@@ -19,6 +19,7 @@
#include <string>
#include "src/tint/ast/attribute.h"
+#include "src/tint/ast/expression.h"
namespace tint::ast {
@@ -29,8 +30,11 @@
/// @param pid the identifier of the program that owns this node
/// @param nid the unique node identifier
/// @param src the source of this node
- /// @param align the align value
- StructMemberAlignAttribute(ProgramID pid, NodeID nid, const Source& src, uint32_t align);
+ /// @param align the align value expression
+ StructMemberAlignAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const ast::Expression* align);
~StructMemberAlignAttribute() override;
/// @returns the WGSL name for the attribute
@@ -42,8 +46,8 @@
/// @return the newly cloned node
const StructMemberAlignAttribute* Clone(CloneContext* ctx) const override;
- /// The align value
- const uint32_t align;
+ /// The align value expression
+ const ast::Expression* const align;
};
} // namespace tint::ast
diff --git a/src/tint/ast/struct_member_align_attribute_test.cc b/src/tint/ast/struct_member_align_attribute_test.cc
index ba4d1bb..f52d32b 100644
--- a/src/tint/ast/struct_member_align_attribute_test.cc
+++ b/src/tint/ast/struct_member_align_attribute_test.cc
@@ -22,8 +22,10 @@
using StructMemberAlignAttributeTest = TestHelper;
TEST_F(StructMemberAlignAttributeTest, Creation) {
- auto* d = create<StructMemberAlignAttribute>(2u);
- EXPECT_EQ(2u, d->align);
+ auto* val = Expr("ident");
+ auto* d = create<StructMemberAlignAttribute>(val);
+ EXPECT_EQ(val, d->align);
+ EXPECT_TRUE(d->align->Is<IdentifierExpression>());
}
} // namespace
diff --git a/src/tint/ast/variable.cc b/src/tint/ast/variable.cc
index ec87e54..bb719c0 100644
--- a/src/tint/ast/variable.cc
+++ b/src/tint/ast/variable.cc
@@ -37,16 +37,4 @@
Variable::~Variable() = default;
-VariableBindingPoint Variable::BindingPoint() const {
- const GroupAttribute* group = nullptr;
- const BindingAttribute* binding = nullptr;
- for (auto* attr : attributes) {
- Switch(
- attr, //
- [&](const GroupAttribute* a) { group = a; },
- [&](const BindingAttribute* a) { binding = a; });
- }
- return VariableBindingPoint{group, binding};
-}
-
} // namespace tint::ast
diff --git a/src/tint/ast/variable.h b/src/tint/ast/variable.h
index 1f5d77a..8772f5b 100644
--- a/src/tint/ast/variable.h
+++ b/src/tint/ast/variable.h
@@ -20,31 +20,19 @@
#include "src/tint/ast/access.h"
#include "src/tint/ast/attribute.h"
+#include "src/tint/ast/binding_attribute.h"
#include "src/tint/ast/expression.h"
+#include "src/tint/ast/group_attribute.h"
#include "src/tint/ast/storage_class.h"
// Forward declarations
namespace tint::ast {
-class BindingAttribute;
-class GroupAttribute;
class LocationAttribute;
class Type;
} // namespace tint::ast
namespace tint::ast {
-/// VariableBindingPoint holds a group and binding attribute.
-struct VariableBindingPoint {
- /// The `@group` part of the binding point
- const GroupAttribute* group = nullptr;
- /// The `@binding` part of the binding point
- const BindingAttribute* binding = nullptr;
-
- /// @returns true if the BindingPoint has a valid group and binding
- /// attribute.
- inline operator bool() const { return group && binding; }
-};
-
/// Variable is the base class for Var, Let, Const, Override and Parameter.
///
/// An instance of this class represents one of five constructs in WGSL: "var" declaration, "let"
@@ -75,9 +63,11 @@
/// Destructor
~Variable() override;
- /// @returns the binding point information from the variable's attributes.
- /// @note binding points should only be applied to Var and Parameter types.
- VariableBindingPoint BindingPoint() const;
+ /// @returns true if the variable has both group and binding attributes
+ bool HasBindingPoint() const {
+ return ast::GetAttribute<ast::BindingAttribute>(attributes) != nullptr &&
+ ast::GetAttribute<ast::GroupAttribute>(attributes) != nullptr;
+ }
/// @returns the kind of the variable, which can be used in diagnostics
/// e.g. "var", "let", "const", etc
diff --git a/src/tint/ast/variable_test.cc b/src/tint/ast/variable_test.cc
index 12f528b..14fb766 100644
--- a/src/tint/ast/variable_test.cc
+++ b/src/tint/ast/variable_test.cc
@@ -105,36 +105,24 @@
EXPECT_EQ(1u, location->value);
}
-TEST_F(VariableTest, BindingPoint) {
+TEST_F(VariableTest, HasBindingPoint_BothProvided) {
auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, Binding(2), Group(1));
- EXPECT_TRUE(var->BindingPoint());
- ASSERT_NE(var->BindingPoint().binding, nullptr);
- ASSERT_NE(var->BindingPoint().group, nullptr);
- EXPECT_EQ(var->BindingPoint().binding->value, 2u);
- EXPECT_EQ(var->BindingPoint().group->value, 1u);
+ EXPECT_TRUE(var->HasBindingPoint());
}
-TEST_F(VariableTest, BindingPointAttributes) {
+TEST_F(VariableTest, HasBindingPoint_NeitherProvided) {
auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, utils::Empty);
- EXPECT_FALSE(var->BindingPoint());
- EXPECT_EQ(var->BindingPoint().group, nullptr);
- EXPECT_EQ(var->BindingPoint().binding, nullptr);
+ EXPECT_FALSE(var->HasBindingPoint());
}
-TEST_F(VariableTest, BindingPointMissingGroupAttribute) {
+TEST_F(VariableTest, HasBindingPoint_MissingGroupAttribute) {
auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, Binding(2));
- EXPECT_FALSE(var->BindingPoint());
- ASSERT_NE(var->BindingPoint().binding, nullptr);
- EXPECT_EQ(var->BindingPoint().binding->value, 2u);
- EXPECT_EQ(var->BindingPoint().group, nullptr);
+ EXPECT_FALSE(var->HasBindingPoint());
}
-TEST_F(VariableTest, BindingPointMissingBindingAttribute) {
+TEST_F(VariableTest, HasBindingPoint_MissingBindingAttribute) {
auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, Group(1));
- EXPECT_FALSE(var->BindingPoint());
- ASSERT_NE(var->BindingPoint().group, nullptr);
- EXPECT_EQ(var->BindingPoint().group->value, 1u);
- EXPECT_EQ(var->BindingPoint().binding, nullptr);
+ EXPECT_FALSE(var->HasBindingPoint());
}
} // namespace
diff --git a/src/tint/castable.h b/src/tint/castable.h
index d654fca..fe820e1 100644
--- a/src/tint/castable.h
+++ b/src/tint/castable.h
@@ -159,7 +159,7 @@
}
/// @returns a compile-time hashcode for the type `T`.
- /// @note the returned hashcode will have at most 2 bits set, as the hashes
+ /// @note the returned hashcode will have exactly 2 bits set, as the hashes
/// are expected to be used in bloom-filters which will quickly saturate when
/// multiple hashcodes are bitwise-or'd together.
template <typename T>
@@ -176,7 +176,8 @@
#endif
constexpr uint32_t bit_a = (crc & 63);
constexpr uint32_t bit_b = ((crc >> 6) & 63);
- return (static_cast<HashCode>(1) << bit_a) | (static_cast<HashCode>(1) << bit_b);
+ constexpr uint32_t bit_c = (bit_a == bit_b) ? ((bit_a + 1) & 63) : bit_b;
+ return (static_cast<HashCode>(1) << bit_a) | (static_cast<HashCode>(1) << bit_c);
}
/// @returns the hashcode of the given type, bitwise-or'd with the hashcodes
@@ -221,17 +222,16 @@
if constexpr (kCount == 0) {
return false;
} else if constexpr (kCount == 1) {
- return Is<std::tuple_element_t<0, TUPLE>>();
- } else if constexpr (kCount == 2) {
- return Is<std::tuple_element_t<0, TUPLE>>() || Is<std::tuple_element_t<1, TUPLE>>();
- } else if constexpr (kCount == 3) {
- return Is<std::tuple_element_t<0, TUPLE>>() || Is<std::tuple_element_t<1, TUPLE>>() ||
- Is<std::tuple_element_t<2, TUPLE>>();
+ return Is(&Of<std::tuple_element_t<0, TUPLE>>());
} else {
- // Optimization: Compare the object's hashcode to the bitwise-or of all
- // the tested type's hashcodes. If there's no intersection of bits in
- // the two masks, then we can guarantee that the type is not in `TO`.
- if (full_hashcode & TypeInfo::CombinedHashCodeOfTuple<TUPLE>()) {
+ // Optimization: Compare the object's hashcode to the bitwise-or of all the tested
+ // type's hashcodes. If there's no intersection of bits in the two masks, then we can
+ // guarantee that the type is not in `TO`.
+ HashCode mask = full_hashcode & TypeInfo::CombinedHashCodeOfTuple<TUPLE>();
+ // HashCodeOf() ensures that two bits are always set for every hash, so we can quickly
+ // eliminate the bitmask where only one bit is set.
+ HashCode two_bits = mask & (mask - 1);
+ if (two_bits) {
// Possibly one of the types in `TUPLE`.
// Split the search in two, and scan each block.
static constexpr auto kMid = kCount / 2;
@@ -607,9 +607,14 @@
return false;
} else {
// Multiple cases.
- // Check the hashcode bits to see if there's any possibility of a case
- // matching in these cases. If there isn't, we can skip all these cases.
- if (type->full_hashcode & TypeInfo::CombinedHashCodeOf<SwitchCaseType<CASES>...>()) {
+ // Check the hashcode bits to see if there's any possibility of a case matching in these
+ // cases. If there isn't, we can skip all these cases.
+ TypeInfo::HashCode mask =
+ type->full_hashcode & TypeInfo::CombinedHashCodeOf<SwitchCaseType<CASES>...>();
+ // HashCodeOf() ensures that two bits are always set for every hash, so we can quickly
+ // eliminate the bitmask where only one bit is set.
+ TypeInfo::HashCode two_bits = mask & (mask - 1);
+ if (two_bits) {
// There's a possibility. We need to scan further.
// Split the cases into two, and recurse.
constexpr size_t kMid = kNumCases / 2;
diff --git a/src/tint/clone_context.cc b/src/tint/clone_context.cc
index 1225294..457522b 100644
--- a/src/tint/clone_context.cc
+++ b/src/tint/clone_context.cc
@@ -27,9 +27,6 @@
Cloneable::Cloneable(Cloneable&&) = default;
Cloneable::~Cloneable() = default;
-CloneContext::ListTransforms::ListTransforms() = default;
-CloneContext::ListTransforms::~ListTransforms() = default;
-
CloneContext::CloneContext(ProgramBuilder* to, Program const* from, bool auto_clone_symbols)
: dst(to), src(from) {
if (auto_clone_symbols) {
@@ -48,7 +45,7 @@
if (!src) {
return s; // In-place clone
}
- return utils::GetOrCreate(cloned_symbols_, s, [&]() -> Symbol {
+ return cloned_symbols_.GetOrCreate(s, [&]() -> Symbol {
if (symbol_transform_) {
return symbol_transform_(s);
}
@@ -76,9 +73,8 @@
}
// Was Replace() called for this object?
- auto it = replacements_.find(object);
- if (it != replacements_.end()) {
- return it->second();
+ if (auto* fn = replacements_.Find(object)) {
+ return (*fn)();
}
// Attempt to clone using the registered replacer functions.
diff --git a/src/tint/clone_context.h b/src/tint/clone_context.h
index e7e2d52..3c5f6ec 100644
--- a/src/tint/clone_context.h
+++ b/src/tint/clone_context.h
@@ -18,8 +18,6 @@
#include <algorithm>
#include <functional>
#include <type_traits>
-#include <unordered_map>
-#include <unordered_set>
#include <utility>
#include <vector>
@@ -28,6 +26,7 @@
#include "src/tint/program_id.h"
#include "src/tint/symbol.h"
#include "src/tint/traits.h"
+#include "src/tint/utils/hashmap.h"
#include "src/tint/utils/vector.h"
// Forward declarations
@@ -201,56 +200,49 @@
void Clone(utils::Vector<T*, N>& to, const utils::Vector<T*, N>& from) {
to.Reserve(from.Length());
- auto list_transform_it = list_transforms_.find(&from);
- if (list_transform_it != list_transforms_.end()) {
- const auto& transforms = list_transform_it->second;
- for (auto* o : transforms.insert_front_) {
+ auto transforms = list_transforms_.Find(&from);
+
+ if (transforms) {
+ for (auto* o : transforms->insert_front_) {
to.Push(CheckedCast<T>(o));
}
for (auto& el : from) {
- auto insert_before_it = transforms.insert_before_.find(el);
- if (insert_before_it != transforms.insert_before_.end()) {
- for (auto insert : insert_before_it->second) {
+ if (auto* insert_before = transforms->insert_before_.Find(el)) {
+ for (auto insert : *insert_before) {
to.Push(CheckedCast<T>(insert));
}
}
- if (transforms.remove_.count(el) == 0) {
+ if (!transforms->remove_.Contains(el)) {
to.Push(Clone(el));
}
- auto insert_after_it = transforms.insert_after_.find(el);
- if (insert_after_it != transforms.insert_after_.end()) {
- for (auto insert : insert_after_it->second) {
+ if (auto* insert_after = transforms->insert_after_.Find(el)) {
+ for (auto insert : *insert_after) {
to.Push(CheckedCast<T>(insert));
}
}
}
- for (auto* o : transforms.insert_back_) {
+ for (auto* o : transforms->insert_back_) {
to.Push(CheckedCast<T>(o));
}
} else {
for (auto& el : from) {
to.Push(Clone(el));
- // Clone(el) may have inserted after
- list_transform_it = list_transforms_.find(&from);
- if (list_transform_it != list_transforms_.end()) {
- const auto& transforms = list_transform_it->second;
-
- auto insert_after_it = transforms.insert_after_.find(el);
- if (insert_after_it != transforms.insert_after_.end()) {
- for (auto insert : insert_after_it->second) {
+ // Clone(el) may have updated the transformation list, adding an `insert_after`
+ // transform for `from`.
+ if (transforms) {
+ if (auto* insert_after = transforms->insert_after_.Find(el)) {
+ for (auto insert : *insert_after) {
to.Push(CheckedCast<T>(insert));
}
}
}
}
- // Clone(el)s may have inserted back
- list_transform_it = list_transforms_.find(&from);
- if (list_transform_it != list_transforms_.end()) {
- const auto& transforms = list_transform_it->second;
-
- for (auto* o : transforms.insert_back_) {
+ // Clone(el) may have updated the transformation list, adding an `insert_back_`
+ // transform for `from`.
+ if (transforms) {
+ for (auto* o : transforms->insert_back_) {
to.Push(CheckedCast<T>(o));
}
}
@@ -358,7 +350,7 @@
CloneContext& Replace(const WHAT* what, const WITH* with) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, with);
- replacements_[what] = [with]() -> const Cloneable* { return with; };
+ replacements_.Add(what, [with]() -> const Cloneable* { return with; });
return *this;
}
@@ -378,7 +370,7 @@
template <typename WHAT, typename WITH, typename = std::invoke_result_t<WITH>>
CloneContext& Replace(const WHAT* what, WITH&& with) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what);
- replacements_[what] = with;
+ replacements_.Add(what, with);
return *this;
}
@@ -396,7 +388,7 @@
return *this;
}
- list_transforms_[&vector].remove_.emplace(object);
+ list_transforms_.Edit(&vector).remove_.Add(object);
return *this;
}
@@ -408,9 +400,7 @@
template <typename T, size_t N, typename OBJECT>
CloneContext& InsertFront(const utils::Vector<T, N>& vector, OBJECT* object) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object);
- auto& transforms = list_transforms_[&vector];
- auto& list = transforms.insert_front_;
- list.Push(object);
+ list_transforms_.Edit(&vector).insert_front_.Push(object);
return *this;
}
@@ -422,9 +412,7 @@
template <typename T, size_t N, typename OBJECT>
CloneContext& InsertBack(const utils::Vector<T, N>& vector, OBJECT* object) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object);
- auto& transforms = list_transforms_[&vector];
- auto& list = transforms.insert_back_;
- list.Push(object);
+ list_transforms_.Edit(&vector).insert_back_.Push(object);
return *this;
}
@@ -446,9 +434,7 @@
return *this;
}
- auto& transforms = list_transforms_[&vector];
- auto& list = transforms.insert_before_[before];
- list.Push(object);
+ list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push(object);
return *this;
}
@@ -470,9 +456,7 @@
return *this;
}
- auto& transforms = list_transforms_[&vector];
- auto& list = transforms.insert_after_[after];
- list.Push(object);
+ list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push(object);
return *this;
}
@@ -502,6 +486,31 @@
std::function<const Cloneable*(const Cloneable*)> function;
};
+ /// A vector of const Cloneable*
+ using CloneableList = utils::Vector<const Cloneable*, 4>;
+
+ /// Transformations to be applied to a list (vector)
+ struct ListTransforms {
+ /// A map of object in #src to omit when cloned into #dst.
+ utils::Hashset<const Cloneable*, 4> remove_;
+
+ /// A list of objects in #dst to insert before any others when the vector is cloned.
+ CloneableList insert_front_;
+
+ /// A list of objects in #dst to insert after all others when the vector is cloned.
+ CloneableList insert_back_;
+
+ /// A map of object in #src to the list of cloned objects in #dst.
+ /// Clone(const utils::Vector<T*>& v) will use this to insert the map-value
+ /// list into the target vector before cloning and inserting the map-key.
+ utils::Hashmap<const Cloneable*, CloneableList, 4> insert_before_;
+
+ /// A map of object in #src to the list of cloned objects in #dst.
+ /// Clone(const utils::Vector<T*>& v) will use this to insert the map-value
+ /// list into the target vector after cloning and inserting the map-key.
+ utils::Hashmap<const Cloneable*, CloneableList, 4> insert_after_;
+ };
+
CloneContext(const CloneContext&) = delete;
CloneContext& operator=(const CloneContext&) = delete;
@@ -530,50 +539,78 @@
/// @returns the diagnostic list of #dst
diag::List& Diagnostics() const;
- /// A vector of const Cloneable*
- using CloneableList = utils::Vector<const Cloneable*, 4>;
+ /// VectorListTransforms is a map of utils::Vector pointer to transforms for that list
+ struct VectorListTransforms {
+ /// An accessor to the VectorListTransforms map.
+ /// Index caches the last map lookup, and will only re-search the map if the transform map
+ /// was modified since the last lookup.
+ struct Index {
+ /// @returns true if the map now holds a value for the index
+ operator bool() {
+ Update();
+ return cached_;
+ }
- /// Transformations to be applied to a list (vector)
- struct ListTransforms {
- /// Constructor
- ListTransforms();
- /// Destructor
- ~ListTransforms();
+ /// @returns a pointer to the indexed map entry
+ const ListTransforms* operator->() {
+ Update();
+ return cached_;
+ }
- /// A map of object in #src to omit when cloned into #dst.
- std::unordered_set<const Cloneable*> remove_;
+ private:
+ friend VectorListTransforms;
- /// A list of objects in #dst to insert before any others when the vector is
- /// cloned.
- CloneableList insert_front_;
+ Index(const void* list,
+ VectorListTransforms& vlt,
+ uint32_t generation,
+ const ListTransforms* cached)
+ : list_(list), vlt_(vlt), generation_(generation), cached_(cached) {}
- /// A list of objects in #dst to insert befor after any others when the
- /// vector is cloned.
- CloneableList insert_back_;
+ void Update() {
+ if (vlt_.generation_ != generation_) {
+ cached_ = vlt_.map_.Find(list_);
+ generation_ = vlt_.generation_;
+ }
+ }
- /// A map of object in #src to the list of cloned objects in #dst.
- /// Clone(const utils::Vector<T*>& v) will use this to insert the map-value
- /// list into the target vector before cloning and inserting the map-key.
- std::unordered_map<const Cloneable*, CloneableList> insert_before_;
+ const void* list_;
+ VectorListTransforms& vlt_;
+ uint32_t generation_;
+ const ListTransforms* cached_;
+ };
- /// A map of object in #src to the list of cloned objects in #dst.
- /// Clone(const utils::Vector<T*>& v) will use this to insert the map-value
- /// list into the target vector after cloning and inserting the map-key.
- std::unordered_map<const Cloneable*, CloneableList> insert_after_;
+ /// Edit returns a reference to the ListTransforms for the given vector pointer and
+ /// increments #list_transform_generation_ signalling that the list transforms have been
+ /// modified.
+ inline ListTransforms& Edit(const void* list) {
+ generation_++;
+ return map_.GetOrZero(list);
+ }
+
+ /// @returns an Index to the transforms for the given list.
+ inline Index Find(const void* list) {
+ return Index{list, *this, generation_, map_.Find(list)};
+ }
+
+ private:
+ /// The map of vector pointer to ListTransforms
+ utils::Hashmap<const void*, ListTransforms, 4> map_;
+
+ /// A counter that's incremented each time list transforms are modified.
+ uint32_t generation_ = 0;
};
- /// A map of object in #src to functions that create their replacement in
- /// #dst
- std::unordered_map<const Cloneable*, std::function<const Cloneable*()>> replacements_;
+ /// A map of object in #src to functions that create their replacement in #dst
+ utils::Hashmap<const Cloneable*, std::function<const Cloneable*()>, 8> replacements_;
/// A map of symbol in #src to their cloned equivalent in #dst
- std::unordered_map<Symbol, Symbol> cloned_symbols_;
+ utils::Hashmap<Symbol, Symbol, 32> cloned_symbols_;
/// Cloneable transform functions registered with ReplaceAll()
utils::Vector<CloneableTransform, 8> transforms_;
- /// Map of utils::Vector pointer to transforms for that list
- std::unordered_map<const void*, ListTransforms> list_transforms_;
+ /// Transformations to apply to vectors
+ VectorListTransforms list_transforms_;
/// Symbol transform registered with ReplaceAll()
SymbolTransform symbol_transform_;
diff --git a/src/tint/fuzzers/random_generator.cc b/src/tint/fuzzers/random_generator.cc
index 186ce1c..35b5caf 100644
--- a/src/tint/fuzzers/random_generator.cc
+++ b/src/tint/fuzzers/random_generator.cc
@@ -34,10 +34,9 @@
/// @param size - number of elements in buffer
/// @returns hash of the data in the buffer
size_t HashBuffer(const uint8_t* data, const size_t size) {
- size_t hash = 102931;
- utils::HashCombine(&hash, size);
+ size_t hash = utils::Hash(size);
for (size_t i = 0; i < size; i++) {
- utils::HashCombine(&hash, data[i]);
+ hash = utils::HashCombine(hash, data[i]);
}
return hash;
}
diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc
index 6073470..2f8d09a 100644
--- a/src/tint/inspector/inspector.cc
+++ b/src/tint/inspector/inspector.cc
@@ -370,8 +370,8 @@
ResourceBinding entry;
entry.resource_type = ResourceBinding::ResourceType::kUniformBuffer;
- entry.bind_group = binding_info.group->value;
- entry.binding = binding_info.binding->value;
+ entry.bind_group = binding_info.group;
+ entry.binding = binding_info.binding;
entry.size = unwrapped_type->Size();
entry.size_no_padding = entry.size;
if (auto* str = unwrapped_type->As<sem::Struct>()) {
@@ -410,8 +410,8 @@
ResourceBinding entry;
entry.resource_type = ResourceBinding::ResourceType::kSampler;
- entry.bind_group = binding_info.group->value;
- entry.binding = binding_info.binding->value;
+ entry.bind_group = binding_info.group;
+ entry.binding = binding_info.binding;
result.push_back(entry);
}
@@ -434,8 +434,8 @@
ResourceBinding entry;
entry.resource_type = ResourceBinding::ResourceType::kComparisonSampler;
- entry.bind_group = binding_info.group->value;
- entry.binding = binding_info.binding->value;
+ entry.bind_group = binding_info.group;
+ entry.binding = binding_info.binding;
result.push_back(entry);
}
@@ -475,8 +475,8 @@
ResourceBinding entry;
entry.resource_type = resource_type;
- entry.bind_group = binding_info.group->value;
- entry.binding = binding_info.binding->value;
+ entry.bind_group = binding_info.group;
+ entry.binding = binding_info.binding;
auto* tex = var->Type()->UnwrapRef()->As<sem::Texture>();
entry.dim = TypeTextureDimensionToResourceBindingTextureDimension(tex->dim());
@@ -692,8 +692,8 @@
ResourceBinding entry;
entry.resource_type = read_only ? ResourceBinding::ResourceType::kReadOnlyStorageBuffer
: ResourceBinding::ResourceType::kStorageBuffer;
- entry.bind_group = binding_info.group->value;
- entry.binding = binding_info.binding->value;
+ entry.bind_group = binding_info.group;
+ entry.binding = binding_info.binding;
entry.size = unwrapped_type->Size();
if (auto* str = unwrapped_type->As<sem::Struct>()) {
entry.size_no_padding = str->SizeNoPadding();
@@ -728,8 +728,8 @@
entry.resource_type = multisampled_only
? ResourceBinding::ResourceType::kMultisampledTexture
: ResourceBinding::ResourceType::kSampledTexture;
- entry.bind_group = binding_info.group->value;
- entry.binding = binding_info.binding->value;
+ entry.bind_group = binding_info.group;
+ entry.binding = binding_info.binding;
auto* texture_type = var->Type()->UnwrapRef()->As<sem::Texture>();
entry.dim = TypeTextureDimensionToResourceBindingTextureDimension(texture_type->dim());
@@ -765,8 +765,8 @@
ResourceBinding entry;
entry.resource_type = ResourceBinding::ResourceType::kWriteOnlyStorageTexture;
- entry.bind_group = binding_info.group->value;
- entry.binding = binding_info.binding->value;
+ entry.bind_group = binding_info.group;
+ entry.binding = binding_info.binding;
entry.dim = TypeTextureDimensionToResourceBindingTextureDimension(texture_type->dim());
@@ -838,13 +838,8 @@
GetOriginatingResources(
std::array<const ast::Expression*, 2>{t, s},
[&](std::array<const sem::GlobalVariable*, 2> globals) {
- auto* texture = globals[0]->Declaration()->As<ast::Var>();
- sem::BindingPoint texture_binding_point = {texture->BindingPoint().group->value,
- texture->BindingPoint().binding->value};
-
- auto* sampler = globals[1]->Declaration()->As<ast::Var>();
- sem::BindingPoint sampler_binding_point = {sampler->BindingPoint().group->value,
- sampler->BindingPoint().binding->value};
+ auto texture_binding_point = globals[0]->BindingPoint();
+ auto sampler_binding_point = globals[1]->BindingPoint();
for (auto* entry_point : entry_points) {
const auto& ep_name =
diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc
index 1d04465..5f0013f 100644
--- a/src/tint/inspector/inspector_test.cc
+++ b/src/tint/inspector/inspector_test.cc
@@ -1811,7 +1811,7 @@
"foo_type",
utils::Vector{
Member("0i32", ty.i32()),
- Member("b", ty.array(ty.u32(), 4_u, /*stride*/ 16), utils::Vector{MemberAlign(16)}),
+ Member("b", ty.array(ty.u32(), 4_u, /*stride*/ 16), utils::Vector{MemberAlign(16_u)}),
});
AddUniformBuffer("foo_ub", ty.Of(foo_struct_type), 0, 0);
@@ -3238,11 +3238,9 @@
// of its last member, rounded up to the alignment of its largest member. So
// here the struct is expected to occupy 1024 bytes of workgroup storage.
const auto* wg_struct_type = MakeStructTypeFromMembers(
- "WgStruct",
- utils::Vector{
- MakeStructMember(0, ty.f32(),
- utils::Vector{create<ast::StructMemberAlignAttribute>(1024u)}),
- });
+ "WgStruct", utils::Vector{
+ MakeStructMember(0, ty.f32(), utils::Vector{MemberAlign(1024_u)}),
+ });
AddWorkgroupStorage("wg_struct_var", ty.Of(wg_struct_type));
MakeStructVariableReferenceBodyFunction("wg_struct_func", "wg_struct_var",
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 42e428a..c797aab 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -2274,17 +2274,19 @@
/// Creates a ast::StructMemberAlignAttribute
/// @param source the source information
- /// @param val the align value
+ /// @param val the align value expression
/// @returns the align attribute pointer
- const ast::StructMemberAlignAttribute* MemberAlign(const Source& source, uint32_t val) {
- return create<ast::StructMemberAlignAttribute>(source, val);
+ template <typename EXPR>
+ const ast::StructMemberAlignAttribute* MemberAlign(const Source& source, EXPR&& val) {
+ return create<ast::StructMemberAlignAttribute>(source, Expr(std::forward<EXPR>(val)));
}
/// Creates a ast::StructMemberAlignAttribute
- /// @param val the align value
+ /// @param val the align value expression
/// @returns the align attribute pointer
- const ast::StructMemberAlignAttribute* MemberAlign(uint32_t val) {
- return create<ast::StructMemberAlignAttribute>(source_, val);
+ template <typename EXPR>
+ const ast::StructMemberAlignAttribute* MemberAlign(EXPR&& val) {
+ return create<ast::StructMemberAlignAttribute>(source_, Expr(std::forward<EXPR>(val)));
}
/// Creates the ast::GroupAttribute
@@ -2969,6 +2971,15 @@
/// Creates an ast::WorkgroupAttribute
/// @param source the source information
/// @param x the x dimension expression
+ /// @returns the workgroup attribute pointer
+ template <typename EXPR_X>
+ const ast::WorkgroupAttribute* WorkgroupSize(const Source& source, EXPR_X&& x) {
+ return WorkgroupSize(source, std::forward<EXPR_X>(x), nullptr, nullptr);
+ }
+
+ /// Creates an ast::WorkgroupAttribute
+ /// @param source the source information
+ /// @param x the x dimension expression
/// @param y the y dimension expression
/// @returns the workgroup attribute pointer
template <typename EXPR_X, typename EXPR_Y>
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index 9335356..bf95b4e 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -1920,7 +1920,7 @@
return Failure::kErrored;
}
- auto initializer = maybe_expression();
+ auto initializer = expression();
if (initializer.errored) {
return Failure::kErrored;
}
@@ -1947,7 +1947,7 @@
return Failure::kErrored;
}
- auto initializer = maybe_expression();
+ auto initializer = expression();
if (initializer.errored) {
return Failure::kErrored;
}
@@ -1974,7 +1974,7 @@
const ast::Expression* initializer = nullptr;
if (match(Token::Type::kEqual)) {
- auto initializer_expr = maybe_expression();
+ auto initializer_expr = expression();
if (initializer_expr.errored) {
return Failure::kErrored;
}
@@ -2881,7 +2881,7 @@
// shift_expression
// : unary_expression shift_expression.post.unary_expression
-Maybe<const ast::Expression*> ParserImpl::maybe_shift_expression() {
+Maybe<const ast::Expression*> ParserImpl::shift_expression() {
auto lhs = unary_expression();
if (lhs.errored) {
return Failure::kErrored;
@@ -2913,12 +2913,13 @@
name = ">>";
}
+ auto& rhs_start = peek();
auto rhs = unary_expression();
if (rhs.errored) {
return Failure::kErrored;
}
if (!rhs.matched) {
- return add_error(t,
+ return add_error(rhs_start,
std::string("unable to parse right side of ") + name + " expression");
}
return create<ast::BinaryExpression>(t.source(), op, lhs, rhs.value);
@@ -2929,7 +2930,7 @@
// relational_expression
// : unary_expression relational_expression.post.unary_expression
-Maybe<const ast::Expression*> ParserImpl::maybe_relational_expression() {
+Maybe<const ast::Expression*> ParserImpl::relational_expression() {
auto lhs = unary_expression();
if (lhs.errored) {
return Failure::kErrored;
@@ -2978,7 +2979,7 @@
}
auto& next = peek();
- auto rhs = maybe_shift_expression();
+ auto rhs = shift_expression();
if (rhs.errored) {
return Failure::kErrored;
}
@@ -3000,7 +3001,7 @@
// relational_expression ( or_or relational_expression )*
//
// Note, a `relational_expression` element was added to simplify many of the right sides
-Maybe<const ast::Expression*> ParserImpl::maybe_expression() {
+Maybe<const ast::Expression*> ParserImpl::expression() {
auto lhs = unary_expression();
if (lhs.errored) {
return Failure::kErrored;
@@ -3130,437 +3131,6 @@
return create<ast::UnaryOpExpression>(t.source(), op, expr.value);
}
-// multiplicative_expr
-// :
-// | STAR unary_expression multiplicative_expr
-// | FORWARD_SLASH unary_expression multiplicative_expr
-// | MODULO unary_expression multiplicative_expr
-Expect<const ast::Expression*> ParserImpl::expect_multiplicative_expr(const ast::Expression* lhs) {
- while (continue_parsing()) {
- ast::BinaryOp op = ast::BinaryOp::kNone;
- if (peek_is(Token::Type::kStar)) {
- op = ast::BinaryOp::kMultiply;
- } else if (peek_is(Token::Type::kForwardSlash)) {
- op = ast::BinaryOp::kDivide;
- } else if (peek_is(Token::Type::kMod)) {
- op = ast::BinaryOp::kModulo;
- } else {
- return lhs;
- }
-
- auto& t = next();
-
- auto rhs = unary_expression();
- if (rhs.errored) {
- return Failure::kErrored;
- }
- if (!rhs.matched) {
- return add_error(peek(), "unable to parse right side of " + std::string(t.to_name()) +
- " expression");
- }
-
- lhs = create<ast::BinaryExpression>(t.source(), op, lhs, rhs.value);
- }
- return Failure::kErrored;
-}
-
-// multiplicative_expression
-// : unary_expression multiplicative_expr
-Maybe<const ast::Expression*> ParserImpl::multiplicative_expression() {
- auto lhs = unary_expression();
- if (lhs.errored) {
- return Failure::kErrored;
- }
- if (!lhs.matched) {
- return Failure::kNoMatch;
- }
-
- return expect_multiplicative_expr(lhs.value);
-}
-
-// additive_expr
-// :
-// | PLUS multiplicative_expression additive_expr
-// | MINUS multiplicative_expression additive_expr
-Expect<const ast::Expression*> ParserImpl::expect_additive_expr(const ast::Expression* lhs) {
- while (continue_parsing()) {
- ast::BinaryOp op = ast::BinaryOp::kNone;
- if (peek_is(Token::Type::kPlus)) {
- op = ast::BinaryOp::kAdd;
- } else if (peek_is(Token::Type::kMinus)) {
- op = ast::BinaryOp::kSubtract;
- } else {
- return lhs;
- }
-
- auto& t = next();
-
- auto rhs = multiplicative_expression();
- if (rhs.errored) {
- return Failure::kErrored;
- }
- if (!rhs.matched) {
- return add_error(peek(), "unable to parse right side of + expression");
- }
-
- lhs = create<ast::BinaryExpression>(t.source(), op, lhs, rhs.value);
- }
- return Failure::kErrored;
-}
-
-// additive_expression
-// : multiplicative_expression additive_expr
-Maybe<const ast::Expression*> ParserImpl::additive_expression() {
- auto lhs = multiplicative_expression();
- if (lhs.errored) {
- return Failure::kErrored;
- }
- if (!lhs.matched) {
- return Failure::kNoMatch;
- }
-
- return expect_additive_expr(lhs.value);
-}
-
-// shift_expr
-// :
-// | SHIFT_LEFT additive_expression shift_expr
-// | SHIFT_RIGHT additive_expression shift_expr
-Expect<const ast::Expression*> ParserImpl::expect_shift_expr(const ast::Expression* lhs) {
- while (continue_parsing()) {
- auto* name = "";
- ast::BinaryOp op = ast::BinaryOp::kNone;
- if (peek_is(Token::Type::kShiftLeft)) {
- op = ast::BinaryOp::kShiftLeft;
- name = "<<";
- } else if (peek_is(Token::Type::kShiftRight)) {
- op = ast::BinaryOp::kShiftRight;
- name = ">>";
- } else {
- return lhs;
- }
-
- auto& t = next();
- auto rhs = additive_expression();
- if (rhs.errored) {
- return Failure::kErrored;
- }
- if (!rhs.matched) {
- return add_error(peek(),
- std::string("unable to parse right side of ") + name + " expression");
- }
-
- return lhs = create<ast::BinaryExpression>(t.source(), op, lhs, rhs.value);
- }
- return Failure::kErrored;
-}
-
-// shift_expression
-// : additive_expression shift_expr
-Maybe<const ast::Expression*> ParserImpl::shift_expression() {
- auto lhs = additive_expression();
- if (lhs.errored) {
- return Failure::kErrored;
- }
- if (!lhs.matched) {
- return Failure::kNoMatch;
- }
-
- return expect_shift_expr(lhs.value);
-}
-
-// relational_expr
-// :
-// | LESS_THAN shift_expression relational_expr
-// | GREATER_THAN shift_expression relational_expr
-// | LESS_THAN_EQUAL shift_expression relational_expr
-// | GREATER_THAN_EQUAL shift_expression relational_expr
-Expect<const ast::Expression*> ParserImpl::expect_relational_expr(const ast::Expression* lhs) {
- while (continue_parsing()) {
- ast::BinaryOp op = ast::BinaryOp::kNone;
- if (peek_is(Token::Type::kLessThan)) {
- op = ast::BinaryOp::kLessThan;
- } else if (peek_is(Token::Type::kGreaterThan)) {
- op = ast::BinaryOp::kGreaterThan;
- } else if (peek_is(Token::Type::kLessThanEqual)) {
- op = ast::BinaryOp::kLessThanEqual;
- } else if (peek_is(Token::Type::kGreaterThanEqual)) {
- op = ast::BinaryOp::kGreaterThanEqual;
- } else {
- return lhs;
- }
-
- auto& t = next();
-
- auto rhs = shift_expression();
- if (rhs.errored) {
- return Failure::kErrored;
- }
- if (!rhs.matched) {
- return add_error(peek(), "unable to parse right side of " + std::string(t.to_name()) +
- " expression");
- }
-
- lhs = create<ast::BinaryExpression>(t.source(), op, lhs, rhs.value);
- }
- return Failure::kErrored;
-}
-
-// relational_expression
-// : shift_expression relational_expr
-Maybe<const ast::Expression*> ParserImpl::relational_expression() {
- auto lhs = shift_expression();
- if (lhs.errored) {
- return Failure::kErrored;
- }
- if (!lhs.matched) {
- return Failure::kNoMatch;
- }
-
- return expect_relational_expr(lhs.value);
-}
-
-// equality_expr
-// :
-// | EQUAL_EQUAL relational_expression equality_expr
-// | NOT_EQUAL relational_expression equality_expr
-Expect<const ast::Expression*> ParserImpl::expect_equality_expr(const ast::Expression* lhs) {
- while (continue_parsing()) {
- ast::BinaryOp op = ast::BinaryOp::kNone;
- if (peek_is(Token::Type::kEqualEqual)) {
- op = ast::BinaryOp::kEqual;
- } else if (peek_is(Token::Type::kNotEqual)) {
- op = ast::BinaryOp::kNotEqual;
- } else {
- return lhs;
- }
-
- auto& t = next();
-
- auto rhs = relational_expression();
- if (rhs.errored) {
- return Failure::kErrored;
- }
- if (!rhs.matched) {
- return add_error(peek(), "unable to parse right side of " + std::string(t.to_name()) +
- " expression");
- }
-
- lhs = create<ast::BinaryExpression>(t.source(), op, lhs, rhs.value);
- }
- return Failure::kErrored;
-}
-
-// equality_expression
-// : relational_expression equality_expr
-Maybe<const ast::Expression*> ParserImpl::equality_expression() {
- auto lhs = relational_expression();
- if (lhs.errored) {
- return Failure::kErrored;
- }
- if (!lhs.matched) {
- return Failure::kNoMatch;
- }
-
- return expect_equality_expr(lhs.value);
-}
-
-// and_expr
-// :
-// | AND equality_expression and_expr
-Expect<const ast::Expression*> ParserImpl::expect_and_expr(const ast::Expression* lhs) {
- while (continue_parsing()) {
- if (!peek_is(Token::Type::kAnd)) {
- return lhs;
- }
-
- auto& t = next();
-
- auto rhs = equality_expression();
- if (rhs.errored) {
- return Failure::kErrored;
- }
- if (!rhs.matched) {
- return add_error(peek(), "unable to parse right side of & expression");
- }
-
- lhs = create<ast::BinaryExpression>(t.source(), ast::BinaryOp::kAnd, lhs, rhs.value);
- }
- return Failure::kErrored;
-}
-
-// and_expression
-// : equality_expression and_expr
-Maybe<const ast::Expression*> ParserImpl::and_expression() {
- auto lhs = equality_expression();
- if (lhs.errored) {
- return Failure::kErrored;
- }
- if (!lhs.matched) {
- return Failure::kNoMatch;
- }
-
- return expect_and_expr(lhs.value);
-}
-
-// exclusive_or_expr
-// :
-// | XOR and_expression exclusive_or_expr
-Expect<const ast::Expression*> ParserImpl::expect_exclusive_or_expr(const ast::Expression* lhs) {
- while (continue_parsing()) {
- Source source;
- if (!match(Token::Type::kXor, &source)) {
- return lhs;
- }
-
- auto rhs = and_expression();
- if (rhs.errored) {
- return Failure::kErrored;
- }
- if (!rhs.matched) {
- return add_error(peek(), "unable to parse right side of ^ expression");
- }
-
- lhs = create<ast::BinaryExpression>(source, ast::BinaryOp::kXor, lhs, rhs.value);
- }
- return Failure::kErrored;
-}
-
-// exclusive_or_expression
-// : and_expression exclusive_or_expr
-Maybe<const ast::Expression*> ParserImpl::exclusive_or_expression() {
- auto lhs = and_expression();
- if (lhs.errored) {
- return Failure::kErrored;
- }
- if (!lhs.matched) {
- return Failure::kNoMatch;
- }
-
- return expect_exclusive_or_expr(lhs.value);
-}
-
-// inclusive_or_expr
-// :
-// | OR exclusive_or_expression inclusive_or_expr
-Expect<const ast::Expression*> ParserImpl::expect_inclusive_or_expr(const ast::Expression* lhs) {
- while (continue_parsing()) {
- Source source;
- if (!match(Token::Type::kOr, &source)) {
- return lhs;
- }
-
- auto rhs = exclusive_or_expression();
- if (rhs.errored) {
- return Failure::kErrored;
- }
- if (!rhs.matched) {
- return add_error(peek(), "unable to parse right side of | expression");
- }
-
- lhs = create<ast::BinaryExpression>(source, ast::BinaryOp::kOr, lhs, rhs.value);
- }
- return Failure::kErrored;
-}
-
-// inclusive_or_expression
-// : exclusive_or_expression inclusive_or_expr
-Maybe<const ast::Expression*> ParserImpl::inclusive_or_expression() {
- auto lhs = exclusive_or_expression();
- if (lhs.errored) {
- return Failure::kErrored;
- }
- if (!lhs.matched) {
- return Failure::kNoMatch;
- }
-
- return expect_inclusive_or_expr(lhs.value);
-}
-
-// logical_and_expr
-// :
-// | AND_AND inclusive_or_expression logical_and_expr
-Expect<const ast::Expression*> ParserImpl::expect_logical_and_expr(const ast::Expression* lhs) {
- while (continue_parsing()) {
- if (!peek_is(Token::Type::kAndAnd)) {
- return lhs;
- }
-
- auto& t = next();
-
- auto rhs = inclusive_or_expression();
- if (rhs.errored) {
- return Failure::kErrored;
- }
- if (!rhs.matched) {
- return add_error(peek(), "unable to parse right side of && expression");
- }
-
- lhs = create<ast::BinaryExpression>(t.source(), ast::BinaryOp::kLogicalAnd, lhs, rhs.value);
- }
- return Failure::kErrored;
-}
-
-// logical_and_expression
-// : inclusive_or_expression logical_and_expr
-Maybe<const ast::Expression*> ParserImpl::logical_and_expression() {
- auto lhs = inclusive_or_expression();
- if (lhs.errored) {
- return Failure::kErrored;
- }
- if (!lhs.matched) {
- return Failure::kNoMatch;
- }
-
- return expect_logical_and_expr(lhs.value);
-}
-
-// logical_or_expr
-// :
-// | OR_OR logical_and_expression logical_or_expr
-Expect<const ast::Expression*> ParserImpl::expect_logical_or_expr(const ast::Expression* lhs) {
- while (continue_parsing()) {
- Source source;
- if (!match(Token::Type::kOrOr, &source)) {
- return lhs;
- }
-
- auto rhs = logical_and_expression();
- if (rhs.errored) {
- return Failure::kErrored;
- }
- if (!rhs.matched) {
- return add_error(peek(), "unable to parse right side of || expression");
- }
-
- lhs = create<ast::BinaryExpression>(source, ast::BinaryOp::kLogicalOr, lhs, rhs.value);
- }
- return Failure::kErrored;
-}
-
-// logical_or_expression
-// : logical_and_expression logical_or_expr
-Maybe<const ast::Expression*> ParserImpl::logical_or_expression() {
- auto lhs = logical_and_expression();
- if (lhs.errored) {
- return Failure::kErrored;
- }
- if (!lhs.matched) {
- return Failure::kNoMatch;
- }
-
- return expect_logical_or_expr(lhs.value);
-}
-
-// expression:
-// : relational_expression
-// | short_circuit_or_expression or_or relational_expression
-// | short_circuit_and_expression and_and relational_expression
-// | bitwise_expression
-Maybe<const ast::Expression*> ParserImpl::expression() {
- return logical_or_expression();
-}
-
// compound_assignment_operator
// : plus_equal
// | minus_equal
@@ -3835,28 +3405,25 @@
}
// attribute
-// : ATTR 'align' PAREN_LEFT expression attrib_end
-// | ATTR 'binding' PAREN_LEFT expression attrib_end
-// | ATTR 'builtin' PAREN_LEFT builtin_value_name attrib_end
+// : ATTR 'align' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'binding' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'builtin' PAREN_LEFT builtin_value_name COMMA? PAREN_RIGHT
// | ATTR 'const'
-// | ATTR 'group' PAREN_LEFT expression attrib_end
-// | ATTR 'id' PAREN_LEFT expression attrib_end
-// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name attrib_end
+// | ATTR 'group' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'id' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA? PAREN_RIGHT
// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA
-// interpolation_sample_name attrib_end
+// interpolation_sample_name COMMA? PAREN_RIGHT
// | ATTR 'invariant'
-// | ATTR 'location' PAREN_LEFT expression attrib_end
-// | ATTR 'size' PAREN_LEFT expression attrib_end
-// | ATTR 'workgroup_size' PAREN_LEFT expression attrib_end
-// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression attrib_end
-// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA expression attrib_end
+// | ATTR 'location' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'size' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA? PAREN_RIGHT
+// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA
+// expression COMMA? PAREN_RIGHT
// | ATTR 'vertex'
// | ATTR 'fragment'
// | ATTR 'compute'
-//
-// attrib_end
-// : COMMA? PAREN_RIGHT
-//
Maybe<const ast::Attribute*> ParserImpl::attribute() {
using Result = Maybe<const ast::Attribute*>;
auto& t = next();
@@ -3874,7 +3441,9 @@
}
match(Token::Type::kComma);
- return create<ast::StructMemberAlignAttribute>(t.source(), val.value);
+ return create<ast::StructMemberAlignAttribute>(
+ t.source(), create<ast::IntLiteralExpression>(
+ val.value, ast::IntLiteralExpression::Suffix::kNone));
});
}
@@ -4031,7 +3600,7 @@
const ast::Expression* y = nullptr;
const ast::Expression* z = nullptr;
- auto expr = primary_expression();
+ auto expr = expression();
if (expr.errored) {
return Failure::kErrored;
} else if (!expr.matched) {
@@ -4041,7 +3610,7 @@
if (match(Token::Type::kComma)) {
if (!peek_is(Token::Type::kParenRight)) {
- expr = primary_expression();
+ expr = expression();
if (expr.errored) {
return Failure::kErrored;
} else if (!expr.matched) {
@@ -4051,7 +3620,7 @@
if (match(Token::Type::kComma)) {
if (!peek_is(Token::Type::kParenRight)) {
- expr = primary_expression();
+ expr = expression();
if (expr.errored) {
return Failure::kErrored;
} else if (!expr.matched) {
diff --git a/src/tint/reader/wgsl/parser_impl.h b/src/tint/reader/wgsl/parser_impl.h
index 44525c0..cc0cd2b 100644
--- a/src/tint/reader/wgsl/parser_impl.h
+++ b/src/tint/reader/wgsl/parser_impl.h
@@ -621,49 +621,9 @@
/// Parses a `unary_expression` grammar element
/// @returns the parsed expression or nullptr
Maybe<const ast::Expression*> unary_expression();
- /// Parses the recursive part of the `multiplicative_expression`, erroring on
- /// parse failure.
- /// @param lhs the left side of the expression
- /// @returns the parsed expression or nullptr
- Expect<const ast::Expression*> expect_multiplicative_expr(const ast::Expression* lhs);
- /// Parses the `multiplicative_expression` grammar element
- /// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> multiplicative_expression();
- /// Parses the recursive part of the `additive_expression`, erroring on parse
- /// failure.
- /// @param lhs the left side of the expression
- /// @returns the parsed expression or nullptr
- Expect<const ast::Expression*> expect_additive_expr(const ast::Expression* lhs);
- /// Parses the `additive_expression` grammar element
- /// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> additive_expression();
- /// Parses the recursive part of the `shift_expression`, erroring on parse
- /// failure.
- /// @param lhs the left side of the expression
- /// @returns the parsed expression or nullptr
- Expect<const ast::Expression*> expect_shift_expr(const ast::Expression* lhs);
- /// Parses the `shift_expression` grammar element
- /// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> shift_expression();
- /// Parses the recursive part of the `relational_expression`, erroring on
- /// parse failure.
- /// @param lhs the left side of the expression
- /// @returns the parsed expression or nullptr
- Expect<const ast::Expression*> expect_relational_expr(const ast::Expression* lhs);
/// Parses the `expression` grammar rule
/// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> maybe_expression();
- /// Parses the `relational_expression` grammar element
- /// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> relational_expression();
- /// Parses the recursive part of the `equality_expression`, erroring on parse
- /// failure.
- /// @param lhs the left side of the expression
- /// @returns the parsed expression or nullptr
- Expect<const ast::Expression*> expect_equality_expr(const ast::Expression* lhs);
- /// Parses the `equality_expression` grammar element
- /// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> equality_expression();
+ Maybe<const ast::Expression*> expression();
/// Parses the `bitwise_expression.post.unary_expression` grammar element
/// @param lhs the left side of the expression
/// @returns the parsed expression or nullptr
@@ -692,7 +652,7 @@
Maybe<const ast::Expression*> element_count_expression();
/// Parses a `unary_expression shift.post.unary_expression`
/// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> maybe_shift_expression();
+ Maybe<const ast::Expression*> shift_expression();
/// Parses a `shift_expression.post.unary_expression` grammar element
/// @param lhs the left side of the expression
/// @returns the parsed expression or `lhs` if no match
@@ -700,7 +660,7 @@
const ast::Expression* lhs);
/// Parses a `unary_expression relational_expression.post.unary_expression`
/// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> maybe_relational_expression();
+ Maybe<const ast::Expression*> relational_expression();
/// Parses a `relational_expression.post.unary_expression` grammar element
/// @param lhs the left side of the expression
/// @returns the parsed expression or `lhs` if no match
@@ -709,49 +669,6 @@
/// Parse the `additive_operator` grammar element
/// @returns the parsed operator if successful
Maybe<ast::BinaryOp> additive_operator();
- /// Parses the recursive part of the `and_expression`, erroring on parse
- /// failure.
- /// @param lhs the left side of the expression
- /// @returns the parsed expression or nullptr
- Expect<const ast::Expression*> expect_and_expr(const ast::Expression* lhs);
- /// Parses the `and_expression` grammar element
- /// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> and_expression();
- /// Parses the recursive part of the `exclusive_or_expression`, erroring on
- /// parse failure.
- /// @param lhs the left side of the expression
- /// @returns the parsed expression or nullptr
- Expect<const ast::Expression*> expect_exclusive_or_expr(const ast::Expression* lhs);
- /// Parses the `exclusive_or_expression` grammar elememnt
- /// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> exclusive_or_expression();
- /// Parses the recursive part of the `inclusive_or_expression`, erroring on
- /// parse failure.
- /// @param lhs the left side of the expression
- /// @returns the parsed expression or nullptr
- Expect<const ast::Expression*> expect_inclusive_or_expr(const ast::Expression* lhs);
- /// Parses the `inclusive_or_expression` grammar element
- /// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> inclusive_or_expression();
- /// Parses the recursive part of the `logical_and_expression`, erroring on
- /// parse failure.
- /// @param lhs the left side of the expression
- /// @returns the parsed expression or nullptr
- Expect<const ast::Expression*> expect_logical_and_expr(const ast::Expression* lhs);
- /// Parses a `logical_and_expression` grammar element
- /// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> logical_and_expression();
- /// Parses the recursive part of the `logical_or_expression`, erroring on
- /// parse failure.
- /// @param lhs the left side of the expression
- /// @returns the parsed expression or nullptr
- Expect<const ast::Expression*> expect_logical_or_expr(const ast::Expression* lhs);
- /// Parses a `logical_or_expression` grammar element
- /// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> logical_or_expression();
- /// Parses an `expression` grammar element
- /// @returns the parsed expression or nullptr
- Maybe<const ast::Expression*> expression();
/// Parses a `compound_assignment_operator` grammar element
/// @returns the parsed compound assignment operator
Maybe<ast::BinaryOp> compound_assignment_operator();
diff --git a/src/tint/reader/wgsl/parser_impl_additive_expression_test.cc b/src/tint/reader/wgsl/parser_impl_additive_expression_test.cc
index 3259e27..971c695 100644
--- a/src/tint/reader/wgsl/parser_impl_additive_expression_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_additive_expression_test.cc
@@ -17,80 +17,6 @@
namespace tint::reader::wgsl {
namespace {
-TEST_F(ParserImplTest, AdditiveExpression_Orig_Parses_Plus) {
- auto p = parser("a + true");
- auto e = p->additive_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 4u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kAdd, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, AdditiveExpression_Orig_Parses_Minus) {
- auto p = parser("a - true");
- auto e = p->additive_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kSubtract, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, AdditiveExpression_Orig_InvalidLHS) {
- auto p = parser("if (a) {} + true");
- auto e = p->additive_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- EXPECT_EQ(e.value, nullptr);
-}
-
-TEST_F(ParserImplTest, AdditiveExpression_Orig_InvalidRHS) {
- auto p = parser("true + if (a) {}");
- auto e = p->additive_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_TRUE(e.errored);
- EXPECT_EQ(e.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:8: unable to parse right side of + expression");
-}
-
-TEST_F(ParserImplTest, AdditiveExpression_Orig_NoOr_ReturnsLHS) {
- auto p = parser("a true");
- auto e = p->additive_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
-}
-
TEST_F(ParserImplTest, AdditiveExpression_Parses_Plus) {
auto p = parser("a + b");
auto lhs = p->unary_expression();
diff --git a/src/tint/reader/wgsl/parser_impl_and_expression_test.cc b/src/tint/reader/wgsl/parser_impl_and_expression_test.cc
deleted file mode 100644
index fd90460..0000000
--- a/src/tint/reader/wgsl/parser_impl_and_expression_test.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2020 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
-
-namespace tint::reader::wgsl {
-namespace {
-
-TEST_F(ParserImplTest, AndExpression_Parses) {
- auto p = parser("a & true");
- auto e = p->and_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 4u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kAnd, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Register("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, AndExpression_InvalidLHS) {
- auto p = parser("if (a) {} & true");
- auto e = p->and_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- EXPECT_EQ(e.value, nullptr);
-}
-
-TEST_F(ParserImplTest, AndExpression_InvalidRHS) {
- auto p = parser("true & if (a) {}");
- auto e = p->and_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_TRUE(e.errored);
- EXPECT_EQ(e.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:8: unable to parse right side of & expression");
-}
-
-TEST_F(ParserImplTest, AndExpression_NoOr_ReturnsLHS) {
- auto p = parser("a true");
- auto e = p->and_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
-}
-
-} // namespace
-} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/parser_impl_equality_expression_test.cc b/src/tint/reader/wgsl/parser_impl_equality_expression_test.cc
deleted file mode 100644
index 158227d..0000000
--- a/src/tint/reader/wgsl/parser_impl_equality_expression_test.cc
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2020 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
-
-namespace tint::reader::wgsl {
-namespace {
-
-TEST_F(ParserImplTest, EqualityExpression_Parses_Equal) {
- auto p = parser("a == true");
- auto e = p->equality_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 5u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kEqual, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, EqualityExpression_Parses_NotEqual) {
- auto p = parser("a != true");
- auto e = p->equality_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 5u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kNotEqual, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, EqualityExpression_InvalidLHS) {
- auto p = parser("if (a) {} == true");
- auto e = p->equality_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- EXPECT_EQ(e.value, nullptr);
-}
-
-TEST_F(ParserImplTest, EqualityExpression_InvalidRHS) {
- auto p = parser("true == if (a) {}");
- auto e = p->equality_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_TRUE(e.errored);
- EXPECT_EQ(e.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:9: unable to parse right side of == expression");
-}
-
-TEST_F(ParserImplTest, EqualityExpression_NoOr_ReturnsLHS) {
- auto p = parser("a true");
- auto e = p->equality_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
-}
-
-} // namespace
-} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/parser_impl_exclusive_or_expression_test.cc b/src/tint/reader/wgsl/parser_impl_exclusive_or_expression_test.cc
deleted file mode 100644
index 2994ae8..0000000
--- a/src/tint/reader/wgsl/parser_impl_exclusive_or_expression_test.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2020 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
-
-namespace tint::reader::wgsl {
-namespace {
-
-TEST_F(ParserImplTest, ExclusiveOrExpression_Parses) {
- auto p = parser("a ^ true");
- auto e = p->exclusive_or_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 4u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kXor, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, ExclusiveOrExpression_InvalidLHS) {
- auto p = parser("if (a) {} ^ true");
- auto e = p->exclusive_or_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- EXPECT_EQ(e.value, nullptr);
-}
-
-TEST_F(ParserImplTest, ExclusiveOrExpression_InvalidRHS) {
- auto p = parser("true ^ if (a) {}");
- auto e = p->exclusive_or_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_TRUE(e.errored);
- EXPECT_EQ(e.value, nullptr);
- ASSERT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:8: unable to parse right side of ^ expression");
-}
-
-TEST_F(ParserImplTest, ExclusiveOrExpression_NoOr_ReturnsLHS) {
- auto p = parser("a true");
- auto e = p->exclusive_or_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
-}
-
-} // namespace
-} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/parser_impl_expression_test.cc b/src/tint/reader/wgsl/parser_impl_expression_test.cc
index 2938e3c..0adbeef 100644
--- a/src/tint/reader/wgsl/parser_impl_expression_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_expression_test.cc
@@ -19,7 +19,7 @@
TEST_F(ParserImplTest, Expression_InvalidLHS) {
auto p = parser("if (a) {} || true");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_FALSE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -28,7 +28,7 @@
TEST_F(ParserImplTest, Expression_Or_Parses) {
auto p = parser("a || true");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -53,7 +53,7 @@
TEST_F(ParserImplTest, Expression_Or_Parses_Multiple) {
auto p = parser("a || true || b");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -85,7 +85,7 @@
TEST_F(ParserImplTest, Expression_Or_InvalidRHS) {
auto p = parser("true || if (a) {}");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
@@ -95,7 +95,7 @@
TEST_F(ParserImplTest, Expression_And_Parses) {
auto p = parser("a && true");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -120,7 +120,7 @@
TEST_F(ParserImplTest, Expression_And_Parses_Multple) {
auto p = parser("a && true && b");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -150,7 +150,7 @@
TEST_F(ParserImplTest, Expression_And_InvalidRHS) {
auto p = parser("true && if (a) {}");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
@@ -160,7 +160,7 @@
TEST_F(ParserImplTest, Expression_Mixing_OrWithAnd) {
auto p = parser("a && true || b");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
@@ -170,7 +170,7 @@
TEST_F(ParserImplTest, Expression_Mixing_AndWithOr) {
auto p = parser("a || true && b");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
@@ -180,7 +180,7 @@
TEST_F(ParserImplTest, Expression_Bitwise) {
auto p = parser("a & b");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -201,7 +201,7 @@
TEST_F(ParserImplTest, Expression_Relational) {
auto p = parser("a <= b");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -222,7 +222,7 @@
TEST_F(ParserImplTest, Expression_InvalidUnary) {
auto p = parser("!if || true");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
@@ -232,7 +232,7 @@
TEST_F(ParserImplTest, Expression_InvalidBitwise) {
auto p = parser("a & if");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
@@ -242,7 +242,7 @@
TEST_F(ParserImplTest, Expression_InvalidRelational) {
auto p = parser("a <= if");
- auto e = p->maybe_expression();
+ auto e = p->expression();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_EQ(e.value, nullptr);
diff --git a/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc b/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
index 8e56ee7..60e8f86 100644
--- a/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
@@ -41,6 +41,35 @@
EXPECT_EQ(values[2], nullptr);
}
+TEST_F(ParserImplTest, Attribute_Workgroup_Expression) {
+ auto p = parser("workgroup_size(4 + 2)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr) << p->error();
+ ASSERT_FALSE(p->has_error());
+ auto* func_attr = attr.value->As<ast::Attribute>();
+ ASSERT_NE(func_attr, nullptr);
+ ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
+
+ auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
+
+ ASSERT_TRUE(values[0]->Is<ast::BinaryExpression>());
+ auto* expr = values[0]->As<ast::BinaryExpression>();
+ EXPECT_EQ(expr->op, ast::BinaryOp::kAdd);
+
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 4);
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 2);
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ EXPECT_EQ(values[1], nullptr);
+ EXPECT_EQ(values[2], nullptr);
+}
+
TEST_F(ParserImplTest, Attribute_Workgroup_1Param_TrailingComma) {
auto p = parser("workgroup_size(4,)");
auto attr = p->attribute();
@@ -99,6 +128,39 @@
EXPECT_EQ(values[2], nullptr);
}
+TEST_F(ParserImplTest, Attribute_Workgroup_2Param_Expression) {
+ auto p = parser("workgroup_size(4, 5 - 2)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr) << p->error();
+ ASSERT_FALSE(p->has_error());
+ auto* func_attr = attr.value->As<ast::Attribute>();
+ ASSERT_NE(func_attr, nullptr) << p->error();
+ ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
+
+ auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
+
+ ASSERT_TRUE(values[0]->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->value, 4);
+ EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ ASSERT_TRUE(values[1]->Is<ast::BinaryExpression>());
+ auto* expr = values[1]->As<ast::BinaryExpression>();
+ EXPECT_EQ(expr->op, ast::BinaryOp::kSubtract);
+
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 5);
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 2);
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ EXPECT_EQ(values[2], nullptr);
+}
+
TEST_F(ParserImplTest, Attribute_Workgroup_2Param_TrailingComma) {
auto p = parser("workgroup_size(4, 5,)");
auto attr = p->attribute();
@@ -164,6 +226,42 @@
ast::IntLiteralExpression::Suffix::kNone);
}
+TEST_F(ParserImplTest, Attribute_Workgroup_3Param_Expression) {
+ auto p = parser("workgroup_size(4, 5, 6 << 1)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr) << p->error();
+ ASSERT_FALSE(p->has_error());
+ auto* func_attr = attr.value->As<ast::Attribute>();
+ ASSERT_NE(func_attr, nullptr);
+ ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
+
+ auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
+
+ ASSERT_TRUE(values[0]->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->value, 4);
+ EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ ASSERT_TRUE(values[1]->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(values[1]->As<ast::IntLiteralExpression>()->value, 5);
+ EXPECT_EQ(values[1]->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ ASSERT_TRUE(values[2]->Is<ast::BinaryExpression>());
+ auto* expr = values[2]->As<ast::BinaryExpression>();
+ EXPECT_EQ(expr->op, ast::BinaryOp::kShiftLeft);
+
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 6);
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 1);
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+}
+
TEST_F(ParserImplTest, Attribute_Workgroup_3Param_TrailingComma) {
auto p = parser("workgroup_size(4, 5, 6,)");
auto attr = p->attribute();
diff --git a/src/tint/reader/wgsl/parser_impl_inclusive_or_expression_test.cc b/src/tint/reader/wgsl/parser_impl_inclusive_or_expression_test.cc
deleted file mode 100644
index f534ff7..0000000
--- a/src/tint/reader/wgsl/parser_impl_inclusive_or_expression_test.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2020 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
-
-namespace tint::reader::wgsl {
-namespace {
-
-TEST_F(ParserImplTest, InclusiveOrExpression_Parses) {
- auto p = parser("a | true");
- auto e = p->inclusive_or_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 4u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kOr, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, InclusiveOrExpression_InvalidLHS) {
- auto p = parser("if (a) {} | true");
- auto e = p->inclusive_or_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_EQ(e.value, nullptr);
-}
-
-TEST_F(ParserImplTest, InclusiveOrExpression_InvalidRHS) {
- auto p = parser("true | if (a) {}");
- auto e = p->inclusive_or_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_TRUE(e.errored);
- EXPECT_EQ(e.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:8: unable to parse right side of | expression");
-}
-
-TEST_F(ParserImplTest, InclusiveOrExpression_NoOr_ReturnsLHS) {
- auto p = parser("a true");
- auto e = p->inclusive_or_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
-}
-
-} // namespace
-} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/parser_impl_logical_and_expression_test.cc b/src/tint/reader/wgsl/parser_impl_logical_and_expression_test.cc
deleted file mode 100644
index 8baadaf..0000000
--- a/src/tint/reader/wgsl/parser_impl_logical_and_expression_test.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2020 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
-
-namespace tint::reader::wgsl {
-namespace {
-
-TEST_F(ParserImplTest, LogicalAndExpression_Parses) {
- auto p = parser("a && true");
- auto e = p->logical_and_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 5u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kLogicalAnd, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, LogicalAndExpression_InvalidLHS) {
- auto p = parser("if (a) {} && true");
- auto e = p->logical_and_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- EXPECT_EQ(e.value, nullptr);
-}
-
-TEST_F(ParserImplTest, LogicalAndExpression_InvalidRHS) {
- auto p = parser("true && if (a) {}");
- auto e = p->logical_and_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_TRUE(e.errored);
- EXPECT_EQ(e.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:9: unable to parse right side of && expression");
-}
-
-TEST_F(ParserImplTest, LogicalAndExpression_NoOr_ReturnsLHS) {
- auto p = parser("a true");
- auto e = p->logical_and_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
-}
-
-} // namespace
-} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/parser_impl_logical_or_expression_test.cc b/src/tint/reader/wgsl/parser_impl_logical_or_expression_test.cc
deleted file mode 100644
index 943b059..0000000
--- a/src/tint/reader/wgsl/parser_impl_logical_or_expression_test.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2020 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
-
-namespace tint::reader::wgsl {
-namespace {
-
-TEST_F(ParserImplTest, LogicalOrExpression_Parses) {
- auto p = parser("a || true");
- auto e = p->logical_or_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 5u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kLogicalOr, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, LogicalOrExpression_InvalidLHS) {
- auto p = parser("if (a) {} || true");
- auto e = p->logical_or_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- EXPECT_EQ(e.value, nullptr);
-}
-
-TEST_F(ParserImplTest, LogicalOrExpression_InvalidRHS) {
- auto p = parser("true || if (a) {}");
- auto e = p->logical_or_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_TRUE(e.errored);
- EXPECT_EQ(e.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:9: unable to parse right side of || expression");
-}
-
-TEST_F(ParserImplTest, LogicalOrExpression_NoOr_ReturnsLHS) {
- auto p = parser("a true");
- auto e = p->logical_or_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
-}
-
-} // namespace
-} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/parser_impl_multiplicative_expression_test.cc b/src/tint/reader/wgsl/parser_impl_multiplicative_expression_test.cc
index 618a0c0..9417477 100644
--- a/src/tint/reader/wgsl/parser_impl_multiplicative_expression_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_multiplicative_expression_test.cc
@@ -17,100 +17,6 @@
namespace tint::reader::wgsl {
namespace {
-TEST_F(ParserImplTest, MultiplicativeExpression_Orig_Parses_Multiply) {
- auto p = parser("a * true");
- auto e = p->multiplicative_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 4u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kMultiply, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, MultiplicativeExpression_Orig_Parses_Divide) {
- auto p = parser("a / true");
- auto e = p->multiplicative_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kDivide, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, MultiplicativeExpression_Orig_Parses_Modulo) {
- auto p = parser("a % true");
- auto e = p->multiplicative_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kModulo, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, MultiplicativeExpression_Orig_InvalidLHS) {
- auto p = parser("if (a) {} * true");
- auto e = p->multiplicative_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- EXPECT_EQ(e.value, nullptr);
-}
-
-TEST_F(ParserImplTest, MultiplicativeExpression_Orig_InvalidRHS) {
- auto p = parser("true * if (a) {}");
- auto e = p->multiplicative_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_TRUE(e.errored);
- EXPECT_EQ(e.value, nullptr);
- ASSERT_TRUE(p->has_error());
- EXPECT_EQ(p->error(), "1:8: unable to parse right side of * expression");
-}
-
-TEST_F(ParserImplTest, MultiplicativeExpression_Orig_NoOr_ReturnsLHS) {
- auto p = parser("a true");
- auto e = p->multiplicative_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
-}
-
TEST_F(ParserImplTest, MultiplicativeExpression_Parses_Multiply) {
auto p = parser("a * b");
auto lhs = p->unary_expression();
diff --git a/src/tint/reader/wgsl/parser_impl_relational_expression_test.cc b/src/tint/reader/wgsl/parser_impl_relational_expression_test.cc
index 59cddc6..ebd8ffd 100644
--- a/src/tint/reader/wgsl/parser_impl_relational_expression_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_relational_expression_test.cc
@@ -17,133 +17,6 @@
namespace tint::reader::wgsl {
namespace {
-TEST_F(ParserImplTest, RelationalExpression_Orig_Parses_LessThan) {
- auto p = parser("a < true");
- auto e = p->relational_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 4u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kLessThan, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Register("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, RelationalExpression_Orig_Parses_GreaterThan) {
- auto p = parser("a > true");
- auto e = p->relational_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 4u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kGreaterThan, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Register("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, RelationalExpression_Orig_Parses_LessThanEqual) {
- auto p = parser("a <= true");
- auto e = p->relational_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 5u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kLessThanEqual, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Register("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, RelationalExpression_Orig_Parses_GreaterThanEqual) {
- auto p = parser("a >= true");
- auto e = p->relational_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 5u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kGreaterThanEqual, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Register("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, RelationalExpression_Orig_InvalidLHS) {
- auto p = parser("if (a) {} < true");
- auto e = p->relational_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- EXPECT_EQ(e.value, nullptr);
-}
-
-TEST_F(ParserImplTest, RelationalExpression_Orig_InvalidRHS) {
- auto p = parser("true < if (a) {}");
- auto e = p->relational_expression();
- ASSERT_TRUE(p->has_error());
- EXPECT_EQ(e.value, nullptr);
- EXPECT_EQ(p->error(), "1:8: unable to parse right side of < expression");
-}
-
-TEST_F(ParserImplTest, RelationalExpression_Orig_NoOr_ReturnsLHS) {
- auto p = parser("a true");
- auto e = p->relational_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
-}
-
TEST_F(ParserImplTest, RelationalExpression_PostUnary_Parses_LessThan) {
auto p = parser("a < true");
auto lhs = p->unary_expression();
@@ -315,7 +188,7 @@
TEST_F(ParserImplTest, RelationalExpression_Matches) {
auto p = parser("a >= true");
- auto e = p->maybe_relational_expression();
+ auto e = p->relational_expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -335,7 +208,7 @@
TEST_F(ParserImplTest, RelationalExpression_InvalidLHS) {
auto p = parser("if (a) {}< 3");
- auto e = p->maybe_relational_expression();
+ auto e = p->relational_expression();
ASSERT_FALSE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_FALSE(p->has_error()) << p->error();
@@ -344,7 +217,7 @@
TEST_F(ParserImplTest, RelationalExpression_InvalidRHS) {
auto p = parser("true < if (a) {}");
- auto e = p->maybe_relational_expression();
+ auto e = p->relational_expression();
ASSERT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
ASSERT_TRUE(p->has_error());
@@ -352,5 +225,65 @@
EXPECT_EQ(p->error(), "1:8: unable to parse right side of < expression");
}
+TEST_F(ParserImplTest, RelationalExpression_Parses_Equal) {
+ auto p = parser("a == true");
+ auto e = p->relational_expression();
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(e.value, nullptr);
+
+ EXPECT_EQ(e->source.range.begin.line, 1u);
+ EXPECT_EQ(e->source.range.begin.column, 3u);
+ EXPECT_EQ(e->source.range.end.line, 1u);
+ EXPECT_EQ(e->source.range.end.column, 5u);
+
+ ASSERT_TRUE(e->Is<ast::BinaryExpression>());
+ auto* rel = e->As<ast::BinaryExpression>();
+ EXPECT_EQ(ast::BinaryOp::kEqual, rel->op);
+
+ ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
+ auto* ident = rel->lhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
+
+ ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
+ ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
+}
+
+TEST_F(ParserImplTest, RelationalExpression_Parses_NotEqual) {
+ auto p = parser("a != true");
+ auto e = p->relational_expression();
+ EXPECT_TRUE(e.matched);
+ EXPECT_FALSE(e.errored);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(e.value, nullptr);
+
+ EXPECT_EQ(e->source.range.begin.line, 1u);
+ EXPECT_EQ(e->source.range.begin.column, 3u);
+ EXPECT_EQ(e->source.range.end.line, 1u);
+ EXPECT_EQ(e->source.range.end.column, 5u);
+
+ ASSERT_TRUE(e->Is<ast::BinaryExpression>());
+ auto* rel = e->As<ast::BinaryExpression>();
+ EXPECT_EQ(ast::BinaryOp::kNotEqual, rel->op);
+
+ ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
+ auto* ident = rel->lhs->As<ast::IdentifierExpression>();
+ EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
+
+ ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
+ ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
+}
+
+TEST_F(ParserImplTest, RelationalExpression_Equal_InvalidRHS) {
+ auto p = parser("true == if (a) {}");
+ auto e = p->relational_expression();
+ EXPECT_FALSE(e.matched);
+ EXPECT_TRUE(e.errored);
+ EXPECT_EQ(e.value, nullptr);
+ EXPECT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:9: unable to parse right side of == expression");
+}
+
} // namespace
} // namespace tint::reader::wgsl
diff --git a/src/tint/reader/wgsl/parser_impl_shift_expression_test.cc b/src/tint/reader/wgsl/parser_impl_shift_expression_test.cc
index d518004..016f86f 100644
--- a/src/tint/reader/wgsl/parser_impl_shift_expression_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_shift_expression_test.cc
@@ -17,103 +17,6 @@
namespace tint::reader::wgsl {
namespace {
-TEST_F(ParserImplTest, ShiftExpression_Orig_Parses_ShiftLeft) {
- auto p = parser("a << true");
- auto e = p->shift_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 5u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kShiftLeft, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, ShiftExpression_Orig_Parses_ShiftRight) {
- auto p = parser("a >> true");
- auto e = p->shift_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
-
- EXPECT_EQ(e->source.range.begin.line, 1u);
- EXPECT_EQ(e->source.range.begin.column, 3u);
- EXPECT_EQ(e->source.range.end.line, 1u);
- EXPECT_EQ(e->source.range.end.column, 5u);
-
- ASSERT_TRUE(e->Is<ast::BinaryExpression>());
- auto* rel = e->As<ast::BinaryExpression>();
- EXPECT_EQ(ast::BinaryOp::kShiftRight, rel->op);
-
- ASSERT_TRUE(rel->lhs->Is<ast::IdentifierExpression>());
- auto* ident = rel->lhs->As<ast::IdentifierExpression>();
- EXPECT_EQ(ident->symbol, p->builder().Symbols().Get("a"));
-
- ASSERT_TRUE(rel->rhs->Is<ast::BoolLiteralExpression>());
- ASSERT_TRUE(rel->rhs->As<ast::BoolLiteralExpression>()->value);
-}
-
-TEST_F(ParserImplTest, ShiftExpression_Orig_InvalidSpaceLeft) {
- auto p = parser("a < < true");
- auto e = p->shift_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- ASSERT_NE(e.value, nullptr);
- EXPECT_FALSE(e.value->Is<ast::BinaryExpression>());
-}
-
-TEST_F(ParserImplTest, ShiftExpression_Orig_InvalidSpaceRight) {
- auto p = parser("a > > true");
- auto e = p->shift_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- ASSERT_NE(e.value, nullptr);
- EXPECT_FALSE(e.value->Is<ast::BinaryExpression>());
-}
-
-TEST_F(ParserImplTest, ShiftExpression_Orig_InvalidLHS) {
- auto p = parser("if (a) {} << true");
- auto e = p->shift_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- EXPECT_EQ(e.value, nullptr);
-}
-
-TEST_F(ParserImplTest, ShiftExpression_Orig_InvalidRHS) {
- auto p = parser("true << if (a) {}");
- auto e = p->shift_expression();
- EXPECT_FALSE(e.matched);
- EXPECT_TRUE(e.errored);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(e.value, nullptr);
- EXPECT_EQ(p->error(), "1:9: unable to parse right side of << expression");
-}
-
-TEST_F(ParserImplTest, ShiftExpression_Orig_NoOr_ReturnsLHS) {
- auto p = parser("a true");
- auto e = p->shift_expression();
- EXPECT_TRUE(e.matched);
- EXPECT_FALSE(e.errored);
- EXPECT_FALSE(p->has_error()) << p->error();
- ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::IdentifierExpression>());
-}
-
TEST_F(ParserImplTest, ShiftExpression_PostUnary_Parses_ShiftLeft) {
auto p = parser("a << true");
auto lhs = p->unary_expression();
@@ -231,7 +134,7 @@
EXPECT_TRUE(e.errored);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(e.value, nullptr);
- EXPECT_EQ(p->error(), "1:3: unable to parse right side of << expression");
+ EXPECT_EQ(p->error(), "1:6: unable to parse right side of << expression");
}
TEST_F(ParserImplTest, ShiftExpression_PostUnary_NoOr_ReturnsLHS) {
@@ -246,7 +149,7 @@
TEST_F(ParserImplTest, ShiftExpression_Parses) {
auto p = parser("a << true");
- auto e = p->maybe_shift_expression();
+ auto e = p->shift_expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -266,7 +169,7 @@
TEST_F(ParserImplTest, ShiftExpression_Invalid_Unary) {
auto p = parser("if >> true");
- auto e = p->maybe_shift_expression();
+ auto e = p->shift_expression();
EXPECT_FALSE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -275,7 +178,7 @@
TEST_F(ParserImplTest, ShiftExpression_Inavlid_ShiftExpressionPostUnary) {
auto p = parser("a * if (a) {}");
- auto e = p->maybe_shift_expression();
+ auto e = p->shift_expression();
EXPECT_FALSE(e.matched);
EXPECT_TRUE(e.errored);
EXPECT_TRUE(p->has_error());
diff --git a/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc b/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
index df75f96..4fb9528 100644
--- a/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
@@ -102,7 +102,10 @@
ASSERT_TRUE(member_attr->Is<ast::StructMemberAlignAttribute>());
auto* o = member_attr->As<ast::StructMemberAlignAttribute>();
- EXPECT_EQ(o->align, 4u);
+ ASSERT_TRUE(o->align->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(o->align->As<ast::IntLiteralExpression>()->value, 4);
+ EXPECT_EQ(o->align->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
}
TEST_F(ParserImplTest, Attribute_Align_TrailingComma) {
@@ -118,7 +121,11 @@
ASSERT_TRUE(member_attr->Is<ast::StructMemberAlignAttribute>());
auto* o = member_attr->As<ast::StructMemberAlignAttribute>();
- EXPECT_EQ(o->align, 4u);
+ ASSERT_TRUE(o->align->Is<ast::IntLiteralExpression>());
+
+ auto* expr = o->align->As<ast::IntLiteralExpression>();
+ EXPECT_EQ(expr->value, 4);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
}
TEST_F(ParserImplTest, Attribute_Align_MissingLeftParen) {
diff --git a/src/tint/reader/wgsl/parser_impl_struct_member_test.cc b/src/tint/reader/wgsl/parser_impl_struct_member_test.cc
index d2ab916..c64f772 100644
--- a/src/tint/reader/wgsl/parser_impl_struct_member_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_struct_member_test.cc
@@ -49,7 +49,13 @@
EXPECT_TRUE(m->type->Is<ast::I32>());
EXPECT_EQ(m->attributes.Length(), 1u);
EXPECT_TRUE(m->attributes[0]->Is<ast::StructMemberAlignAttribute>());
- EXPECT_EQ(m->attributes[0]->As<ast::StructMemberAlignAttribute>()->align, 2u);
+
+ auto* attr = m->attributes[0]->As<ast::StructMemberAlignAttribute>();
+ ASSERT_TRUE(attr->align->Is<ast::IntLiteralExpression>());
+
+ auto* expr = attr->align->As<ast::IntLiteralExpression>();
+ EXPECT_EQ(expr->value, 2);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
EXPECT_EQ(m->source.range, (Source::Range{{1u, 11u}, {1u, 12u}}));
EXPECT_EQ(m->type->source.range, (Source::Range{{1u, 15u}, {1u, 18u}}));
@@ -89,10 +95,16 @@
EXPECT_EQ(m->symbol, builder.Symbols().Get("a"));
EXPECT_TRUE(m->type->Is<ast::I32>());
EXPECT_EQ(m->attributes.Length(), 2u);
- EXPECT_TRUE(m->attributes[0]->Is<ast::StructMemberSizeAttribute>());
+ ASSERT_TRUE(m->attributes[0]->Is<ast::StructMemberSizeAttribute>());
EXPECT_EQ(m->attributes[0]->As<ast::StructMemberSizeAttribute>()->size, 2u);
- EXPECT_TRUE(m->attributes[1]->Is<ast::StructMemberAlignAttribute>());
- EXPECT_EQ(m->attributes[1]->As<ast::StructMemberAlignAttribute>()->align, 4u);
+
+ ASSERT_TRUE(m->attributes[1]->Is<ast::StructMemberAlignAttribute>());
+ auto* attr = m->attributes[1]->As<ast::StructMemberAlignAttribute>();
+
+ ASSERT_TRUE(attr->align->Is<ast::IntLiteralExpression>());
+ auto* expr = attr->align->As<ast::IntLiteralExpression>();
+ EXPECT_EQ(expr->value, 4);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
EXPECT_EQ(m->source.range, (Source::Range{{2u, 11u}, {2u, 12u}}));
EXPECT_EQ(m->type->source.range, (Source::Range{{2u, 15u}, {2u, 18u}}));
diff --git a/src/tint/reader/wgsl/parser_impl_unary_expression_test.cc b/src/tint/reader/wgsl/parser_impl_unary_expression_test.cc
index 184de66..29598d9 100644
--- a/src/tint/reader/wgsl/parser_impl_unary_expression_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_unary_expression_test.cc
@@ -86,7 +86,7 @@
TEST_F(ParserImplTest, UnaryExpression_AddressOf_Precedence) {
auto p = parser("&x.y");
- auto e = p->logical_or_expression();
+ auto e = p->unary_expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
@@ -100,7 +100,7 @@
TEST_F(ParserImplTest, UnaryExpression_Dereference_Precedence) {
auto p = parser("*x.y");
- auto e = p->logical_or_expression();
+ auto e = p->unary_expression();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc
index 4f6545c..70d4612 100644
--- a/src/tint/resolver/attribute_validation_test.cc
+++ b/src/tint/resolver/attribute_validation_test.cc
@@ -89,7 +89,7 @@
AttributeKind kind) {
switch (kind) {
case AttributeKind::kAlign:
- return {builder.create<ast::StructMemberAlignAttribute>(source, 4u)};
+ return {builder.MemberAlign(source, 4_u)};
case AttributeKind::kBinding:
return {builder.create<ast::BindingAttribute>(source, 1u)};
case AttributeKind::kBuiltin:
@@ -625,14 +625,13 @@
TestParams{AttributeKind::kWorkgroup, false},
TestParams{AttributeKind::kBindingAndGroup, false}));
TEST_F(StructMemberAttributeTest, DuplicateAttribute) {
- Structure("mystruct",
- utils::Vector{
- Member("a", ty.i32(),
- utils::Vector{
- create<ast::StructMemberAlignAttribute>(Source{{12, 34}}, 4u),
- create<ast::StructMemberAlignAttribute>(Source{{56, 78}}, 8u),
- }),
- });
+ Structure("mystruct", utils::Vector{
+ Member("a", ty.i32(),
+ utils::Vector{
+ MemberAlign(Source{{12, 34}}, 4_u),
+ MemberAlign(Source{{56, 78}}, 8_u),
+ }),
+ });
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(56:78 error: duplicate align attribute
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 1e66e5a..9bc1d47 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -323,7 +323,7 @@
size_t CalcHash() {
auto h = utils::Hash(type, all_zero, any_zero);
for (auto* el : elements) {
- utils::HashCombine(&h, el->Hash());
+ h = utils::HashCombine(h, el->Hash());
}
return h;
}
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index f699504..22f2620 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -545,6 +545,19 @@
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Expr) {
+ // @compute @workgroup_size(1 + 2)
+ // fn main() {}
+
+ Func("main", utils::Empty, ty.void_(), utils::Empty,
+ utils::Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Source{{12, 34}}, Add(1_u, 2_u)),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) {
// @compute @workgroup_size(1u, 2, 3_i)
// fn main() {}
@@ -750,13 +763,43 @@
"overridable of type abstract-integer, i32 or u32");
}
-TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr) {
- // @compute @workgroup_size(i32(1))
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
+ // @compute @workgroup_size(1 << 2 + 4)
// fn main() {}
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
- WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), 1_i)),
+ WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be either a literal, constant, or "
+ "overridable of type abstract-integer, i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
+ // @compute @workgroup_size(1, 1 << 2 + 4)
+ // fn main() {}
+ Func("main", utils::Empty, ty.void_(), utils::Empty,
+ utils::Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be either a literal, constant, or "
+ "overridable of type abstract-integer, i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
+ // @compute @workgroup_size(1, 1, 1 << 2 + 4)
+ // fn main() {}
+ Func("main", utils::Empty, ty.void_(), utils::Empty,
+ utils::Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
});
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc
index 67a75c3..0cd0abe 100644
--- a/src/tint/resolver/intrinsic_table.cc
+++ b/src/tint/resolver/intrinsic_table.cc
@@ -935,7 +935,7 @@
inline std::size_t operator()(const IntrinsicPrototype& i) const {
size_t hash = utils::Hash(i.parameters.Length());
for (auto& p : i.parameters) {
- utils::HashCombine(&hash, p.type, p.usage);
+ hash = utils::HashCombine(hash, p.type, p.usage);
}
return utils::Hash(hash, i.overload, i.return_type);
}
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index cd8bdc8..90ec254 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -578,8 +578,21 @@
sem::Variable* sem = nullptr;
if (is_global) {
sem::BindingPoint binding_point;
- if (auto bp = var->BindingPoint()) {
- binding_point = {bp.group->value, bp.binding->value};
+ if (var->HasBindingPoint()) {
+ uint32_t binding = 0;
+ {
+ auto* attr = ast::GetAttribute<ast::BindingAttribute>(var->attributes);
+ // TODO(dsinclair): Materialize when binding attribute is an expression
+ binding = attr->value;
+ }
+
+ uint32_t group = 0;
+ {
+ auto* attr = ast::GetAttribute<ast::GroupAttribute>(var->attributes);
+ // TODO(dsinclair): Materialize when group attribute is an expression
+ group = attr->value;
+ }
+ binding_point = {group, binding};
}
sem = builder_->create<sem::GlobalVariable>(var, var_ty, sem::EvaluationStage::kRuntime,
storage_class, access,
@@ -629,8 +642,23 @@
}
}
+ sem::BindingPoint binding_point;
+ if (param->HasBindingPoint()) {
+ {
+ auto* attr = ast::GetAttribute<ast::BindingAttribute>(param->attributes);
+ // TODO(dsinclair): Materialize the binding information
+ binding_point.binding = attr->value;
+ }
+ {
+ auto* attr = ast::GetAttribute<ast::GroupAttribute>(param->attributes);
+ // TODO(dsinclair): Materialize the group information
+ binding_point.group = attr->value;
+ }
+ }
+
auto* sem = builder_->create<sem::Parameter>(param, index, ty, ast::StorageClass::kNone,
- ast::Access::kUndefined);
+ ast::Access::kUndefined,
+ sem::ParameterUsage::kNone, binding_point);
builder_->Sem().Add(param, sem);
return sem;
}
@@ -937,7 +965,7 @@
for (size_t i = 0; i < 3; i++) {
// Each argument to this attribute can either be a literal, an identifier for a module-scope
- // constants, or nullptr if not specified.
+ // constants, a constant expression, or nullptr if not specified.
auto* value = values[i];
if (!value) {
break;
@@ -995,7 +1023,7 @@
ws[i].value = 0;
continue;
}
- } else if (values[i]->Is<ast::LiteralExpression>()) {
+ } else if (values[i]->Is<ast::LiteralExpression>() || args[i]->ConstantValue()) {
value = materialized->ConstantValue();
} else {
AddError(kErrBadExpr, values[i]->source);
@@ -2183,121 +2211,124 @@
auto* object = sem_.Get(expr->structure);
auto* source_var = object->SourceVariable();
- const sem::Type* ret = nullptr;
- utils::Vector<uint32_t, 4> swizzle;
+ const sem::Type* ty = nullptr;
// Object may be a side-effecting expression (e.g. function call).
bool has_side_effects = object && object->HasSideEffects();
- if (auto* str = storage_ty->As<sem::Struct>()) {
- Mark(expr->member);
- auto symbol = expr->member->symbol;
+ return Switch(
+ storage_ty, //
+ [&](const sem::Struct* str) -> sem::Expression* {
+ Mark(expr->member);
+ auto symbol = expr->member->symbol;
- const sem::StructMember* member = nullptr;
- for (auto* m : str->Members()) {
- if (m->Name() == symbol) {
- ret = m->Type();
- member = m;
- break;
- }
- }
-
- if (ret == nullptr) {
- AddError("struct member " + builder_->Symbols().NameFor(symbol) + " not found",
- expr->source);
- return nullptr;
- }
-
- // If we're extracting from a reference, we return a reference.
- if (auto* ref = structure->As<sem::Reference>()) {
- ret = builder_->create<sem::Reference>(ret, ref->StorageClass(), ref->Access());
- }
-
- const sem::Constant* val = nullptr;
- if (auto r = const_eval_.MemberAccess(object, member)) {
- val = r.Get();
- } else {
- return nullptr;
- }
- return builder_->create<sem::StructMemberAccess>(expr, ret, current_statement_, val, object,
- member, has_side_effects, source_var);
- }
-
- if (auto* vec = storage_ty->As<sem::Vector>()) {
- Mark(expr->member);
- std::string s = builder_->Symbols().NameFor(expr->member->symbol);
- auto size = s.size();
- swizzle.Reserve(s.size());
-
- for (auto c : s) {
- switch (c) {
- case 'x':
- case 'r':
- swizzle.Push(0u);
+ const sem::StructMember* member = nullptr;
+ for (auto* m : str->Members()) {
+ if (m->Name() == symbol) {
+ ty = m->Type();
+ member = m;
break;
- case 'y':
- case 'g':
- swizzle.Push(1u);
- break;
- case 'z':
- case 'b':
- swizzle.Push(2u);
- break;
- case 'w':
- case 'a':
- swizzle.Push(3u);
- break;
- default:
- AddError("invalid vector swizzle character",
- expr->member->source.Begin() + swizzle.Length());
- return nullptr;
+ }
}
- if (swizzle.Back() >= vec->Width()) {
- AddError("invalid vector swizzle member", expr->member->source);
+ if (ty == nullptr) {
+ AddError("struct member " + builder_->Symbols().NameFor(symbol) + " not found",
+ expr->source);
return nullptr;
}
- }
- if (size < 1 || size > 4) {
- AddError("invalid vector swizzle size", expr->member->source);
- return nullptr;
- }
-
- // All characters are valid, check if they're being mixed
- auto is_rgba = [](char c) { return c == 'r' || c == 'g' || c == 'b' || c == 'a'; };
- auto is_xyzw = [](char c) { return c == 'x' || c == 'y' || c == 'z' || c == 'w'; };
- if (!std::all_of(s.begin(), s.end(), is_rgba) &&
- !std::all_of(s.begin(), s.end(), is_xyzw)) {
- AddError("invalid mixing of vector swizzle characters rgba with xyzw",
- expr->member->source);
- return nullptr;
- }
-
- if (size == 1) {
- // A single element swizzle is just the type of the vector.
- ret = vec->type();
// If we're extracting from a reference, we return a reference.
if (auto* ref = structure->As<sem::Reference>()) {
- ret = builder_->create<sem::Reference>(ret, ref->StorageClass(), ref->Access());
+ ty = builder_->create<sem::Reference>(ty, ref->StorageClass(), ref->Access());
}
- } else {
- // The vector will have a number of components equal to the length of
- // the swizzle.
- ret = builder_->create<sem::Vector>(vec->type(), static_cast<uint32_t>(size));
- }
- if (auto r = const_eval_.Swizzle(ret, object, swizzle)) {
- auto* val = r.Get();
- return builder_->create<sem::Swizzle>(expr, ret, current_statement_, val, object,
- std::move(swizzle), has_side_effects, source_var);
- }
- return nullptr;
- }
- AddError("invalid member accessor expression. Expected vector or struct, got '" +
- sem_.TypeNameOf(storage_ty) + "'",
- expr->structure->source);
- return nullptr;
+ auto val = const_eval_.MemberAccess(object, member);
+ if (!val) {
+ return nullptr;
+ }
+ return builder_->create<sem::StructMemberAccess>(expr, ty, current_statement_,
+ val.Get(), object, member,
+ has_side_effects, source_var);
+ },
+
+ [&](const sem::Vector* vec) -> sem::Expression* {
+ Mark(expr->member);
+ std::string s = builder_->Symbols().NameFor(expr->member->symbol);
+ auto size = s.size();
+ utils::Vector<uint32_t, 4> swizzle;
+ swizzle.Reserve(s.size());
+
+ for (auto c : s) {
+ switch (c) {
+ case 'x':
+ case 'r':
+ swizzle.Push(0u);
+ break;
+ case 'y':
+ case 'g':
+ swizzle.Push(1u);
+ break;
+ case 'z':
+ case 'b':
+ swizzle.Push(2u);
+ break;
+ case 'w':
+ case 'a':
+ swizzle.Push(3u);
+ break;
+ default:
+ AddError("invalid vector swizzle character",
+ expr->member->source.Begin() + swizzle.Length());
+ return nullptr;
+ }
+
+ if (swizzle.Back() >= vec->Width()) {
+ AddError("invalid vector swizzle member", expr->member->source);
+ return nullptr;
+ }
+ }
+
+ if (size < 1 || size > 4) {
+ AddError("invalid vector swizzle size", expr->member->source);
+ return nullptr;
+ }
+
+ // All characters are valid, check if they're being mixed
+ auto is_rgba = [](char c) { return c == 'r' || c == 'g' || c == 'b' || c == 'a'; };
+ auto is_xyzw = [](char c) { return c == 'x' || c == 'y' || c == 'z' || c == 'w'; };
+ if (!std::all_of(s.begin(), s.end(), is_rgba) &&
+ !std::all_of(s.begin(), s.end(), is_xyzw)) {
+ AddError("invalid mixing of vector swizzle characters rgba with xyzw",
+ expr->member->source);
+ return nullptr;
+ }
+
+ if (size == 1) {
+ // A single element swizzle is just the type of the vector.
+ ty = vec->type();
+ // If we're extracting from a reference, we return a reference.
+ if (auto* ref = structure->As<sem::Reference>()) {
+ ty = builder_->create<sem::Reference>(ty, ref->StorageClass(), ref->Access());
+ }
+ } else {
+ // The vector will have a number of components equal to the length of
+ // the swizzle.
+ ty = builder_->create<sem::Vector>(vec->type(), static_cast<uint32_t>(size));
+ }
+ auto val = const_eval_.Swizzle(ty, object, swizzle);
+ if (!val) {
+ return nullptr;
+ }
+ return builder_->create<sem::Swizzle>(expr, ty, current_statement_, val.Get(), object,
+ std::move(swizzle), has_side_effects, source_var);
+ },
+
+ [&](Default) {
+ AddError("invalid member accessor expression. Expected vector or struct, got '" +
+ sem_.TypeNameOf(storage_ty) + "'",
+ expr->structure->source);
+ return nullptr;
+ });
}
sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
@@ -2673,11 +2704,26 @@
align = 1;
has_offset_attr = true;
} else if (auto* a = attr->As<ast::StructMemberAlignAttribute>()) {
- if (a->align <= 0 || !utils::IsPowerOfTwo(a->align)) {
+ const auto* expr = Expression(a->align);
+ if (!expr) {
+ return nullptr;
+ }
+ auto* materialized = Materialize(expr);
+ if (!materialized) {
+ return nullptr;
+ }
+ auto const_value = materialized->ConstantValue();
+ if (!const_value) {
+ AddError("'align' must be constant expression", a->align->source);
+ return nullptr;
+ }
+ auto value = const_value->As<AInt>();
+
+ if (value <= 0 || !utils::IsPowerOfTwo(value)) {
AddError("align value must be a positive, power-of-two integer", a->source);
return nullptr;
}
- align = a->align;
+ align = const_value->As<u32>();
has_align_attr = true;
} else if (auto* s = attr->As<ast::StructMemberSizeAttribute>()) {
if (s->size < size) {
diff --git a/src/tint/resolver/storage_class_layout_validation_test.cc b/src/tint/resolver/storage_class_layout_validation_test.cc
index 9c16ec4..51b8ecc 100644
--- a/src/tint/resolver/storage_class_layout_validation_test.cc
+++ b/src/tint/resolver/storage_class_layout_validation_test.cc
@@ -36,7 +36,7 @@
Structure(Source{{12, 34}}, "S",
utils::Vector{
Member("a", ty.f32(), utils::Vector{MemberSize(5)}),
- Member(Source{{34, 56}}, "b", ty.f32(), utils::Vector{MemberAlign(1)}),
+ Member(Source{{34, 56}}, "b", ty.f32(), utils::Vector{MemberAlign(1_u)}),
});
GlobalVar(Source{{78, 90}}, "a", ty.type_name("S"), ast::StorageClass::kStorage, Group(0),
@@ -66,7 +66,7 @@
Structure(Source{{12, 34}}, "S",
utils::Vector{
Member("a", ty.f32(), utils::Vector{MemberSize(5)}),
- Member(Source{{34, 56}}, "b", ty.f32(), utils::Vector{MemberAlign(4)}),
+ Member(Source{{34, 56}}, "b", ty.f32(), utils::Vector{MemberAlign(4_u)}),
});
GlobalVar(Source{{78, 90}}, "a", ty.type_name("S"), ast::StorageClass::kStorage, Group(0),
@@ -142,7 +142,7 @@
utils::Vector{
Member("scalar", ty.f32()),
Member(Source{{56, 78}}, "inner", ty.type_name("Inner"),
- utils::Vector{MemberAlign(16)}),
+ utils::Vector{MemberAlign(16_u)}),
});
GlobalVar(Source{{78, 90}}, "a", ty.type_name("Outer"), ast::StorageClass::kUniform, Group(0),
@@ -201,7 +201,7 @@
utils::Vector{
Member("scalar", ty.f32()),
Member(Source{{34, 56}}, "inner", ty.type_name("Inner"),
- utils::Vector{MemberAlign(16)}),
+ utils::Vector{MemberAlign(16_u)}),
});
GlobalVar(Source{{78, 90}}, "a", ty.type_name("Outer"), ast::StorageClass::kUniform, Group(0),
@@ -227,7 +227,7 @@
Structure(Source{{12, 34}}, "Inner",
utils::Vector{
- Member("scalar", ty.i32(), utils::Vector{MemberAlign(1), MemberSize(5)}),
+ Member("scalar", ty.i32(), utils::Vector{MemberAlign(1_u), MemberSize(5)}),
});
Structure(Source{{34, 56}}, "Outer",
@@ -279,7 +279,7 @@
Member("a", ty.i32()),
Member("b", ty.i32()),
Member("c", ty.i32()),
- Member("scalar", ty.i32(), utils::Vector{MemberAlign(1), MemberSize(5)}),
+ Member("scalar", ty.i32(), utils::Vector{MemberAlign(1_u), MemberSize(5)}),
});
Structure(Source{{34, 56}}, "Outer",
@@ -327,13 +327,13 @@
Structure(Source{{12, 34}}, "Inner",
utils::Vector{
- Member("scalar", ty.i32(), utils::Vector{MemberAlign(1), MemberSize(5)}),
+ Member("scalar", ty.i32(), utils::Vector{MemberAlign(1_u), MemberSize(5)}),
});
Structure(Source{{34, 56}}, "Outer",
utils::Vector{
Member(Source{{56, 78}}, "inner", ty.type_name("Inner")),
- Member(Source{{78, 90}}, "scalar", ty.i32(), utils::Vector{MemberAlign(16)}),
+ Member(Source{{78, 90}}, "scalar", ty.i32(), utils::Vector{MemberAlign(16_u)}),
});
GlobalVar(Source{{22, 34}}, "a", ty.type_name("Outer"), ast::StorageClass::kUniform, Group(0),
@@ -551,7 +551,7 @@
Structure(
Source{{12, 34}}, "S",
utils::Vector{Member("a", ty.f32(), utils::Vector{MemberSize(5)}),
- Member(Source{{34, 56}}, "b", ty.f32(), utils::Vector{MemberAlign(1)})});
+ Member(Source{{34, 56}}, "b", ty.f32(), utils::Vector{MemberAlign(1_u)})});
GlobalVar(Source{{78, 90}}, "a", ty.type_name("S"), ast::StorageClass::kPushConstant);
ASSERT_FALSE(r()->Resolve());
@@ -576,7 +576,7 @@
// var<push_constant> a : S;
Enable(ast::Extension::kChromiumExperimentalPushConstant);
Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{MemberSize(5)}),
- Member("b", ty.f32(), utils::Vector{MemberAlign(4)})});
+ Member("b", ty.f32(), utils::Vector{MemberAlign(4_u)})});
GlobalVar("a", ty.type_name("S"), ast::StorageClass::kPushConstant);
ASSERT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/struct_layout_test.cc b/src/tint/resolver/struct_layout_test.cc
index 59faf8d..ae3972b 100644
--- a/src/tint/resolver/struct_layout_test.cc
+++ b/src/tint/resolver/struct_layout_test.cc
@@ -498,15 +498,15 @@
TEST_F(ResolverStructLayoutTest, AlignAttributes) {
auto* inner = Structure("Inner", utils::Vector{
- Member("a", ty.f32(), utils::Vector{MemberAlign(8)}),
- Member("b", ty.f32(), utils::Vector{MemberAlign(16)}),
- Member("c", ty.f32(), utils::Vector{MemberAlign(4)}),
+ Member("a", ty.f32(), utils::Vector{MemberAlign(8_u)}),
+ Member("b", ty.f32(), utils::Vector{MemberAlign(16_u)}),
+ Member("c", ty.f32(), utils::Vector{MemberAlign(4_u)}),
});
auto* s = Structure("S", utils::Vector{
- Member("a", ty.f32(), utils::Vector{MemberAlign(4)}),
- Member("b", ty.u32(), utils::Vector{MemberAlign(8)}),
+ Member("a", ty.f32(), utils::Vector{MemberAlign(4_u)}),
+ Member("b", ty.u32(), utils::Vector{MemberAlign(8_u)}),
Member("c", ty.Of(inner)),
- Member("d", ty.i32(), utils::Vector{MemberAlign(32)}),
+ Member("d", ty.i32(), utils::Vector{MemberAlign(32_u)}),
});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -536,7 +536,7 @@
TEST_F(ResolverStructLayoutTest, StructWithLotsOfPadding) {
auto* s = Structure("S", utils::Vector{
- Member("a", ty.i32(), utils::Vector{MemberAlign(1024)}),
+ Member("a", ty.i32(), utils::Vector{MemberAlign(1024_u)}),
});
ASSERT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc
index 623ff1c..cbd05c1 100644
--- a/src/tint/resolver/validation_test.cc
+++ b/src/tint/resolver/validation_test.cc
@@ -1239,7 +1239,7 @@
TEST_F(ResolverValidationTest, NonPOTStructMemberAlignAttribute) {
Structure("S", utils::Vector{
- Member("a", ty.f32(), utils::Vector{MemberAlign(Source{{12, 34}}, 3)}),
+ Member("a", ty.f32(), utils::Vector{MemberAlign(Source{{12, 34}}, 3_u)}),
});
EXPECT_FALSE(r()->Resolve());
@@ -1248,7 +1248,7 @@
TEST_F(ResolverValidationTest, ZeroStructMemberAlignAttribute) {
Structure("S", utils::Vector{
- Member("a", ty.f32(), utils::Vector{MemberAlign(Source{{12, 34}}, 0)}),
+ Member("a", ty.f32(), utils::Vector{MemberAlign(Source{{12, 34}}, 0_u)}),
});
EXPECT_FALSE(r()->Resolve());
@@ -1279,7 +1279,7 @@
TEST_F(ResolverValidationTest, OffsetAndAlignAttribute) {
Structure("S", utils::Vector{
Member(Source{{12, 34}}, "a", ty.f32(),
- utils::Vector{MemberOffset(0), MemberAlign(4)}),
+ utils::Vector{MemberOffset(0), MemberAlign(4_u)}),
});
EXPECT_FALSE(r()->Resolve());
@@ -1291,7 +1291,7 @@
TEST_F(ResolverValidationTest, OffsetAndAlignAndSizeAttribute) {
Structure("S", utils::Vector{
Member(Source{{12, 34}}, "a", ty.f32(),
- utils::Vector{MemberOffset(0), MemberAlign(4), MemberSize(4)}),
+ utils::Vector{MemberOffset(0), MemberAlign(4_u), MemberSize(4)}),
});
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 6382d34..08451f0 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -664,7 +664,6 @@
return false;
}
- auto binding_point = decl->BindingPoint();
switch (global->StorageClass()) {
case ast::StorageClass::kUniform:
case ast::StorageClass::kStorage:
@@ -672,20 +671,23 @@
// https://gpuweb.github.io/gpuweb/wgsl/#resource-interface
// Each resource variable must be declared with both group and binding
// attributes.
- if (!binding_point) {
+ if (!decl->HasBindingPoint()) {
AddError("resource variables require @group and @binding attributes", decl->source);
return false;
}
break;
}
- default:
- if (binding_point.binding || binding_point.group) {
+ default: {
+ auto* binding_attr = ast::GetAttribute<ast::BindingAttribute>(decl->attributes);
+ auto* group_attr = ast::GetAttribute<ast::GroupAttribute>(decl->attributes);
+ if (binding_attr || group_attr) {
// https://gpuweb.github.io/gpuweb/wgsl/#attribute-binding
// Must only be applied to a resource variable
AddError("non-resource variables must not have @group or @binding attributes",
decl->source);
return false;
}
+ }
}
return true;
@@ -1351,7 +1353,7 @@
std::unordered_map<sem::BindingPoint, const ast::Variable*> binding_points;
for (auto* global : func->TransitivelyReferencedGlobals()) {
auto* var_decl = global->Declaration()->As<ast::Var>();
- if (!var_decl || !var_decl->BindingPoint()) {
+ if (!var_decl || !var_decl->HasBindingPoint()) {
continue;
}
auto bp = global->BindingPoint();
diff --git a/src/tint/sem/call_target.cc b/src/tint/sem/call_target.cc
index 362482c..e3b03d4 100644
--- a/src/tint/sem/call_target.cc
+++ b/src/tint/sem/call_target.cc
@@ -70,7 +70,7 @@
const tint::sem::CallTargetSignature& sig) const {
size_t hash = tint::utils::Hash(sig.parameters.Length());
for (auto* p : sig.parameters) {
- tint::utils::HashCombine(&hash, p->Type(), p->Usage());
+ hash = tint::utils::HashCombine(hash, p->Type(), p->Usage());
}
return tint::utils::Hash(hash, sig.return_type);
}
diff --git a/src/tint/sem/function.cc b/src/tint/sem/function.cc
index ff3a2a7..97171fe 100644
--- a/src/tint/sem/function.cc
+++ b/src/tint/sem/function.cc
@@ -70,8 +70,8 @@
continue;
}
- if (auto binding_point = global->Declaration()->BindingPoint()) {
- ret.push_back({global, binding_point});
+ if (global->Declaration()->HasBindingPoint()) {
+ ret.push_back({global, global->BindingPoint()});
}
}
return ret;
@@ -85,8 +85,8 @@
continue;
}
- if (auto binding_point = global->Declaration()->BindingPoint()) {
- ret.push_back({global, binding_point});
+ if (global->Declaration()->HasBindingPoint()) {
+ ret.push_back({global, global->BindingPoint()});
}
}
return ret;
@@ -129,8 +129,8 @@
for (auto* global : TransitivelyReferencedGlobals()) {
auto* unwrapped_type = global->Type()->UnwrapRef();
if (unwrapped_type->TypeInfo().Is(type)) {
- if (auto binding_point = global->Declaration()->BindingPoint()) {
- ret.push_back({global, binding_point});
+ if (global->Declaration()->HasBindingPoint()) {
+ ret.push_back({global, global->BindingPoint()});
}
}
}
@@ -157,8 +157,8 @@
continue;
}
- if (auto binding_point = global->Declaration()->BindingPoint()) {
- ret.push_back({global, binding_point});
+ if (global->Declaration()->HasBindingPoint()) {
+ ret.push_back({global, global->BindingPoint()});
}
}
return ret;
@@ -182,8 +182,8 @@
continue;
}
- if (auto binding_point = global->Declaration()->BindingPoint()) {
- ret.push_back({global, binding_point});
+ if (global->Declaration()->HasBindingPoint()) {
+ ret.push_back({global, global->BindingPoint()});
}
}
diff --git a/src/tint/sem/function.h b/src/tint/sem/function.h
index a09f81c..d4cb1c7 100644
--- a/src/tint/sem/function.h
+++ b/src/tint/sem/function.h
@@ -54,8 +54,8 @@
/// Function holds the semantic information for function nodes.
class Function final : public Castable<Function, CallTarget> {
public:
- /// A vector of [Variable*, ast::VariableBindingPoint] pairs
- using VariableBindings = std::vector<std::pair<const Variable*, ast::VariableBindingPoint>>;
+ /// A vector of [Variable*, sem::BindingPoint] pairs
+ using VariableBindings = std::vector<std::pair<const Variable*, sem::BindingPoint>>;
/// Constructor
/// @param declaration the ast::Function
diff --git a/src/tint/sem/variable.cc b/src/tint/sem/variable.cc
index 28a373b..f233053 100644
--- a/src/tint/sem/variable.cc
+++ b/src/tint/sem/variable.cc
@@ -72,10 +72,12 @@
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access,
- const ParameterUsage usage /* = ParameterUsage::kNone */)
+ const ParameterUsage usage /* = ParameterUsage::kNone */,
+ sem::BindingPoint binding_point /* = {} */)
: Base(declaration, type, EvaluationStage::kRuntime, storage_class, access, nullptr),
index_(index),
- usage_(usage) {}
+ usage_(usage),
+ binding_point_(binding_point) {}
Parameter::~Parameter() = default;
diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h
index 2bf6d23..b511563 100644
--- a/src/tint/sem/variable.h
+++ b/src/tint/sem/variable.h
@@ -188,12 +188,14 @@
/// @param storage_class the variable storage class
/// @param access the variable access control type
/// @param usage the semantic usage for the parameter
+ /// @param binding_point the optional resource binding point of the parameter
Parameter(const ast::Parameter* declaration,
uint32_t index,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access,
- const ParameterUsage usage = ParameterUsage::kNone);
+ const ParameterUsage usage = ParameterUsage::kNone,
+ sem::BindingPoint binding_point = {});
/// Destructor
~Parameter() override;
@@ -217,11 +219,15 @@
/// @param shadows the Type, Function or Variable that this variable shadows
void SetShadows(const sem::Node* shadows) { shadows_ = shadows; }
+ /// @returns the resource binding point for the parameter
+ sem::BindingPoint BindingPoint() const { return binding_point_; }
+
private:
const uint32_t index_;
const ParameterUsage usage_;
CallTarget const* owner_ = nullptr;
const sem::Node* shadows_ = nullptr;
+ const sem::BindingPoint binding_point_;
};
/// VariableUser holds the semantic information for an identifier expression
diff --git a/src/tint/transform/add_spirv_block_attribute.cc b/src/tint/transform/add_spirv_block_attribute.cc
index 3615812..25abdce 100644
--- a/src/tint/transform/add_spirv_block_attribute.cc
+++ b/src/tint/transform/add_spirv_block_attribute.cc
@@ -14,13 +14,12 @@
#include "src/tint/transform/add_spirv_block_attribute.h"
-#include <unordered_map>
-#include <unordered_set>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/variable.h"
-#include "src/tint/utils/map.h"
+#include "src/tint/utils/hashmap.h"
+#include "src/tint/utils/hashset.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::AddSpirvBlockAttribute);
TINT_INSTANTIATE_TYPEINFO(tint::transform::AddSpirvBlockAttribute::SpirvBlockAttribute);
@@ -35,59 +34,64 @@
auto& sem = ctx.src->Sem();
// Collect the set of structs that are nested in other types.
- std::unordered_set<const sem::Struct*> nested_structs;
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- if (auto* arr = sem.Get<sem::Array>(node->As<ast::Array>())) {
- if (auto* nested_str = arr->ElemType()->As<sem::Struct>()) {
- nested_structs.insert(nested_str);
- }
- } else if (auto* str = sem.Get<sem::Struct>(node->As<ast::Struct>())) {
- for (auto* member : str->Members()) {
- if (auto* nested_str = member->Type()->As<sem::Struct>()) {
- nested_structs.insert(nested_str);
+ utils::Hashset<const sem::Struct*, 8> nested_structs;
+ for (auto* ty : ctx.src->Types()) {
+ Switch(
+ ty,
+ [&](const sem::Array* arr) {
+ if (auto* nested_str = arr->ElemType()->As<sem::Struct>()) {
+ nested_structs.Add(nested_str);
}
- }
- }
+ },
+ [&](const sem::Struct* str) {
+ for (auto* member : str->Members()) {
+ if (auto* nested_str = member->Type()->As<sem::Struct>()) {
+ nested_structs.Add(nested_str);
+ }
+ }
+ });
}
- // A map from a type in the source program to a block-decorated wrapper that
- // contains it in the destination program.
- std::unordered_map<const sem::Type*, const ast::Struct*> wrapper_structs;
+ // A map from a type in the source program to a block-decorated wrapper that contains it in the
+ // destination program.
+ utils::Hashmap<const sem::Type*, const ast::Struct*, 8> wrapper_structs;
// Process global 'var' declarations that are buffers.
- for (auto* var : ctx.src->AST().Globals<ast::Var>()) {
- auto* sem_var = sem.Get<sem::GlobalVariable>(var);
- if (var->declared_storage_class != ast::StorageClass::kStorage &&
- var->declared_storage_class != ast::StorageClass::kUniform &&
- var->declared_storage_class != ast::StorageClass::kPushConstant) {
+ for (auto* global : ctx.src->AST().GlobalVariables()) {
+ auto* var = sem.Get(global);
+ if (!ast::IsHostShareable(var->StorageClass())) {
+ // Not declared in a host-sharable storage class
continue;
}
- auto* ty = sem.Get(var->type);
+ auto* ty = var->Type()->UnwrapRef();
auto* str = ty->As<sem::Struct>();
- if (!str || nested_structs.count(str)) {
+ bool needs_wrapping = !str || // Type is not a structure
+ nested_structs.Contains(str); // Structure is nested by another type
+
+ if (needs_wrapping) {
const char* kMemberName = "inner";
// This is a non-struct or a struct that is nested somewhere else, so we
// need to wrap it first.
- auto* wrapper = utils::GetOrCreate(wrapper_structs, ty, [&]() {
+ auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] {
auto* block = ctx.dst->ASTNodes().Create<SpirvBlockAttribute>(
ctx.dst->ID(), ctx.dst->AllocateNodeID());
- auto wrapper_name = ctx.src->Symbols().NameFor(var->symbol) + "_block";
+ auto wrapper_name = ctx.src->Symbols().NameFor(global->symbol) + "_block";
auto* ret = ctx.dst->create<ast::Struct>(
ctx.dst->Symbols().New(wrapper_name),
utils::Vector{ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))},
utils::Vector{block});
- ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), var, ret);
+ ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), global, ret);
return ret;
});
- ctx.Replace(var->type, ctx.dst->ty.Of(wrapper));
+ ctx.Replace(global->type, ctx.dst->ty.Of(wrapper));
// Insert a member accessor to get the original type from the wrapper at
// any usage of the original variable.
- for (auto* user : sem_var->Users()) {
+ for (auto* user : var->Users()) {
ctx.Replace(user->Declaration(),
- ctx.dst->MemberAccessor(ctx.Clone(var->symbol), kMemberName));
+ ctx.dst->MemberAccessor(ctx.Clone(global->symbol), kMemberName));
}
} else {
// Add a block attribute to this struct directly.
diff --git a/src/tint/transform/add_spirv_block_attribute_test.cc b/src/tint/transform/add_spirv_block_attribute_test.cc
index 90f9219..62abae3 100644
--- a/src/tint/transform/add_spirv_block_attribute_test.cc
+++ b/src/tint/transform/add_spirv_block_attribute_test.cc
@@ -163,7 +163,40 @@
EXPECT_EQ(expect, str(got));
}
-TEST_F(AddSpirvBlockAttributeTest, BasicStruct) {
+TEST_F(AddSpirvBlockAttributeTest, BasicStruct_AccessRoot) {
+ auto* src = R"(
+struct S {
+ f : f32,
+};
+
+@group(0) @binding(0)
+var<uniform> u : S;
+
+@fragment
+fn main() {
+ let f = u;
+}
+)";
+ auto* expect = R"(
+@internal(spirv_block)
+struct S {
+ f : f32,
+}
+
+@group(0) @binding(0) var<uniform> u : S;
+
+@fragment
+fn main() {
+ let f = u;
+}
+)";
+
+ auto got = Run<AddSpirvBlockAttribute>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(AddSpirvBlockAttributeTest, BasicStruct_AccessField) {
auto* src = R"(
struct S {
f : f32,
diff --git a/src/tint/transform/binding_remapper.cc b/src/tint/transform/binding_remapper.cc
index 2807145..eaf9725 100644
--- a/src/tint/transform/binding_remapper.cc
+++ b/src/tint/transform/binding_remapper.cc
@@ -69,8 +69,9 @@
auto* func = ctx.src->Sem().Get(func_ast);
std::unordered_map<sem::BindingPoint, int> binding_point_counts;
for (auto* global : func->TransitivelyReferencedGlobals()) {
- if (auto binding_point = global->Declaration()->BindingPoint()) {
- BindingPoint from{binding_point.group->value, binding_point.binding->value};
+ if (global->Declaration()->HasBindingPoint()) {
+ BindingPoint from = global->BindingPoint();
+
auto bp_it = remappings->binding_points.find(from);
if (bp_it != remappings->binding_points.end()) {
// Remapped
@@ -90,9 +91,11 @@
}
for (auto* var : ctx.src->AST().Globals<ast::Var>()) {
- if (auto binding_point = var->BindingPoint()) {
+ if (var->HasBindingPoint()) {
+ auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(var);
+
// The original binding point
- BindingPoint from{binding_point.group->value, binding_point.binding->value};
+ BindingPoint from = global_sem->BindingPoint();
// The binding point after remapping
BindingPoint bp = from;
@@ -106,8 +109,11 @@
auto* new_group = ctx.dst->create<ast::GroupAttribute>(to.group);
auto* new_binding = ctx.dst->create<ast::BindingAttribute>(to.binding);
- ctx.Replace(binding_point.group, new_group);
- ctx.Replace(binding_point.binding, new_binding);
+ auto* old_group = ast::GetAttribute<ast::GroupAttribute>(var->attributes);
+ auto* old_binding = ast::GetAttribute<ast::BindingAttribute>(var->attributes);
+
+ ctx.Replace(old_group, new_group);
+ ctx.Replace(old_binding, new_binding);
bp = to;
}
diff --git a/src/tint/transform/combine_samplers.cc b/src/tint/transform/combine_samplers.cc
index 13e6de4..6290b26 100644
--- a/src/tint/transform/combine_samplers.cc
+++ b/src/tint/transform/combine_samplers.cc
@@ -149,12 +149,14 @@
// Remove all texture and sampler global variables. These will be replaced
// by combined samplers.
for (auto* global : ctx.src->AST().GlobalVariables()) {
+ auto* global_sem = sem.Get(global)->As<sem::GlobalVariable>();
auto* type = sem.Get(global->type);
if (tint::IsAnyOf<sem::Texture, sem::Sampler>(type) &&
!type->Is<sem::StorageTexture>()) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), global);
- } else if (auto binding_point = global->BindingPoint()) {
- if (binding_point.group->value == 0 && binding_point.binding->value == 0) {
+ } else if (global->HasBindingPoint()) {
+ auto binding_point = global_sem->BindingPoint();
+ if (binding_point.group == 0 && binding_point.binding == 0) {
auto* attribute =
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
ctx.InsertFront(global->attributes, attribute);
diff --git a/src/tint/transform/first_index_offset.cc b/src/tint/transform/first_index_offset.cc
index ee80e95..34b481e 100644
--- a/src/tint/transform/first_index_offset.cc
+++ b/src/tint/transform/first_index_offset.cc
@@ -145,7 +145,7 @@
}
}
}
- // Not interested in this experssion. Just clone.
+ // Not interested in this expression. Just clone.
return nullptr;
});
}
diff --git a/src/tint/transform/multiplanar_external_texture.cc b/src/tint/transform/multiplanar_external_texture.cc
index d3ac05f..2fb4248 100644
--- a/src/tint/transform/multiplanar_external_texture.cc
+++ b/src/tint/transform/multiplanar_external_texture.cc
@@ -87,7 +87,7 @@
// represent the secondary plane and one uniform buffer for the
// ExternalTextureParams struct).
for (auto* global : ctx.src->AST().GlobalVariables()) {
- auto* sem_var = sem.Get(global);
+ auto* sem_var = sem.Get<sem::GlobalVariable>(global);
if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) {
continue;
}
@@ -109,8 +109,7 @@
// provided to this transform. We fetch the new binding points by
// providing the original texture_external binding points into the
// passed map.
- BindingPoint bp = {global->BindingPoint().group->value,
- global->BindingPoint().binding->value};
+ BindingPoint bp = sem_var->BindingPoint();
BindingsMap::const_iterator it = new_binding_points->bindings_map.find(bp);
if (it == new_binding_points->bindings_map.end()) {
diff --git a/src/tint/transform/num_workgroups_from_uniform.cc b/src/tint/transform/num_workgroups_from_uniform.cc
index 293c6c4..5772424 100644
--- a/src/tint/transform/num_workgroups_from_uniform.cc
+++ b/src/tint/transform/num_workgroups_from_uniform.cc
@@ -136,9 +136,11 @@
group = 0;
for (auto* global : ctx.src->AST().GlobalVariables()) {
- if (auto binding_point = global->BindingPoint()) {
- if (binding_point.group->value >= group) {
- group = binding_point.group->value + 1;
+ if (global->HasBindingPoint()) {
+ auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(global);
+ auto binding_point = global_sem->BindingPoint();
+ if (binding_point.group >= group) {
+ group = binding_point.group + 1;
}
}
}
diff --git a/src/tint/transform/remove_phonies.cc b/src/tint/transform/remove_phonies.cc
index 5e64252..c7c2dc5 100644
--- a/src/tint/transform/remove_phonies.cc
+++ b/src/tint/transform/remove_phonies.cc
@@ -33,33 +33,7 @@
namespace tint::transform {
namespace {
-struct SinkSignature {
- std::vector<const sem::Type*> types;
-
- bool operator==(const SinkSignature& other) const {
- if (types.size() != other.types.size()) {
- return false;
- }
- for (size_t i = 0; i < types.size(); i++) {
- if (types[i] != other.types[i]) {
- return false;
- }
- }
- return true;
- }
-
- struct Hasher {
- /// @param sig the CallTargetSignature to hash
- /// @return the hash value
- std::size_t operator()(const SinkSignature& sig) const {
- size_t hash = tint::utils::Hash(sig.types.size());
- for (auto* ty : sig.types) {
- tint::utils::HashCombine(&hash, ty);
- }
- return hash;
- }
- };
-};
+using SinkSignature = std::vector<const sem::Type*>;
} // namespace
@@ -84,7 +58,7 @@
void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem();
- std::unordered_map<SinkSignature, Symbol, SinkSignature::Hasher> sinks;
+ std::unordered_map<SinkSignature, Symbol, utils::Hasher<SinkSignature>> sinks;
for (auto* node : ctx.src->ASTNodes().Objects()) {
Switch(
@@ -138,12 +112,12 @@
ctx.Replace(stmt, [&, side_effects] {
SinkSignature sig;
for (auto* arg : side_effects) {
- sig.types.push_back(sem.Get(arg)->Type()->UnwrapRef());
+ sig.push_back(sem.Get(arg)->Type()->UnwrapRef());
}
auto sink = utils::GetOrCreate(sinks, sig, [&] {
auto name = ctx.dst->Symbols().New("phony_sink");
utils::Vector<const ast::Parameter*, 8> params;
- for (auto* ty : sig.types) {
+ for (auto* ty : sig) {
auto* ast_ty = CreateASTTypeFor(ctx, ty);
params.Push(
ctx.dst->Param("p" + std::to_string(params.Length()), ast_ty));
diff --git a/src/tint/utils/hash.h b/src/tint/utils/hash.h
index d1fac11..ad53841 100644
--- a/src/tint/utils/hash.h
+++ b/src/tint/utils/hash.h
@@ -48,55 +48,96 @@
} // namespace detail
-// Forward declaration
+/// Forward declarations (see below)
template <typename... ARGS>
-size_t Hash(const ARGS&... args);
+size_t Hash(const ARGS&... values);
-/// HashCombine "hashes" together an existing hash and hashable values.
-template <typename T>
-void HashCombine(size_t* hash, const T& value) {
- constexpr size_t offset = detail::HashCombineOffset<sizeof(size_t)>::value();
- *hash ^= std::hash<T>()(value) + offset + (*hash << 6) + (*hash >> 2);
-}
+template <typename... ARGS>
+size_t HashCombine(size_t hash, const ARGS&... values);
-/// HashCombine "hashes" together an existing hash and hashable values.
+/// A STL-compatible hasher that does a more thorough job than most implementations of std::hash.
+/// Hasher has been optimized for a better quality hash at the expense of increased computation
+/// costs.
template <typename T>
-void HashCombine(size_t* hash, const std::vector<T>& vector) {
- HashCombine(hash, vector.size());
- for (auto& el : vector) {
- HashCombine(hash, el);
+struct Hasher {
+ /// @param value the value to hash
+ /// @returns a hash of the value
+ size_t operator()(const T& value) const { return std::hash<T>()(value); }
+};
+
+/// Hasher specialization for pointers
+/// std::hash<T*> typically uses a reinterpret of the pointer to a size_t.
+/// As most pointers a 4 or 16 byte aligned, this usually results in the LSBs of the hash being 0,
+/// resulting in bad hashes for hashtables. This implementation mixes up those LSBs.
+template <typename T>
+struct Hasher<T*> {
+ /// @param ptr the pointer to hash
+ /// @returns a hash of the pointer
+ size_t operator()(T* ptr) const {
+ auto hash = std::hash<T*>()(ptr);
+ return hash ^ (hash >> 4);
}
-}
+};
-/// HashCombine "hashes" together an existing hash and hashable values.
+/// Hasher specialization for std::vector
+template <typename T>
+struct Hasher<std::vector<T>> {
+ /// @param vector the vector to hash
+ /// @returns a hash of the vector
+ size_t operator()(const std::vector<T>& vector) const {
+ auto hash = Hash(vector.size());
+ for (auto& el : vector) {
+ hash = HashCombine(hash, el);
+ }
+ return hash;
+ }
+};
+
+/// Hasher specialization for utils::vector
template <typename T, size_t N>
-void HashCombine(size_t* hash, const utils::Vector<T, N>& list) {
- HashCombine(hash, list.Length());
- for (auto& el : list) {
- HashCombine(hash, el);
+struct Hasher<utils::Vector<T, N>> {
+ /// @param vector the vector to hash
+ /// @returns a hash of the vector
+ size_t operator()(const utils::Vector<T, N>& vector) const {
+ auto hash = Hash(vector.Length());
+ for (auto& el : vector) {
+ hash = HashCombine(hash, el);
+ }
+ return hash;
}
-}
+};
-/// HashCombine "hashes" together an existing hash and hashable values.
+/// Hasher specialization for std::tuple
template <typename... TYPES>
-void HashCombine(size_t* hash, const std::tuple<TYPES...>& tuple) {
- HashCombine(hash, sizeof...(TYPES));
- HashCombine(hash, std::apply(Hash<TYPES...>, tuple));
-}
+struct Hasher<std::tuple<TYPES...>> {
+ /// @param tuple the tuple to hash
+ /// @returns a hash of the tuple
+ size_t operator()(const std::tuple<TYPES...>& tuple) const {
+ return std::apply(Hash<TYPES...>, tuple);
+ }
+};
-/// HashCombine "hashes" together an existing hash and hashable values.
-template <typename T, typename... ARGS>
-void HashCombine(size_t* hash, const T& value, const ARGS&... args) {
- HashCombine(hash, value);
- HashCombine(hash, args...);
-}
-
-/// @returns a hash of the combined arguments. The returned hash is dependent on
-/// the order of the arguments.
+/// @returns a hash of the variadic list of arguments.
+/// The returned hash is dependent on the order of the arguments.
template <typename... ARGS>
size_t Hash(const ARGS&... args) {
- size_t hash = 102931; // seed with an arbitrary prime
- HashCombine(&hash, args...);
+ if constexpr (sizeof...(ARGS) == 0) {
+ return 0;
+ } else if constexpr (sizeof...(ARGS) == 1) {
+ using T = std::tuple_element_t<0, std::tuple<ARGS...>>;
+ return Hasher<T>()(args...);
+ } else {
+ size_t hash = 102931; // seed with an arbitrary prime
+ return HashCombine(hash, args...);
+ }
+}
+
+/// @returns a hash of the variadic list of arguments.
+/// The returned hash is dependent on the order of the arguments.
+template <typename... ARGS>
+size_t HashCombine(size_t hash, const ARGS&... values) {
+ constexpr size_t offset = detail::HashCombineOffset<sizeof(size_t)>::value();
+ ((hash ^= Hash(values) + offset + (hash << 6) + (hash >> 2)), ...);
return hash;
}
diff --git a/src/tint/utils/hashmap.h b/src/tint/utils/hashmap.h
index 81bebf2..e179389 100644
--- a/src/tint/utils/hashmap.h
+++ b/src/tint/utils/hashmap.h
@@ -19,6 +19,7 @@
#include <optional>
#include <utility>
+#include "src/tint/utils/hash.h"
#include "src/tint/utils/hashset.h"
namespace tint::utils {
@@ -31,7 +32,7 @@
template <typename K,
typename V,
size_t N,
- typename HASH = std::hash<K>,
+ typename HASH = Hasher<K>,
typename EQUAL = std::equal_to<K>>
class Hashmap {
/// LazyCreator is a transient structure used to late-build the Entry::value, when inserted into
diff --git a/src/tint/utils/hashset.h b/src/tint/utils/hashset.h
index f88a304..3009c72 100644
--- a/src/tint/utils/hashset.h
+++ b/src/tint/utils/hashset.h
@@ -23,6 +23,7 @@
#include <utility>
#include "src/tint/debug.h"
+#include "src/tint/utils/hash.h"
#include "src/tint/utils/vector.h"
namespace tint::utils {
@@ -39,7 +40,7 @@
/// An unordered set that uses a robin-hood hashing algorithm.
/// @see the fantastic tutorial: https://programming.guide/robin-hood-hashing.html
-template <typename T, size_t N, typename HASH = std::hash<T>, typename EQUAL = std::equal_to<T>>
+template <typename T, size_t N, typename HASH = Hasher<T>, typename EQUAL = std::equal_to<T>>
class Hashset {
/// A slot is a single entry in the underlying vector.
/// A slot can either be empty or filled with a value. If the slot is empty, #hash and #distance
diff --git a/src/tint/writer/flatten_bindings.cc b/src/tint/writer/flatten_bindings.cc
index 1efc02a..bedec75 100644
--- a/src/tint/writer/flatten_bindings.cc
+++ b/src/tint/writer/flatten_bindings.cc
@@ -33,6 +33,7 @@
auto entry_points = inspector.GetEntryPoints();
for (auto& entry_point : entry_points) {
auto bindings = inspector.GetResourceBindings(entry_point.name);
+
for (auto& binding : bindings) {
BindingPoint src = {binding.bind_group, binding.binding};
if (binding_points.count(src)) {
diff --git a/src/tint/writer/flatten_bindings_test.cc b/src/tint/writer/flatten_bindings_test.cc
index 830e720..7137218 100644
--- a/src/tint/writer/flatten_bindings_test.cc
+++ b/src/tint/writer/flatten_bindings_test.cc
@@ -18,7 +18,6 @@
#include "gtest/gtest.h"
#include "src/tint/program_builder.h"
-#include "src/tint/resolver/resolver.h"
#include "src/tint/sem/variable.h"
namespace tint::writer {
@@ -28,9 +27,6 @@
TEST_F(FlattenBindingsTest, NoBindings) {
ProgramBuilder b;
-
- resolver::Resolver resolver(&b);
-
Program program(std::move(b));
ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();
@@ -44,8 +40,6 @@
b.GlobalVar("b", b.ty.i32(), ast::StorageClass::kUniform, b.Group(0), b.Binding(1));
b.GlobalVar("c", b.ty.i32(), ast::StorageClass::kUniform, b.Group(0), b.Binding(2));
- resolver::Resolver resolver(&b);
-
Program program(std::move(b));
ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();
@@ -60,8 +54,6 @@
b.GlobalVar("c", b.ty.i32(), ast::StorageClass::kUniform, b.Group(2), b.Binding(2));
b.WrapInFunction(b.Expr("a"), b.Expr("b"), b.Expr("c"));
- resolver::Resolver resolver(&b);
-
Program program(std::move(b));
ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();
@@ -69,12 +61,21 @@
EXPECT_TRUE(flattened);
auto& vars = flattened->AST().GlobalVariables();
- EXPECT_EQ(vars[0]->BindingPoint().group->value, 0u);
- EXPECT_EQ(vars[0]->BindingPoint().binding->value, 0u);
- EXPECT_EQ(vars[1]->BindingPoint().group->value, 0u);
- EXPECT_EQ(vars[1]->BindingPoint().binding->value, 1u);
- EXPECT_EQ(vars[2]->BindingPoint().group->value, 0u);
- EXPECT_EQ(vars[2]->BindingPoint().binding->value, 2u);
+
+ auto* sem = flattened->Sem().Get<sem::GlobalVariable>(vars[0]);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->BindingPoint().group, 0u);
+ EXPECT_EQ(sem->BindingPoint().binding, 0u);
+
+ sem = flattened->Sem().Get<sem::GlobalVariable>(vars[1]);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->BindingPoint().group, 0u);
+ EXPECT_EQ(sem->BindingPoint().binding, 1u);
+
+ sem = flattened->Sem().Get<sem::GlobalVariable>(vars[2]);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->BindingPoint().group, 0u);
+ EXPECT_EQ(sem->BindingPoint().binding, 2u);
}
TEST_F(FlattenBindingsTest, NotFlat_MultipleNamespaces) {
@@ -113,8 +114,6 @@
b.Assign(b.Phony(), "texture4"), b.Assign(b.Phony(), "texture5"),
b.Assign(b.Phony(), "texture6"));
- resolver::Resolver resolver(&b);
-
Program program(std::move(b));
ASSERT_TRUE(program.IsValid()) << program.Diagnostics().str();
@@ -124,16 +123,22 @@
auto& vars = flattened->AST().GlobalVariables();
for (size_t i = 0; i < num_buffers; ++i) {
- EXPECT_EQ(vars[i]->BindingPoint().group->value, 0u);
- EXPECT_EQ(vars[i]->BindingPoint().binding->value, i);
+ auto* sem = flattened->Sem().Get<sem::GlobalVariable>(vars[i]);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->BindingPoint().group, 0u);
+ EXPECT_EQ(sem->BindingPoint().binding, i);
}
for (size_t i = 0; i < num_samplers; ++i) {
- EXPECT_EQ(vars[i + num_buffers]->BindingPoint().group->value, 0u);
- EXPECT_EQ(vars[i + num_buffers]->BindingPoint().binding->value, i);
+ auto* sem = flattened->Sem().Get<sem::GlobalVariable>(vars[i + num_buffers]);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->BindingPoint().group, 0u);
+ EXPECT_EQ(sem->BindingPoint().binding, i);
}
for (size_t i = 0; i < num_textures; ++i) {
- EXPECT_EQ(vars[i + num_buffers + num_samplers]->BindingPoint().group->value, 0u);
- EXPECT_EQ(vars[i + num_buffers + num_samplers]->BindingPoint().binding->value, i);
+ auto* sem = flattened->Sem().Get<sem::GlobalVariable>(vars[i + num_buffers + num_samplers]);
+ ASSERT_NE(sem, nullptr);
+ EXPECT_EQ(sem->BindingPoint().group, 0u);
+ EXPECT_EQ(sem->BindingPoint().binding, i);
}
}
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index c66e174..31f749f 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -1743,31 +1743,15 @@
}
}
return Switch(
- expr,
- [&](const ast::IndexAccessorExpression* a) { //
- return EmitIndexAccessor(out, a);
- },
- [&](const ast::BinaryExpression* b) { //
- return EmitBinary(out, b);
- },
- [&](const ast::BitcastExpression* b) { //
- return EmitBitcast(out, b);
- },
- [&](const ast::CallExpression* c) { //
- return EmitCall(out, c);
- },
- [&](const ast::IdentifierExpression* i) { //
- return EmitIdentifier(out, i);
- },
- [&](const ast::LiteralExpression* l) { //
- return EmitLiteral(out, l);
- },
- [&](const ast::MemberAccessorExpression* m) { //
- return EmitMemberAccessor(out, m);
- },
- [&](const ast::UnaryOpExpression* u) { //
- return EmitUnaryOp(out, u);
- },
+ expr, //
+ [&](const ast::IndexAccessorExpression* a) { return EmitIndexAccessor(out, a); },
+ [&](const ast::BinaryExpression* b) { return EmitBinary(out, b); },
+ [&](const ast::BitcastExpression* b) { return EmitBitcast(out, b); },
+ [&](const ast::CallExpression* c) { return EmitCall(out, c); },
+ [&](const ast::IdentifierExpression* i) { return EmitIdentifier(out, i); },
+ [&](const ast::LiteralExpression* l) { return EmitLiteral(out, l); },
+ [&](const ast::MemberAccessorExpression* m) { return EmitMemberAccessor(out, m); },
+ [&](const ast::UnaryOpExpression* u) { return EmitUnaryOp(out, u); },
[&](Default) { //
diagnostics_.add_error(diag::System::Writer, "unknown expression type: " +
std::string(expr->TypeInfo().name));
@@ -1922,10 +1906,10 @@
TINT_ICE(Writer, builder_.Diagnostics()) << "storage variable must be of struct type";
return false;
}
- ast::VariableBindingPoint bp = var->BindingPoint();
+ auto bp = sem->As<sem::GlobalVariable>()->BindingPoint();
{
auto out = line();
- out << "layout(binding = " << bp.binding->value;
+ out << "layout(binding = " << bp.binding;
if (version_.IsDesktop()) {
out << ", std140";
}
@@ -1946,8 +1930,8 @@
TINT_ICE(Writer, builder_.Diagnostics()) << "storage variable must be of struct type";
return false;
}
- ast::VariableBindingPoint bp = var->BindingPoint();
- line() << "layout(binding = " << bp.binding->value << ", std430) buffer "
+ auto bp = sem->As<sem::GlobalVariable>()->BindingPoint();
+ line() << "layout(binding = " << bp.binding << ", std430) buffer "
<< UniqueIdentifier(StructName(str)) << " {";
EmitStructMembers(current_buffer_, str, /* emit_offsets */ true);
auto name = builder_.Symbols().NameFor(var->symbol);
diff --git a/src/tint/writer/glsl/generator_impl_storage_buffer_test.cc b/src/tint/writer/glsl/generator_impl_storage_buffer_test.cc
index 842bf35..0264b02 100644
--- a/src/tint/writer/glsl/generator_impl_storage_buffer_test.cc
+++ b/src/tint/writer/glsl/generator_impl_storage_buffer_test.cc
@@ -13,9 +13,11 @@
// limitations under the License.
#include "gmock/gmock.h"
+#include "src/tint/number.h"
#include "src/tint/writer/glsl/test_helper.h"
using ::testing::HasSubstr;
+using namespace tint::number_suffixes; // NOLINT
namespace tint::writer::glsl {
namespace {
@@ -31,9 +33,9 @@
// @group(0) @binding(0) var<storage, read_write> nephews : Nephews;
auto* nephews = ctx->Structure(
"Nephews", utils::Vector{
- ctx->Member("huey", ctx->ty.f32(), utils::Vector{ctx->MemberAlign(256)}),
- ctx->Member("dewey", ctx->ty.f32(), utils::Vector{ctx->MemberAlign(256)}),
- ctx->Member("louie", ctx->ty.f32(), utils::Vector{ctx->MemberAlign(256)}),
+ ctx->Member("huey", ctx->ty.f32(), utils::Vector{ctx->MemberAlign(256_u)}),
+ ctx->Member("dewey", ctx->ty.f32(), utils::Vector{ctx->MemberAlign(256_u)}),
+ ctx->Member("louie", ctx->ty.f32(), utils::Vector{ctx->MemberAlign(256_u)}),
});
ctx->GlobalVar("nephews", ctx->ty.Of(nephews), ast::StorageClass::kStorage, ctx->Binding(0),
ctx->Group(0));
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 444746c..fca1515 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -136,15 +136,15 @@
// Helper for writing " : register(RX, spaceY)", where R is the register, X is
// the binding point binding value, and Y is the binding point group value.
struct RegisterAndSpace {
- RegisterAndSpace(char r, ast::VariableBindingPoint bp) : reg(r), binding_point(bp) {}
+ RegisterAndSpace(char r, sem::BindingPoint bp) : reg(r), binding_point(bp) {}
const char reg;
- ast::VariableBindingPoint const binding_point;
+ sem::BindingPoint const binding_point;
};
std::ostream& operator<<(std::ostream& s, const RegisterAndSpace& rs) {
- s << " : register(" << rs.reg << rs.binding_point.binding->value << ", space"
- << rs.binding_point.group->value << ")";
+ s << " : register(" << rs.reg << rs.binding_point.binding << ", space" << rs.binding_point.group
+ << ")";
return s;
}
@@ -2631,32 +2631,16 @@
}
}
return Switch(
- expr,
- [&](const ast::IndexAccessorExpression* a) { //
- return EmitIndexAccessor(out, a);
- },
- [&](const ast::BinaryExpression* b) { //
- return EmitBinary(out, b);
- },
- [&](const ast::BitcastExpression* b) { //
- return EmitBitcast(out, b);
- },
- [&](const ast::CallExpression* c) { //
- return EmitCall(out, c);
- },
- [&](const ast::IdentifierExpression* i) { //
- return EmitIdentifier(out, i);
- },
- [&](const ast::LiteralExpression* l) { //
- return EmitLiteral(out, l);
- },
- [&](const ast::MemberAccessorExpression* m) { //
- return EmitMemberAccessor(out, m);
- },
- [&](const ast::UnaryOpExpression* u) { //
- return EmitUnaryOp(out, u);
- },
- [&](Default) { //
+ expr, //
+ [&](const ast::IndexAccessorExpression* a) { return EmitIndexAccessor(out, a); },
+ [&](const ast::BinaryExpression* b) { return EmitBinary(out, b); },
+ [&](const ast::BitcastExpression* b) { return EmitBitcast(out, b); },
+ [&](const ast::CallExpression* c) { return EmitCall(out, c); },
+ [&](const ast::IdentifierExpression* i) { return EmitIdentifier(out, i); },
+ [&](const ast::LiteralExpression* l) { return EmitLiteral(out, l); },
+ [&](const ast::MemberAccessorExpression* m) { return EmitMemberAccessor(out, m); },
+ [&](const ast::UnaryOpExpression* u) { return EmitUnaryOp(out, u); },
+ [&](Default) {
diagnostics_.add_error(diag::System::Writer, "unknown expression type: " +
std::string(expr->TypeInfo().name));
return false;
@@ -2877,7 +2861,7 @@
}
bool GeneratorImpl::EmitUniformVariable(const ast::Var* var, const sem::Variable* sem) {
- auto binding_point = var->BindingPoint();
+ auto binding_point = sem->As<sem::GlobalVariable>()->BindingPoint();
auto* type = sem->Type()->UnwrapRef();
auto name = builder_.Symbols().NameFor(var->symbol);
line() << "cbuffer cbuffer_" << name << RegisterAndSpace('b', binding_point) << " {";
@@ -2904,7 +2888,9 @@
return false;
}
- out << RegisterAndSpace(sem->Access() == ast::Access::kRead ? 't' : 'u', var->BindingPoint())
+ auto* global_sem = sem->As<sem::GlobalVariable>();
+ out << RegisterAndSpace(sem->Access() == ast::Access::kRead ? 't' : 'u',
+ global_sem->BindingPoint())
<< ";";
return true;
@@ -2932,9 +2918,8 @@
}
if (register_space) {
- auto bp = var->BindingPoint();
- out << " : register(" << register_space << bp.binding->value << ", space" << bp.group->value
- << ")";
+ auto bp = sem->As<sem::GlobalVariable>()->BindingPoint();
+ out << " : register(" << register_space << bp.binding << ", space" << bp.group << ")";
}
out << ";";
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index cccf622..7617678 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -1764,36 +1764,20 @@
bool GeneratorImpl::EmitExpression(std::ostream& out, const ast::Expression* expr) {
if (auto* sem = builder_.Sem().Get(expr)) {
- if (auto constant = sem->ConstantValue()) {
+ if (auto* constant = sem->ConstantValue()) {
return EmitConstant(out, constant);
}
}
return Switch(
- expr,
- [&](const ast::IndexAccessorExpression* a) { //
- return EmitIndexAccessor(out, a);
- },
- [&](const ast::BinaryExpression* b) { //
- return EmitBinary(out, b);
- },
- [&](const ast::BitcastExpression* b) { //
- return EmitBitcast(out, b);
- },
- [&](const ast::CallExpression* c) { //
- return EmitCall(out, c);
- },
- [&](const ast::IdentifierExpression* i) { //
- return EmitIdentifier(out, i);
- },
- [&](const ast::LiteralExpression* l) { //
- return EmitLiteral(out, l);
- },
- [&](const ast::MemberAccessorExpression* m) { //
- return EmitMemberAccessor(out, m);
- },
- [&](const ast::UnaryOpExpression* u) { //
- return EmitUnaryOp(out, u);
- },
+ expr, //
+ [&](const ast::IndexAccessorExpression* a) { return EmitIndexAccessor(out, a); },
+ [&](const ast::BinaryExpression* b) { return EmitBinary(out, b); },
+ [&](const ast::BitcastExpression* b) { return EmitBitcast(out, b); },
+ [&](const ast::CallExpression* c) { return EmitCall(out, c); },
+ [&](const ast::IdentifierExpression* i) { return EmitIdentifier(out, i); },
+ [&](const ast::LiteralExpression* l) { return EmitLiteral(out, l); },
+ [&](const ast::MemberAccessorExpression* m) { return EmitMemberAccessor(out, m); },
+ [&](const ast::UnaryOpExpression* u) { return EmitUnaryOp(out, u); },
[&](Default) { //
diagnostics_.add_error(diag::System::Writer, "unknown expression type: " +
std::string(expr->TypeInfo().name));
@@ -1930,18 +1914,19 @@
// attribute have a value of zero.
const uint32_t kInvalidBindingIndex = std::numeric_limits<uint32_t>::max();
auto get_binding_index = [&](const ast::Parameter* param) -> uint32_t {
- auto bp = param->BindingPoint();
- if (bp.group == nullptr || bp.binding == nullptr) {
+ if (!param->HasBindingPoint()) {
TINT_ICE(Writer, diagnostics_)
<< "missing binding attributes for entry point parameter";
return kInvalidBindingIndex;
}
- if (bp.group->value != 0) {
+ auto* param_sem = program_->Sem().Get<sem::Parameter>(param);
+ auto bp = param_sem->BindingPoint();
+ if (bp.group != 0) {
TINT_ICE(Writer, diagnostics_) << "encountered non-zero resource group index (use "
"BindingRemapper to fix)";
return kInvalidBindingIndex;
}
- return bp.binding->value;
+ return bp.binding;
};
{
diff --git a/src/tint/writer/msl/generator_impl_type_test.cc b/src/tint/writer/msl/generator_impl_type_test.cc
index 72f9fa5..531a45d 100644
--- a/src/tint/writer/msl/generator_impl_type_test.cc
+++ b/src/tint/writer/msl/generator_impl_type_test.cc
@@ -252,35 +252,35 @@
}
TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_NonComposites) {
- auto* s =
- Structure("S", utils::Vector{
- Member("a", ty.i32(), utils::Vector{MemberSize(32)}),
- Member("b", ty.f32(), utils::Vector{MemberAlign(128), MemberSize(128)}),
- Member("c", ty.vec2<f32>()),
- Member("d", ty.u32()),
- Member("e", ty.vec3<f32>()),
- Member("f", ty.u32()),
- Member("g", ty.vec4<f32>()),
- Member("h", ty.u32()),
- Member("i", ty.mat2x2<f32>()),
- Member("j", ty.u32()),
- Member("k", ty.mat2x3<f32>()),
- Member("l", ty.u32()),
- Member("m", ty.mat2x4<f32>()),
- Member("n", ty.u32()),
- Member("o", ty.mat3x2<f32>()),
- Member("p", ty.u32()),
- Member("q", ty.mat3x3<f32>()),
- Member("r", ty.u32()),
- Member("s", ty.mat3x4<f32>()),
- Member("t", ty.u32()),
- Member("u", ty.mat4x2<f32>()),
- Member("v", ty.u32()),
- Member("w", ty.mat4x3<f32>()),
- Member("x", ty.u32()),
- Member("y", ty.mat4x4<f32>()),
- Member("z", ty.f32()),
- });
+ auto* s = Structure(
+ "S", utils::Vector{
+ Member("a", ty.i32(), utils::Vector{MemberSize(32)}),
+ Member("b", ty.f32(), utils::Vector{MemberAlign(128_u), MemberSize(128)}),
+ Member("c", ty.vec2<f32>()),
+ Member("d", ty.u32()),
+ Member("e", ty.vec3<f32>()),
+ Member("f", ty.u32()),
+ Member("g", ty.vec4<f32>()),
+ Member("h", ty.u32()),
+ Member("i", ty.mat2x2<f32>()),
+ Member("j", ty.u32()),
+ Member("k", ty.mat2x3<f32>()),
+ Member("l", ty.u32()),
+ Member("m", ty.mat2x4<f32>()),
+ Member("n", ty.u32()),
+ Member("o", ty.mat3x2<f32>()),
+ Member("p", ty.u32()),
+ Member("q", ty.mat3x3<f32>()),
+ Member("r", ty.u32()),
+ Member("s", ty.mat3x4<f32>()),
+ Member("t", ty.u32()),
+ Member("u", ty.mat4x2<f32>()),
+ Member("v", ty.u32()),
+ Member("w", ty.mat4x3<f32>()),
+ Member("x", ty.u32()),
+ Member("y", ty.mat4x4<f32>()),
+ Member("z", ty.f32()),
+ });
GlobalVar("G", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, Binding(0), Group(0));
@@ -368,10 +368,11 @@
TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_Structures) {
// inner_x: size(1024), align(512)
- auto* inner_x = Structure("inner_x", utils::Vector{
- Member("a", ty.i32()),
- Member("b", ty.f32(), utils::Vector{MemberAlign(512)}),
- });
+ auto* inner_x =
+ Structure("inner_x", utils::Vector{
+ Member("a", ty.i32()),
+ Member("b", ty.f32(), utils::Vector{MemberAlign(512_u)}),
+ });
// inner_y: size(516), align(4)
auto* inner_y = Structure("inner_y", utils::Vector{
@@ -456,7 +457,7 @@
// inner: size(1024), align(512)
auto* inner = Structure("inner", utils::Vector{
Member("a", ty.i32()),
- Member("b", ty.f32(), utils::Vector{MemberAlign(512)}),
+ Member("b", ty.f32(), utils::Vector{MemberAlign(512_u)}),
});
// array_x: size(28), align(4)
@@ -588,36 +589,36 @@
}
TEST_F(MslGeneratorImplTest, AttemptTintPadSymbolCollision) {
- auto* s = Structure(
- "S", utils::Vector{
- // uses symbols tint_pad_[0..9] and tint_pad_[20..35]
- Member("tint_pad_2", ty.i32(), utils::Vector{MemberSize(32)}),
- Member("tint_pad_20", ty.f32(), utils::Vector{MemberAlign(128), MemberSize(128)}),
- Member("tint_pad_33", ty.vec2<f32>()),
- Member("tint_pad_1", ty.u32()),
- Member("tint_pad_3", ty.vec3<f32>()),
- Member("tint_pad_7", ty.u32()),
- Member("tint_pad_25", ty.vec4<f32>()),
- Member("tint_pad_5", ty.u32()),
- Member("tint_pad_27", ty.mat2x2<f32>()),
- Member("tint_pad_24", ty.u32()),
- Member("tint_pad_23", ty.mat2x3<f32>()),
- Member("tint_pad", ty.u32()),
- Member("tint_pad_8", ty.mat2x4<f32>()),
- Member("tint_pad_26", ty.u32()),
- Member("tint_pad_29", ty.mat3x2<f32>()),
- Member("tint_pad_6", ty.u32()),
- Member("tint_pad_22", ty.mat3x3<f32>()),
- Member("tint_pad_32", ty.u32()),
- Member("tint_pad_34", ty.mat3x4<f32>()),
- Member("tint_pad_35", ty.u32()),
- Member("tint_pad_30", ty.mat4x2<f32>()),
- Member("tint_pad_9", ty.u32()),
- Member("tint_pad_31", ty.mat4x3<f32>()),
- Member("tint_pad_28", ty.u32()),
- Member("tint_pad_4", ty.mat4x4<f32>()),
- Member("tint_pad_21", ty.f32()),
- });
+ auto* s = Structure("S", utils::Vector{
+ // uses symbols tint_pad_[0..9] and tint_pad_[20..35]
+ Member("tint_pad_2", ty.i32(), utils::Vector{MemberSize(32)}),
+ Member("tint_pad_20", ty.f32(),
+ utils::Vector{MemberAlign(128_u), MemberSize(128_u)}),
+ Member("tint_pad_33", ty.vec2<f32>()),
+ Member("tint_pad_1", ty.u32()),
+ Member("tint_pad_3", ty.vec3<f32>()),
+ Member("tint_pad_7", ty.u32()),
+ Member("tint_pad_25", ty.vec4<f32>()),
+ Member("tint_pad_5", ty.u32()),
+ Member("tint_pad_27", ty.mat2x2<f32>()),
+ Member("tint_pad_24", ty.u32()),
+ Member("tint_pad_23", ty.mat2x3<f32>()),
+ Member("tint_pad", ty.u32()),
+ Member("tint_pad_8", ty.mat2x4<f32>()),
+ Member("tint_pad_26", ty.u32()),
+ Member("tint_pad_29", ty.mat3x2<f32>()),
+ Member("tint_pad_6", ty.u32()),
+ Member("tint_pad_22", ty.mat3x3<f32>()),
+ Member("tint_pad_32", ty.u32()),
+ Member("tint_pad_34", ty.mat3x4<f32>()),
+ Member("tint_pad_35", ty.u32()),
+ Member("tint_pad_30", ty.mat4x2<f32>()),
+ Member("tint_pad_9", ty.u32()),
+ Member("tint_pad_31", ty.mat4x3<f32>()),
+ Member("tint_pad_28", ty.u32()),
+ Member("tint_pad_4", ty.mat4x4<f32>()),
+ Member("tint_pad_21", ty.f32()),
+ });
GlobalVar("G", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, Binding(0), Group(0));
diff --git a/src/tint/writer/spirv/builder_type_test.cc b/src/tint/writer/spirv/builder_type_test.cc
index c6f044e..3b8e0b1 100644
--- a/src/tint/writer/spirv/builder_type_test.cc
+++ b/src/tint/writer/spirv/builder_type_test.cc
@@ -336,7 +336,7 @@
TEST_F(BuilderTest_Type, GenerateStruct_DecoratedMembers) {
auto* s = Structure("S", utils::Vector{
Member("a", ty.f32()),
- Member("b", ty.f32(), utils::Vector{MemberAlign(8)}),
+ Member("b", ty.f32(), utils::Vector{MemberAlign(8_u)}),
});
spirv::Builder& b = Build();
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index efc9601..9f9c4c6 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -776,7 +776,11 @@
return true;
},
[&](const ast::StructMemberAlignAttribute* align) {
- out << "align(" << align->align << ")";
+ out << "align(";
+ if (!EmitExpression(out, align->align)) {
+ return false;
+ }
+ out << ")";
return true;
},
[&](const ast::StrideAttribute* stride) {
diff --git a/src/tint/writer/wgsl/generator_impl_type_test.cc b/src/tint/writer/wgsl/generator_impl_type_test.cc
index 73fc186..cdba585 100644
--- a/src/tint/writer/wgsl/generator_impl_type_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_type_test.cc
@@ -219,8 +219,8 @@
TEST_F(WgslGeneratorImplTest, EmitType_StructAlignDecl) {
auto* s = Structure("S", utils::Vector{
- Member("a", ty.i32(), utils::Vector{MemberAlign(8)}),
- Member("b", ty.f32(), utils::Vector{MemberAlign(16)}),
+ Member("a", ty.i32(), utils::Vector{MemberAlign(8_a)}),
+ Member("b", ty.f32(), utils::Vector{MemberAlign(16_a)}),
});
GeneratorImpl& gen = Build();
@@ -256,7 +256,7 @@
TEST_F(WgslGeneratorImplTest, EmitType_Struct_WithAttribute) {
auto* s = Structure("S", utils::Vector{
Member("a", ty.i32()),
- Member("b", ty.f32(), utils::Vector{MemberAlign(8)}),
+ Member("b", ty.f32(), utils::Vector{MemberAlign(8_a)}),
});
GeneratorImpl& gen = Build();