[tint] Resolve types without recursion
Change ast::TraverseExpressions to consider identifier templated
expressions. This will make the resolver resolve the template
sub-expressions without recursion.
Change the resolver to fetch the type and expressions from the sem map
instead of calling the Type() and Expression() root resolve functions.
Change-Id: I4ad9283f33b85a58cc39e4887f699ef34f0d3617
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/155143
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/spirv/writer/ast_printer/builder.cc b/src/tint/lang/spirv/writer/ast_printer/builder.cc
index 02ae3bd..1a57747 100644
--- a/src/tint/lang/spirv/writer/ast_printer/builder.cc
+++ b/src/tint/lang/spirv/writer/ast_printer/builder.cc
@@ -1223,20 +1223,24 @@
return 0;
}
-bool Builder::IsConstructorConst(const ast::Expression* expr) {
+bool Builder::IsConstructorConst(const ast::CallExpression* expr) {
bool is_const = true;
ast::TraverseExpressions(expr, [&](const ast::Expression* e) {
+ auto* val = builder_.Sem().GetVal(e);
+ if (!val) {
+ return ast::TraverseAction::Descend;
+ }
+
if (e->Is<ast::LiteralExpression>()) {
return ast::TraverseAction::Descend;
}
- if (auto* ce = e->As<ast::CallExpression>()) {
- auto* sem = builder_.Sem().Get(ce);
- if (sem->Is<sem::Materialize>()) {
+ if (e->Is<ast::CallExpression>()) {
+ if (val->Is<sem::Materialize>()) {
// Materialize can only occur on compile time expressions, so this sub-tree must be
// constant.
return ast::TraverseAction::Skip;
}
- auto* call = sem->As<sem::Call>();
+ auto* call = val->As<sem::Call>();
if (call->Target()->Is<sem::ValueConstructor>()) {
return ast::TraverseAction::Descend;
}
diff --git a/src/tint/lang/spirv/writer/ast_printer/builder.h b/src/tint/lang/spirv/writer/ast_printer/builder.h
index ae4b639..78214d1 100644
--- a/src/tint/lang/spirv/writer/ast_printer/builder.h
+++ b/src/tint/lang/spirv/writer/ast_printer/builder.h
@@ -465,7 +465,7 @@
/// Determines if the given value constructor is created from constant values
/// @param expr the expression to check
/// @returns true if the constructor is constant
- bool IsConstructorConst(const ast::Expression* expr);
+ bool IsConstructorConst(const ast::CallExpression* expr);
private:
/// @returns an Operand with a new result ID in it. Increments the next_id_
diff --git a/src/tint/lang/wgsl/ast/traverse_expressions.h b/src/tint/lang/wgsl/ast/traverse_expressions.h
index 17688c2..7319a78 100644
--- a/src/tint/lang/wgsl/ast/traverse_expressions.h
+++ b/src/tint/lang/wgsl/ast/traverse_expressions.h
@@ -24,6 +24,7 @@
#include "src/tint/lang/wgsl/ast/literal_expression.h"
#include "src/tint/lang/wgsl/ast/member_accessor_expression.h"
#include "src/tint/lang/wgsl/ast/phony_expression.h"
+#include "src/tint/lang/wgsl/ast/templated_identifier.h"
#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
#include "src/tint/utils/containers/reverse.h"
#include "src/tint/utils/containers/vector.h"
@@ -73,7 +74,7 @@
auto push_single = [&](const Expression* expr, size_t depth) { to_visit.Push({expr, depth}); };
auto push_pair = [&](const Expression* left, const Expression* right, size_t depth) {
- if (ORDER == TraverseOrder::LeftToRight) {
+ if constexpr (ORDER == TraverseOrder::LeftToRight) {
to_visit.Push({right, depth});
to_visit.Push({left, depth});
} else {
@@ -82,7 +83,7 @@
}
};
auto push_list = [&](VectorRef<const Expression*> exprs, size_t depth) {
- if (ORDER == TraverseOrder::LeftToRight) {
+ if constexpr (ORDER == TraverseOrder::LeftToRight) {
for (auto* expr : tint::Reverse(exprs)) {
to_visit.Push({expr, depth});
}
@@ -117,6 +118,12 @@
bool ok = Switch(
expr,
+ [&](const IdentifierExpression* ident) {
+ if (auto* tmpl = ident->identifier->As<TemplatedIdentifier>()) {
+ push_list(tmpl->arguments, p.depth + 1);
+ }
+ return true;
+ },
[&](const IndexAccessorExpression* idx) {
push_pair(idx->object, idx->index, p.depth + 1);
return true;
@@ -130,7 +137,13 @@
return true;
},
[&](const CallExpression* call) {
- push_list(call->args, p.depth + 1);
+ if constexpr (ORDER == TraverseOrder::LeftToRight) {
+ push_list(call->args, p.depth + 1);
+ push_single(call->target, p.depth + 1);
+ } else {
+ push_single(call->target, p.depth + 1);
+ push_list(call->args, p.depth + 1);
+ }
return true;
},
[&](const MemberAccessorExpression* member) {
@@ -142,8 +155,7 @@
return true;
},
[&](Default) {
- if (TINT_LIKELY((expr->IsAnyOf<LiteralExpression, IdentifierExpression,
- PhonyExpression>()))) {
+ if (TINT_LIKELY((expr->IsAnyOf<LiteralExpression, PhonyExpression>()))) {
return true; // Leaf expression
}
TINT_ICE() << "unhandled expression type: "
diff --git a/src/tint/lang/wgsl/ast/traverse_expressions_test.cc b/src/tint/lang/wgsl/ast/traverse_expressions_test.cc
index c477ac0..85da66c 100644
--- a/src/tint/lang/wgsl/ast/traverse_expressions_test.cc
+++ b/src/tint/lang/wgsl/ast/traverse_expressions_test.cc
@@ -27,22 +27,44 @@
using TraverseExpressionsTest = TestHelper;
+TEST_F(TraverseExpressionsTest, DescendTemplatedIdentifier) {
+ tint::Vector e{Expr(1_i), Expr(2_i), Expr(1_i), Expr(1_i)};
+ tint::Vector c{Expr(Ident("a", e[0], e[1])), Expr(Ident("b", e[2], e[3]))};
+ auto* root = Expr(Ident("c", c[0], c[1]));
+ {
+ Vector<const Expression*, 8> l2r;
+ TraverseExpressions<TraverseOrder::LeftToRight>(root, [&](const Expression* expr) {
+ l2r.Push(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(l2r, ElementsAre(root, c[0], e[0], e[1], c[1], e[2], e[3]));
+ }
+ {
+ Vector<const Expression*, 8> r2l;
+ TraverseExpressions<TraverseOrder::RightToLeft>(root, [&](const Expression* expr) {
+ r2l.Push(expr);
+ return TraverseAction::Descend;
+ });
+ EXPECT_THAT(r2l, ElementsAre(root, c[1], e[3], e[2], c[0], e[1], e[0]));
+ }
+}
+
TEST_F(TraverseExpressionsTest, DescendIndexAccessor) {
- std::vector<const Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
- std::vector<const Expression*> i = {IndexAccessor(e[0], e[1]), IndexAccessor(e[2], e[3])};
+ Vector e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
+ Vector i = {IndexAccessor(e[0], e[1]), IndexAccessor(e[2], e[3])};
auto* root = IndexAccessor(i[0], i[1]);
{
- std::vector<const Expression*> l2r;
+ Vector<const ast::Expression*, 8> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(root, [&](const Expression* expr) {
- l2r.push_back(expr);
+ l2r.Push(expr);
return TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, i[0], e[0], e[1], i[1], e[2], e[3]));
}
{
- std::vector<const Expression*> r2l;
+ Vector<const ast::Expression*, 8> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(root, [&](const Expression* expr) {
- r2l.push_back(expr);
+ r2l.Push(expr);
return TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, i[1], e[3], e[2], i[0], e[1], e[0]));
@@ -50,21 +72,21 @@
}
TEST_F(TraverseExpressionsTest, DescendBinaryExpression) {
- std::vector<const Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
- std::vector<const Expression*> i = {Add(e[0], e[1]), Sub(e[2], e[3])};
+ Vector e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
+ Vector i = {Add(e[0], e[1]), Sub(e[2], e[3])};
auto* root = Mul(i[0], i[1]);
{
- std::vector<const Expression*> l2r;
+ Vector<const ast::Expression*, 8> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(root, [&](const Expression* expr) {
- l2r.push_back(expr);
+ l2r.Push(expr);
return TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, i[0], e[0], e[1], i[1], e[2], e[3]));
}
{
- std::vector<const Expression*> r2l;
+ Vector<const ast::Expression*, 8> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(root, [&](const Expression* expr) {
- r2l.push_back(expr);
+ r2l.Push(expr);
return TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, i[1], e[3], e[2], i[0], e[1], e[0]));
@@ -72,8 +94,8 @@
}
TEST_F(TraverseExpressionsTest, Depth) {
- std::vector<const Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
- std::vector<const Expression*> i = {Add(e[0], e[1]), Sub(e[2], e[3])};
+ Vector e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
+ Vector i = {Add(e[0], e[1]), Sub(e[2], e[3])};
auto* root = Mul(i[0], i[1]);
size_t j = 0;
@@ -113,16 +135,17 @@
}
TEST_F(TraverseExpressionsTest, DescendCallExpression) {
- tint::Vector e{Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
- tint::Vector c{Call("a", e[0], e[1]), Call("b", e[2], e[3])};
- auto* root = Call("c", c[0], c[1]);
+ tint::Vector i{Expr("a"), Expr("b"), Expr("c")};
+ tint::Vector e{Expr(1_i), Expr(2_i), Expr(1_i), Expr(1_i)};
+ tint::Vector c{Call(i[0], e[0], e[1]), Call(i[1], e[2], e[3])};
+ auto* root = Call(i[2], c[0], c[1]);
{
Vector<const Expression*, 8> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(root, [&](const Expression* expr) {
l2r.Push(expr);
return TraverseAction::Descend;
});
- EXPECT_THAT(l2r, ElementsAre(root, c[0], e[0], e[1], c[1], e[2], e[3]));
+ EXPECT_THAT(l2r, ElementsAre(root, i[2], c[0], i[0], e[0], e[1], c[1], i[1], e[2], e[3]));
}
{
Vector<const Expression*, 8> r2l;
@@ -130,7 +153,7 @@
r2l.Push(expr);
return TraverseAction::Descend;
});
- EXPECT_THAT(r2l, ElementsAre(root, c[1], e[3], e[2], c[0], e[1], e[0]));
+ EXPECT_THAT(r2l, ElementsAre(root, c[1], e[3], e[2], i[1], c[0], e[1], e[0], i[0], i[2]));
}
}
@@ -139,17 +162,17 @@
auto* m = MemberAccessor(e, "a");
auto* root = MemberAccessor(m, "b");
{
- std::vector<const Expression*> l2r;
+ Vector<const ast::Expression*, 8> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(root, [&](const Expression* expr) {
- l2r.push_back(expr);
+ l2r.Push(expr);
return TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, m, e));
}
{
- std::vector<const Expression*> r2l;
+ Vector<const ast::Expression*, 8> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(root, [&](const Expression* expr) {
- r2l.push_back(expr);
+ r2l.Push(expr);
return TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, m, e));
@@ -165,17 +188,17 @@
auto* f = IndexAccessor(d, e);
auto* root = IndexAccessor(c, f);
{
- std::vector<const Expression*> l2r;
+ Vector<const ast::Expression*, 8> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(root, [&](const Expression* expr) {
- l2r.push_back(expr);
+ l2r.Push(expr);
return TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, c, a, b, f, d, e));
}
{
- std::vector<const Expression*> r2l;
+ Vector<const ast::Expression*, 8> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(root, [&](const Expression* expr) {
- r2l.push_back(expr);
+ r2l.Push(expr);
return TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, f, e, d, c, b, a));
@@ -189,17 +212,17 @@
auto* u2 = AddressOf(u1);
auto* root = Deref(u2);
{
- std::vector<const Expression*> l2r;
+ Vector<const ast::Expression*, 8> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(root, [&](const Expression* expr) {
- l2r.push_back(expr);
+ l2r.Push(expr);
return TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, u2, u1, u0, e));
}
{
- std::vector<const Expression*> r2l;
+ Vector<const ast::Expression*, 8> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(root, [&](const Expression* expr) {
- r2l.push_back(expr);
+ r2l.Push(expr);
return TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, u2, u1, u0, e));
@@ -207,24 +230,24 @@
}
TEST_F(TraverseExpressionsTest, Skip) {
- std::vector<const Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
- std::vector<const Expression*> i = {IndexAccessor(e[0], e[1]), IndexAccessor(e[2], e[3])};
+ Vector e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
+ Vector i = {IndexAccessor(e[0], e[1]), IndexAccessor(e[2], e[3])};
auto* root = IndexAccessor(i[0], i[1]);
- std::vector<const Expression*> order;
+ Vector<const ast::Expression*, 8> order;
TraverseExpressions<TraverseOrder::LeftToRight>(root, [&](const Expression* expr) {
- order.push_back(expr);
+ order.Push(expr);
return expr == i[0] ? TraverseAction::Skip : TraverseAction::Descend;
});
EXPECT_THAT(order, ElementsAre(root, i[0], i[1], e[2], e[3]));
}
TEST_F(TraverseExpressionsTest, Stop) {
- std::vector<const Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
- std::vector<const Expression*> i = {IndexAccessor(e[0], e[1]), IndexAccessor(e[2], e[3])};
+ Vector e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
+ Vector i = {IndexAccessor(e[0], e[1]), IndexAccessor(e[2], e[3])};
auto* root = IndexAccessor(i[0], i[1]);
- std::vector<const Expression*> order;
+ Vector<const ast::Expression*, 8> order;
TraverseExpressions<TraverseOrder::LeftToRight>(root, [&](const Expression* expr) {
- order.push_back(expr);
+ order.Push(expr);
return expr == i[0] ? TraverseAction::Stop : TraverseAction::Descend;
});
EXPECT_THAT(order, ElementsAre(root, i[0]));
diff --git a/src/tint/lang/wgsl/resolver/builtin_enum_test.cc b/src/tint/lang/wgsl/resolver/builtin_enum_test.cc
index 3f73a5d..1ad9ea7 100644
--- a/src/tint/lang/wgsl/resolver/builtin_enum_test.cc
+++ b/src/tint/lang/wgsl/resolver/builtin_enum_test.cc
@@ -36,8 +36,8 @@
using ResolverAccessUsedWithTemplateArgs = ResolverTestWithParam<const char*>;
TEST_P(ResolverAccessUsedWithTemplateArgs, Test) {
- // @group(0) @binding(0) var t : texture_storage_2d<rgba8unorm, ACCESS<T>>;
- auto* tmpl = Ident(Source{{12, 34}}, GetParam(), "T");
+ // @group(0) @binding(0) var t : texture_storage_2d<rgba8unorm, ACCESS<i32>>;
+ auto* tmpl = Ident(Source{{12, 34}}, GetParam(), "i32");
GlobalVar("v", ty("texture_storage_2d", "rgba8unorm", tmpl), Group(0_u), Binding(0_u));
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/lang/wgsl/resolver/dependency_graph.cc b/src/tint/lang/wgsl/resolver/dependency_graph.cc
index 1d4f8f8..8d9fcd4 100644
--- a/src/tint/lang/wgsl/resolver/dependency_graph.cc
+++ b/src/tint/lang/wgsl/resolver/dependency_graph.cc
@@ -349,13 +349,7 @@
expr,
[&](const ast::IdentifierExpression* e) {
AddDependency(e->identifier, e->identifier->symbol);
- if (auto* tmpl_ident = e->identifier->As<ast::TemplatedIdentifier>()) {
- for (auto* arg : tmpl_ident->arguments) {
- pending.Push(arg);
- }
- }
},
- [&](const ast::CallExpression* call) { TraverseExpression(call->target); },
[&](const ast::BitcastExpression* cast) { TraverseExpression(cast->type); });
return ast::TraverseAction::Descend;
});
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 29a88ca..8c940f9 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2056,7 +2056,7 @@
// * A builtin call.
// * A value constructor.
// * A value conversion.
- auto* target = Expression(expr->target);
+ auto* target = sem_.Get(expr->target);
if (TINT_UNLIKELY(!target)) {
return nullptr;
}
@@ -2727,7 +2727,7 @@
return nullptr;
}
- auto* ty = Type(tmpl_ident->arguments[0]);
+ auto* ty = sem_.GetType(tmpl_ident->arguments[0]);
if (TINT_UNLIKELY(!ty)) {
return nullptr;
}
@@ -2766,7 +2766,7 @@
return nullptr;
}
- auto* el_ty = Type(tmpl_ident->arguments[0]);
+ auto* el_ty = sem_.GetType(tmpl_ident->arguments[0]);
if (TINT_UNLIKELY(!el_ty)) {
return nullptr;
}
@@ -2787,7 +2787,7 @@
auto* ast_el_ty = tmpl_ident->arguments[0];
auto* ast_count = (tmpl_ident->arguments.Length() > 1) ? tmpl_ident->arguments[1] : nullptr;
- auto* el_ty = Type(ast_el_ty);
+ auto* el_ty = sem_.GetType(ast_el_ty);
if (!el_ty) {
return nullptr;
}
@@ -2829,14 +2829,13 @@
return nullptr;
}
- auto* ty_expr = TypeExpression(tmpl_ident->arguments[0]);
- if (TINT_UNLIKELY(!ty_expr)) {
+ auto* el_ty = sem_.GetType(tmpl_ident->arguments[0]);
+ if (TINT_UNLIKELY(!el_ty)) {
return nullptr;
}
- auto* ty = ty_expr->Type();
- auto* out = b.create<core::type::Atomic>(ty);
- if (!validator_.Atomic(tmpl_ident, out)) {
+ auto* out = b.create<core::type::Atomic>(el_ty);
+ if (TINT_UNLIKELY(!validator_.Atomic(tmpl_ident, out))) {
return nullptr;
}
return out;
@@ -2848,34 +2847,32 @@
return nullptr;
}
- auto* address_space_expr = AddressSpaceExpression(tmpl_ident->arguments[0]);
- if (TINT_UNLIKELY(!address_space_expr)) {
+ auto address_space = sem_.GetAddressSpace(tmpl_ident->arguments[0]);
+ if (TINT_UNLIKELY(address_space == core::AddressSpace::kUndefined)) {
return nullptr;
}
- auto address_space = address_space_expr->Value();
- auto* store_ty_expr = TypeExpression(tmpl_ident->arguments[1]);
- if (TINT_UNLIKELY(!store_ty_expr)) {
+ auto* store_ty = const_cast<core::type::Type*>(sem_.GetType(tmpl_ident->arguments[1]));
+ if (TINT_UNLIKELY(!store_ty)) {
return nullptr;
}
- auto* store_ty = const_cast<core::type::Type*>(store_ty_expr->Type());
- auto access = DefaultAccessForAddressSpace(address_space);
+ core::Access access = core::Access::kUndefined;
if (tmpl_ident->arguments.Length() > 2) {
- auto* access_expr = AccessExpression(tmpl_ident->arguments[2]);
- if (TINT_UNLIKELY(!access_expr)) {
+ access = sem_.GetAccess(tmpl_ident->arguments[2]);
+ if (TINT_UNLIKELY(access == core::Access::kUndefined)) {
return nullptr;
}
- access = access_expr->Value();
+ } else {
+ access = DefaultAccessForAddressSpace(address_space);
}
auto* out = b.create<core::type::Pointer>(address_space, store_ty, access);
- if (!validator_.Pointer(tmpl_ident, out)) {
+ if (TINT_UNLIKELY(!validator_.Pointer(tmpl_ident, out))) {
return nullptr;
}
- if (!ApplyAddressSpaceUsageToType(address_space, store_ty,
- store_ty_expr->Declaration()->source)) {
+ if (!ApplyAddressSpaceUsageToType(address_space, store_ty, tmpl_ident->arguments[1]->source)) {
AddNote("while instantiating " + out->FriendlyName(), ident->source);
return nullptr;
}
@@ -2889,12 +2886,12 @@
return nullptr;
}
- auto* ty_expr = TypeExpression(tmpl_ident->arguments[0]);
+ auto* ty_expr = sem_.GetType(tmpl_ident->arguments[0]);
if (TINT_UNLIKELY(!ty_expr)) {
return nullptr;
}
- auto* out = b.create<core::type::SampledTexture>(dim, ty_expr->Type());
+ auto* out = b.create<core::type::SampledTexture>(dim, ty_expr);
return validator_.SampledTexture(out, ident->source) ? out : nullptr;
}
@@ -2905,12 +2902,12 @@
return nullptr;
}
- auto* ty_expr = TypeExpression(tmpl_ident->arguments[0]);
+ auto* ty_expr = sem_.GetType(tmpl_ident->arguments[0]);
if (TINT_UNLIKELY(!ty_expr)) {
return nullptr;
}
- auto* out = b.create<core::type::MultisampledTexture>(dim, ty_expr->Type());
+ auto* out = b.create<core::type::MultisampledTexture>(dim, ty_expr);
return validator_.MultisampledTexture(out, ident->source) ? out : nullptr;
}
@@ -2921,19 +2918,18 @@
return nullptr;
}
- auto* format = TexelFormatExpression(tmpl_ident->arguments[0]);
- if (TINT_UNLIKELY(!format)) {
+ auto format = sem_.GetTexelFormat(tmpl_ident->arguments[0]);
+ if (TINT_UNLIKELY(format == core::TexelFormat::kUndefined)) {
return nullptr;
}
- auto* access = AccessExpression(tmpl_ident->arguments[1]);
- if (TINT_UNLIKELY(!access)) {
+ auto access = sem_.GetAccess(tmpl_ident->arguments[1]);
+ if (TINT_UNLIKELY(access == core::Access::kUndefined)) {
return nullptr;
}
- auto* subtype = core::type::StorageTexture::SubtypeFor(format->Value(), b.Types());
- auto* tex =
- b.create<core::type::StorageTexture>(dim, format->Value(), access->Value(), subtype);
+ auto* subtype = core::type::StorageTexture::SubtypeFor(format, b.Types());
+ auto* tex = b.create<core::type::StorageTexture>(dim, format, access, subtype);
if (!validator_.StorageTexture(tex, ident->source)) {
return nullptr;
}
@@ -2947,7 +2943,7 @@
return nullptr;
}
- auto* el_ty = Type(tmpl_ident->arguments[0]);
+ auto* el_ty = sem_.GetType(tmpl_ident->arguments[0]);
if (TINT_UNLIKELY(!el_ty)) {
return nullptr;
}
@@ -3976,44 +3972,56 @@
const core::type::ArrayCount* Resolver::ArrayCount(const ast::Expression* count_expr) {
// Evaluate the constant array count expression.
- const auto* count_sem = Materialize(ValueExpression(count_expr));
+ const auto* count_sem = Materialize(sem_.GetVal(count_expr));
if (!count_sem) {
return nullptr;
}
- if (count_sem->Stage() == core::EvaluationStage::kOverride) {
- // array count is an override expression.
- // Is the count a named 'override'?
- if (auto* user = count_sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
- if (auto* global = user->Variable()->As<sem::GlobalVariable>()) {
- return b.create<sem::NamedOverrideArrayCount>(global);
+ switch (count_sem->Stage()) {
+ case core::EvaluationStage::kNotEvaluated:
+ // Happens in expressions like:
+ // false && array<T, N>()[i]
+ // The end result will not be used, so just make N=1.
+ return b.create<core::type::ConstantArrayCount>(static_cast<uint32_t>(1));
+
+ case core::EvaluationStage::kOverride: {
+ // array count is an override expression.
+ // Is the count a named 'override'?
+ if (auto* user = count_sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
+ if (auto* global = user->Variable()->As<sem::GlobalVariable>()) {
+ return b.create<sem::NamedOverrideArrayCount>(global);
+ }
}
+ return b.create<sem::UnnamedOverrideArrayCount>(count_sem);
}
- return b.create<sem::UnnamedOverrideArrayCount>(count_sem);
- }
- auto* count_val = count_sem->ConstantValue();
- if (!count_val) {
- AddError("array count must evaluate to a constant integer expression or override variable",
- count_expr->source);
- return nullptr;
- }
+ case core::EvaluationStage::kConstant: {
+ auto* count_val = count_sem->ConstantValue();
+ if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) {
+ AddError(
+ "array count must evaluate to a constant integer expression, but is type '" +
+ ty->FriendlyName() + "'",
+ count_expr->source);
+ return nullptr;
+ }
- if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) {
- AddError("array count must evaluate to a constant integer expression, but is type '" +
- ty->FriendlyName() + "'",
- count_expr->source);
- return nullptr;
- }
+ int64_t count = count_val->ValueAs<AInt>();
+ if (count < 1) {
+ AddError("array count (" + std::to_string(count) + ") must be greater than 0",
+ count_expr->source);
+ return nullptr;
+ }
- int64_t count = count_val->ValueAs<AInt>();
- if (count < 1) {
- AddError("array count (" + std::to_string(count) + ") must be greater than 0",
- count_expr->source);
- return nullptr;
- }
+ return b.create<core::type::ConstantArrayCount>(static_cast<uint32_t>(count));
+ }
- return b.create<core::type::ConstantArrayCount>(static_cast<uint32_t>(count));
+ default: {
+ AddError(
+ "array count must evaluate to a constant integer expression or override variable",
+ count_expr->source);
+ return nullptr;
+ }
+ }
}
bool Resolver::ArrayAttributes(VectorRef<const ast::Attribute*> attributes,