tint: Replace all remaining AST types with ast::Type
This CL removes the following AST nodes:
* ast::Array
* ast::Atomic
* ast::Matrix
* ast::MultisampledTexture
* ast::Pointer
* ast::SampledTexture
* ast::Texture
* ast::TypeName
* ast::Vector
ast::Type, which used to be the base class for all AST types, is now a
thin wrapper around ast::IdentifierExpression. All types are now
referred to using their type name.
The resolver now handles type resolution and validation of the types
listed above based on the TemplatedIdentifier arguments.
Other changes:
* ProgramBuilder has undergone substantial refactoring.
* ProgramBuilder helpers for type inferencing are now more explicit.
Instead of passing 'nullptr', a new 'Infer' template argument is
passed.
* ast::CheckIdentifier() is used for more tests that check identifiers,
including types.
Bug: tint:1810
Change-Id: I8e739ef49435dc1c20a462f3ec5ba265661a7edb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118723
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/resolver/address_space_validation_test.cc b/src/tint/resolver/address_space_validation_test.cc
index fb4981e..f96b29f 100644
--- a/src/tint/resolver/address_space_validation_test.cc
+++ b/src/tint/resolver/address_space_validation_test.cc
@@ -38,10 +38,10 @@
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_NoAddressSpace_Fail) {
// type g = ptr<f32>;
- Alias("g", ty.pointer(Source{{12, 34}}, ty.f32(), type::AddressSpace::kUndefined));
+ Alias("g", ty(Source{{12, 34}}, "ptr", ty.f32()));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: ptr missing address space");
+ EXPECT_EQ(r()->error(), "12:34 error: 'ptr' requires at least 2 template arguments");
}
TEST_F(ResolverAddressSpaceValidationTest, GlobalVariable_FunctionAddressSpace_Fail) {
@@ -473,7 +473,7 @@
}
TEST_F(ResolverAddressSpaceValidationTest, PointerAlias_NotStorage_AccessMode) {
- // type t = ptr<private, read, a>;
+ // type t = ptr<private, i32, read>;
Alias("t", ty.pointer(Source{{12, 34}}, ty.i32(), type::AddressSpace::kPrivate,
type::Access::kRead));
diff --git a/src/tint/resolver/alias_analysis_test.cc b/src/tint/resolver/alias_analysis_test.cc
index 44db995..cc55f0b 100644
--- a/src/tint/resolver/alias_analysis_test.cc
+++ b/src/tint/resolver/alias_analysis_test.cc
@@ -759,7 +759,7 @@
ty.void_(),
utils::Vector{
Assign(Phony(), MemberAccessor(Deref("p2"), "a")),
- Assign(Deref("p1"), Call(ty("S"))),
+ Assign(Deref("p1"), Call("S")),
});
Func("f1", utils::Empty, ty.void_(),
utils::Vector{
diff --git a/src/tint/resolver/assignment_validation_test.cc b/src/tint/resolver/assignment_validation_test.cc
index 642340b..cf8b505 100644
--- a/src/tint/resolver/assignment_validation_test.cc
+++ b/src/tint/resolver/assignment_validation_test.cc
@@ -373,8 +373,8 @@
Assign(Phony(), 3_f), //
Assign(Phony(), 4_a), //
Assign(Phony(), 5.0_a), //
- Assign(Phony(), vec(nullptr, 2u, 6_a)), //
- Assign(Phony(), vec(nullptr, 3u, 7.0_a)), //
+ Assign(Phony(), vec2<Infer>(6_a)), //
+ Assign(Phony(), vec3<Infer>(7.0_a)), //
Assign(Phony(), vec4<bool>()), //
Assign(Phony(), "tex"), //
Assign(Phony(), "smp"), //
diff --git a/src/tint/resolver/atomics_validation_test.cc b/src/tint/resolver/atomics_validation_test.cc
index f57cf0b..bd0c5cc 100644
--- a/src/tint/resolver/atomics_validation_test.cc
+++ b/src/tint/resolver/atomics_validation_test.cc
@@ -173,9 +173,9 @@
// var<private> g : S0;
auto* atomic_array = Alias("AtomicArray", ty.atomic(ty.i32()));
- auto* array_i32_4 = ty.array<i32, 4>();
- auto* array_atomic_u32_8 = ty.array(ty.atomic(ty.u32()), 8_u);
- auto* array_atomic_i32_4 = ty.array(ty.atomic(ty.i32()), 4_u);
+ auto array_i32_4 = ty.array<i32, 4>();
+ auto array_atomic_u32_8 = ty.array(ty.atomic(ty.u32()), 8_u);
+ auto array_atomic_i32_4 = ty.array(ty.atomic(ty.i32()), 4_u);
auto* s6 = Structure("S6", utils::Vector{Member("x", array_i32_4)});
auto* s5 = Structure("S5", utils::Vector{Member("x", ty.Of(s6)), //
@@ -272,9 +272,9 @@
// var<storage, read> g : S0;
auto* atomic_array = Alias("AtomicArray", ty.atomic(ty.i32()));
- auto* array_i32_4 = ty.array<i32, 4>();
- auto* array_atomic_u32_8 = ty.array(ty.atomic(ty.u32()), 8_u);
- auto* array_atomic_i32_4 = ty.array(ty.atomic(ty.i32()), 4_u);
+ auto array_i32_4 = ty.array<i32, 4>();
+ auto array_atomic_u32_8 = ty.array(ty.atomic(ty.u32()), 8_u);
+ auto array_atomic_i32_4 = ty.array(ty.atomic(ty.i32()), 4_u);
auto* s6 = Structure("S6", utils::Vector{Member("x", array_i32_4)});
auto* s5 = Structure("S5", utils::Vector{Member("x", ty.Of(s6)), //
diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc
index 196796f..0f09306 100644
--- a/src/tint/resolver/attribute_validation_test.cc
+++ b/src/tint/resolver/attribute_validation_test.cc
@@ -849,7 +849,7 @@
TEST_P(ArrayAttributeTest, IsValid) {
auto& params = GetParam();
- auto* arr = ty.array(ty.f32(), nullptr, createAttributes(Source{{12, 34}}, *this, params.kind));
+ auto arr = ty.array(ty.f32(), createAttributes(Source{{12, 34}}, *this, params.kind));
Structure("mystruct", utils::Vector{
Member("a", arr),
});
@@ -1145,17 +1145,17 @@
using ArrayStrideTest = TestWithParams;
TEST_P(ArrayStrideTest, All) {
auto& params = GetParam();
- auto* el_ty = params.create_el_type(*this);
+ ast::Type el_ty = params.create_el_type(*this);
std::stringstream ss;
ss << "el_ty: " << FriendlyName(el_ty) << ", stride: " << params.stride
<< ", should_pass: " << params.should_pass;
SCOPED_TRACE(ss.str());
- auto* arr = ty.array(el_ty, 4_u,
- utils::Vector{
- create<ast::StrideAttribute>(Source{{12, 34}}, params.stride),
- });
+ auto arr = ty.array(el_ty, 4_u,
+ utils::Vector{
+ create<ast::StrideAttribute>(Source{{12, 34}}, params.stride),
+ });
GlobalVar("myarray", arr, type::AddressSpace::kPrivate);
@@ -1234,11 +1234,11 @@
ParamsFor<mat4x4<f32>>((default_mat4x4.align - 1) * 7, false)));
TEST_F(ArrayStrideTest, DuplicateAttribute) {
- auto* arr = ty.array(Source{{12, 34}}, ty.i32(), 4_u,
- utils::Vector{
- create<ast::StrideAttribute>(Source{{12, 34}}, 4u),
- create<ast::StrideAttribute>(Source{{56, 78}}, 4u),
- });
+ auto arr = ty.array(Source{{12, 34}}, ty.i32(), 4_u,
+ utils::Vector{
+ create<ast::StrideAttribute>(Source{{12, 34}}, 4u),
+ create<ast::StrideAttribute>(Source{{56, 78}}, 4u),
+ });
GlobalVar("myarray", arr, type::AddressSpace::kPrivate);
diff --git a/src/tint/resolver/builtin_test.cc b/src/tint/resolver/builtin_test.cc
index 74c1d60..dc49f3a 100644
--- a/src/tint/resolver/builtin_test.cc
+++ b/src/tint/resolver/builtin_test.cc
@@ -213,7 +213,7 @@
using ResolverBuiltinArrayTest = ResolverTest;
TEST_F(ResolverBuiltinArrayTest, ArrayLength_Vector) {
- auto* ary = ty.array<i32>();
+ auto ary = ty.array<i32>();
auto* str = Structure("S", utils::Vector{Member("x", ary)});
GlobalVar("a", ty.Of(str), type::AddressSpace::kStorage, type::Access::kRead, Binding(0_a),
Group(0_a));
@@ -2097,33 +2097,26 @@
/// @param dim dimensionality of the texture being sampled
/// @param scalar the scalar type
/// @returns a pointer to a type appropriate for the coord param
- const ast::Type* GetCoordsType(type::TextureDimension dim, const ast::Type* scalar) {
+ ast::Type GetCoordsType(type::TextureDimension dim, ast::Type scalar) {
switch (dim) {
case type::TextureDimension::k1d:
- return scalar;
+ return ty(scalar);
case type::TextureDimension::k2d:
case type::TextureDimension::k2dArray:
- return ty.vec(scalar, 2);
+ return ty.vec2(scalar);
case type::TextureDimension::k3d:
case type::TextureDimension::kCube:
case type::TextureDimension::kCubeArray:
- return ty.vec(scalar, 3);
+ return ty.vec3(scalar);
default:
[=]() { FAIL() << "Unsupported texture dimension: " << dim; }();
}
- return nullptr;
+ return ast::Type{};
}
- void add_call_param(std::string name, const ast::Type* type, ExpressionList* call_params) {
- if (auto* type_name = type->As<ast::TypeName>()) {
- auto n = Symbols().NameFor(type_name->name->symbol);
- if (utils::HasPrefix(n, "texture") || utils::HasPrefix(n, "sampler")) {
- GlobalVar(name, type, Binding(0_a), Group(0_a));
- return;
- }
- }
-
- if (type->Is<ast::Texture>()) {
+ void add_call_param(std::string name, ast::Type type, ExpressionList* call_params) {
+ std::string type_name = Symbols().NameFor(type->identifier->symbol);
+ if (utils::HasPrefix(type_name, "texture") || utils::HasPrefix(type_name, "sampler")) {
GlobalVar(name, type, Binding(0_a), Group(0_a));
} else {
GlobalVar(name, type, type::AddressSpace::kPrivate);
@@ -2131,7 +2124,7 @@
call_params->Push(Expr(name));
}
- const ast::Type* subtype(Texture type) {
+ ast::Type subtype(Texture type) {
if (type == Texture::kF32) {
return ty.f32();
}
@@ -2147,9 +2140,9 @@
auto dim = GetParam().dim;
auto type = GetParam().type;
- auto* s = subtype(type);
- auto* coords_type = GetCoordsType(dim, ty.i32());
- auto* texture_type = ty.sampled_texture(dim, s);
+ ast::Type s = subtype(type);
+ ast::Type coords_type = GetCoordsType(dim, ty.i32());
+ auto texture_type = ty.sampled_texture(dim, s);
ExpressionList call_params;
diff --git a/src/tint/resolver/builtin_validation_test.cc b/src/tint/resolver/builtin_validation_test.cc
index 2750d97..b9c31c9 100644
--- a/src/tint/resolver/builtin_validation_test.cc
+++ b/src/tint/resolver/builtin_validation_test.cc
@@ -115,7 +115,8 @@
WrapInFunction(Decl(Var("v", Expr(Source{{56, 78}}, "mix"))));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(56:78 error: missing '(' for function call)");
+ EXPECT_EQ(r()->error(), R"(56:78 error: cannot use function 'mix' as value
+12:34 note: function 'mix' declared here)");
}
TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsGlobalConstUsedAsVariable) {
@@ -159,7 +160,7 @@
TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsAliasUsedAsType) {
auto* mix = Alias(Source{{12, 34}}, "mix", ty.i32());
- auto* use = Call(ty("mix"));
+ auto* use = Call("mix");
WrapInFunction(use);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -183,7 +184,7 @@
auto* mix = Structure("mix", utils::Vector{
Member("m", ty.i32()),
});
- auto* use = Call(ty("mix"));
+ auto* use = Call("mix");
WrapInFunction(use);
ASSERT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index 09b1247..e7f6f14 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -1953,8 +1953,7 @@
Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
- auto* rhs =
- Equal(MemberAccessor(Call(ty("S"), Expr(1_a), Expr(Source{{12, 34}}, true)), "a"), 0_a);
+ auto* rhs = Equal(MemberAccessor(Call("S", Expr(1_a), Expr(Source{{12, 34}}, true)), "a"), 0_a);
GlobalConst("result", LogicalAnd(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
@@ -1973,8 +1972,7 @@
Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
- auto* rhs =
- Equal(MemberAccessor(Call(ty("S"), Expr(1_a), Expr(Source{{12, 34}}, true)), "a"), 0_a);
+ auto* rhs = Equal(MemberAccessor(Call("S", Expr(1_a), Expr(Source{{12, 34}}, true)), "a"), 0_a);
GlobalConst("result", LogicalOr(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
@@ -2146,7 +2144,7 @@
// const one = 1;
// const result = (one == 0) && (s.c == 0);
Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
- GlobalConst("s", Call(ty("S"), Expr(1_a), Expr(2.0_a)));
+ GlobalConst("s", Call("S", Expr(1_a), Expr(2.0_a)));
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
auto* rhs = Equal(MemberAccessor(Source{{12, 34}}, "s", "c"), 0_a);
@@ -2165,7 +2163,7 @@
// const one = 1;
// const result = (one == 1) || (s.c == 0);
Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())});
- GlobalConst("s", Call(ty("S"), Expr(1_a), Expr(2.0_a)));
+ GlobalConst("s", Call("S", Expr(1_a), Expr(2.0_a)));
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
auto* rhs = Equal(MemberAccessor(Source{{12, 34}}, "s", "c"), 0_a);
@@ -2187,7 +2185,7 @@
// const result = (one == 0) && (vec2(1, 2).z == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 0_a);
- auto* rhs = Equal(MemberAccessor(vec2<AInt>(1_a, 2_a), Ident(Source{{12, 34}}, "z")), 0_a);
+ auto* rhs = Equal(MemberAccessor(vec2<Infer>(1_a, 2_a), Ident(Source{{12, 34}}, "z")), 0_a);
GlobalConst("result", LogicalAnd(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
@@ -2199,7 +2197,7 @@
// const result = (one == 1) || (vec2(1, 2).z == 0);
GlobalConst("one", Expr(1_a));
auto* lhs = Equal("one", 1_a);
- auto* rhs = Equal(MemberAccessor(vec2<AInt>(1_a, 2_a), Ident(Source{{12, 34}}, "z")), 0_a);
+ auto* rhs = Equal(MemberAccessor(vec2<Infer>(1_a, 2_a), Ident(Source{{12, 34}}, "z")), 0_a);
GlobalConst("result", LogicalOr(lhs, rhs));
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/const_eval_bitcast_test.cc b/src/tint/resolver/const_eval_bitcast_test.cc
index 32c1382..cbb09d5 100644
--- a/src/tint/resolver/const_eval_bitcast_test.cc
+++ b/src/tint/resolver/const_eval_bitcast_test.cc
@@ -64,7 +64,7 @@
target_create_ptrs = expected.Failure().create_ptrs;
}
- auto* target_ty = target_create_ptrs.ast(*this);
+ auto target_ty = target_create_ptrs.ast(*this);
ASSERT_NE(target_ty, nullptr);
auto* input_val = input.Expr(*this);
const ast::Expression* expr = Bitcast(Source{{12, 34}}, target_ty, input_val);
diff --git a/src/tint/resolver/const_eval_construction_test.cc b/src/tint/resolver/const_eval_construction_test.cc
index 80f0c11..fbf7515 100644
--- a/src/tint/resolver/const_eval_construction_test.cc
+++ b/src/tint/resolver/const_eval_construction_test.cc
@@ -148,7 +148,7 @@
TEST_P(ResolverConstEvalZeroInitTest, Test) {
Enable(ast::Extension::kF16);
auto& param = GetParam();
- auto* ty = param.type(*this);
+ auto ty = param.type(*this);
auto* expr = Call(ty);
auto* a = Const("a", expr);
WrapInFunction(a);
@@ -552,7 +552,7 @@
}
TEST_F(ResolverConstEvalTest, Vec3_FullConstruct_AInt) {
- auto* expr = vec3<AInt>(1_a, 2_a, 3_a);
+ auto* expr = vec3<Infer>(1_a, 2_a, 3_a);
auto* a = Const("a", expr);
WrapInFunction(a);
@@ -586,7 +586,7 @@
}
TEST_F(ResolverConstEvalTest, Vec3_FullConstruct_AFloat) {
- auto* expr = vec3<AFloat>(1.0_a, 2.0_a, 3.0_a);
+ auto* expr = vec3<Infer>(1.0_a, 2.0_a, 3.0_a);
auto* a = Const("a", expr);
WrapInFunction(a);
@@ -1392,7 +1392,7 @@
}
TEST_F(ResolverConstEvalTest, Mat3x2_Construct_Scalars_af) {
- auto* expr = Call(ty.mat(nullptr, 3, 2), 1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a);
+ auto* expr = Call(ty.mat3x2<Infer>(), 1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a);
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -1441,10 +1441,10 @@
}
TEST_F(ResolverConstEvalTest, Mat3x2_Construct_Columns_af) {
- auto* expr = Call(ty.mat(nullptr, 3, 2), //
- vec(nullptr, 2u, 1.0_a, 2.0_a), //
- vec(nullptr, 2u, 3.0_a, 4.0_a), //
- vec(nullptr, 2u, 5.0_a, 6.0_a));
+ auto* expr = Call(ty.mat<Infer>(3, 2), //
+ vec2<Infer>(1.0_a, 2.0_a), //
+ vec2<Infer>(3.0_a, 4.0_a), //
+ vec2<Infer>(5.0_a, 6.0_a));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -1853,9 +1853,9 @@
Member("m1", ty.f32()),
Member("m2", ty.f32()),
});
- auto* expr = Call(ty.array(ty("S"), 2_u), //
- Call(ty("S"), 1_f, 2_f), //
- Call(ty("S"), 3_f, 4_f));
+ auto* expr = Call(ty.array(ty("S"), 2_u), //
+ Call("S", 1_f, 2_f), //
+ Call("S", 3_f, 4_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -1992,7 +1992,7 @@
TEST_F(ResolverConstEvalTest, Struct_I32s_ZeroInit) {
Structure(
"S", utils::Vector{Member("m1", ty.i32()), Member("m2", ty.i32()), Member("m3", ty.i32())});
- auto* expr = Call(ty("S"));
+ auto* expr = Call("S");
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -2037,7 +2037,7 @@
Member("m4", ty.f16()),
Member("m5", ty.bool_()),
});
- auto* expr = Call(ty("S"));
+ auto* expr = Call("S");
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -2090,7 +2090,7 @@
Member("m2", ty.vec3<f32>()),
Member("m3", ty.vec3<f32>()),
});
- auto* expr = Call(ty("S"));
+ auto* expr = Call("S");
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -2147,7 +2147,7 @@
Member("m4", ty.vec3<f16>()),
Member("m5", ty.vec2<bool>()),
});
- auto* expr = Call(ty("S"));
+ auto* expr = Call("S");
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -2224,7 +2224,7 @@
Member("m1", ty("Inner")),
Member("m2", ty("Inner")),
});
- auto* expr = Call(ty("Outer"));
+ auto* expr = Call("Outer");
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -2267,7 +2267,7 @@
Member("m4", ty.f16()),
Member("m5", ty.bool_()),
});
- auto* expr = Call(ty("S"), 1_i, 2_u, 3_f, 4_h, false);
+ auto* expr = Call("S", 1_i, 2_u, 3_f, 4_h, false);
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -2324,7 +2324,7 @@
Member("m4", ty.vec3<f16>()),
Member("m5", ty.vec2<bool>()),
});
- auto* expr = Call(ty("S"), vec2<i32>(1_i), vec3<u32>(2_u), vec4<f32>(3_f), vec3<f16>(4_h),
+ auto* expr = Call("S", vec2<i32>(1_i), vec3<u32>(2_u), vec4<f32>(3_f), vec3<f16>(4_h),
vec2<bool>(false));
WrapInFunction(expr);
@@ -2402,8 +2402,8 @@
Member("m1", ty("Inner")),
Member("m2", ty("Inner")),
});
- auto* expr = Call(ty("Outer"), //
- Call(ty("Inner"), 1_i, 2_u, 3_f), Call(ty("Inner"), 4_i, 0_u, 6_f));
+ auto* expr = Call("Outer", //
+ Call("Inner", 1_i, 2_u, 3_f), Call("Inner", 4_i, 0_u, 6_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -2441,7 +2441,7 @@
Member("m1", ty.array<i32, 2>()),
Member("m2", ty.array<f32, 3>()),
});
- auto* expr = Call(ty("S"), //
+ auto* expr = Call("S", //
Call(ty.array<i32, 2>(), 1_i, 2_i), Call(ty.array<f32, 3>(), 1_f, 2_f, 3_f));
WrapInFunction(expr);
diff --git a/src/tint/resolver/const_eval_conversion_test.cc b/src/tint/resolver/const_eval_conversion_test.cc
index 739294d..e159b0e 100644
--- a/src/tint/resolver/const_eval_conversion_test.cc
+++ b/src/tint/resolver/const_eval_conversion_test.cc
@@ -73,7 +73,7 @@
auto* input_val = input.Expr(*this);
auto* expr = Call(type.ast(*this), input_val);
if (kind == Kind::kVector) {
- expr = Call(ty.vec(nullptr, 3), expr);
+ expr = Call(ty.vec<Infer>(3), expr);
}
WrapInFunction(expr);
diff --git a/src/tint/resolver/const_eval_member_access_test.cc b/src/tint/resolver/const_eval_member_access_test.cc
index c1fcfa1..e93d63b 100644
--- a/src/tint/resolver/const_eval_member_access_test.cc
+++ b/src/tint/resolver/const_eval_member_access_test.cc
@@ -31,8 +31,8 @@
Member("o1", ty("Inner")),
Member("o2", ty("Inner")),
});
- auto* outer_expr = Call(ty("Outer"), //
- Call(ty("Inner"), 1_i, 2_u, 3_f, true), Call(ty("Inner")));
+ auto* outer_expr = Call("Outer", //
+ Call("Inner", 1_i, 2_u, 3_f, true), Call("Inner"));
auto* o1_expr = MemberAccessor(outer_expr, "o1");
auto* i2_expr = MemberAccessor(o1_expr, "i2");
WrapInFunction(i2_expr);
@@ -71,9 +71,9 @@
}
TEST_F(ResolverConstEvalTest, Matrix_AFloat_Construct_From_AInt_Vectors) {
- auto* c = Const("a", Call(ty.mat(nullptr, 2, 2), //
- Call(ty.vec(nullptr, 2), Expr(1_a), Expr(2_a)),
- Call(ty.vec(nullptr, 2), Expr(3_a), Expr(4_a))));
+ auto* c = Const("a", Call(ty.mat2x2<Infer>(), //
+ Call(ty.vec<Infer>(2), Expr(1_a), Expr(2_a)),
+ Call(ty.vec<Infer>(2), Expr(3_a), Expr(4_a))));
WrapInFunction(c);
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -97,9 +97,9 @@
}
TEST_F(ResolverConstEvalTest, MatrixMemberAccess_AFloat) {
- auto* c = Const("a", Call(ty.mat(nullptr, 2, 3), //
- Call(ty.vec(nullptr, 3), Expr(1.0_a), Expr(2.0_a), Expr(3.0_a)),
- Call(ty.vec(nullptr, 3), Expr(4.0_a), Expr(5.0_a), Expr(6.0_a))));
+ auto* c = Const("a", Call(ty.mat2x3<Infer>(), //
+ Call(ty.vec3<Infer>(), Expr(1.0_a), Expr(2.0_a), Expr(3.0_a)),
+ Call(ty.vec3<Infer>(), Expr(4.0_a), Expr(5.0_a), Expr(6.0_a))));
auto* col_0 = Const("col_0", IndexAccessor("a", Expr(0_i)));
auto* col_1 = Const("col_1", IndexAccessor("a", Expr(1_i)));
@@ -174,9 +174,9 @@
}
TEST_F(ResolverConstEvalTest, MatrixMemberAccess_f32) {
- auto* c = Const("a", Call(ty.mat(nullptr, 2, 3), //
- Call(ty.vec(nullptr, 3), Expr(1.0_f), Expr(2.0_f), Expr(3.0_f)),
- Call(ty.vec(nullptr, 3), Expr(4.0_f), Expr(5.0_f), Expr(6.0_f))));
+ auto* c = Const("a", Call(ty.mat2x3<Infer>(), //
+ Call(ty.vec3<Infer>(), Expr(1.0_f), Expr(2.0_f), Expr(3.0_f)),
+ Call(ty.vec3<Infer>(), Expr(4.0_f), Expr(5.0_f), Expr(6.0_f))));
auto* col_0 = Const("col_0", IndexAccessor("a", Expr(0_i)));
auto* col_1 = Const("col_1", IndexAccessor("a", Expr(1_i)));
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index e581b9a..16a3f8e 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -19,9 +19,7 @@
#include <vector>
#include "src/tint/ast/alias.h"
-#include "src/tint/ast/array.h"
#include "src/tint/ast/assignment_statement.h"
-#include "src/tint/ast/atomic.h"
#include "src/tint/ast/block_statement.h"
#include "src/tint/ast/break_if_statement.h"
#include "src/tint/ast/break_statement.h"
@@ -41,12 +39,8 @@
#include "src/tint/ast/invariant_attribute.h"
#include "src/tint/ast/location_attribute.h"
#include "src/tint/ast/loop_statement.h"
-#include "src/tint/ast/matrix.h"
-#include "src/tint/ast/multisampled_texture.h"
#include "src/tint/ast/override.h"
-#include "src/tint/ast/pointer.h"
#include "src/tint/ast/return_statement.h"
-#include "src/tint/ast/sampled_texture.h"
#include "src/tint/ast/stage_attribute.h"
#include "src/tint/ast/stride_attribute.h"
#include "src/tint/ast/struct.h"
@@ -56,10 +50,8 @@
#include "src/tint/ast/switch_statement.h"
#include "src/tint/ast/templated_identifier.h"
#include "src/tint/ast/traverse_expressions.h"
-#include "src/tint/ast/type_name.h"
#include "src/tint/ast/var.h"
#include "src/tint/ast/variable_decl_statement.h"
-#include "src/tint/ast/vector.h"
#include "src/tint/ast/while_statement.h"
#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/scope_stack.h"
@@ -180,12 +172,12 @@
Declare(str->name->symbol, str);
for (auto* member : str->members) {
TraverseAttributes(member->attributes);
- TraverseType(member->type);
+ TraverseTypeExpression(member->type);
}
},
[&](const ast::Alias* alias) {
Declare(alias->name->symbol, alias);
- TraverseType(alias->type);
+ TraverseTypeExpression(alias->type);
},
[&](const ast::Function* func) {
Declare(func->name->symbol, func);
@@ -193,10 +185,10 @@
},
[&](const ast::Variable* var) {
Declare(var->name->symbol, var);
- TraverseType(var->type);
+ TraverseTypeExpression(var->type);
TraverseAttributes(var->attributes);
if (var->initializer) {
- TraverseExpression(var->initializer);
+ TraverseValueExpression(var->initializer);
}
},
[&](const ast::DiagnosticDirective*) {
@@ -205,7 +197,9 @@
[&](const ast::Enable*) {
// Enable directives do not affect the dependency graph.
},
- [&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); },
+ [&](const ast::ConstAssert* assertion) {
+ TraverseValueExpression(assertion->condition);
+ },
[&](Default) { UnhandledNode(diagnostics_, global->node); });
}
@@ -220,10 +214,10 @@
// with the same identifier as its type.
for (auto* param : func->params) {
TraverseAttributes(param->attributes);
- TraverseType(param->type);
+ TraverseTypeExpression(param->type);
}
// Resolve the return type
- TraverseType(func->return_type);
+ TraverseTypeExpression(func->return_type);
// Push the scope stack for the parameters and function body.
scope_stack_.Push();
@@ -257,29 +251,29 @@
Switch(
stmt, //
[&](const ast::AssignmentStatement* a) {
- TraverseExpression(a->lhs);
- TraverseExpression(a->rhs);
+ TraverseValueExpression(a->lhs);
+ TraverseValueExpression(a->rhs);
},
[&](const ast::BlockStatement* b) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
TraverseStatements(b->statements);
},
- [&](const ast::BreakIfStatement* b) { TraverseExpression(b->condition); },
- [&](const ast::CallStatement* r) { TraverseExpression(r->expr); },
+ [&](const ast::BreakIfStatement* b) { TraverseValueExpression(b->condition); },
+ [&](const ast::CallStatement* r) { TraverseValueExpression(r->expr); },
[&](const ast::CompoundAssignmentStatement* a) {
- TraverseExpression(a->lhs);
- TraverseExpression(a->rhs);
+ TraverseValueExpression(a->lhs);
+ TraverseValueExpression(a->rhs);
},
[&](const ast::ForLoopStatement* l) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
TraverseStatement(l->initializer);
- TraverseExpression(l->condition);
+ TraverseValueExpression(l->condition);
TraverseStatement(l->continuing);
TraverseStatement(l->body);
},
- [&](const ast::IncrementDecrementStatement* i) { TraverseExpression(i->lhs); },
+ [&](const ast::IncrementDecrementStatement* i) { TraverseValueExpression(i->lhs); },
[&](const ast::LoopStatement* l) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
@@ -287,18 +281,18 @@
TraverseStatement(l->continuing);
},
[&](const ast::IfStatement* i) {
- TraverseExpression(i->condition);
+ TraverseValueExpression(i->condition);
TraverseStatement(i->body);
if (i->else_statement) {
TraverseStatement(i->else_statement);
}
},
- [&](const ast::ReturnStatement* r) { TraverseExpression(r->value); },
+ [&](const ast::ReturnStatement* r) { TraverseValueExpression(r->value); },
[&](const ast::SwitchStatement* s) {
- TraverseExpression(s->condition);
+ TraverseValueExpression(s->condition);
for (auto* c : s->body) {
for (auto* sel : c->selectors) {
- TraverseExpression(sel->expr);
+ TraverseValueExpression(sel->expr);
}
TraverseStatement(c->body);
}
@@ -307,17 +301,19 @@
if (auto* shadows = scope_stack_.Get(v->variable->name->symbol)) {
graph_.shadows.Add(v->variable, shadows);
}
- TraverseType(v->variable->type);
- TraverseExpression(v->variable->initializer);
+ TraverseTypeExpression(v->variable->type);
+ TraverseValueExpression(v->variable->initializer);
Declare(v->variable->name->symbol, v->variable);
},
[&](const ast::WhileStatement* w) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
- TraverseExpression(w->condition);
+ TraverseValueExpression(w->condition);
TraverseStatement(w->body);
},
- [&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); },
+ [&](const ast::ConstAssert* assertion) {
+ TraverseValueExpression(assertion->condition);
+ },
[&](Default) {
if (TINT_UNLIKELY((!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
ast::DiscardStatement>()))) {
@@ -337,87 +333,63 @@
}
}
- /// Traverses the expression, performing symbol resolution and determining global dependencies.
- void TraverseExpression(const ast::Expression* root) {
- if (!root) {
+ /// Traverses the expression @p root for the intended use as a value, performing symbol
+ /// resolution and determining global dependencies.
+ void TraverseValueExpression(const ast::Expression* root) {
+ TraverseExpression(root, "identifier", "references");
+ }
+
+ /// Traverses the expression @p root for the intended use as a type, performing symbol
+ /// resolution and determining global dependencies.
+ void TraverseTypeExpression(const ast::Expression* root) {
+ TraverseExpression(root, "type", "references");
+ }
+
+ /// Traverses the expression @p root for the intended use as a call target, performing
+ /// symbol resolution and determining global dependencies.
+ void TraverseCallableExpression(const ast::Expression* root) {
+ TraverseExpression(root, "function", "calls");
+ }
+
+ /// Traverses the expression @p root_expr, performing symbol resolution and determining global
+ /// dependencies.
+ void TraverseExpression(const ast::Expression* root_expr,
+ const char* root_use,
+ const char* root_action) {
+ if (!root_expr) {
return;
}
- utils::Vector<const ast::Expression*, 8> pending{root};
+
+ struct Pending {
+ const ast::Expression* expr;
+ const char* use;
+ const char* action;
+ };
+ utils::Vector<Pending, 8> pending{{root_expr, root_use, root_action}};
while (!pending.IsEmpty()) {
- ast::TraverseExpressions(pending.Pop(), diagnostics_, [&](const ast::Expression* expr) {
+ auto next = pending.Pop();
+ ast::TraverseExpressions(next.expr, diagnostics_, [&](const ast::Expression* expr) {
Switch(
expr,
[&](const ast::IdentifierExpression* e) {
- AddDependency(e->identifier, e->identifier->symbol, "identifier",
- "references");
+ AddDependency(e->identifier, e->identifier->symbol, next.use, next.action);
if (auto* tmpl_ident = e->identifier->As<ast::TemplatedIdentifier>()) {
for (auto* arg : tmpl_ident->arguments) {
- pending.Push(arg);
+ pending.Push({arg, "identifier", "references"});
}
}
},
[&](const ast::CallExpression* call) {
- if (call->target.name) {
- AddDependency(call->target.name, call->target.name->symbol, "function",
- "calls");
- if (auto* tmpl_ident =
- call->target.name->As<ast::TemplatedIdentifier>()) {
- for (auto* arg : tmpl_ident->arguments) {
- pending.Push(arg);
- }
- }
- }
- if (call->target.type) {
- TraverseType(call->target.type);
- }
+ TraverseCallableExpression(call->target);
},
- [&](const ast::BitcastExpression* cast) { TraverseType(cast->type); });
+ [&](const ast::BitcastExpression* cast) {
+ TraverseTypeExpression(cast->type);
+ });
return ast::TraverseAction::Descend;
});
}
}
- /// Traverses the type node, performing symbol resolution and determining
- /// global dependencies.
- void TraverseType(const ast::Type* ty) {
- if (!ty) {
- return;
- }
- Switch(
- ty, //
- [&](const ast::Array* arr) {
- TraverseType(arr->type); //
- TraverseExpression(arr->count);
- },
- [&](const ast::Atomic* atomic) { //
- TraverseType(atomic->type);
- },
- [&](const ast::Matrix* mat) { //
- TraverseType(mat->type);
- },
- [&](const ast::Pointer* ptr) { //
- TraverseType(ptr->type);
- },
- [&](const ast::TypeName* tn) { //
- AddDependency(tn->name, tn->name->symbol, "type", "references");
- if (auto* tmpl_ident = tn->name->As<ast::TemplatedIdentifier>()) {
- for (auto* arg : tmpl_ident->arguments) {
- TraverseExpression(arg);
- }
- }
- },
- [&](const ast::Vector* vec) { //
- TraverseType(vec->type);
- },
- [&](const ast::SampledTexture* tex) { //
- TraverseType(tex->type);
- },
- [&](const ast::MultisampledTexture* tex) { //
- TraverseType(tex->type);
- },
- [&](Default) { UnhandledNode(diagnostics_, ty); });
- }
-
/// Traverses the attribute list, performing symbol resolution and
/// determining global dependencies.
void TraverseAttributes(utils::VectorRef<const ast::Attribute*> attrs) {
@@ -432,33 +404,33 @@
bool handled = Switch(
attr,
[&](const ast::BindingAttribute* binding) {
- TraverseExpression(binding->expr);
+ TraverseValueExpression(binding->expr);
return true;
},
[&](const ast::GroupAttribute* group) {
- TraverseExpression(group->expr);
+ TraverseValueExpression(group->expr);
return true;
},
[&](const ast::IdAttribute* id) {
- TraverseExpression(id->expr);
+ TraverseValueExpression(id->expr);
return true;
},
[&](const ast::LocationAttribute* loc) {
- TraverseExpression(loc->expr);
+ TraverseValueExpression(loc->expr);
return true;
},
[&](const ast::StructMemberAlignAttribute* align) {
- TraverseExpression(align->expr);
+ TraverseValueExpression(align->expr);
return true;
},
[&](const ast::StructMemberSizeAttribute* size) {
- TraverseExpression(size->expr);
+ TraverseValueExpression(size->expr);
return true;
},
[&](const ast::WorkgroupAttribute* wg) {
- TraverseExpression(wg->x);
- TraverseExpression(wg->y);
- TraverseExpression(wg->z);
+ TraverseValueExpression(wg->x);
+ TraverseValueExpression(wg->y);
+ TraverseValueExpression(wg->z);
return true;
});
if (handled) {
@@ -517,8 +489,7 @@
graph_.resolved_identifiers.Add(from, ResolvedIdentifier(resolved));
}
- /// Appends an error to the diagnostics that the given symbol cannot be
- /// resolved.
+ /// Appends an error to the diagnostics that the given symbol cannot be resolved.
void UnknownSymbol(Symbol name, Source source, const char* use) {
AddError(diagnostics_, "unknown " + std::string(use) + ": '" + symbols_.NameFor(name) + "'",
source);
diff --git a/src/tint/resolver/dependency_graph.h b/src/tint/resolver/dependency_graph.h
index 8b2e668..7819b52 100644
--- a/src/tint/resolver/dependency_graph.h
+++ b/src/tint/resolver/dependency_graph.h
@@ -21,6 +21,7 @@
#include "src/tint/ast/module.h"
#include "src/tint/diagnostic/diagnostic.h"
#include "src/tint/sem/builtin_type.h"
+#include "src/tint/symbol_table.h"
#include "src/tint/type/access.h"
#include "src/tint/type/builtin.h"
#include "src/tint/type/texel_format.h"
diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc
index 7a98b2f..8d61599 100644
--- a/src/tint/resolver/dependency_graph_test.cc
+++ b/src/tint/resolver/dependency_graph_test.cc
@@ -277,26 +277,26 @@
std::string DiagString(SymbolUseKind kind) {
switch (kind) {
case SymbolUseKind::GlobalVarType:
+ case SymbolUseKind::GlobalConstType:
+ case SymbolUseKind::AliasType:
+ case SymbolUseKind::StructMemberType:
+ case SymbolUseKind::ParameterType:
+ case SymbolUseKind::LocalVarType:
+ case SymbolUseKind::LocalLetType:
+ case SymbolUseKind::NestedLocalVarType:
+ case SymbolUseKind::NestedLocalLetType:
+ return "type";
case SymbolUseKind::GlobalVarArrayElemType:
case SymbolUseKind::GlobalVarVectorElemType:
case SymbolUseKind::GlobalVarMatrixElemType:
case SymbolUseKind::GlobalVarSampledTexElemType:
case SymbolUseKind::GlobalVarMultisampledTexElemType:
- case SymbolUseKind::GlobalConstType:
case SymbolUseKind::GlobalConstArrayElemType:
case SymbolUseKind::GlobalConstVectorElemType:
case SymbolUseKind::GlobalConstMatrixElemType:
- case SymbolUseKind::AliasType:
- case SymbolUseKind::StructMemberType:
- case SymbolUseKind::ParameterType:
- case SymbolUseKind::LocalVarType:
case SymbolUseKind::LocalVarArrayElemType:
case SymbolUseKind::LocalVarVectorElemType:
case SymbolUseKind::LocalVarMatrixElemType:
- case SymbolUseKind::LocalLetType:
- case SymbolUseKind::NestedLocalVarType:
- case SymbolUseKind::NestedLocalLetType:
- return "type";
case SymbolUseKind::GlobalVarValue:
case SymbolUseKind::GlobalVarArraySizeValue:
case SymbolUseKind::GlobalConstValue:
@@ -469,14 +469,14 @@
auto& b = *builder;
switch (kind) {
case SymbolUseKind::GlobalVarType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), node, type::AddressSpace::kPrivate);
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::GlobalVarArrayElemType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), b.ty.array(node, 4_i), type::AddressSpace::kPrivate);
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::GlobalVarArraySizeValue: {
auto* node = b.Expr(source, symbol);
@@ -484,24 +484,24 @@
return node->identifier;
}
case SymbolUseKind::GlobalVarVectorElemType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), b.ty.vec3(node), type::AddressSpace::kPrivate);
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::GlobalVarMatrixElemType: {
- auto* node = b.ty(source, symbol);
+ ast::Type node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), b.ty.mat3x4(node), type::AddressSpace::kPrivate);
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::GlobalVarSampledTexElemType: {
- auto* node = b.ty(source, symbol);
+ ast::Type node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), b.ty.sampled_texture(type::TextureDimension::k2d, node));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::GlobalVarMultisampledTexElemType: {
- auto* node = b.ty(source, symbol);
+ ast::Type node = b.ty(source, symbol);
b.GlobalVar(b.Sym(), b.ty.multisampled_texture(type::TextureDimension::k2d, node));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::GlobalVarValue: {
auto* node = b.Expr(source, symbol);
@@ -509,14 +509,14 @@
return node->identifier;
}
case SymbolUseKind::GlobalConstType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
b.GlobalConst(b.Sym(), node, b.Expr(1_i));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::GlobalConstArrayElemType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
b.GlobalConst(b.Sym(), b.ty.array(node, 4_i), b.Expr(1_i));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::GlobalConstArraySizeValue: {
auto* node = b.Expr(source, symbol);
@@ -524,14 +524,14 @@
return node->identifier;
}
case SymbolUseKind::GlobalConstVectorElemType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
b.GlobalConst(b.Sym(), b.ty.vec3(node), b.Expr(1_i));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::GlobalConstMatrixElemType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
b.GlobalConst(b.Sym(), b.ty.mat3x4(node), b.Expr(1_i));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::GlobalConstValue: {
auto* node = b.Expr(source, symbol);
@@ -539,14 +539,14 @@
return node->identifier;
}
case SymbolUseKind::AliasType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
b.Alias(b.Sym(), node);
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::StructMemberType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
b.Structure(b.Sym(), utils::Vector{b.Member("m", node)});
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::CallFunction: {
auto* node = b.Ident(source, symbol);
@@ -554,19 +554,19 @@
return node;
}
case SymbolUseKind::ParameterType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
parameters.Push(b.Param(b.Sym(), node));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::LocalVarType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
statements.Push(b.Decl(b.Var(b.Sym(), node)));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::LocalVarArrayElemType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
statements.Push(b.Decl(b.Var(b.Sym(), b.ty.array(node, 4_u), b.Expr(1_i))));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::LocalVarArraySizeValue: {
auto* node = b.Expr(source, symbol);
@@ -574,14 +574,14 @@
return node->identifier;
}
case SymbolUseKind::LocalVarVectorElemType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
statements.Push(b.Decl(b.Var(b.Sym(), b.ty.vec3(node))));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::LocalVarMatrixElemType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
statements.Push(b.Decl(b.Var(b.Sym(), b.ty.mat3x4(node))));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::LocalVarValue: {
auto* node = b.Expr(source, symbol);
@@ -589,9 +589,9 @@
return node->identifier;
}
case SymbolUseKind::LocalLetType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
statements.Push(b.Decl(b.Let(b.Sym(), node, b.Expr(1_i))));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::LocalLetValue: {
auto* node = b.Expr(source, symbol);
@@ -599,9 +599,9 @@
return node->identifier;
}
case SymbolUseKind::NestedLocalVarType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
nested_statements.Push(b.Decl(b.Var(b.Sym(), node)));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::NestedLocalVarValue: {
auto* node = b.Expr(source, symbol);
@@ -609,9 +609,9 @@
return node->identifier;
}
case SymbolUseKind::NestedLocalLetType: {
- auto* node = b.ty(source, symbol);
+ auto node = b.ty(source, symbol);
nested_statements.Push(b.Decl(b.Let(b.Sym(), node, b.Expr(1_i))));
- return node->name;
+ return node->identifier;
}
case SymbolUseKind::NestedLocalLetValue: {
auto* node = b.Expr(source, symbol);
@@ -666,7 +666,7 @@
// type T = i32;
Func("F", utils::Empty, ty.void_(),
- utils::Vector{Block(Ignore(Call(ty(Source{{12, 34}}, "T"))))});
+ utils::Vector{Block(Ignore(Call(Ident(Source{{12, 34}}, "T"))))});
Alias(Source{{56, 78}}, "T", ty.i32());
Build();
@@ -1638,10 +1638,6 @@
return expr->identifier;
}
-static const ast::Identifier* IdentifierOf(const ast::TypeName* ty) {
- return ty->name;
-}
-
static const ast::Identifier* IdentifierOf(const ast::Identifier* ident) {
return ident;
}
@@ -1665,7 +1661,7 @@
utils::Vector<SymbolUse, 64> symbol_uses;
- auto add_use = [&](const ast::Node* decl, auto* use, int line, const char* kind) {
+ auto add_use = [&](const ast::Node* decl, auto use, int line, const char* kind) {
symbol_uses.Push(
SymbolUse{decl, IdentifierOf(use),
std::string(__FILE__) + ":" + std::to_string(line) + ": " + kind});
@@ -1776,7 +1772,7 @@
Structure("B", utils::Vector{Member("b", ty.i32())});
Func("f", utils::Vector{Param("a", ty("A"))}, ty("B"),
utils::Vector{
- Return(Call(ty("B"))),
+ Return(Call("B")),
});
Build();
}
diff --git a/src/tint/resolver/evaluation_stage_test.cc b/src/tint/resolver/evaluation_stage_test.cc
index 28c8a99..2e148e5 100644
--- a/src/tint/resolver/evaluation_stage_test.cc
+++ b/src/tint/resolver/evaluation_stage_test.cc
@@ -270,7 +270,7 @@
// const str = S();
// str.m
Structure("S", utils::Vector{Member("m", ty.i32())});
- auto* str = Const("str", Call(ty("S")));
+ auto* str = Const("str", Call("S"));
auto* expr = MemberAccessor(str, "m");
WrapInFunction(str, expr);
@@ -284,7 +284,7 @@
// var str = S();
// str.m
Structure("S", utils::Vector{Member("m", ty.i32())});
- auto* str = Var("str", Call(ty("S")));
+ auto* str = Var("str", Call("S"));
auto* expr = MemberAccessor(str, "m");
WrapInFunction(str, expr);
diff --git a/src/tint/resolver/expression_kind_test.cc b/src/tint/resolver/expression_kind_test.cc
index 7737654..c8cea54 100644
--- a/src/tint/resolver/expression_kind_test.cc
+++ b/src/tint/resolver/expression_kind_test.cc
@@ -119,78 +119,144 @@
TEST_P(ResolverExpressionKindTest, Test) {
Symbol sym;
+ std::function<void(const sem::Expression*)> check_expr;
switch (GetParam().def) {
- case Def::kAccess:
+ case Def::kAccess: {
sym = Sym("write");
+ check_expr = [](const sem::Expression* expr) {
+ ASSERT_NE(expr, nullptr);
+ auto* enum_expr = expr->As<sem::BuiltinEnumExpression<type::Access>>();
+ ASSERT_NE(enum_expr, nullptr);
+ EXPECT_EQ(enum_expr->Value(), type::Access::kWrite);
+ };
break;
- case Def::kAddressSpace:
+ }
+ case Def::kAddressSpace: {
sym = Sym("workgroup");
+ check_expr = [](const sem::Expression* expr) {
+ ASSERT_NE(expr, nullptr);
+ auto* enum_expr = expr->As<sem::BuiltinEnumExpression<type::AddressSpace>>();
+ ASSERT_NE(enum_expr, nullptr);
+ EXPECT_EQ(enum_expr->Value(), type::AddressSpace::kWorkgroup);
+ };
break;
- case Def::kBuiltinFunction:
+ }
+ case Def::kBuiltinFunction: {
sym = Sym("workgroupBarrier");
+ check_expr = [](const sem::Expression* expr) { EXPECT_EQ(expr, nullptr); };
break;
- case Def::kBuiltinType:
+ }
+ case Def::kBuiltinType: {
sym = Sym("vec4f");
+ check_expr = [](const sem::Expression* expr) {
+ ASSERT_NE(expr, nullptr);
+ auto* ty_expr = expr->As<sem::TypeExpression>();
+ ASSERT_NE(ty_expr, nullptr);
+ EXPECT_TRUE(ty_expr->Type()->Is<type::Vector>());
+ };
break;
- case Def::kFunction:
- Func(kDefSource, "FUNCTION", utils::Empty, ty.i32(), Return(1_i));
+ }
+ case Def::kFunction: {
+ auto* fn = Func(kDefSource, "FUNCTION", utils::Empty, ty.i32(), Return(1_i));
sym = Sym("FUNCTION");
+ check_expr = [fn](const sem::Expression* expr) {
+ ASSERT_NE(expr, nullptr);
+ auto* fn_expr = expr->As<sem::FunctionExpression>();
+ ASSERT_NE(fn_expr, nullptr);
+ EXPECT_EQ(fn_expr->Function()->Declaration(), fn);
+ };
break;
- case Def::kStruct:
- Structure(kDefSource, "STRUCT", utils::Vector{Member("m", ty.i32())});
+ }
+ case Def::kStruct: {
+ auto* s = Structure(kDefSource, "STRUCT", utils::Vector{Member("m", ty.i32())});
sym = Sym("STRUCT");
+ check_expr = [s](const sem::Expression* expr) {
+ ASSERT_NE(expr, nullptr);
+ auto* ty_expr = expr->As<sem::TypeExpression>();
+ ASSERT_NE(ty_expr, nullptr);
+ auto* got = ty_expr->Type()->As<sem::Struct>();
+ ASSERT_NE(got, nullptr);
+ EXPECT_EQ(got->Declaration(), s);
+ };
break;
- case Def::kTexelFormat:
+ }
+ case Def::kTexelFormat: {
sym = Sym("rgba8unorm");
+ check_expr = [](const sem::Expression* expr) {
+ ASSERT_NE(expr, nullptr);
+ auto* enum_expr = expr->As<sem::BuiltinEnumExpression<type::TexelFormat>>();
+ ASSERT_NE(enum_expr, nullptr);
+ EXPECT_EQ(enum_expr->Value(), type::TexelFormat::kRgba8Unorm);
+ };
break;
- case Def::kTypeAlias:
+ }
+ case Def::kTypeAlias: {
Alias(kDefSource, "ALIAS", ty.i32());
sym = Sym("ALIAS");
+ check_expr = [](const sem::Expression* expr) {
+ ASSERT_NE(expr, nullptr);
+ auto* ty_expr = expr->As<sem::TypeExpression>();
+ ASSERT_NE(ty_expr, nullptr);
+ EXPECT_TRUE(ty_expr->Type()->Is<type::I32>());
+ };
break;
- case Def::kVariable:
- GlobalConst(kDefSource, "VARIABLE", Expr(1_i));
+ }
+ case Def::kVariable: {
+ auto* c = GlobalConst(kDefSource, "VARIABLE", Expr(1_i));
sym = Sym("VARIABLE");
+ check_expr = [c](const sem::Expression* expr) {
+ ASSERT_NE(expr, nullptr);
+ auto* var_expr = expr->As<sem::VariableUser>();
+ ASSERT_NE(var_expr, nullptr);
+ EXPECT_EQ(var_expr->Variable()->Declaration(), c);
+ };
break;
+ }
}
+ auto* expr = Expr(Ident(kUseSource, sym));
switch (GetParam().use) {
case Use::kAccess:
- GlobalVar("v", ty("texture_storage_2d", "rgba8unorm", sym), Group(0_u), Binding(0_u));
+ GlobalVar("v", ty("texture_storage_2d", "rgba8unorm", expr), Group(0_u), Binding(0_u));
break;
case Use::kAddressSpace:
- return; // TODO(crbug.com/tint/1810)
+ Enable(ast::Extension::kChromiumExperimentalFullPtrParameters);
+ Func("f", utils::Vector{Param("p", ty("ptr", expr, ty.f32()))}, ty.void_(),
+ utils::Empty);
+ break;
case Use::kCallExpr:
- Func("f", utils::Empty, ty.void_(), Decl(Var("v", Call(Ident(kUseSource, sym)))));
+ Func("f", utils::Empty, ty.void_(), Decl(Var("v", Call(expr))));
break;
case Use::kCallStmt:
- Func("f", utils::Empty, ty.void_(), CallStmt(Call(Ident(kUseSource, sym))));
+ Func("f", utils::Empty, ty.void_(), CallStmt(Call(expr)));
break;
case Use::kBinaryOp:
- GlobalVar("v", type::AddressSpace::kPrivate, Mul(1_a, Expr(kUseSource, sym)));
+ GlobalVar("v", type::AddressSpace::kPrivate, Mul(1_a, expr));
break;
case Use::kFunctionReturnType:
- Func("f", utils::Empty, ty(kUseSource, sym), Return(Call(sym)));
+ Func("f", utils::Empty, ty(expr), Return(Call(sym)));
break;
case Use::kMemberType:
- Structure("s", utils::Vector{Member("m", ty(kUseSource, sym))});
+ Structure("s", utils::Vector{Member("m", ty(expr))});
break;
case Use::kTexelFormat:
- GlobalVar("v", ty("texture_storage_2d", sym, "write"), Group(0_u), Binding(0_u));
+ GlobalVar("v", ty("texture_storage_2d", ty(expr), "write"), Group(0_u), Binding(0_u));
break;
case Use::kValueExpression:
- GlobalVar("v", type::AddressSpace::kPrivate, Expr(kUseSource, sym));
+ GlobalVar("v", type::AddressSpace::kPrivate, expr);
break;
case Use::kVariableType:
- GlobalVar("v", type::AddressSpace::kPrivate, ty(kUseSource, sym));
+ GlobalVar("v", type::AddressSpace::kPrivate, ty(expr));
break;
case Use::kUnaryOp:
- GlobalVar("v", type::AddressSpace::kPrivate, Negation(Expr(kUseSource, sym)));
+ GlobalVar("v", type::AddressSpace::kPrivate, Negation(expr));
break;
}
if (GetParam().error == kPass) {
EXPECT_TRUE(r()->Resolve());
EXPECT_EQ(r()->error(), "");
+ check_expr(Sem().Get(expr));
} else {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), GetParam().error);
@@ -202,19 +268,21 @@
ResolverExpressionKindTest,
testing::ValuesIn(std::vector<Case>{
{Def::kAccess, Use::kAccess, kPass},
- {Def::kAccess, Use::kAddressSpace, R"(TODO(crbug.com/tint/1810))"},
+ {Def::kAccess, Use::kAddressSpace,
+ R"(5:6 error: cannot use access 'write' as address space)"},
{Def::kAccess, Use::kBinaryOp, R"(5:6 error: cannot use access 'write' as value)"},
{Def::kAccess, Use::kCallExpr, R"(5:6 error: cannot use access 'write' as call target)"},
{Def::kAccess, Use::kCallStmt, R"(5:6 error: cannot use access 'write' as call target)"},
{Def::kAccess, Use::kFunctionReturnType, R"(5:6 error: cannot use access 'write' as type)"},
{Def::kAccess, Use::kMemberType, R"(5:6 error: cannot use access 'write' as type)"},
- {Def::kAccess, Use::kTexelFormat, R"(error: cannot use access 'write' as texel format)"},
+ {Def::kAccess, Use::kTexelFormat,
+ R"(5:6 error: cannot use access 'write' as texel format)"},
{Def::kAccess, Use::kValueExpression, R"(5:6 error: cannot use access 'write' as value)"},
{Def::kAccess, Use::kVariableType, R"(5:6 error: cannot use access 'write' as type)"},
{Def::kAccess, Use::kUnaryOp, R"(5:6 error: cannot use access 'write' as value)"},
{Def::kAddressSpace, Use::kAccess,
- R"(error: cannot use address space 'workgroup' as access)"},
+ R"(5:6 error: cannot use address space 'workgroup' as access)"},
{Def::kAddressSpace, Use::kAddressSpace, kPass},
{Def::kAddressSpace, Use::kBinaryOp,
R"(5:6 error: cannot use address space 'workgroup' as value)"},
@@ -227,7 +295,7 @@
{Def::kAddressSpace, Use::kMemberType,
R"(5:6 error: cannot use address space 'workgroup' as type)"},
{Def::kAddressSpace, Use::kTexelFormat,
- R"(error: cannot use address space 'workgroup' as texel format)"},
+ R"(5:6 error: cannot use address space 'workgroup' as texel format)"},
{Def::kAddressSpace, Use::kValueExpression,
R"(5:6 error: cannot use address space 'workgroup' as value)"},
{Def::kAddressSpace, Use::kVariableType,
@@ -235,26 +303,29 @@
{Def::kAddressSpace, Use::kUnaryOp,
R"(5:6 error: cannot use address space 'workgroup' as value)"},
- {Def::kBuiltinFunction, Use::kAccess, R"(error: missing '(' for builtin function call)"},
- {Def::kBuiltinFunction, Use::kAddressSpace, R"(TODO(crbug.com/tint/1810))"},
+ {Def::kBuiltinFunction, Use::kAccess,
+ R"(7:8 error: missing '(' for builtin function call)"},
+ {Def::kBuiltinFunction, Use::kAddressSpace,
+ R"(7:8 error: missing '(' for builtin function call)"},
{Def::kBuiltinFunction, Use::kBinaryOp,
R"(7:8 error: missing '(' for builtin function call)"},
{Def::kBuiltinFunction, Use::kCallStmt, kPass},
{Def::kBuiltinFunction, Use::kFunctionReturnType,
- R"(5:6 error: cannot use builtin function 'workgroupBarrier' as type)"},
+ R"(7:8 error: missing '(' for builtin function call)"},
{Def::kBuiltinFunction, Use::kMemberType,
- R"(5:6 error: cannot use builtin function 'workgroupBarrier' as type)"},
+ R"(7:8 error: missing '(' for builtin function call)"},
{Def::kBuiltinFunction, Use::kTexelFormat,
- R"(error: missing '(' for builtin function call)"},
+ R"(7:8 error: missing '(' for builtin function call)"},
{Def::kBuiltinFunction, Use::kValueExpression,
R"(7:8 error: missing '(' for builtin function call)"},
{Def::kBuiltinFunction, Use::kVariableType,
- R"(5:6 error: cannot use builtin function 'workgroupBarrier' as type)"},
+ R"(7:8 error: missing '(' for builtin function call)"},
{Def::kBuiltinFunction, Use::kUnaryOp,
R"(7:8 error: missing '(' for builtin function call)"},
- {Def::kBuiltinType, Use::kAccess, R"(error: cannot use type 'vec4<f32>' as access)"},
- {Def::kBuiltinType, Use::kAddressSpace, kPass},
+ {Def::kBuiltinType, Use::kAccess, R"(5:6 error: cannot use type 'vec4<f32>' as access)"},
+ {Def::kBuiltinType, Use::kAddressSpace,
+ R"(5:6 error: cannot use type 'vec4<f32>' as address space)"},
{Def::kBuiltinType, Use::kBinaryOp,
R"(5:6 error: cannot use type 'vec4<f32>' as value
7:8 note: are you missing '()' for type initializer?)"},
@@ -262,7 +333,7 @@
{Def::kBuiltinType, Use::kFunctionReturnType, kPass},
{Def::kBuiltinType, Use::kMemberType, kPass},
{Def::kBuiltinType, Use::kTexelFormat,
- R"(error: cannot use type 'vec4<f32>' as texel format)"},
+ R"(5:6 error: cannot use type 'vec4<f32>' as texel format)"},
{Def::kBuiltinType, Use::kValueExpression,
R"(5:6 error: cannot use type 'vec4<f32>' as value
7:8 note: are you missing '()' for type initializer?)"},
@@ -271,9 +342,13 @@
R"(5:6 error: cannot use type 'vec4<f32>' as value
7:8 note: are you missing '()' for type initializer?)"},
- {Def::kFunction, Use::kAccess, R"(error: missing '(' for function call)"},
- {Def::kFunction, Use::kAddressSpace, R"(TODO(crbug.com/tint/1810))"},
- {Def::kFunction, Use::kBinaryOp, R"(7:8 error: missing '(' for function call)"},
+ {Def::kFunction, Use::kAccess, R"(5:6 error: cannot use function 'FUNCTION' as access
+1:2 note: function 'FUNCTION' declared here)"},
+ {Def::kFunction, Use::kAddressSpace,
+ R"(5:6 error: cannot use function 'FUNCTION' as address space
+1:2 note: function 'FUNCTION' declared here)"},
+ {Def::kFunction, Use::kBinaryOp, R"(5:6 error: cannot use function 'FUNCTION' as value
+1:2 note: function 'FUNCTION' declared here)"},
{Def::kFunction, Use::kCallExpr, kPass},
{Def::kFunction, Use::kCallStmt, kPass},
{Def::kFunction, Use::kFunctionReturnType,
@@ -282,21 +357,27 @@
{Def::kFunction, Use::kMemberType,
R"(5:6 error: cannot use function 'FUNCTION' as type
1:2 note: function 'FUNCTION' declared here)"},
- {Def::kFunction, Use::kTexelFormat, R"(error: missing '(' for function call)"},
- {Def::kFunction, Use::kValueExpression, R"(7:8 error: missing '(' for function call)"},
+ {Def::kFunction, Use::kTexelFormat,
+ R"(5:6 error: cannot use function 'FUNCTION' as texel format
+1:2 note: function 'FUNCTION' declared here)"},
+ {Def::kFunction, Use::kValueExpression,
+ R"(5:6 error: cannot use function 'FUNCTION' as value
+1:2 note: function 'FUNCTION' declared here)"},
{Def::kFunction, Use::kVariableType,
R"(5:6 error: cannot use function 'FUNCTION' as type
1:2 note: function 'FUNCTION' declared here)"},
- {Def::kFunction, Use::kUnaryOp, R"(7:8 error: missing '(' for function call)"},
+ {Def::kFunction, Use::kUnaryOp, R"(5:6 error: cannot use function 'FUNCTION' as value
+1:2 note: function 'FUNCTION' declared here)"},
- {Def::kStruct, Use::kAccess, R"(error: cannot use type 'STRUCT' as access)"},
- {Def::kStruct, Use::kAddressSpace, R"(TODO(crbug.com/tint/1810))"},
+ {Def::kStruct, Use::kAccess, R"(5:6 error: cannot use type 'STRUCT' as access)"},
+ {Def::kStruct, Use::kAddressSpace,
+ R"(5:6 error: cannot use type 'STRUCT' as address space)"},
{Def::kStruct, Use::kBinaryOp, R"(5:6 error: cannot use type 'STRUCT' as value
7:8 note: are you missing '()' for type initializer?
1:2 note: struct 'STRUCT' declared here)"},
{Def::kStruct, Use::kFunctionReturnType, kPass},
{Def::kStruct, Use::kMemberType, kPass},
- {Def::kStruct, Use::kTexelFormat, R"(error: cannot use type 'STRUCT' as texel format)"},
+ {Def::kStruct, Use::kTexelFormat, R"(5:6 error: cannot use type 'STRUCT' as texel format)"},
{Def::kStruct, Use::kValueExpression,
R"(5:6 error: cannot use type 'STRUCT' as value
7:8 note: are you missing '()' for type initializer?
@@ -308,8 +389,9 @@
1:2 note: struct 'STRUCT' declared here)"},
{Def::kTexelFormat, Use::kAccess,
- R"(error: cannot use texel format 'rgba8unorm' as access)"},
- {Def::kTexelFormat, Use::kAddressSpace, R"(TODO(crbug.com/tint/1810))"},
+ R"(5:6 error: cannot use texel format 'rgba8unorm' as access)"},
+ {Def::kTexelFormat, Use::kAddressSpace,
+ R"(5:6 error: cannot use texel format 'rgba8unorm' as address space)"},
{Def::kTexelFormat, Use::kBinaryOp,
R"(5:6 error: cannot use texel format 'rgba8unorm' as value)"},
{Def::kTexelFormat, Use::kCallExpr,
@@ -328,15 +410,16 @@
{Def::kTexelFormat, Use::kUnaryOp,
R"(5:6 error: cannot use texel format 'rgba8unorm' as value)"},
- {Def::kTypeAlias, Use::kAccess, R"(error: cannot use type 'i32' as access)"},
- {Def::kTypeAlias, Use::kAddressSpace, R"(TODO(crbug.com/tint/1810))"},
+ {Def::kTypeAlias, Use::kAccess, R"(5:6 error: cannot use type 'i32' as access)"},
+ {Def::kTypeAlias, Use::kAddressSpace,
+ R"(5:6 error: cannot use type 'i32' as address space)"},
{Def::kTypeAlias, Use::kBinaryOp,
R"(5:6 error: cannot use type 'i32' as value
7:8 note: are you missing '()' for type initializer?)"},
{Def::kTypeAlias, Use::kCallExpr, kPass},
{Def::kTypeAlias, Use::kFunctionReturnType, kPass},
{Def::kTypeAlias, Use::kMemberType, kPass},
- {Def::kTypeAlias, Use::kTexelFormat, R"(error: cannot use type 'i32' as texel format)"},
+ {Def::kTypeAlias, Use::kTexelFormat, R"(5:6 error: cannot use type 'i32' as texel format)"},
{Def::kTypeAlias, Use::kValueExpression,
R"(5:6 error: cannot use type 'i32' as value
7:8 note: are you missing '()' for type initializer?)"},
@@ -345,8 +428,11 @@
R"(5:6 error: cannot use type 'i32' as value
7:8 note: are you missing '()' for type initializer?)"},
- {Def::kVariable, Use::kAccess, R"(error: cannot use 'VARIABLE' of type 'i32' as access)"},
- {Def::kVariable, Use::kAddressSpace, R"(TODO(crbug.com/tint/1810))"},
+ {Def::kVariable, Use::kAccess, R"(5:6 error: cannot use const 'VARIABLE' as access
+1:2 note: const 'VARIABLE' declared here)"},
+ {Def::kVariable, Use::kAddressSpace,
+ R"(5:6 error: cannot use const 'VARIABLE' as address space
+1:2 note: const 'VARIABLE' declared here)"},
{Def::kVariable, Use::kBinaryOp, kPass},
{Def::kVariable, Use::kCallStmt,
R"(5:6 error: cannot use const 'VARIABLE' as call target
@@ -361,7 +447,8 @@
R"(5:6 error: cannot use const 'VARIABLE' as type
1:2 note: const 'VARIABLE' declared here)"},
{Def::kVariable, Use::kTexelFormat,
- R"(error: cannot use 'VARIABLE' of type 'i32' as texel format)"},
+ R"(5:6 error: cannot use const 'VARIABLE' as texel format
+1:2 note: const 'VARIABLE' declared here)"},
{Def::kVariable, Use::kValueExpression, kPass},
{Def::kVariable, Use::kVariableType,
R"(5:6 error: cannot use const 'VARIABLE' as type
diff --git a/src/tint/resolver/f16_extension_test.cc b/src/tint/resolver/f16_extension_test.cc
index 1d327e5..f0d515f 100644
--- a/src/tint/resolver/f16_extension_test.cc
+++ b/src/tint/resolver/f16_extension_test.cc
@@ -65,7 +65,7 @@
// var<private> v = vec2<f16>();
Enable(ast::Extension::kF16);
- GlobalVar("v", Call(ty.vec2<f16>()), type::AddressSpace::kPrivate);
+ GlobalVar("v", vec2<f16>(), type::AddressSpace::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
@@ -83,15 +83,14 @@
// var<private> v = vec2<f16>(vec2<f32>());
Enable(ast::Extension::kF16);
- GlobalVar("v", Call(ty.vec2<f16>(), Call(ty.vec2<f32>())), type::AddressSpace::kPrivate);
+ GlobalVar("v", vec2<f16>(vec2<f32>()), type::AddressSpace::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverF16ExtensionTest, Vec2TypeConvUsedWithoutExtension) {
// var<private> v = vec2<f16>(vec2<f32>());
- GlobalVar("v", Call(ty.vec2(ty.f16(Source{{12, 34}})), Call(ty.vec2<f32>())),
- type::AddressSpace::kPrivate);
+ GlobalVar("v", vec2(ty.f16(Source{{12, 34}}), vec2<f32>()), type::AddressSpace::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: f16 type used without 'f16' extension enabled");
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index 65aff49..511d379 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -936,7 +936,7 @@
}
TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_NonPlain) {
- auto* ret_type = ty.pointer(Source{{12, 34}}, ty.i32(), type::AddressSpace::kFunction);
+ auto ret_type = ty.pointer(Source{{12, 34}}, ty.i32(), type::AddressSpace::kFunction);
Func("f", utils::Empty, ret_type, utils::Empty);
EXPECT_FALSE(r()->Resolve());
@@ -944,7 +944,7 @@
}
TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_AtomicInt) {
- auto* ret_type = ty.atomic(Source{{12, 34}}, ty.i32());
+ auto ret_type = ty.atomic(Source{{12, 34}}, ty.i32());
Func("f", utils::Empty, ret_type, utils::Empty);
EXPECT_FALSE(r()->Resolve());
@@ -952,7 +952,7 @@
}
TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_ArrayOfAtomic) {
- auto* ret_type = ty.array(Source{{12, 34}}, ty.atomic(ty.i32()), 10_u);
+ auto ret_type = ty.array(Source{{12, 34}}, ty.atomic(ty.i32()), 10_u);
Func("f", utils::Empty, ret_type, utils::Empty);
EXPECT_FALSE(r()->Resolve());
@@ -963,7 +963,7 @@
Structure("S", utils::Vector{
Member("m", ty.atomic(ty.i32())),
});
- auto* ret_type = ty(Source{{12, 34}}, "S");
+ auto ret_type = ty(Source{{12, 34}}, "S");
Func("f", utils::Empty, ret_type, utils::Empty);
EXPECT_FALSE(r()->Resolve());
@@ -971,7 +971,7 @@
}
TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_RuntimeArray) {
- auto* ret_type = ty.array(Source{{12, 34}}, ty.i32());
+ auto ret_type = ty.array(Source{{12, 34}}, ty.i32());
Func("f", utils::Empty, ret_type, utils::Empty);
EXPECT_FALSE(r()->Resolve());
@@ -982,7 +982,7 @@
Structure("S", utils::Vector{
Member("m", ty.atomic(ty.i32())),
});
- auto* ret_type = ty(Source{{12, 34}}, "S");
+ auto ret_type = ty(Source{{12, 34}}, "S");
auto* bar = Param("bar", ret_type);
Func("f", utils::Vector{bar}, ty.void_(), utils::Empty);
@@ -994,7 +994,7 @@
Structure("S", utils::Vector{
Member("m", ty.i32()),
});
- auto* ret_type = ty(Source{{12, 34}}, "S");
+ auto ret_type = ty(Source{{12, 34}}, "S");
auto* bar = Param(Source{{12, 34}}, "bar", ret_type);
Func("f", utils::Vector{bar}, ty.void_(), utils::Empty);
@@ -1025,29 +1025,28 @@
TEST_F(ResolverFunctionValidationTest, ParameterVectorNoType) {
// fn f(p : vec3) {}
- Func(Source{{12, 34}}, "f",
- utils::Vector{Param("p", create<ast::Vector>(Source{{12, 34}}, nullptr, 3u))}, ty.void_(),
- utils::Empty);
-
- EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: missing vector element type");
-}
-
-TEST_F(ResolverFunctionValidationTest, ParameterMatrixNoType) {
- // fn f(p : vec3) {}
-
- Func(Source{{12, 34}}, "f",
- utils::Vector{Param("p", create<ast::Matrix>(Source{{12, 34}}, nullptr, 3u, 3u))},
+ Func(Source{{12, 34}}, "f", utils::Vector{Param("p", ty.vec3<Infer>(Source{{12, 34}}))},
ty.void_(), utils::Empty);
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: missing matrix element type");
+ EXPECT_EQ(r()->error(), "12:34 error: expected '<' for 'vec3'");
+}
+
+TEST_F(ResolverFunctionValidationTest, ParameterMatrixNoType) {
+ // fn f(p : mat3x3) {}
+
+ Func(Source{{12, 34}}, "f", utils::Vector{Param("p", ty.mat3x3<Infer>(Source{{12, 34}}))},
+ ty.void_(), utils::Empty);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: expected '<' for 'mat3x3'");
}
enum class Expectation {
kAlwaysPass,
kPassWithFullPtrParameterExtension,
kAlwaysFail,
+ kInvalid,
};
struct TestParams {
type::AddressSpace address_space;
@@ -1059,7 +1058,7 @@
using ResolverFunctionParameterValidationTest = TestWithParams;
TEST_P(ResolverFunctionParameterValidationTest, AddressSpaceNoExtension) {
auto& param = GetParam();
- auto* ptr_type = ty.pointer(Source{{12, 34}}, ty.i32(), param.address_space);
+ auto ptr_type = ty("ptr", Ident(Source{{12, 34}}, param.address_space), ty.i32());
auto* arg = Param(Source{{12, 34}}, "p", ptr_type);
Func("f", utils::Vector{arg}, ty.void_(), utils::Empty);
@@ -1069,37 +1068,48 @@
std::stringstream ss;
ss << param.address_space;
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: function parameter of pointer type cannot be in '" +
- ss.str() + "' address space");
+ if (param.expectation == Expectation::kInvalid) {
+ EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: '" + ss.str() + "'");
+ } else {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function parameter of pointer type cannot be in '" + ss.str() +
+ "' address space");
+ }
}
}
TEST_P(ResolverFunctionParameterValidationTest, AddressSpaceWithExtension) {
auto& param = GetParam();
- auto* ptr_type = ty.pointer(Source{{12, 34}}, ty.i32(), param.address_space);
+ auto ptr_type = ty("ptr", Ident(Source{{12, 34}}, param.address_space), ty.i32());
auto* arg = Param(Source{{12, 34}}, "p", ptr_type);
Enable(ast::Extension::kChromiumExperimentalFullPtrParameters);
Func("f", utils::Vector{arg}, ty.void_(), utils::Empty);
- if (param.expectation != Expectation::kAlwaysFail) {
+ if (param.expectation == Expectation::kAlwaysPass ||
+ param.expectation == Expectation::kPassWithFullPtrParameterExtension) {
ASSERT_TRUE(r()->Resolve()) << r()->error();
} else {
std::stringstream ss;
ss << param.address_space;
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: function parameter of pointer type cannot be in '" +
- ss.str() + "' address space");
+ if (param.expectation == Expectation::kInvalid) {
+ EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: '" + ss.str() + "'");
+ } else {
+ EXPECT_EQ(r()->error(),
+ "12:34 error: function parameter of pointer type cannot be in '" + ss.str() +
+ "' address space");
+ }
}
}
INSTANTIATE_TEST_SUITE_P(
ResolverTest,
ResolverFunctionParameterValidationTest,
testing::Values(
- TestParams{type::AddressSpace::kNone, Expectation::kAlwaysFail},
- TestParams{type::AddressSpace::kIn, Expectation::kAlwaysFail},
- TestParams{type::AddressSpace::kOut, Expectation::kAlwaysFail},
+ TestParams{type::AddressSpace::kNone, Expectation::kInvalid},
+ TestParams{type::AddressSpace::kIn, Expectation::kInvalid},
+ TestParams{type::AddressSpace::kOut, Expectation::kInvalid},
TestParams{type::AddressSpace::kUniform, Expectation::kPassWithFullPtrParameterExtension},
TestParams{type::AddressSpace::kWorkgroup, Expectation::kPassWithFullPtrParameterExtension},
- TestParams{type::AddressSpace::kHandle, Expectation::kAlwaysFail},
+ TestParams{type::AddressSpace::kHandle, Expectation::kInvalid},
TestParams{type::AddressSpace::kStorage, Expectation::kPassWithFullPtrParameterExtension},
TestParams{type::AddressSpace::kPrivate, Expectation::kAlwaysPass},
TestParams{type::AddressSpace::kFunction, Expectation::kAlwaysPass}));
diff --git a/src/tint/resolver/inferred_type_test.cc b/src/tint/resolver/inferred_type_test.cc
index eba38ff..8a609d2 100644
--- a/src/tint/resolver/inferred_type_test.cc
+++ b/src/tint/resolver/inferred_type_test.cc
@@ -134,7 +134,7 @@
INSTANTIATE_TEST_SUITE_P(ResolverTest, ResolverInferredTypeParamTest, testing::ValuesIn(all_cases));
TEST_F(ResolverInferredTypeTest, InferArray_Pass) {
- auto* type = ty.array<u32, 10>();
+ auto type = ty.array<u32, 10>();
auto* expected_type = create<type::Array>(
create<type::U32>(), create<type::ConstantArrayCount>(10u), 4u, 4u * 10u, 4u, 4u);
diff --git a/src/tint/resolver/load_test.cc b/src/tint/resolver/load_test.cc
index 7a74224..2334227 100644
--- a/src/tint/resolver/load_test.cc
+++ b/src/tint/resolver/load_test.cc
@@ -154,7 +154,7 @@
// var ref = vec4(1);
// var v = ref.xyz;
auto* ident = Expr("ref");
- WrapInFunction(Var("ref", Call(ty.vec4<i32>(), 1_i)), //
+ WrapInFunction(Var("ref", vec4<i32>(1_i)), //
Var("v", MemberAccessor(ident, "xyz")));
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -230,7 +230,7 @@
},
ty.vec4<f32>(),
utils::Vector{
- Return(Call("textureSampleLevel", "tp", "sp", Call(ty.vec2<f32>()), 0_a)),
+ Return(Call("textureSampleLevel", "tp", "sp", vec2<f32>(), 0_a)),
});
auto* t_ident = Expr("t");
auto* s_ident = Expr("s");
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index 366fbc3..acd69a8 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -350,7 +350,7 @@
break;
case Method::kStruct:
Structure("S", utils::Vector{Member("v", target_ty())});
- WrapInFunction(Call(ty("S"), abstract_expr));
+ WrapInFunction(Call("S", abstract_expr));
break;
case Method::kBinaryOp: {
// Add 0 to ensure no overflow with max float values
@@ -1262,7 +1262,7 @@
TEST_F(MaterializeAbstractStructure, Modf_Vector_DefaultType) {
// var v = modf(vec2(1));
- auto* call = Call("modf", Call(ty.vec2(nullptr), 1_a));
+ auto* call = Call("modf", Call(ty.vec2<Infer>(), 1_a));
WrapInFunction(Decl(Var("v", call)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(call);
@@ -1302,8 +1302,8 @@
// var v = modf(vec2(1_h)); // v is __modf_result_vec2_f16
// v = modf(vec2(1)); // __modf_result_vec2_f16 <- __modf_result_vec2_abstract
Enable(ast::Extension::kF16);
- auto* call = Call("modf", Call(ty.vec2(nullptr), 1_a));
- WrapInFunction(Decl(Var("v", Call("modf", Call(ty.vec2(nullptr), 1_h)))), Assign("v", call));
+ auto* call = Call("modf", Call(ty.vec2<Infer>(), 1_a));
+ WrapInFunction(Decl(Var("v", Call("modf", Call(ty.vec2<Infer>(), 1_h)))), Assign("v", call));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(call);
ASSERT_TRUE(sem->Is<sem::Materialize>());
@@ -1339,7 +1339,7 @@
TEST_F(MaterializeAbstractStructure, Frexp_Vector_DefaultType) {
// var v = frexp(vec2(1));
- auto* call = Call("frexp", Call(ty.vec2(nullptr), 1_a));
+ auto* call = Call("frexp", Call(ty.vec2<Infer>(), 1_a));
WrapInFunction(Decl(Var("v", call)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(call);
@@ -1385,8 +1385,8 @@
// var v = frexp(vec2(1_h)); // v is __frexp_result_vec2_f16
// v = frexp(vec2(1)); // __frexp_result_vec2_f16 <- __frexp_result_vec2_abstract
Enable(ast::Extension::kF16);
- auto* call = Call("frexp", Call(ty.vec2(nullptr), 1_a));
- WrapInFunction(Decl(Var("v", Call("frexp", Call(ty.vec2(nullptr), 1_h)))), Assign("v", call));
+ auto* call = Call("frexp", Call(ty.vec2<Infer>(), 1_a));
+ WrapInFunction(Decl(Var("v", Call("frexp", Call(ty.vec2<Infer>(), 1_h)))), Assign("v", call));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(call);
ASSERT_TRUE(sem->Is<sem::Materialize>());
diff --git a/src/tint/resolver/override_test.cc b/src/tint/resolver/override_test.cc
index 813d4de..2fd48ed 100644
--- a/src/tint/resolver/override_test.cc
+++ b/src/tint/resolver/override_test.cc
@@ -208,8 +208,8 @@
TEST_F(ResolverOverrideTest, TransitiveReferences_ViaArraySize) {
auto* a = Override("a", ty.i32());
auto* b = Override("b", ty.i32(), Mul(2_a, "a"));
- auto* arr_ty = ty.array(ty.i32(), Mul(2_a, "b"));
- auto* arr = GlobalVar("arr", type::AddressSpace::kWorkgroup, arr_ty);
+ auto* arr = GlobalVar("arr", type::AddressSpace::kWorkgroup, ty.array(ty.i32(), Mul(2_a, "b")));
+ auto arr_ty = arr->type;
Override("unused", ty.i32(), Expr(1_a));
auto* func = Func("foo", utils::Empty, ty.void_(),
utils::Vector{
@@ -219,7 +219,7 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
{
- auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get(arr_ty));
+ auto* r = Sem().TransitivelyReferencedOverrides(TypeOf(arr_ty));
ASSERT_NE(r, nullptr);
auto& refs = *r;
ASSERT_EQ(refs.Length(), 2u);
@@ -248,8 +248,9 @@
TEST_F(ResolverOverrideTest, TransitiveReferences_ViaArraySize_Alias) {
auto* a = Override("a", ty.i32());
auto* b = Override("b", ty.i32(), Mul(2_a, "a"));
- auto* arr_ty = Alias("arr_ty", ty.array(ty.i32(), Mul(2_a, "b")));
+ Alias("arr_ty", ty.array(ty.i32(), Mul(2_a, "b")));
auto* arr = GlobalVar("arr", type::AddressSpace::kWorkgroup, ty("arr_ty"));
+ auto arr_ty = arr->type;
Override("unused", ty.i32(), Expr(1_a));
auto* func = Func("foo", utils::Empty, ty.void_(),
utils::Vector{
@@ -259,7 +260,7 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
{
- auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get<type::Array>(arr_ty->type));
+ auto* r = Sem().TransitivelyReferencedOverrides(TypeOf(arr_ty));
ASSERT_NE(r, nullptr);
auto& refs = *r;
ASSERT_EQ(refs.Length(), 2u);
diff --git a/src/tint/resolver/ptr_ref_test.cc b/src/tint/resolver/ptr_ref_test.cc
index 64b21ae..692bd3b 100644
--- a/src/tint/resolver/ptr_ref_test.cc
+++ b/src/tint/resolver/ptr_ref_test.cc
@@ -88,7 +88,7 @@
WrapInFunction(function, function_ptr, private_ptr, workgroup_ptr, uniform_ptr, storage_ptr);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_TRUE(TypeOf(function_ptr)->Is<type::Pointer>())
<< "function_ptr is " << TypeOf(function_ptr)->TypeInfo().name;
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index ab1970a..352c5e4 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -21,7 +21,6 @@
#include <utility>
#include "src/tint/ast/alias.h"
-#include "src/tint/ast/array.h"
#include "src/tint/ast/assignment_statement.h"
#include "src/tint/ast/attribute.h"
#include "src/tint/ast/bitcast_expression.h"
@@ -36,16 +35,11 @@
#include "src/tint/ast/internal_attribute.h"
#include "src/tint/ast/interpolate_attribute.h"
#include "src/tint/ast/loop_statement.h"
-#include "src/tint/ast/matrix.h"
-#include "src/tint/ast/pointer.h"
#include "src/tint/ast/return_statement.h"
-#include "src/tint/ast/sampled_texture.h"
#include "src/tint/ast/switch_statement.h"
#include "src/tint/ast/traverse_expressions.h"
-#include "src/tint/ast/type_name.h"
#include "src/tint/ast/unary_op_expression.h"
#include "src/tint/ast/variable_decl_statement.h"
-#include "src/tint/ast/vector.h"
#include "src/tint/ast/while_statement.h"
#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/resolver/uniformity.h"
@@ -54,6 +48,7 @@
#include "src/tint/sem/call.h"
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/function.h"
+#include "src/tint/sem/function_expression.h"
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/index_accessor_expression.h"
#include "src/tint/sem/load.h"
@@ -217,138 +212,6 @@
return result;
}
-type::Type* Resolver::Type(const ast::Type* ty) {
- if (ty == nullptr) {
- return builder_->create<type::Void>();
- }
-
- Mark(ty);
- auto* s = Switch(
- ty, //
- [&](const ast::Vector* t) -> type::Vector* {
- if (!t->type) {
- AddError("missing vector element type", t->source.End());
- return nullptr;
- }
- if (auto* el = Type(t->type)) {
- if (auto* vector = builder_->create<type::Vector>(el, t->width)) {
- if (validator_.Vector(vector, t->source)) {
- return vector;
- }
- }
- }
- return nullptr;
- },
- [&](const ast::Matrix* t) -> type::Matrix* {
- if (!t->type) {
- AddError("missing matrix element type", t->source.End());
- return nullptr;
- }
- if (auto* el = Type(t->type)) {
- if (auto* column_type = builder_->create<type::Vector>(el, t->rows)) {
- if (auto* matrix = builder_->create<type::Matrix>(column_type, t->columns)) {
- if (validator_.Matrix(matrix, t->source)) {
- return matrix;
- }
- }
- }
- }
- return nullptr;
- },
- [&](const ast::Array* t) { return Array(t); },
- [&](const ast::Atomic* t) -> type::Atomic* {
- if (auto* el = Type(t->type)) {
- auto* a = builder_->create<type::Atomic>(el);
- if (!validator_.Atomic(t, a)) {
- return nullptr;
- }
- return a;
- }
- return nullptr;
- },
- [&](const ast::Pointer* t) -> type::Pointer* {
- if (auto* el = Type(t->type)) {
- auto access = t->access;
- if (access == type::Access::kUndefined) {
- access = DefaultAccessForAddressSpace(t->address_space);
- }
- auto ptr = builder_->create<type::Pointer>(el, t->address_space, access);
- if (!ptr) {
- return nullptr;
- }
- if (!validator_.Pointer(t, ptr)) {
- return nullptr;
- }
- if (!ApplyAddressSpaceUsageToType(t->address_space, el, t->type->source)) {
- AddNote("while instantiating " + builder_->FriendlyName(ptr), t->source);
- return nullptr;
- }
- return ptr;
- }
- return nullptr;
- },
- [&](const ast::SampledTexture* t) -> type::SampledTexture* {
- if (auto* el = Type(t->type)) {
- auto* sem = builder_->create<type::SampledTexture>(t->dim, el);
- if (!validator_.SampledTexture(sem, t->source)) {
- return nullptr;
- }
- return sem;
- }
- return nullptr;
- },
- [&](const ast::MultisampledTexture* t) -> type::MultisampledTexture* {
- if (auto* el = Type(t->type)) {
- auto* sem = builder_->create<type::MultisampledTexture>(t->dim, el);
- if (!validator_.MultisampledTexture(sem, t->source)) {
- return nullptr;
- }
- return sem;
- }
- return nullptr;
- },
- [&](const ast::TypeName* t) -> type::Type* {
- Mark(t->name);
-
- auto resolved = dependencies_.resolved_identifiers.Get(t->name);
- if (!resolved) {
- TINT_ICE(Resolver, diagnostics_)
- << "identifier '" << builder_->Symbols().NameFor(t->name->symbol)
- << "' was not resolved";
- return nullptr;
- }
-
- if (auto* ast_node = resolved->Node()) {
- auto* type = sem_.Get<type::Type>(ast_node);
- if (TINT_UNLIKELY(!type)) {
- ErrorMismatchedResolvedIdentifier(t->source, *resolved, "type");
- return nullptr;
- }
-
- if (TINT_UNLIKELY(t->name->Is<ast::TemplatedIdentifier>())) {
- AddError("type '" + builder_->Symbols().NameFor(t->name->symbol) +
- "' does not take template arguments",
- t->source);
- NoteDeclarationSource(ast_node);
- return nullptr;
- }
-
- return type;
- }
- if (auto b = resolved->BuiltinType(); b != type::Builtin::kUndefined) {
- return BuiltinType(b, t->name);
- }
-
- ErrorMismatchedResolvedIdentifier(t->source, *resolved, "type");
- return nullptr;
- });
-
- if (s) {
- builder_->Sem().Add(ty, s);
- }
- return s;
-}
-
sem::Variable* Resolver::Variable(const ast::Variable* v, bool is_global) {
Mark(v->name);
@@ -576,7 +439,7 @@
const type::Type* storage_ty = nullptr;
// If the variable has a declared type, resolve it.
- if (auto* ty = var->type) {
+ if (auto ty = var->type) {
storage_ty = Type(ty);
if (!storage_ty) {
return nullptr;
@@ -1032,7 +895,7 @@
// Resolve the return type
type::Type* return_type = nullptr;
- if (auto* ty = decl->return_type) {
+ if (auto ty = decl->return_type) {
return_type = Type(ty);
if (!return_type) {
return nullptr;
@@ -1601,7 +1464,37 @@
}
sem::ValueExpression* Resolver::ValueExpression(const ast::Expression* expr) {
- return sem_.AsValue(Expression(expr));
+ return sem_.AsValueExpression(Expression(expr));
+}
+
+sem::TypeExpression* Resolver::TypeExpression(const ast::Expression* expr) {
+ return sem_.AsTypeExpression(Expression(expr));
+}
+
+sem::FunctionExpression* Resolver::FunctionExpression(const ast::Expression* expr) {
+ return sem_.AsFunctionExpression(Expression(expr));
+}
+
+type::Type* Resolver::Type(const ast::Expression* ast) {
+ auto* type_expr = TypeExpression(ast);
+ if (!type_expr) {
+ return nullptr;
+ }
+ return const_cast<type::Type*>(type_expr->Type());
+}
+
+sem::BuiltinEnumExpression<type::AddressSpace>* Resolver::AddressSpaceExpression(
+ const ast::Expression* expr) {
+ return sem_.AsAddressSpace(Expression(expr));
+}
+
+sem::BuiltinEnumExpression<type::TexelFormat>* Resolver::TexelFormatExpression(
+ const ast::Expression* expr) {
+ return sem_.AsTexelFormat(Expression(expr));
+}
+
+sem::BuiltinEnumExpression<type::Access>* Resolver::AccessExpression(const ast::Expression* expr) {
+ return sem_.AsAccess(Expression(expr));
}
void Resolver::RegisterStore(const sem::ValueExpression* expr) {
@@ -2002,6 +1895,11 @@
// * A builtin call.
// * A type initializer.
// * A type conversion.
+ auto* target = expr->target;
+ Mark(target);
+
+ auto* ident = target->identifier;
+ Mark(ident);
// Resolve all of the arguments, their types and the set of behaviors.
utils::Vector<const sem::ValueExpression*, 8> args;
@@ -2023,9 +1921,9 @@
bool has_side_effects =
std::any_of(args.begin(), args.end(), [](auto* e) { return e->HasSideEffects(); });
- // ct_init_or_conv is a helper for building either a sem::TypeInitializer or
+ // init_or_conv is a helper for building either a sem::TypeInitializer or
// sem::TypeConversion call for a InitConvIntrinsic with an optional template argument type.
- auto ct_init_or_conv = [&](InitConvIntrinsic ty, const type::Type* template_arg) -> sem::Call* {
+ auto init_or_conv = [&](InitConvIntrinsic ty, const type::Type* template_arg) -> sem::Call* {
auto arg_tys = utils::Transform(args, [](auto* arg) { return arg->Type(); });
auto ctor_or_conv =
intrinsic_table_->Lookup(ty, template_arg, arg_tys, args_stage, expr->source);
@@ -2087,26 +1985,24 @@
current_statement_, value, has_side_effects);
};
- // ty_init_or_conv is a helper for building either a sem::TypeInitializer or
- // sem::TypeConversion call for the given semantic type.
- auto ty_init_or_conv = [&](const type::Type* ty) {
+ auto ty_init_or_conv = [&](const type::Type* type) {
return Switch(
- ty, //
- [&](const type::Vector* v) {
- return ct_init_or_conv(VectorInitConvIntrinsic(v->Width()), v->type());
- },
- [&](const type::Matrix* m) {
- return ct_init_or_conv(MatrixInitConvIntrinsic(m->columns(), m->rows()), m->type());
- },
- [&](const type::I32*) { return ct_init_or_conv(InitConvIntrinsic::kI32, nullptr); },
- [&](const type::U32*) { return ct_init_or_conv(InitConvIntrinsic::kU32, nullptr); },
+ type, //
+ [&](const type::I32*) { return init_or_conv(InitConvIntrinsic::kI32, nullptr); },
+ [&](const type::U32*) { return init_or_conv(InitConvIntrinsic::kU32, nullptr); },
[&](const type::F16*) {
return validator_.CheckF16Enabled(expr->source)
- ? ct_init_or_conv(InitConvIntrinsic::kF16, nullptr)
+ ? init_or_conv(InitConvIntrinsic::kF16, nullptr)
: nullptr;
},
- [&](const type::F32*) { return ct_init_or_conv(InitConvIntrinsic::kF32, nullptr); },
- [&](const type::Bool*) { return ct_init_or_conv(InitConvIntrinsic::kBool, nullptr); },
+ [&](const type::F32*) { return init_or_conv(InitConvIntrinsic::kF32, nullptr); },
+ [&](const type::Bool*) { return init_or_conv(InitConvIntrinsic::kBool, nullptr); },
+ [&](const type::Vector* v) {
+ return init_or_conv(VectorInitConvIntrinsic(v->Width()), v->type());
+ },
+ [&](const type::Matrix* m) {
+ return init_or_conv(MatrixInitConvIntrinsic(m->columns(), m->rows()), m->type());
+ },
[&](const type::Array* arr) -> sem::Call* {
auto* call_target = array_inits_.GetOrCreate(
ArrayInitializerSig{{arr, args.Length(), args_stage}},
@@ -2169,170 +2065,117 @@
});
};
- // ast::CallExpression has a target which is either an ast::Type or an
- // ast::IdentifierExpression
+ auto inferred_array = [&]() -> tint::sem::Call* {
+ auto el_count =
+ builder_->create<type::ConstantArrayCount>(static_cast<uint32_t>(args.Length()));
+ auto arg_tys = utils::Transform(args, [](auto* arg) { return arg->Type()->UnwrapRef(); });
+ auto el_ty = type::Type::Common(arg_tys);
+ if (!el_ty) {
+ AddError("cannot infer common array element type from initializer arguments",
+ expr->source);
+ utils::Hashset<const type::Type*, 8> types;
+ for (size_t i = 0; i < args.Length(); i++) {
+ if (types.Add(args[i]->Type())) {
+ AddNote("argument " + std::to_string(i) + " is of type '" +
+ sem_.TypeNameOf(args[i]->Type()) + "'",
+ args[i]->Declaration()->source);
+ }
+ }
+ return nullptr;
+ }
+ auto* arr = Array(expr->source, expr->source, el_ty, el_count, /* explicit_stride */ 0);
+ if (!arr) {
+ return nullptr;
+ }
+ return ty_init_or_conv(arr);
+ };
+
auto call = [&]() -> sem::Call* {
- if (expr->target.type) {
- // ast::CallExpression has an ast::Type as the target.
- // This call is either a type initializer or type conversion.
+ auto resolved = dependencies_.resolved_identifiers.Get(ident);
+ if (!resolved) {
+ TINT_ICE(Resolver, diagnostics_)
+ << "identifier '" << builder_->Symbols().NameFor(ident->symbol)
+ << "' was not resolved";
+ return nullptr;
+ }
+
+ if (auto* ast_node = resolved->Node()) {
return Switch(
- expr->target.type,
- [&](const ast::Vector* v) -> sem::Call* {
- Mark(v);
- // vector element type must be inferred if it was not specified.
- type::Type* template_arg = nullptr;
- if (v->type) {
- template_arg = Type(v->type);
- if (!template_arg) {
- return nullptr;
- }
- }
- if (auto* c =
- ct_init_or_conv(VectorInitConvIntrinsic(v->width), template_arg)) {
- builder_->Sem().Add(expr->target.type, c->Target()->ReturnType());
- return c;
- }
- return nullptr;
- },
- [&](const ast::Matrix* m) -> sem::Call* {
- Mark(m);
- // matrix element type must be inferred if it was not specified.
- type::Type* template_arg = nullptr;
- if (m->type) {
- template_arg = Type(m->type);
- if (!template_arg) {
- return nullptr;
- }
- }
- if (auto* c = ct_init_or_conv(MatrixInitConvIntrinsic(m->columns, m->rows),
- template_arg)) {
- builder_->Sem().Add(expr->target.type, c->Target()->ReturnType());
- return c;
- }
- return nullptr;
- },
- [&](const ast::Array* a) -> sem::Call* {
- Mark(a);
- // array element type must be inferred if it was not specified.
- const type::ArrayCount* el_count = nullptr;
- const type::Type* el_ty = nullptr;
- if (a->type) {
- el_ty = Type(a->type);
- if (!el_ty) {
- return nullptr;
- }
- if (!a->count) {
- AddError("cannot construct a runtime-sized array", expr->source);
- return nullptr;
- }
- el_count = ArrayCount(a->count);
- if (!el_count) {
- return nullptr;
- }
- // Note: validation later will detect any mismatches between explicit array
- // size and number of initializer expressions.
- } else {
- el_count = builder_->create<type::ConstantArrayCount>(
- static_cast<uint32_t>(args.Length()));
- auto arg_tys = utils::Transform(
- args, [](auto* arg) { return arg->Type()->UnwrapRef(); });
- el_ty = type::Type::Common(arg_tys);
- if (!el_ty) {
- AddError(
- "cannot infer common array element type from initializer arguments",
- expr->source);
- utils::Hashset<const type::Type*, 8> types;
- for (size_t i = 0; i < args.Length(); i++) {
- if (types.Add(args[i]->Type())) {
- AddNote("argument " + std::to_string(i) + " is of type '" +
- sem_.TypeNameOf(args[i]->Type()) + "'",
- args[i]->Declaration()->source);
- }
- }
- return nullptr;
- }
- }
- uint32_t explicit_stride = 0;
- if (!ArrayAttributes(a->attributes, el_ty, explicit_stride)) {
- return nullptr;
- }
-
- auto* arr = Array(a->type ? a->type->source : a->source,
- a->count ? a->count->source : a->source, //
- el_ty, el_count, explicit_stride);
- if (!arr) {
- return nullptr;
- }
- builder_->Sem().Add(a, arr);
-
- return ty_init_or_conv(arr);
- },
- [&](const ast::Type* ast) -> sem::Call* {
- // Handler for AST types that do not have an optional element type.
- if (auto* ty = Type(ast)) {
- return ty_init_or_conv(ty);
- }
+ sem_.Get(ast_node), //
+ [&](type::Type* t) { return ty_init_or_conv(t); },
+ [&](sem::Function* f) { return FunctionCall(expr, f, args, arg_behaviors); },
+ [&](sem::Expression* e) {
+ sem_.ErrorUnexpectedExprKind(e, "call target");
return nullptr;
},
[&](Default) {
- TINT_ICE(Resolver, diagnostics_)
- << expr->source << " unhandled CallExpression target:\n"
- << "type: "
- << (expr->target.type ? expr->target.type->TypeInfo().name : "<null>");
+ ErrorMismatchedResolvedIdentifier(ident->source, *resolved, "call target");
return nullptr;
});
- } else {
- // ast::CallExpression has an ast::IdentifierExpression as the target.
- // This call is either a function call, builtin call, type initializer or type
- // conversion.
- auto* ident = expr->target.name;
- Mark(ident);
+ }
- auto resolved = dependencies_.resolved_identifiers.Get(ident);
- if (!resolved) {
- TINT_ICE(Resolver, diagnostics_)
- << "identifier '" << builder_->Symbols().NameFor(ident->symbol)
- << "' was not resolved";
+ if (auto f = resolved->BuiltinFunction(); f != sem::BuiltinType::kNone) {
+ return BuiltinCall(expr, f, args);
+ }
+
+ if (auto b = resolved->BuiltinType(); b != type::Builtin::kUndefined) {
+ if (!ident->Is<ast::TemplatedIdentifier>()) {
+ // No template arguments provided.
+ // Check to see if this is an inferred-element-type call.
+ switch (b) {
+ case type::Builtin::kArray:
+ return inferred_array();
+ case type::Builtin::kVec2:
+ return init_or_conv(InitConvIntrinsic::kVec2, nullptr);
+ case type::Builtin::kVec3:
+ return init_or_conv(InitConvIntrinsic::kVec3, nullptr);
+ case type::Builtin::kVec4:
+ return init_or_conv(InitConvIntrinsic::kVec4, nullptr);
+ case type::Builtin::kMat2X2:
+ return init_or_conv(InitConvIntrinsic::kMat2x2, nullptr);
+ case type::Builtin::kMat2X3:
+ return init_or_conv(InitConvIntrinsic::kMat2x3, nullptr);
+ case type::Builtin::kMat2X4:
+ return init_or_conv(InitConvIntrinsic::kMat2x4, nullptr);
+ case type::Builtin::kMat3X2:
+ return init_or_conv(InitConvIntrinsic::kMat3x2, nullptr);
+ case type::Builtin::kMat3X3:
+ return init_or_conv(InitConvIntrinsic::kMat3x3, nullptr);
+ case type::Builtin::kMat3X4:
+ return init_or_conv(InitConvIntrinsic::kMat3x4, nullptr);
+ case type::Builtin::kMat4X2:
+ return init_or_conv(InitConvIntrinsic::kMat4x2, nullptr);
+ case type::Builtin::kMat4X3:
+ return init_or_conv(InitConvIntrinsic::kMat4x3, nullptr);
+ case type::Builtin::kMat4X4:
+ return init_or_conv(InitConvIntrinsic::kMat4x4, nullptr);
+ default:
+ break;
+ }
+ }
+ auto* ty = BuiltinType(b, ident);
+ if (TINT_UNLIKELY(!ty)) {
return nullptr;
}
-
- if (auto* ast_node = resolved->Node()) {
- return Switch(
- sem_.Get(ast_node), //
- [&](const type::Type* ty) {
- // A type initializer or conversions.
- // Note: Unlike the code path where we're resolving the call target from an
- // ast::Type, all types must already have the element type explicitly
- // specified, so there's no need to infer element types.
- return ty_init_or_conv(ty);
- },
- [&](sem::Function* func) {
- return FunctionCall(expr, func, args, arg_behaviors);
- },
- [&](Default) {
- ErrorMismatchedResolvedIdentifier(ident->source, *resolved, "call target");
- return nullptr;
- });
- }
-
- if (auto f = resolved->BuiltinFunction(); f != sem::BuiltinType::kNone) {
- return BuiltinCall(expr, f, args);
- }
-
- if (auto b = resolved->BuiltinType(); b != type::Builtin::kUndefined) {
- auto* ty = BuiltinType(b, expr->target.name);
- return ty ? ty_init_or_conv(ty) : nullptr;
- }
-
- ErrorMismatchedResolvedIdentifier(ident->source, *resolved, "call target");
- return nullptr;
+ return ty_init_or_conv(ty);
}
+
+ ErrorMismatchedResolvedIdentifier(ident->source, *resolved, "call target");
+ return nullptr;
}();
if (!call) {
return nullptr;
}
+ if (call->Target()->IsAnyOf<sem::TypeInitializer, sem::TypeConversion>()) {
+ // The target of the call was a type.
+ // Associate the target identifier expression with the resolved type.
+ auto* ty_expr =
+ builder_->create<sem::TypeExpression>(target, current_statement_, call->Type());
+ builder_->Sem().Add(target, ty_expr);
+ }
+
return validator_.Call(call, current_statement_) ? call : nullptr;
}
@@ -2444,32 +2287,224 @@
auto f16 = [&] {
return validator_.CheckF16Enabled(ident->source) ? b.create<type::F16>() : nullptr;
};
- auto vec = [&](type::Type* el, uint32_t n) {
- return el ? b.create<type::Vector>(el, n) : nullptr;
- };
- auto mat = [&](type::Type* el, uint32_t num_columns, uint32_t num_rows) {
- return el ? b.create<type::Matrix>(vec(el, num_rows), num_columns) : nullptr;
- };
- auto templated_identifier = [&](size_t num_args) -> const ast::TemplatedIdentifier* {
+ auto templated_identifier =
+ [&](size_t min_args, size_t max_args = /* use min */ 0) -> const ast::TemplatedIdentifier* {
+ if (max_args == 0) {
+ max_args = min_args;
+ }
auto* tmpl_ident = ident->As<ast::TemplatedIdentifier>();
- if (TINT_UNLIKELY(!tmpl_ident)) {
- AddError("expected '<' for '" + b.Symbols().NameFor(ident->symbol) + "'",
- Source{ident->source.range.end});
+ if (!tmpl_ident) {
+ if (TINT_UNLIKELY(min_args != 0)) {
+ AddError("expected '<' for '" + b.Symbols().NameFor(ident->symbol) + "'",
+ Source{ident->source.range.end});
+ }
return nullptr;
}
- if (TINT_UNLIKELY(tmpl_ident->arguments.Length() != num_args)) {
- AddError("'" + b.Symbols().NameFor(ident->symbol) + "' requires " +
- std::to_string(num_args) + " template arguments",
- ident->source);
- return nullptr;
+ if (min_args == max_args) {
+ if (TINT_UNLIKELY(tmpl_ident->arguments.Length() != min_args)) {
+ AddError("'" + b.Symbols().NameFor(ident->symbol) + "' requires " +
+ std::to_string(min_args) + " template arguments",
+ ident->source);
+ return nullptr;
+ }
+ } else {
+ if (TINT_UNLIKELY(tmpl_ident->arguments.Length() < min_args)) {
+ AddError("'" + b.Symbols().NameFor(ident->symbol) + "' requires at least " +
+ std::to_string(min_args) + " template arguments",
+ ident->source);
+ return nullptr;
+ }
+ if (TINT_UNLIKELY(tmpl_ident->arguments.Length() > max_args)) {
+ AddError("'" + b.Symbols().NameFor(ident->symbol) + "' requires at most " +
+ std::to_string(max_args) + " template arguments",
+ ident->source);
+ return nullptr;
+ }
}
return tmpl_ident;
};
+ auto vec = [&](type::Type* el, uint32_t n) -> type::Vector* {
+ if (TINT_UNLIKELY(!el)) {
+ return nullptr;
+ }
+ if (TINT_UNLIKELY(!validator_.Vector(el, ident->source))) {
+ return nullptr;
+ }
+ return b.create<type::Vector>(el, n);
+ };
+ auto mat = [&](type::Type* el, uint32_t num_columns, uint32_t num_rows) -> type::Matrix* {
+ if (TINT_UNLIKELY(!el)) {
+ return nullptr;
+ }
+ if (TINT_UNLIKELY(!validator_.Matrix(el, ident->source))) {
+ return nullptr;
+ }
+ auto* column = vec(el, num_rows);
+ if (!column) {
+ return nullptr;
+ }
+ return b.create<type::Matrix>(column, num_columns);
+ };
+ auto vec_t = [&](uint32_t n) -> type::Vector* {
+ auto* tmpl_ident = templated_identifier(1);
+ if (TINT_UNLIKELY(!tmpl_ident)) {
+ return nullptr;
+ }
+ auto* ty = Type(tmpl_ident->arguments[0]);
+ if (TINT_UNLIKELY(!ty)) {
+ return nullptr;
+ }
+ return vec(const_cast<type::Type*>(ty), n);
+ };
+ auto mat_t = [&](uint32_t num_columns, uint32_t num_rows) -> type::Matrix* {
+ auto* tmpl_ident = templated_identifier(1);
+ if (TINT_UNLIKELY(!tmpl_ident)) {
+ return nullptr;
+ }
+ auto* ty = Type(tmpl_ident->arguments[0]);
+ if (TINT_UNLIKELY(!ty)) {
+ return nullptr;
+ }
+ return mat(const_cast<type::Type*>(ty), num_columns, num_rows);
+ };
+ auto array = [&]() -> type::Array* {
+ utils::UniqueVector<const sem::GlobalVariable*, 4> transitively_referenced_overrides;
+ TINT_SCOPED_ASSIGNMENT(resolved_overrides_, &transitively_referenced_overrides);
+
+ auto* tmpl_ident = templated_identifier(1, 2);
+ if (TINT_UNLIKELY(!tmpl_ident)) {
+ return nullptr;
+ }
+ auto* ast_el_ty = tmpl_ident->arguments[0];
+ auto* ast_count = (tmpl_ident->arguments.Length() > 1) ? tmpl_ident->arguments[1] : nullptr;
+
+ auto* el_ty = Type(ast_el_ty);
+ if (!el_ty) {
+ return nullptr;
+ }
+
+ const type::ArrayCount* el_count =
+ ast_count ? ArrayCount(ast_count) : builder_->create<type::RuntimeArrayCount>();
+ if (!el_count) {
+ return nullptr;
+ }
+
+ // Look for explicit stride via @stride(n) attribute
+ uint32_t explicit_stride = 0;
+ if (!ArrayAttributes(tmpl_ident->attributes, el_ty, explicit_stride)) {
+ return nullptr;
+ }
+
+ auto* out = Array(ast_el_ty->source, //
+ ast_count ? ast_count->source : ident->source, //
+ el_ty, el_count, explicit_stride);
+ if (!out) {
+ return nullptr;
+ }
+
+ if (el_ty->Is<type::Atomic>()) {
+ atomic_composite_info_.Add(out, &ast_el_ty->source);
+ } else {
+ if (auto found = atomic_composite_info_.Get(el_ty)) {
+ atomic_composite_info_.Add(out, *found);
+ }
+ }
+
+ // Track the pipeline-overridable constants that are transitively referenced by this
+ // array type.
+ for (auto* var : transitively_referenced_overrides) {
+ builder_->Sem().AddTransitivelyReferencedOverride(out, var);
+ }
+ return out;
+ };
+ auto atomic = [&]() -> type::Atomic* {
+ auto* tmpl_ident = templated_identifier(1); // atomic<type>
+ if (TINT_UNLIKELY(!tmpl_ident)) {
+ return nullptr;
+ }
+
+ auto* ty_expr = TypeExpression(tmpl_ident->arguments[0]);
+ if (TINT_UNLIKELY(!ty_expr)) {
+ return nullptr;
+ }
+ auto* ty = ty_expr->Type();
+
+ auto* out = builder_->create<type::Atomic>(ty);
+ if (!validator_.Atomic(tmpl_ident, out)) {
+ return nullptr;
+ }
+ return out;
+ };
+ auto ptr = [&]() -> type::Pointer* {
+ auto* tmpl_ident = templated_identifier(2, 3); // ptr<address, type [, access]>
+ if (TINT_UNLIKELY(!tmpl_ident)) {
+ return nullptr;
+ }
+
+ auto* address_space_expr = AddressSpaceExpression(tmpl_ident->arguments[0]);
+ if (TINT_UNLIKELY(!address_space_expr)) {
+ return nullptr;
+ }
+ auto address_space = address_space_expr->Value();
+
+ auto* store_ty_expr = TypeExpression(tmpl_ident->arguments[1]);
+ if (TINT_UNLIKELY(!store_ty_expr)) {
+ return nullptr;
+ }
+ auto* store_ty = const_cast<type::Type*>(store_ty_expr->Type());
+
+ auto access = DefaultAccessForAddressSpace(address_space);
+ if (tmpl_ident->arguments.Length() > 2) {
+ auto* access_expr = AccessExpression(tmpl_ident->arguments[2]);
+ if (TINT_UNLIKELY(!access_expr)) {
+ return nullptr;
+ }
+ access = access_expr->Value();
+ }
+
+ auto* out = b.create<type::Pointer>(store_ty, address_space, access);
+ if (!validator_.Pointer(tmpl_ident, out)) {
+ return nullptr;
+ }
+ if (!ApplyAddressSpaceUsageToType(address_space, store_ty,
+ store_ty_expr->Declaration()->source)) {
+ AddNote("while instantiating " + builder_->FriendlyName(out), ident->source);
+ return nullptr;
+ }
+ return out;
+ };
+ auto sampled_texture = [&](type::TextureDimension dim) -> type::SampledTexture* {
+ auto* tmpl_ident = templated_identifier(1);
+ if (TINT_UNLIKELY(!tmpl_ident)) {
+ return nullptr;
+ }
+
+ auto* ty_expr = TypeExpression(tmpl_ident->arguments[0]);
+ if (TINT_UNLIKELY(!ty_expr)) {
+ return nullptr;
+ }
+ auto* out = b.create<type::SampledTexture>(dim, ty_expr->Type());
+ return validator_.SampledTexture(out, ident->source) ? out : nullptr;
+ };
+ auto multisampled_texture = [&](type::TextureDimension dim) -> type::MultisampledTexture* {
+ auto* tmpl_ident = templated_identifier(1);
+ if (TINT_UNLIKELY(!tmpl_ident)) {
+ return nullptr;
+ }
+
+ auto* ty_expr = TypeExpression(tmpl_ident->arguments[0]);
+ if (TINT_UNLIKELY(!ty_expr)) {
+ return nullptr;
+ }
+ auto* out = b.create<type::MultisampledTexture>(dim, ty_expr->Type());
+ return validator_.MultisampledTexture(out, ident->source) ? out : nullptr;
+ };
auto storage_texture = [&](type::TextureDimension dim) -> type::StorageTexture* {
auto* tmpl_ident = templated_identifier(2);
if (TINT_UNLIKELY(!tmpl_ident)) {
return nullptr;
}
+
auto* format = sem_.AsTexelFormat(Expression(tmpl_ident->arguments[0]));
if (TINT_UNLIKELY(!format)) {
return nullptr;
@@ -2497,6 +2532,30 @@
return check_no_tmpl_args(f16());
case type::Builtin::kF32:
return check_no_tmpl_args(b.create<type::F32>());
+ case type::Builtin::kVec2:
+ return vec_t(2);
+ case type::Builtin::kVec3:
+ return vec_t(3);
+ case type::Builtin::kVec4:
+ return vec_t(4);
+ case type::Builtin::kMat2X2:
+ return mat_t(2, 2);
+ case type::Builtin::kMat2X3:
+ return mat_t(2, 3);
+ case type::Builtin::kMat2X4:
+ return mat_t(2, 4);
+ case type::Builtin::kMat3X2:
+ return mat_t(3, 2);
+ case type::Builtin::kMat3X3:
+ return mat_t(3, 3);
+ case type::Builtin::kMat3X4:
+ return mat_t(3, 4);
+ case type::Builtin::kMat4X2:
+ return mat_t(4, 2);
+ case type::Builtin::kMat4X3:
+ return mat_t(4, 3);
+ case type::Builtin::kMat4X4:
+ return mat_t(4, 4);
case type::Builtin::kMat2X2F:
return check_no_tmpl_args(mat(f32(), 2u, 2u));
case type::Builtin::kMat2X3F:
@@ -2557,11 +2616,29 @@
return check_no_tmpl_args(vec(u32(), 3u));
case type::Builtin::kVec4U:
return check_no_tmpl_args(vec(u32(), 4u));
+ case type::Builtin::kArray:
+ return array();
+ case type::Builtin::kAtomic:
+ return atomic();
+ case type::Builtin::kPtr:
+ return ptr();
case type::Builtin::kSampler:
return check_no_tmpl_args(builder_->create<type::Sampler>(type::SamplerKind::kSampler));
case type::Builtin::kSamplerComparison:
return check_no_tmpl_args(
builder_->create<type::Sampler>(type::SamplerKind::kComparisonSampler));
+ case type::Builtin::kTexture1D:
+ return sampled_texture(type::TextureDimension::k1d);
+ case type::Builtin::kTexture2D:
+ return sampled_texture(type::TextureDimension::k2d);
+ case type::Builtin::kTexture2DArray:
+ return sampled_texture(type::TextureDimension::k2dArray);
+ case type::Builtin::kTexture3D:
+ return sampled_texture(type::TextureDimension::k3d);
+ case type::Builtin::kTextureCube:
+ return sampled_texture(type::TextureDimension::kCube);
+ case type::Builtin::kTextureCubeArray:
+ return sampled_texture(type::TextureDimension::kCubeArray);
case type::Builtin::kTextureDepth2D:
return check_no_tmpl_args(
builder_->create<type::DepthTexture>(type::TextureDimension::k2d));
@@ -2579,6 +2656,8 @@
builder_->create<type::DepthMultisampledTexture>(type::TextureDimension::k2d));
case type::Builtin::kTextureExternal:
return check_no_tmpl_args(builder_->create<type::ExternalTexture>());
+ case type::Builtin::kTextureMultisampled2D:
+ return multisampled_texture(type::TextureDimension::k2d);
case type::Builtin::kTextureStorage1D:
return storage_texture(type::TextureDimension::k1d);
case type::Builtin::kTextureStorage2D:
@@ -2626,9 +2705,6 @@
sem::Function* target,
utils::Vector<const sem::ValueExpression*, N>& args,
sem::Behaviors arg_behaviors) {
- auto sym = expr->target.name->symbol;
- auto name = builder_->Symbols().NameFor(sym);
-
if (!MaybeMaterializeAndLoadArguments(args, target)) {
return nullptr;
}
@@ -2671,6 +2747,11 @@
CollectTextureSamplerPairs(target, call->Arguments());
}
+ // Associate the target identifier expression with the resolved function.
+ auto* fn_expr =
+ builder_->create<sem::FunctionExpression>(expr->target, current_statement_, target);
+ builder_->Sem().Add(expr->target, fn_expr);
+
return call;
}
@@ -2758,13 +2839,13 @@
}
sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
- Mark(expr->identifier);
+ auto* ident = expr->identifier;
+ Mark(ident);
- auto resolved = dependencies_.resolved_identifiers.Get(expr->identifier);
+ auto resolved = dependencies_.resolved_identifiers.Get(ident);
if (!resolved) {
TINT_ICE(Resolver, diagnostics_)
- << "identifier '" << builder_->Symbols().NameFor(expr->identifier->symbol)
- << "' was not resolved";
+ << "identifier '" << builder_->Symbols().NameFor(ident->symbol) << "' was not resolved";
return nullptr;
}
@@ -2773,7 +2854,7 @@
return Switch(
resolved_node, //
[&](sem::Variable* variable) -> sem::VariableUser* {
- auto symbol = expr->identifier->symbol;
+ auto symbol = ident->symbol;
auto* user =
builder_->create<sem::VariableUser>(expr, current_statement_, variable);
@@ -2844,17 +2925,32 @@
variable->AddUser(user);
return user;
},
- [&](const type::Type* ty) {
+ [&](const type::Type* ty) -> sem::TypeExpression* {
+ if (TINT_UNLIKELY(ident->Is<ast::TemplatedIdentifier>())) {
+ AddError("type '" + builder_->Symbols().NameFor(ident->symbol) +
+ "' does not take template arguments",
+ ident->source);
+ sem_.NoteDeclarationSource(ast_node);
+ return nullptr;
+ }
+
return builder_->create<sem::TypeExpression>(expr, current_statement_, ty);
},
- [&](const sem::Function*) {
- AddError("missing '(' for function call", expr->source.End());
- return nullptr;
+ [&](const sem::Function* fn) -> sem::FunctionExpression* {
+ if (TINT_UNLIKELY(ident->Is<ast::TemplatedIdentifier>())) {
+ AddError("function '" + builder_->Symbols().NameFor(ident->symbol) +
+ "' does not take template arguments",
+ ident->source);
+ sem_.NoteDeclarationSource(ast_node);
+ return nullptr;
+ }
+
+ return builder_->create<sem::FunctionExpression>(expr, current_statement_, fn);
});
}
if (auto builtin_ty = resolved->BuiltinType(); builtin_ty != type::Builtin::kUndefined) {
- auto* ty = BuiltinType(builtin_ty, expr->identifier);
+ auto* ty = BuiltinType(builtin_ty, ident);
if (!ty) {
return nullptr;
}
@@ -3237,62 +3333,6 @@
return result;
}
-type::Array* Resolver::Array(const ast::Array* arr) {
- if (!arr->type) {
- AddError("missing array element type", arr->source.End());
- return nullptr;
- }
-
- utils::UniqueVector<const sem::GlobalVariable*, 4> transitively_referenced_overrides;
- TINT_SCOPED_ASSIGNMENT(resolved_overrides_, &transitively_referenced_overrides);
-
- auto* el_ty = Type(arr->type);
- if (!el_ty) {
- return nullptr;
- }
-
- // Look for explicit stride via @stride(n) attribute
- uint32_t explicit_stride = 0;
- if (!ArrayAttributes(arr->attributes, el_ty, explicit_stride)) {
- return nullptr;
- }
-
- const type::ArrayCount* el_count = nullptr;
-
- // Evaluate the constant array count expression.
- if (auto* count_expr = arr->count) {
- el_count = ArrayCount(count_expr);
- if (!el_count) {
- return nullptr;
- }
- } else {
- el_count = builder_->create<type::RuntimeArrayCount>();
- }
-
- auto* out = Array(arr->type->source, //
- arr->count ? arr->count->source : arr->source, //
- el_ty, el_count, explicit_stride);
- if (out == nullptr) {
- return nullptr;
- }
-
- if (el_ty->Is<type::Atomic>()) {
- atomic_composite_info_.Add(out, &arr->type->source);
- } else {
- if (auto found = atomic_composite_info_.Get(el_ty)) {
- atomic_composite_info_.Add(out, *found);
- }
- }
-
- // Track the pipeline-overridable constants that are transitively referenced by this array
- // type.
- for (auto* var : transitively_referenced_overrides) {
- builder_->Sem().AddTransitivelyReferencedOverride(out, var);
- }
-
- return out;
-}
-
const type::ArrayCount* Resolver::ArrayCount(const ast::Expression* count_expr) {
// Evaluate the constant array count expression.
const auto* count_sem = Materialize(ValueExpression(count_expr));
@@ -3442,7 +3482,7 @@
}
// Resolve member type
- auto* type = Type(member->type);
+ auto type = Type(member->type);
if (!type) {
return nullptr;
}
@@ -4074,45 +4114,7 @@
AddError("cannot use " + resolved.String(builder_->Symbols(), diagnostics_) + " as " +
std::string(wanted),
source);
- NoteDeclarationSource(resolved.Node());
-}
-
-void Resolver::NoteDeclarationSource(const ast::Node* node) {
- Switch(
- node,
- [&](const ast::Struct* n) {
- AddNote("struct '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
- n->source);
- },
- [&](const ast::Alias* n) {
- AddNote("alias '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
- n->source);
- },
- [&](const ast::Var* n) {
- AddNote("var '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
- n->source);
- },
- [&](const ast::Let* n) {
- AddNote("let '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
- n->source);
- },
- [&](const ast::Override* n) {
- AddNote("override '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
- n->source);
- },
- [&](const ast::Const* n) {
- AddNote("const '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
- n->source);
- },
- [&](const ast::Parameter* n) {
- AddNote(
- "parameter '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
- n->source);
- },
- [&](const ast::Function* n) {
- AddNote("function '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
- n->source);
- });
+ sem_.NoteDeclarationSource(resolved.Node());
}
void Resolver::AddError(const std::string& msg, const Source& source) const {
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index bc8821a..c5223db 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -126,7 +126,36 @@
/// not a sem::ValueExpression, then an error diagnostic is raised and nullptr is returned.
sem::ValueExpression* ValueExpression(const ast::Expression* expr);
- /// Expression traverses the graph of expressions starting at `expr`, building a postordered
+ /// @returns the call of Expression() cast to a sem::TypeExpression. If the sem::Expression is
+ /// not a sem::TypeExpression, then an error diagnostic is raised and nullptr is returned.
+ sem::TypeExpression* TypeExpression(const ast::Expression* expr);
+
+ /// @returns the call of Expression() cast to a sem::FunctionExpression. If the sem::Expression
+ /// is not a sem::FunctionExpression, then an error diagnostic is raised and nullptr is
+ /// returned.
+ sem::FunctionExpression* FunctionExpression(const ast::Expression* expr);
+
+ /// @returns the resolved type from an expression, or nullptr on error
+ type::Type* Type(const ast::Expression* ast);
+
+ /// @returns the call of Expression() cast to a sem::BuiltinEnumExpression<type::AddressSpace>.
+ /// If the sem::Expression is not a sem::BuiltinEnumExpression<type::AddressSpace>, then an
+ /// error diagnostic is raised and nullptr is returned.
+ sem::BuiltinEnumExpression<type::AddressSpace>* AddressSpaceExpression(
+ const ast::Expression* expr);
+
+ /// @returns the call of Expression() cast to a sem::BuiltinEnumExpression<type::TexelFormat>.
+ /// If the sem::Expression is not a sem::BuiltinEnumExpression<type::TexelFormat>, then an error
+ /// diagnostic is raised and nullptr is returned.
+ sem::BuiltinEnumExpression<type::TexelFormat>* TexelFormatExpression(
+ const ast::Expression* expr);
+
+ /// @returns the call of Expression() cast to a sem::BuiltinEnumExpression<type::Access>*.
+ /// If the sem::Expression is not a sem::BuiltinEnumExpression<type::Access>*, then an error
+ /// diagnostic is raised and nullptr is returned.
+ sem::BuiltinEnumExpression<type::Access>* AccessExpression(const ast::Expression* expr);
+
+ /// Expression traverses the graph of expressions starting at `expr`, building a post-ordered
/// list (leaf-first) of all the expression nodes. Each of the expressions are then resolved by
/// dispatching to the appropriate expression handlers below.
/// @returns the resolved semantic node for the expression `expr`, or nullptr on failure.
@@ -259,12 +288,6 @@
/// current_function_
bool WorkgroupSize(const ast::Function*);
- /// @returns the type::Type for the ast::Type `ty`, building it if it
- /// hasn't been constructed already. If an error is raised, nullptr is
- /// returned.
- /// @param ty the ast::Type
- type::Type* Type(const ast::Type* ty);
-
/// @param control the diagnostic control
/// @returns true on success, false on failure
bool DiagnosticControl(const ast::DiagnosticControl& control);
@@ -277,13 +300,6 @@
/// @returns the resolved semantic type
type::Type* TypeDecl(const ast::TypeDecl* named_type);
- /// Builds and returns the semantic information for the AST array `arr`.
- /// This method does not mark the ast::Array node, nor attach the generated semantic information
- /// to the AST node.
- /// @returns the semantic Array information, or nullptr if an error is raised.
- /// @param arr the Array to get semantic information for
- type::Array* Array(const ast::Array* arr);
-
/// Resolves and validates the expression used as the count parameter of an array.
/// @param count_expr the expression used as the second template parameter to an array<>.
/// @returns the number of elements in the array.
@@ -432,11 +448,6 @@
const ResolvedIdentifier& resolved,
std::string_view wanted);
- /// If @p node is a module-scope type, variable or function declaration, then appends a note
- /// diagnostic where this declaration was declared, otherwise the function does nothing.
- /// @param node the AST node.
- void NoteDeclarationSource(const ast::Node* node);
-
/// Adds the given error message to the diagnostics
void AddError(const std::string& msg, const Source& source) const;
diff --git a/src/tint/resolver/resolver_behavior_test.cc b/src/tint/resolver/resolver_behavior_test.cc
index af69c3c..bc4af42 100644
--- a/src/tint/resolver/resolver_behavior_test.cc
+++ b/src/tint/resolver/resolver_behavior_test.cc
@@ -82,7 +82,7 @@
Func("ArrayDiscardOrNext", utils::Empty, ty.array<i32, 4>(),
utils::Vector{
If(true, Block(Discard())),
- Return(Call(ty.array<i32, 4>())),
+ Return(array<i32, 4>()),
});
auto* stmt = Decl(Var("lhs", ty.i32(), IndexAccessor(Call("ArrayDiscardOrNext"), 1_i)));
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index 0630eb9..3d32799 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -1641,8 +1641,8 @@
TEST_P(Expr_Binary_Test_Valid, All) {
auto& params = GetParam();
- auto* lhs_type = params.create_lhs_type(*this);
- auto* rhs_type = params.create_rhs_type(*this);
+ ast::Type lhs_type = params.create_lhs_type(*this);
+ ast::Type rhs_type = params.create_rhs_type(*this);
auto* result_type = params.create_result_type(*this);
std::stringstream ss;
@@ -1674,8 +1674,8 @@
? params.create_rhs_alias_type
: params.create_rhs_type;
- auto* lhs_type = create_lhs_type(*this);
- auto* rhs_type = create_rhs_type(*this);
+ ast::Type lhs_type = create_lhs_type(*this);
+ ast::Type rhs_type = create_rhs_type(*this);
std::stringstream ss;
ss << FriendlyName(lhs_type) << " " << params.op << " " << FriendlyName(rhs_type);
@@ -1723,8 +1723,8 @@
}
}
- auto* lhs_type = lhs_create_type_func(*this);
- auto* rhs_type = rhs_create_type_func(*this);
+ ast::Type lhs_type = lhs_create_type_func(*this);
+ ast::Type rhs_type = rhs_create_type_func(*this);
std::stringstream ss;
ss << FriendlyName(lhs_type) << " " << op << " " << FriendlyName(rhs_type);
@@ -1753,8 +1753,8 @@
uint32_t mat_rows = std::get<2>(GetParam());
uint32_t mat_cols = std::get<3>(GetParam());
- const ast::Type* lhs_type = nullptr;
- const ast::Type* rhs_type = nullptr;
+ ast::Type lhs_type;
+ ast::Type rhs_type;
const type::Type* result_type = nullptr;
bool is_valid_expr;
@@ -1800,8 +1800,8 @@
uint32_t rhs_mat_rows = std::get<2>(GetParam());
uint32_t rhs_mat_cols = std::get<3>(GetParam());
- auto* lhs_type = ty.mat<f32>(lhs_mat_cols, lhs_mat_rows);
- auto* rhs_type = ty.mat<f32>(rhs_mat_cols, rhs_mat_rows);
+ auto lhs_type = ty.mat<f32>(lhs_mat_cols, lhs_mat_rows);
+ auto rhs_type = ty.mat<f32>(rhs_mat_cols, rhs_mat_rows);
auto* f32 = create<type::F32>();
auto* col = create<type::Vector>(f32, lhs_mat_rows);
@@ -1876,7 +1876,7 @@
}
TEST_F(ResolverTest, AddressSpace_SetForSampler) {
- auto* t = ty.sampler(type::SamplerKind::kSampler);
+ auto t = ty.sampler(type::SamplerKind::kSampler);
auto* var = GlobalVar("var", t, Binding(0_a), Group(0_a));
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -1885,7 +1885,7 @@
}
TEST_F(ResolverTest, AddressSpace_SetForTexture) {
- auto* t = ty.sampled_texture(type::TextureDimension::k1d, ty.f32());
+ auto t = ty.sampled_texture(type::TextureDimension::k1d, ty.f32());
auto* var = GlobalVar("var", t, Binding(0_a), Group(0_a));
EXPECT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index ee913f4..67d71b8 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -112,7 +112,7 @@
/// @param type a type
/// @returns the name for `type` that closely resembles how it would be
/// declared in WGSL.
- std::string FriendlyName(const ast::Type* type) { return type->FriendlyName(Symbols()); }
+ std::string FriendlyName(ast::Type type) { return Symbols().NameFor(type->identifier->symbol); }
/// @param type a type
/// @returns the name for `type` that closely resembles how it would be
@@ -199,7 +199,7 @@
return std::visit([](auto&& v) { return static_cast<T>(v); }, s);
}
-using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
+using ast_type_func_ptr = ast::Type (*)(ProgramBuilder& b);
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b,
utils::VectorRef<Scalar> args);
using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v);
@@ -222,7 +222,7 @@
using ElementType = void;
/// @return nullptr
- static inline const ast::Type* AST(ProgramBuilder&) { return nullptr; }
+ static inline ast::Type AST(ProgramBuilder&) { return {}; }
/// @return nullptr
static inline const type::Type* Sem(ProgramBuilder&) { return nullptr; }
};
@@ -238,7 +238,7 @@
/// @param b the ProgramBuilder
/// @return a new AST bool type
- static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.bool_(); }
+ static inline ast::Type AST(ProgramBuilder& b) { return b.ty.bool_(); }
/// @param b the ProgramBuilder
/// @return the semantic bool type
static inline const type::Type* Sem(ProgramBuilder& b) { return b.create<type::Bool>(); }
@@ -269,7 +269,7 @@
/// @param b the ProgramBuilder
/// @return a new AST i32 type
- static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.i32(); }
+ static inline ast::Type AST(ProgramBuilder& b) { return b.ty.i32(); }
/// @param b the ProgramBuilder
/// @return the semantic i32 type
static inline const type::Type* Sem(ProgramBuilder& b) { return b.create<type::I32>(); }
@@ -300,7 +300,7 @@
/// @param b the ProgramBuilder
/// @return a new AST u32 type
- static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.u32(); }
+ static inline ast::Type AST(ProgramBuilder& b) { return b.ty.u32(); }
/// @param b the ProgramBuilder
/// @return the semantic u32 type
static inline const type::Type* Sem(ProgramBuilder& b) { return b.create<type::U32>(); }
@@ -331,7 +331,7 @@
/// @param b the ProgramBuilder
/// @return a new AST f32 type
- static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.f32(); }
+ static inline ast::Type AST(ProgramBuilder& b) { return b.ty.f32(); }
/// @param b the ProgramBuilder
/// @return the semantic f32 type
static inline const type::Type* Sem(ProgramBuilder& b) { return b.create<type::F32>(); }
@@ -362,7 +362,7 @@
/// @param b the ProgramBuilder
/// @return a new AST f16 type
- static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.f16(); }
+ static inline ast::Type AST(ProgramBuilder& b) { return b.ty.f16(); }
/// @param b the ProgramBuilder
/// @return the semantic f16 type
static inline const type::Type* Sem(ProgramBuilder& b) { return b.create<type::F16>(); }
@@ -392,7 +392,7 @@
static constexpr bool is_composite = false;
/// @returns nullptr, as abstract floats are un-typeable
- static inline const ast::Type* AST(ProgramBuilder&) { return nullptr; }
+ static inline ast::Type AST(ProgramBuilder&) { return {}; }
/// @param b the ProgramBuilder
/// @return the semantic abstract-float type
static inline const type::Type* Sem(ProgramBuilder& b) {
@@ -424,7 +424,7 @@
static constexpr bool is_composite = false;
/// @returns nullptr, as abstract integers are un-typeable
- static inline const ast::Type* AST(ProgramBuilder&) { return nullptr; }
+ static inline ast::Type AST(ProgramBuilder&) { return {}; }
/// @param b the ProgramBuilder
/// @return the semantic abstract-int type
static inline const type::Type* Sem(ProgramBuilder& b) { return b.create<type::AbstractInt>(); }
@@ -455,8 +455,12 @@
/// @param b the ProgramBuilder
/// @return a new AST vector type
- static inline const ast::Type* AST(ProgramBuilder& b) {
- return b.ty.vec(DataType<T>::AST(b), N);
+ static inline ast::Type AST(ProgramBuilder& b) {
+ if (IsInferOrAbstract<T>) {
+ return b.ty.vec<Infer, N>();
+ } else {
+ return b.ty.vec(DataType<T>::AST(b), N);
+ }
}
/// @param b the ProgramBuilder
/// @return the semantic vector type
@@ -503,8 +507,12 @@
/// @param b the ProgramBuilder
/// @return a new AST matrix type
- static inline const ast::Type* AST(ProgramBuilder& b) {
- return b.ty.mat(DataType<T>::AST(b), N, M);
+ static inline ast::Type AST(ProgramBuilder& b) {
+ if (IsInferOrAbstract<T>) {
+ return b.ty.mat<Infer, N, M>();
+ } else {
+ return b.ty.mat(DataType<T>::AST(b), N, M);
+ }
}
/// @param b the ProgramBuilder
/// @return the semantic matrix type
@@ -562,14 +570,15 @@
/// @param b the ProgramBuilder
/// @return a new AST alias type
- static inline const ast::Type* AST(ProgramBuilder& b) {
+ static inline ast::Type AST(ProgramBuilder& b) {
auto name = b.Symbols().Register("alias_" + std::to_string(ID));
if (!b.AST().LookupType(name)) {
- auto* type = DataType<T>::AST(b);
+ auto type = DataType<T>::AST(b);
b.AST().AddTypeDecl(b.ty.alias(name, type));
}
return b.ty(name);
}
+
/// @param b the ProgramBuilder
/// @return the semantic aliased type
static inline const type::Type* Sem(ProgramBuilder& b) { return DataType<T>::Sem(b); }
@@ -618,9 +627,9 @@
/// @param b the ProgramBuilder
/// @return a new AST alias type
- static inline const ast::Type* AST(ProgramBuilder& b) {
- return b.create<ast::Pointer>(DataType<T>::AST(b), type::AddressSpace::kPrivate,
- type::Access::kUndefined);
+ static inline ast::Type AST(ProgramBuilder& b) {
+ return b.ty.pointer(DataType<T>::AST(b), type::AddressSpace::kPrivate,
+ type::Access::kUndefined);
}
/// @param b the ProgramBuilder
/// @return the semantic aliased type
@@ -660,11 +669,11 @@
/// @param b the ProgramBuilder
/// @return a new AST array type
- static inline const ast::Type* AST(ProgramBuilder& b) {
- if (auto* ast = DataType<T>::AST(b)) {
+ static inline ast::Type AST(ProgramBuilder& b) {
+ if (auto ast = DataType<T>::AST(b)) {
return b.ty.array(ast, u32(N));
}
- return b.ty.array(nullptr, nullptr);
+ return b.ty.array<Infer>();
}
/// @param b the ProgramBuilder
/// @return the semantic array type
diff --git a/src/tint/resolver/sem_helper.cc b/src/tint/resolver/sem_helper.cc
index e288707..01a737c 100644
--- a/src/tint/resolver/sem_helper.cc
+++ b/src/tint/resolver/sem_helper.cc
@@ -15,6 +15,8 @@
#include "src/tint/resolver/sem_helper.h"
#include "src/tint/sem/builtin_enum_expression.h"
+#include "src/tint/sem/function.h"
+#include "src/tint/sem/function_expression.h"
#include "src/tint/sem/type_expression.h"
#include "src/tint/sem/value_expression.h"
@@ -42,11 +44,18 @@
Switch(
expr, //
[&](const sem::VariableUser* var_expr) {
- auto name =
- builder_->Symbols().NameFor(var_expr->Variable()->Declaration()->name->symbol);
- auto type = var_expr->Type()->FriendlyName(builder_->Symbols());
- AddError("cannot use '" + name + "' of type '" + type + "' as " + std::string(wanted),
+ auto* variable = var_expr->Variable()->Declaration();
+ auto name = builder_->Symbols().NameFor(variable->name->symbol);
+ std::string kind = Switch(
+ variable, //
+ [&](const ast::Var*) { return "var"; }, //
+ [&](const ast::Const*) { return "const"; }, //
+ [&](const ast::Parameter*) { return "parameter"; }, //
+ [&](const ast::Override*) { return "override"; }, //
+ [&](Default) { return "variable"; });
+ AddError("cannot use " + kind + " '" + name + "' as " + std::string(wanted),
var_expr->Declaration()->source);
+ NoteDeclarationSource(variable);
},
[&](const sem::ValueExpression* val_expr) {
auto type = val_expr->Type()->FriendlyName(builder_->Symbols());
@@ -58,6 +67,13 @@
AddError("cannot use type '" + name + "' as " + std::string(wanted),
ty_expr->Declaration()->source);
},
+ [&](const sem::FunctionExpression* fn_expr) {
+ auto* fn = fn_expr->Function()->Declaration();
+ auto name = builder_->Symbols().NameFor(fn->name->symbol);
+ AddError("cannot use function '" + name + "' as " + std::string(wanted),
+ fn_expr->Declaration()->source);
+ NoteDeclarationSource(fn);
+ },
[&](const sem::BuiltinEnumExpression<type::Access>* access) {
AddError("cannot use access '" + utils::ToString(access->Value()) + "' as " +
std::string(wanted),
@@ -93,6 +109,44 @@
}
}
+void SemHelper::NoteDeclarationSource(const ast::Node* node) const {
+ Switch(
+ node,
+ [&](const ast::Struct* n) {
+ AddNote("struct '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
+ n->source);
+ },
+ [&](const ast::Alias* n) {
+ AddNote("alias '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
+ n->source);
+ },
+ [&](const ast::Var* n) {
+ AddNote("var '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
+ n->source);
+ },
+ [&](const ast::Let* n) {
+ AddNote("let '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
+ n->source);
+ },
+ [&](const ast::Override* n) {
+ AddNote("override '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
+ n->source);
+ },
+ [&](const ast::Const* n) {
+ AddNote("const '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
+ n->source);
+ },
+ [&](const ast::Parameter* n) {
+ AddNote(
+ "parameter '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
+ n->source);
+ },
+ [&](const ast::Function* n) {
+ AddNote("function '" + builder_->Symbols().NameFor(n->name->symbol) + "' declared here",
+ n->source);
+ });
+}
+
void SemHelper::AddError(const std::string& msg, const Source& source) const {
builder_->Diagnostics().add_error(diag::System::Resolver, msg, source);
}
diff --git a/src/tint/resolver/sem_helper.h b/src/tint/resolver/sem_helper.h
index db3eda4..dd752d3 100644
--- a/src/tint/resolver/sem_helper.h
+++ b/src/tint/resolver/sem_helper.h
@@ -21,6 +21,8 @@
#include "src/tint/program_builder.h"
#include "src/tint/resolver/dependency_graph.h"
#include "src/tint/sem/builtin_enum_expression.h"
+#include "src/tint/sem/function_expression.h"
+#include "src/tint/sem/type_expression.h"
#include "src/tint/utils/map.h"
namespace tint::resolver {
@@ -58,16 +60,16 @@
/// @returns the sem node for @p ast
template <typename AST = ast::Node>
auto* GetVal(const AST* ast) const {
- return AsValue(Get(ast));
+ return AsValueExpression(Get(ast));
}
/// @param expr the semantic node
/// @returns nullptr if @p expr is nullptr, or @p expr cast to sem::ValueExpression if the cast
/// is successful, otherwise an error diagnostic is raised.
- sem::ValueExpression* AsValue(sem::Expression* expr) const {
+ sem::ValueExpression* AsValueExpression(sem::Expression* expr) const {
if (TINT_LIKELY(expr)) {
- if (auto* val = expr->As<sem::ValueExpression>(); TINT_LIKELY(val)) {
- return val;
+ if (auto* val_expr = expr->As<sem::ValueExpression>(); TINT_LIKELY(val_expr)) {
+ return val_expr;
}
ErrorExpectedValueExpr(expr);
}
@@ -75,14 +77,56 @@
}
/// @param expr the semantic node
+ /// @returns nullptr if @p expr is nullptr, or @p expr cast to type::Type if the cast is
+ /// successful, otherwise an error diagnostic is raised.
+ sem::TypeExpression* AsTypeExpression(sem::Expression* expr) const {
+ if (TINT_LIKELY(expr)) {
+ if (auto* ty_expr = expr->As<sem::TypeExpression>(); TINT_LIKELY(ty_expr)) {
+ return ty_expr;
+ }
+ ErrorUnexpectedExprKind(expr, "type");
+ }
+ return nullptr;
+ }
+
+ /// @param expr the semantic node
+ /// @returns nullptr if @p expr is nullptr, or @p expr cast to sem::Function if the cast is
+ /// successful, otherwise an error diagnostic is raised.
+ sem::FunctionExpression* AsFunctionExpression(sem::Expression* expr) const {
+ if (TINT_LIKELY(expr)) {
+ auto* fn_expr = expr->As<sem::FunctionExpression>();
+ if (TINT_LIKELY(fn_expr)) {
+ return fn_expr;
+ }
+ ErrorUnexpectedExprKind(expr, "function");
+ }
+ return nullptr;
+ }
+
+ /// @param expr the semantic node
+ /// @returns nullptr if @p expr is nullptr, or @p expr cast to
+ /// sem::BuiltinEnumExpression<type::AddressSpace> if the cast is successful, otherwise an error
+ /// diagnostic is raised.
+ sem::BuiltinEnumExpression<type::AddressSpace>* AsAddressSpace(sem::Expression* expr) const {
+ if (TINT_LIKELY(expr)) {
+ auto* enum_expr = expr->As<sem::BuiltinEnumExpression<type::AddressSpace>>();
+ if (TINT_LIKELY(enum_expr)) {
+ return enum_expr;
+ }
+ ErrorUnexpectedExprKind(expr, "address space");
+ }
+ return nullptr;
+ }
+
+ /// @param expr the semantic node
/// @returns nullptr if @p expr is nullptr, or @p expr cast to
/// sem::BuiltinEnumExpression<type::TexelFormat> if the cast is successful, otherwise an error
/// diagnostic is raised.
sem::BuiltinEnumExpression<type::TexelFormat>* AsTexelFormat(sem::Expression* expr) const {
if (TINT_LIKELY(expr)) {
- if (auto* val = expr->As<sem::BuiltinEnumExpression<type::TexelFormat>>();
- TINT_LIKELY(val)) {
- return val;
+ auto* enum_expr = expr->As<sem::BuiltinEnumExpression<type::TexelFormat>>();
+ if (TINT_LIKELY(enum_expr)) {
+ return enum_expr;
}
ErrorUnexpectedExprKind(expr, "texel format");
}
@@ -95,9 +139,9 @@
/// diagnostic is raised.
sem::BuiltinEnumExpression<type::Access>* AsAccess(sem::Expression* expr) const {
if (TINT_LIKELY(expr)) {
- if (auto* val = expr->As<sem::BuiltinEnumExpression<type::Access>>();
- TINT_LIKELY(val)) {
- return val;
+ auto* enum_expr = expr->As<sem::BuiltinEnumExpression<type::Access>>();
+ if (TINT_LIKELY(enum_expr)) {
+ return enum_expr;
}
ErrorUnexpectedExprKind(expr, "access");
}
@@ -121,11 +165,17 @@
/// @param expr the expression
void ErrorExpectedValueExpr(const sem::Expression* expr) const;
- private:
/// Raises an error diagnostic that the expression @p got was not of the kind @p wanted.
/// @param expr the expression
+ /// @param wanted the expected expression kind
void ErrorUnexpectedExprKind(const sem::Expression* expr, std::string_view wanted) const;
+ /// If @p node is a module-scope type, variable or function declaration, then appends a note
+ /// diagnostic where this declaration was declared, otherwise the function does nothing.
+ /// @param node the AST node.
+ void NoteDeclarationSource(const ast::Node* node) const;
+
+ private:
/// Adds the given error message to the diagnostics
void AddError(const std::string& msg, const Source& source) const;
diff --git a/src/tint/resolver/struct_address_space_use_test.cc b/src/tint/resolver/struct_address_space_use_test.cc
index 6500158..2d011bb 100644
--- a/src/tint/resolver/struct_address_space_use_test.cc
+++ b/src/tint/resolver/struct_address_space_use_test.cc
@@ -99,7 +99,7 @@
TEST_F(ResolverAddressSpaceUseTest, StructReachableViaGlobalArray) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32())});
- auto* a = ty.array(ty.Of(s), 3_u);
+ auto a = ty.array(ty.Of(s), 3_u);
GlobalVar("g", a, type::AddressSpace::kPrivate);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -147,7 +147,7 @@
TEST_F(ResolverAddressSpaceUseTest, StructReachableViaLocalArray) {
auto* s = Structure("S", utils::Vector{Member("a", ty.f32())});
- auto* a = ty.array(ty.Of(s), 3_u);
+ auto a = ty.array(ty.Of(s), 3_u);
WrapInFunction(Var("g", a));
ASSERT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/struct_layout_test.cc b/src/tint/resolver/struct_layout_test.cc
index 8d9dee9..7288449 100644
--- a/src/tint/resolver/struct_layout_test.cc
+++ b/src/tint/resolver/struct_layout_test.cc
@@ -256,8 +256,8 @@
}
TEST_F(ResolverStructLayoutTest, ImplicitStrideArrayOfExplicitStrideArray) {
- auto* inner = ty.array<i32, 2>(utils::Vector{Stride(16)}); // size: 32
- auto* outer = ty.array(inner, 12_u); // size: 12 * 32
+ auto inner = ty.array<i32, 2>(utils::Vector{Stride(16)}); // size: 32
+ auto outer = ty.array(inner, 12_u); // size: 12 * 32
auto* s = Structure("S", utils::Vector{
Member("c", outer),
});
@@ -283,8 +283,8 @@
Member("a", ty.vec2<i32>()),
Member("b", ty.vec3<i32>()),
Member("c", ty.vec4<i32>()),
- }); // size: 48
- auto* outer = ty.array(ty.Of(inner), 12_u); // size: 12 * 48
+ }); // size: 48
+ auto outer = ty.array(ty.Of(inner), 12_u); // size: 12 * 48
auto* s = Structure("S", utils::Vector{
Member("c", outer),
});
diff --git a/src/tint/resolver/type_initializer_validation_test.cc b/src/tint/resolver/type_initializer_validation_test.cc
index c1212ec..0f7144b 100644
--- a/src/tint/resolver/type_initializer_validation_test.cc
+++ b/src/tint/resolver/type_initializer_validation_test.cc
@@ -346,9 +346,9 @@
Enable(ast::Extension::kF16);
// var a : <lhs_type1> = <lhs_type2>(<rhs_type>(<rhs_value_expr>));
- auto* lhs_type1 = params.lhs_type(*this);
- auto* lhs_type2 = params.lhs_type(*this);
- auto* rhs_type = params.rhs_type(*this);
+ auto lhs_type1 = params.lhs_type(*this);
+ auto lhs_type2 = params.lhs_type(*this);
+ auto rhs_type = params.rhs_type(*this);
auto* rhs_value_expr = params.rhs_value_expr(*this, 0);
std::stringstream ss;
@@ -439,9 +439,9 @@
}
// var a : <lhs_type1> = <lhs_type2>(<rhs_type>(<rhs_value_expr>));
- auto* lhs_type1 = lhs_params.ast(*this);
- auto* lhs_type2 = lhs_params.ast(*this);
- auto* rhs_type = rhs_params.ast(*this);
+ auto lhs_type1 = lhs_params.ast(*this);
+ auto lhs_type2 = lhs_params.ast(*this);
+ auto rhs_type = rhs_params.ast(*this);
auto* rhs_value_expr = rhs_params.expr_from_double(*this, 0);
std::stringstream ss;
@@ -523,7 +523,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArray_U32U32U32) {
// array(0u, 10u, 20u);
- auto* tc = array(Source{{12, 34}}, nullptr, nullptr, 0_u, 10_u, 20_u);
+ auto* tc = array<Infer>(Source{{12, 34}}, 0_u, 10_u, 20_u);
WrapInFunction(tc);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -561,7 +561,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArray_U32AIU32) {
- // array(0u, 10u, 20u);
+ // array(0u, 10, 20u);
- auto* tc = array(Source{{12, 34}}, nullptr, nullptr, 0_u, 10_a, 20_u);
+ auto* tc = array<Infer>(Source{{12, 34}}, 0_u, 10_a, 20_u);
WrapInFunction(tc);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -599,7 +599,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArray_AIAIAI) {
// const c = array(0, 10, 20);
- auto* tc = array(Source{{12, 34}}, nullptr, nullptr, 0_a, 10_a, 20_a);
+ auto* tc = array<Infer>(Source{{12, 34}}, 0_a, 10_a, 20_a);
WrapInFunction(Decl(Const("C", tc)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -618,9 +618,9 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArrayU32_VecI32_VecAI) {
- // array(vec2(10i), vec2(20));
+ // array(vec2(20i), vec2(20));
- auto* tc = array(Source{{12, 34}}, nullptr, nullptr, //
- Call(ty.vec(nullptr, 2), 20_i), //
- Call(ty.vec(nullptr, 2), 20_a));
+ auto* tc = array<Infer>(Source{{12, 34}}, //
+ Call(ty.vec<Infer>(2), 20_i), //
+ Call(ty.vec<Infer>(2), 20_a));
WrapInFunction(tc);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -640,9 +640,9 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArrayU32_VecAI_VecF32) {
- // array(vec2(20), vec2(10f));
+ // array(vec2(20), vec2(20f));
- auto* tc = array(Source{{12, 34}}, nullptr, nullptr, //
- Call(ty.vec(nullptr, 2), 20_a), //
- Call(ty.vec(nullptr, 2), 20_f));
+ auto* tc = array<Infer>(Source{{12, 34}}, //
+ Call(ty.vec<Infer>(2), 20_a), //
+ Call(ty.vec<Infer>(2), 20_f));
WrapInFunction(tc);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -671,7 +671,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArrayArgumentTypeMismatch_U32F32) {
// array(0u, 1.0f, 20u);
- auto* tc = array(Source{{12, 34}}, nullptr, nullptr, 0_u, 1_f, 20_u);
+ auto* tc = array<Infer>(Source{{12, 34}}, 0_u, 1_f, 20_u);
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
@@ -692,7 +692,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArrayArgumentTypeMismatch_F32I32) {
// array(1f, 1i);
- auto* tc = array(Source{{12, 34}}, nullptr, nullptr, 1_f, 1_i);
+ auto* tc = array<Infer>(Source{{12, 34}}, 1_f, 1_i);
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
@@ -713,7 +713,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArrayArgumentTypeMismatch_U32I32) {
// array(1i, 0u, 0u, 0u, 0u, 0u);
- auto* tc = array(Source{{12, 34}}, nullptr, nullptr, 1_i, 0_u, 0_u, 0_u, 0_u);
+ auto* tc = array<Infer>(Source{{12, 34}}, 1_i, 0_u, 0_u, 0_u, 0_u);
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
@@ -734,7 +734,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArrayArgumentTypeMismatch_I32Vec2) {
// array(1i, vec2<i32>());
- auto* tc = array(Source{{12, 34}}, nullptr, nullptr, 1_i, vec2<i32>());
+ auto* tc = array<Infer>(Source{{12, 34}}, 1_i, vec2<i32>());
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@@ -755,7 +755,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArrayArgumentTypeMismatch_Vec3i32_Vec3u32) {
// array(vec3<i32>(), vec3<u32>());
- auto* t = array(Source{{12, 34}}, nullptr, nullptr, vec3<i32>(), vec3<u32>());
+ auto* t = array<Infer>(Source{{12, 34}}, vec3<i32>(), vec3<u32>());
WrapInFunction(t);
EXPECT_FALSE(r()->Resolve());
@@ -767,7 +767,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArrayArgumentTypeMismatch_Vec3i32_Vec3AF) {
// array(vec3<i32>(), vec3(1.0));
- auto* t = array(Source{{12, 34}}, nullptr, nullptr, vec3<i32>(), Call(ty.vec3(nullptr), 1._a));
+ auto* t = array<Infer>(Source{{12, 34}}, vec3<i32>(), Call("vec3", 1._a));
WrapInFunction(t);
EXPECT_FALSE(r()->Resolve());
@@ -789,7 +789,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArrayArgumentTypeMismatch_Vec3i32_Vec3bool) {
// array(vec3<i32>(), vec3<bool>());
- auto* t = array(Source{{12, 34}}, nullptr, nullptr, vec3<i32>(), vec3<bool>());
+ auto* t = array<Infer>(Source{{12, 34}}, vec3<i32>(), vec3<bool>());
WrapInFunction(t);
EXPECT_FALSE(r()->Resolve());
@@ -811,7 +811,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArrayOfArray_SubElemSizeMismatch) {
// array<array<i32, 2u>, 2u>(array<i32, 3u>(), array<i32, 2u>());
- auto* t = array(Source{{12, 34}}, nullptr, nullptr, array<i32, 3>(), array<i32, 2>());
+ auto* t = array<Infer>(Source{{12, 34}}, array<i32, 3>(), array<i32, 2>());
WrapInFunction(t);
EXPECT_FALSE(r()->Resolve());
@@ -833,7 +833,7 @@
TEST_F(ResolverTypeInitializerValidationTest, InferredArrayOfArray_SubElemTypeMismatch) {
// array<array<i32, 2u>, 2u>(array<i32, 2u>(), array<u32, 2u>());
- auto* t = array(Source{{12, 34}}, nullptr, nullptr, array<i32, 2>(), array<u32, 2>());
+ auto* t = array<Infer>(Source{{12, 34}}, array<i32, 2>(), array<u32, 2>());
WrapInFunction(t);
EXPECT_FALSE(r()->Resolve());
@@ -869,7 +869,7 @@
TEST_F(ResolverTypeInitializerValidationTest, Array_Runtime) {
// array<i32>(1i);
- auto* tc = array(Source{{12, 34}}, ty.i32(), nullptr, Expr(1_i));
+ auto* tc = array<i32>(Source{{12, 34}}, Expr(1_i));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
@@ -878,7 +878,7 @@
TEST_F(ResolverTypeInitializerValidationTest, Array_RuntimeZeroValue) {
// array<i32>();
- auto* tc = array(Source{{12, 34}}, ty.i32(), nullptr);
+ auto* tc = array<i32>(Source{{12, 34}});
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
@@ -1992,7 +1992,7 @@
auto* f32_alias = Alias("Float32", ty.f32());
// vec2<Float32>(1.0f, 1u)
- auto* vec_type = ty.vec(ty.Of(f32_alias), 2);
+ auto vec_type = ty.vec(ty.Of(f32_alias), 2);
WrapInFunction(Call(Source{{12, 34}}, vec_type, 1_f, 1_u));
EXPECT_FALSE(r()->Resolve());
@@ -2004,7 +2004,7 @@
auto* f32_alias = Alias("Float32", ty.f32());
// vec2<Float32>(1.0f, 1.0f)
- auto* vec_type = ty.vec(ty.Of(f32_alias), 2);
+ auto vec_type = ty.vec(ty.Of(f32_alias), 2);
auto* tc = Call(Source{{12, 34}}, vec_type, 1_f, 1_f);
WrapInFunction(tc);
@@ -2015,7 +2015,7 @@
auto* f32_alias = Alias("Float32", ty.f32());
// vec3<u32>(vec<Float32>(), 1.0f)
- auto* vec_type = ty.vec(ty.Of(f32_alias), 2);
+ auto vec_type = ty.vec(ty.Of(f32_alias), 2);
WrapInFunction(vec3<u32>(Source{{12, 34}}, Call(vec_type), 1_f));
EXPECT_FALSE(r()->Resolve());
@@ -2027,7 +2027,7 @@
auto* f32_alias = Alias("Float32", ty.f32());
// vec3<f32>(vec<Float32>(), 1.0f)
- auto* vec_type = ty.vec(ty.Of(f32_alias), 2);
+ auto vec_type = ty.vec(ty.Of(f32_alias), 2);
auto* tc = vec3<f32>(Call(Source{{12, 34}}, vec_type), 1_f);
WrapInFunction(tc);
@@ -2037,11 +2037,11 @@
TEST_F(ResolverTypeInitializerValidationTest, InferVec2ElementTypeFromScalars) {
Enable(ast::Extension::kF16);
- auto* vec2_bool = Call(create<ast::Vector>(nullptr, 2u), Expr(true), Expr(false));
- auto* vec2_i32 = Call(create<ast::Vector>(nullptr, 2u), Expr(1_i), Expr(2_i));
- auto* vec2_u32 = Call(create<ast::Vector>(nullptr, 2u), Expr(1_u), Expr(2_u));
- auto* vec2_f32 = Call(create<ast::Vector>(nullptr, 2u), Expr(1_f), Expr(2_f));
- auto* vec2_f16 = Call(create<ast::Vector>(nullptr, 2u), Expr(1_h), Expr(2_h));
+ auto* vec2_bool = vec2<Infer>(true, false);
+ auto* vec2_i32 = vec2<Infer>(1_i, 2_i);
+ auto* vec2_u32 = vec2<Infer>(1_u, 2_u);
+ auto* vec2_f32 = vec2<Infer>(1_f, 2_f);
+ auto* vec2_f16 = vec2<Infer>(1_h, 2_h);
WrapInFunction(vec2_bool, vec2_i32, vec2_u32, vec2_f32, vec2_f16);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -2061,21 +2061,16 @@
EXPECT_EQ(TypeOf(vec2_u32)->As<type::Vector>()->Width(), 2u);
EXPECT_EQ(TypeOf(vec2_f32)->As<type::Vector>()->Width(), 2u);
EXPECT_EQ(TypeOf(vec2_f16)->As<type::Vector>()->Width(), 2u);
- EXPECT_EQ(TypeOf(vec2_bool), TypeOf(vec2_bool->target.type));
- EXPECT_EQ(TypeOf(vec2_i32), TypeOf(vec2_i32->target.type));
- EXPECT_EQ(TypeOf(vec2_u32), TypeOf(vec2_u32->target.type));
- EXPECT_EQ(TypeOf(vec2_f32), TypeOf(vec2_f32->target.type));
- EXPECT_EQ(TypeOf(vec2_f16), TypeOf(vec2_f16->target.type));
}
TEST_F(ResolverTypeInitializerValidationTest, InferVec2ElementTypeFromVec2) {
Enable(ast::Extension::kF16);
- auto* vec2_bool = Call(create<ast::Vector>(nullptr, 2u), vec2<bool>(true, false));
- auto* vec2_i32 = Call(create<ast::Vector>(nullptr, 2u), vec2<i32>(1_i, 2_i));
- auto* vec2_u32 = Call(create<ast::Vector>(nullptr, 2u), vec2<u32>(1_u, 2_u));
- auto* vec2_f32 = Call(create<ast::Vector>(nullptr, 2u), vec2<f32>(1_f, 2_f));
- auto* vec2_f16 = Call(create<ast::Vector>(nullptr, 2u), vec2<f16>(1_h, 2_h));
+ auto* vec2_bool = vec2<Infer>(vec2<bool>(true, false));
+ auto* vec2_i32 = vec2<Infer>(vec2<i32>(1_i, 2_i));
+ auto* vec2_u32 = vec2<Infer>(vec2<u32>(1_u, 2_u));
+ auto* vec2_f32 = vec2<Infer>(vec2<f32>(1_f, 2_f));
+ auto* vec2_f16 = vec2<Infer>(vec2<f16>(1_h, 2_h));
WrapInFunction(vec2_bool, vec2_i32, vec2_u32, vec2_f32, vec2_f16);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -2095,21 +2090,16 @@
EXPECT_EQ(TypeOf(vec2_u32)->As<type::Vector>()->Width(), 2u);
EXPECT_EQ(TypeOf(vec2_f32)->As<type::Vector>()->Width(), 2u);
EXPECT_EQ(TypeOf(vec2_f16)->As<type::Vector>()->Width(), 2u);
- EXPECT_EQ(TypeOf(vec2_bool), TypeOf(vec2_bool->target.type));
- EXPECT_EQ(TypeOf(vec2_i32), TypeOf(vec2_i32->target.type));
- EXPECT_EQ(TypeOf(vec2_u32), TypeOf(vec2_u32->target.type));
- EXPECT_EQ(TypeOf(vec2_f32), TypeOf(vec2_f32->target.type));
- EXPECT_EQ(TypeOf(vec2_f16), TypeOf(vec2_f16->target.type));
}
TEST_F(ResolverTypeInitializerValidationTest, InferVec3ElementTypeFromScalars) {
Enable(ast::Extension::kF16);
- auto* vec3_bool = Call(create<ast::Vector>(nullptr, 3u), Expr(true), Expr(false), Expr(true));
- auto* vec3_i32 = Call(create<ast::Vector>(nullptr, 3u), Expr(1_i), Expr(2_i), Expr(3_i));
- auto* vec3_u32 = Call(create<ast::Vector>(nullptr, 3u), Expr(1_u), Expr(2_u), Expr(3_u));
- auto* vec3_f32 = Call(create<ast::Vector>(nullptr, 3u), Expr(1_f), Expr(2_f), Expr(3_f));
- auto* vec3_f16 = Call(create<ast::Vector>(nullptr, 3u), Expr(1_h), Expr(2_h), Expr(3_h));
+ auto* vec3_bool = vec3<Infer>(Expr(true), Expr(false), Expr(true));
+ auto* vec3_i32 = vec3<Infer>(Expr(1_i), Expr(2_i), Expr(3_i));
+ auto* vec3_u32 = vec3<Infer>(Expr(1_u), Expr(2_u), Expr(3_u));
+ auto* vec3_f32 = vec3<Infer>(Expr(1_f), Expr(2_f), Expr(3_f));
+ auto* vec3_f16 = vec3<Infer>(Expr(1_h), Expr(2_h), Expr(3_h));
WrapInFunction(vec3_bool, vec3_i32, vec3_u32, vec3_f32, vec3_f16);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -2129,21 +2119,16 @@
EXPECT_EQ(TypeOf(vec3_u32)->As<type::Vector>()->Width(), 3u);
EXPECT_EQ(TypeOf(vec3_f32)->As<type::Vector>()->Width(), 3u);
EXPECT_EQ(TypeOf(vec3_f16)->As<type::Vector>()->Width(), 3u);
- EXPECT_EQ(TypeOf(vec3_bool), TypeOf(vec3_bool->target.type));
- EXPECT_EQ(TypeOf(vec3_i32), TypeOf(vec3_i32->target.type));
- EXPECT_EQ(TypeOf(vec3_u32), TypeOf(vec3_u32->target.type));
- EXPECT_EQ(TypeOf(vec3_f32), TypeOf(vec3_f32->target.type));
- EXPECT_EQ(TypeOf(vec3_f16), TypeOf(vec3_f16->target.type));
}
TEST_F(ResolverTypeInitializerValidationTest, InferVec3ElementTypeFromVec3) {
Enable(ast::Extension::kF16);
- auto* vec3_bool = Call(create<ast::Vector>(nullptr, 3u), vec3<bool>(true, false, true));
- auto* vec3_i32 = Call(create<ast::Vector>(nullptr, 3u), vec3<i32>(1_i, 2_i, 3_i));
- auto* vec3_u32 = Call(create<ast::Vector>(nullptr, 3u), vec3<u32>(1_u, 2_u, 3_u));
- auto* vec3_f32 = Call(create<ast::Vector>(nullptr, 3u), vec3<f32>(1_f, 2_f, 3_f));
- auto* vec3_f16 = Call(create<ast::Vector>(nullptr, 3u), vec3<f16>(1_h, 2_h, 3_h));
+ auto* vec3_bool = vec3<Infer>(vec3<bool>(true, false, true));
+ auto* vec3_i32 = vec3<Infer>(vec3<i32>(1_i, 2_i, 3_i));
+ auto* vec3_u32 = vec3<Infer>(vec3<u32>(1_u, 2_u, 3_u));
+ auto* vec3_f32 = vec3<Infer>(vec3<f32>(1_f, 2_f, 3_f));
+ auto* vec3_f16 = vec3<Infer>(vec3<f16>(1_h, 2_h, 3_h));
WrapInFunction(vec3_bool, vec3_i32, vec3_u32, vec3_f32, vec3_f16);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -2163,21 +2148,16 @@
EXPECT_EQ(TypeOf(vec3_u32)->As<type::Vector>()->Width(), 3u);
EXPECT_EQ(TypeOf(vec3_f32)->As<type::Vector>()->Width(), 3u);
EXPECT_EQ(TypeOf(vec3_f16)->As<type::Vector>()->Width(), 3u);
- EXPECT_EQ(TypeOf(vec3_bool), TypeOf(vec3_bool->target.type));
- EXPECT_EQ(TypeOf(vec3_i32), TypeOf(vec3_i32->target.type));
- EXPECT_EQ(TypeOf(vec3_u32), TypeOf(vec3_u32->target.type));
- EXPECT_EQ(TypeOf(vec3_f32), TypeOf(vec3_f32->target.type));
- EXPECT_EQ(TypeOf(vec3_f16), TypeOf(vec3_f16->target.type));
}
TEST_F(ResolverTypeInitializerValidationTest, InferVec3ElementTypeFromScalarAndVec2) {
Enable(ast::Extension::kF16);
- auto* vec3_bool = Call(create<ast::Vector>(nullptr, 3u), Expr(true), vec2<bool>(false, true));
- auto* vec3_i32 = Call(create<ast::Vector>(nullptr, 3u), Expr(1_i), vec2<i32>(2_i, 3_i));
- auto* vec3_u32 = Call(create<ast::Vector>(nullptr, 3u), Expr(1_u), vec2<u32>(2_u, 3_u));
- auto* vec3_f32 = Call(create<ast::Vector>(nullptr, 3u), Expr(1_f), vec2<f32>(2_f, 3_f));
- auto* vec3_f16 = Call(create<ast::Vector>(nullptr, 3u), Expr(1_h), vec2<f16>(2_h, 3_h));
+ auto* vec3_bool = vec3<Infer>(Expr(true), vec2<bool>(false, true));
+ auto* vec3_i32 = vec3<Infer>(Expr(1_i), vec2<i32>(2_i, 3_i));
+ auto* vec3_u32 = vec3<Infer>(Expr(1_u), vec2<u32>(2_u, 3_u));
+ auto* vec3_f32 = vec3<Infer>(Expr(1_f), vec2<f32>(2_f, 3_f));
+ auto* vec3_f16 = vec3<Infer>(Expr(1_h), vec2<f16>(2_h, 3_h));
WrapInFunction(vec3_bool, vec3_i32, vec3_u32, vec3_f32, vec3_f16);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -2197,26 +2177,16 @@
EXPECT_EQ(TypeOf(vec3_u32)->As<type::Vector>()->Width(), 3u);
EXPECT_EQ(TypeOf(vec3_f32)->As<type::Vector>()->Width(), 3u);
EXPECT_EQ(TypeOf(vec3_f16)->As<type::Vector>()->Width(), 3u);
- EXPECT_EQ(TypeOf(vec3_bool), TypeOf(vec3_bool->target.type));
- EXPECT_EQ(TypeOf(vec3_i32), TypeOf(vec3_i32->target.type));
- EXPECT_EQ(TypeOf(vec3_u32), TypeOf(vec3_u32->target.type));
- EXPECT_EQ(TypeOf(vec3_f32), TypeOf(vec3_f32->target.type));
- EXPECT_EQ(TypeOf(vec3_f16), TypeOf(vec3_f16->target.type));
}
TEST_F(ResolverTypeInitializerValidationTest, InferVec4ElementTypeFromScalars) {
Enable(ast::Extension::kF16);
- auto* vec4_bool =
- Call(create<ast::Vector>(nullptr, 4u), Expr(true), Expr(false), Expr(true), Expr(false));
- auto* vec4_i32 =
- Call(create<ast::Vector>(nullptr, 4u), Expr(1_i), Expr(2_i), Expr(3_i), Expr(4_i));
- auto* vec4_u32 =
- Call(create<ast::Vector>(nullptr, 4u), Expr(1_u), Expr(2_u), Expr(3_u), Expr(4_u));
- auto* vec4_f32 =
- Call(create<ast::Vector>(nullptr, 4u), Expr(1_f), Expr(2_f), Expr(3_f), Expr(4_f));
- auto* vec4_f16 =
- Call(create<ast::Vector>(nullptr, 4u), Expr(1_h), Expr(2_h), Expr(3_h), Expr(4_h));
+ auto* vec4_bool = vec4<Infer>(Expr(true), Expr(false), Expr(true), Expr(false));
+ auto* vec4_i32 = vec4<Infer>(Expr(1_i), Expr(2_i), Expr(3_i), Expr(4_i));
+ auto* vec4_u32 = vec4<Infer>(Expr(1_u), Expr(2_u), Expr(3_u), Expr(4_u));
+ auto* vec4_f32 = vec4<Infer>(Expr(1_f), Expr(2_f), Expr(3_f), Expr(4_f));
+ auto* vec4_f16 = vec4<Infer>(Expr(1_h), Expr(2_h), Expr(3_h), Expr(4_h));
WrapInFunction(vec4_bool, vec4_i32, vec4_u32, vec4_f32, vec4_f16);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -2236,21 +2206,16 @@
EXPECT_EQ(TypeOf(vec4_u32)->As<type::Vector>()->Width(), 4u);
EXPECT_EQ(TypeOf(vec4_f32)->As<type::Vector>()->Width(), 4u);
EXPECT_EQ(TypeOf(vec4_f16)->As<type::Vector>()->Width(), 4u);
- EXPECT_EQ(TypeOf(vec4_bool), TypeOf(vec4_bool->target.type));
- EXPECT_EQ(TypeOf(vec4_i32), TypeOf(vec4_i32->target.type));
- EXPECT_EQ(TypeOf(vec4_u32), TypeOf(vec4_u32->target.type));
- EXPECT_EQ(TypeOf(vec4_f32), TypeOf(vec4_f32->target.type));
- EXPECT_EQ(TypeOf(vec4_f16), TypeOf(vec4_f16->target.type));
}
TEST_F(ResolverTypeInitializerValidationTest, InferVec4ElementTypeFromVec4) {
Enable(ast::Extension::kF16);
- auto* vec4_bool = Call(create<ast::Vector>(nullptr, 4u), vec4<bool>(true, false, true, false));
- auto* vec4_i32 = Call(create<ast::Vector>(nullptr, 4u), vec4<i32>(1_i, 2_i, 3_i, 4_i));
- auto* vec4_u32 = Call(create<ast::Vector>(nullptr, 4u), vec4<u32>(1_u, 2_u, 3_u, 4_u));
- auto* vec4_f32 = Call(create<ast::Vector>(nullptr, 4u), vec4<f32>(1_f, 2_f, 3_f, 4_f));
- auto* vec4_f16 = Call(create<ast::Vector>(nullptr, 4u), vec4<f16>(1_h, 2_h, 3_h, 4_h));
+ auto* vec4_bool = vec4<Infer>(vec4<bool>(true, false, true, false));
+ auto* vec4_i32 = vec4<Infer>(vec4<i32>(1_i, 2_i, 3_i, 4_i));
+ auto* vec4_u32 = vec4<Infer>(vec4<u32>(1_u, 2_u, 3_u, 4_u));
+ auto* vec4_f32 = vec4<Infer>(vec4<f32>(1_f, 2_f, 3_f, 4_f));
+ auto* vec4_f16 = vec4<Infer>(vec4<f16>(1_h, 2_h, 3_h, 4_h));
WrapInFunction(vec4_bool, vec4_i32, vec4_u32, vec4_f32, vec4_f16);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -2270,22 +2235,16 @@
EXPECT_EQ(TypeOf(vec4_u32)->As<type::Vector>()->Width(), 4u);
EXPECT_EQ(TypeOf(vec4_f32)->As<type::Vector>()->Width(), 4u);
EXPECT_EQ(TypeOf(vec4_f16)->As<type::Vector>()->Width(), 4u);
- EXPECT_EQ(TypeOf(vec4_bool), TypeOf(vec4_bool->target.type));
- EXPECT_EQ(TypeOf(vec4_i32), TypeOf(vec4_i32->target.type));
- EXPECT_EQ(TypeOf(vec4_u32), TypeOf(vec4_u32->target.type));
- EXPECT_EQ(TypeOf(vec4_f32), TypeOf(vec4_f32->target.type));
- EXPECT_EQ(TypeOf(vec4_f16), TypeOf(vec4_f16->target.type));
}
TEST_F(ResolverTypeInitializerValidationTest, InferVec4ElementTypeFromScalarAndVec3) {
Enable(ast::Extension::kF16);
- auto* vec4_bool =
- Call(create<ast::Vector>(nullptr, 4u), Expr(true), vec3<bool>(false, true, false));
- auto* vec4_i32 = Call(create<ast::Vector>(nullptr, 4u), Expr(1_i), vec3<i32>(2_i, 3_i, 4_i));
- auto* vec4_u32 = Call(create<ast::Vector>(nullptr, 4u), Expr(1_u), vec3<u32>(2_u, 3_u, 4_u));
- auto* vec4_f32 = Call(create<ast::Vector>(nullptr, 4u), Expr(1_f), vec3<f32>(2_f, 3_f, 4_f));
- auto* vec4_f16 = Call(create<ast::Vector>(nullptr, 4u), Expr(1_h), vec3<f16>(2_h, 3_h, 4_h));
+ auto* vec4_bool = vec4<Infer>(Expr(true), vec3<bool>(false, true, false));
+ auto* vec4_i32 = vec4<Infer>(Expr(1_i), vec3<i32>(2_i, 3_i, 4_i));
+ auto* vec4_u32 = vec4<Infer>(Expr(1_u), vec3<u32>(2_u, 3_u, 4_u));
+ auto* vec4_f32 = vec4<Infer>(Expr(1_f), vec3<f32>(2_f, 3_f, 4_f));
+ auto* vec4_f16 = vec4<Infer>(Expr(1_h), vec3<f16>(2_h, 3_h, 4_h));
WrapInFunction(vec4_bool, vec4_i32, vec4_u32, vec4_f32, vec4_f16);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -2305,26 +2264,16 @@
EXPECT_EQ(TypeOf(vec4_u32)->As<type::Vector>()->Width(), 4u);
EXPECT_EQ(TypeOf(vec4_f32)->As<type::Vector>()->Width(), 4u);
EXPECT_EQ(TypeOf(vec4_f16)->As<type::Vector>()->Width(), 4u);
- EXPECT_EQ(TypeOf(vec4_bool), TypeOf(vec4_bool->target.type));
- EXPECT_EQ(TypeOf(vec4_i32), TypeOf(vec4_i32->target.type));
- EXPECT_EQ(TypeOf(vec4_u32), TypeOf(vec4_u32->target.type));
- EXPECT_EQ(TypeOf(vec4_f32), TypeOf(vec4_f32->target.type));
- EXPECT_EQ(TypeOf(vec4_f16), TypeOf(vec4_f16->target.type));
}
TEST_F(ResolverTypeInitializerValidationTest, InferVec4ElementTypeFromVec2AndVec2) {
Enable(ast::Extension::kF16);
- auto* vec4_bool =
- Call(create<ast::Vector>(nullptr, 4u), vec2<bool>(true, false), vec2<bool>(true, false));
- auto* vec4_i32 =
- Call(create<ast::Vector>(nullptr, 4u), vec2<i32>(1_i, 2_i), vec2<i32>(3_i, 4_i));
- auto* vec4_u32 =
- Call(create<ast::Vector>(nullptr, 4u), vec2<u32>(1_u, 2_u), vec2<u32>(3_u, 4_u));
- auto* vec4_f32 =
- Call(create<ast::Vector>(nullptr, 4u), vec2<f32>(1_f, 2_f), vec2<f32>(3_f, 4_f));
- auto* vec4_f16 =
- Call(create<ast::Vector>(nullptr, 4u), vec2<f16>(1_h, 2_h), vec2<f16>(3_h, 4_h));
+ auto* vec4_bool = vec4<Infer>(vec2<bool>(true, false), vec2<bool>(true, false));
+ auto* vec4_i32 = vec4<Infer>(vec2<i32>(1_i, 2_i), vec2<i32>(3_i, 4_i));
+ auto* vec4_u32 = vec4<Infer>(vec2<u32>(1_u, 2_u), vec2<u32>(3_u, 4_u));
+ auto* vec4_f32 = vec4<Infer>(vec2<f32>(1_f, 2_f), vec2<f32>(3_f, 4_f));
+ auto* vec4_f16 = vec4<Infer>(vec2<f16>(1_h, 2_h), vec2<f16>(3_h, 4_h));
WrapInFunction(vec4_bool, vec4_i32, vec4_u32, vec4_f32, vec4_f16);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -2344,22 +2293,17 @@
EXPECT_EQ(TypeOf(vec4_u32)->As<type::Vector>()->Width(), 4u);
EXPECT_EQ(TypeOf(vec4_f32)->As<type::Vector>()->Width(), 4u);
EXPECT_EQ(TypeOf(vec4_f16)->As<type::Vector>()->Width(), 4u);
- EXPECT_EQ(TypeOf(vec4_bool), TypeOf(vec4_bool->target.type));
- EXPECT_EQ(TypeOf(vec4_i32), TypeOf(vec4_i32->target.type));
- EXPECT_EQ(TypeOf(vec4_u32), TypeOf(vec4_u32->target.type));
- EXPECT_EQ(TypeOf(vec4_f32), TypeOf(vec4_f32->target.type));
- EXPECT_EQ(TypeOf(vec4_f16), TypeOf(vec4_f16->target.type));
}
TEST_F(ResolverTypeInitializerValidationTest, CannotInferVectorElementTypeWithoutArgs) {
- WrapInFunction(Call(Source{{12, 34}}, create<ast::Vector>(nullptr, 3u)));
+ WrapInFunction(Call(Source{{12, 34}}, "vec3"));
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching initializer for vec3()"));
}
TEST_F(ResolverTypeInitializerValidationTest, CannotInferVec2ElementTypeFromScalarsMismatch) {
- WrapInFunction(Call(Source{{1, 1}}, create<ast::Vector>(nullptr, 2u),
+ WrapInFunction(Call(Source{{1, 1}}, "vec2", //
Expr(Source{{1, 2}}, 1_i), //
Expr(Source{{1, 3}}, 2_u)));
@@ -2368,7 +2312,7 @@
}
TEST_F(ResolverTypeInitializerValidationTest, CannotInferVec3ElementTypeFromScalarsMismatch) {
- WrapInFunction(Call(Source{{1, 1}}, create<ast::Vector>(nullptr, 3u),
+ WrapInFunction(Call(Source{{1, 1}}, "vec3", //
Expr(Source{{1, 2}}, 1_i), //
Expr(Source{{1, 3}}, 2_u), //
Expr(Source{{1, 4}}, 3_i)));
@@ -2379,7 +2323,7 @@
}
TEST_F(ResolverTypeInitializerValidationTest, CannotInferVec3ElementTypeFromScalarAndVec2Mismatch) {
- WrapInFunction(Call(Source{{1, 1}}, create<ast::Vector>(nullptr, 3u),
+ WrapInFunction(Call(Source{{1, 1}}, "vec3", //
Expr(Source{{1, 2}}, 1_i), //
Call(Source{{1, 3}}, ty.vec2<f32>(), 2_f, 3_f)));
@@ -2389,7 +2333,7 @@
}
TEST_F(ResolverTypeInitializerValidationTest, CannotInferVec4ElementTypeFromScalarsMismatch) {
- WrapInFunction(Call(Source{{1, 1}}, create<ast::Vector>(nullptr, 4u),
+ WrapInFunction(Call(Source{{1, 1}}, "vec4", //
Expr(Source{{1, 2}}, 1_i), //
Expr(Source{{1, 3}}, 2_i), //
Expr(Source{{1, 4}}, 3_f), //
@@ -2401,7 +2345,7 @@
}
TEST_F(ResolverTypeInitializerValidationTest, CannotInferVec4ElementTypeFromScalarAndVec3Mismatch) {
- WrapInFunction(Call(Source{{1, 1}}, create<ast::Vector>(nullptr, 4u),
+ WrapInFunction(Call(Source{{1, 1}}, "vec4", //
Expr(Source{{1, 2}}, 1_i), //
Call(Source{{1, 3}}, ty.vec3<u32>(), 2_u, 3_u, 4_u)));
@@ -2411,7 +2355,7 @@
}
TEST_F(ResolverTypeInitializerValidationTest, CannotInferVec4ElementTypeFromVec2AndVec2Mismatch) {
- WrapInFunction(Call(Source{{1, 1}}, create<ast::Vector>(nullptr, 4u),
+ WrapInFunction(Call(Source{{1, 1}}, "vec4", //
Call(Source{{1, 2}}, ty.vec2<i32>(), 3_i, 4_i), //
Call(Source{{1, 3}}, ty.vec2<u32>(), 3_u, 4_u)));
@@ -2468,7 +2412,7 @@
std::stringstream args_tys;
utils::Vector<const ast::Expression*, 8> args;
for (uint32_t i = 0; i < param.columns - 1; i++) {
- auto* vec_type = param.create_column_ast_type(*this);
+ ast::Type vec_type = param.create_column_ast_type(*this);
args.Push(Call(vec_type));
if (i > 0) {
args_tys << ", ";
@@ -2476,7 +2420,7 @@
args_tys << "vec" << param.rows << "<" + element_type_name + ">";
}
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(Source{{12, 34}}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2504,7 +2448,7 @@
args_tys << element_type_name;
}
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(Source{{12, 34}}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2525,7 +2469,7 @@
std::stringstream args_tys;
utils::Vector<const ast::Expression*, 8> args;
for (uint32_t i = 0; i < param.columns + 1; i++) {
- auto* vec_type = param.create_column_ast_type(*this);
+ ast::Type vec_type = param.create_column_ast_type(*this);
args.Push(Call(vec_type));
if (i > 0) {
args_tys << ", ";
@@ -2533,7 +2477,7 @@
args_tys << "vec" << param.rows << "<" + element_type_name + ">";
}
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(Source{{12, 34}}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2561,7 +2505,7 @@
args_tys << element_type_name;
}
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(Source{{12, 34}}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2581,7 +2525,7 @@
std::stringstream args_tys;
utils::Vector<const ast::Expression*, 8> args;
for (uint32_t i = 0; i < param.columns; i++) {
- auto* vec_type = ty.vec<u32>(param.rows);
+ auto vec_type = ty.vec<u32>(param.rows);
args.Push(Call(vec_type));
if (i > 0) {
args_tys << ", ";
@@ -2589,7 +2533,7 @@
args_tys << "vec" << param.rows << "<u32>";
}
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(Source{{12, 34}}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2616,7 +2560,7 @@
args_tys << "u32";
}
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(Source{{12, 34}}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2642,7 +2586,7 @@
std::stringstream args_tys;
utils::Vector<const ast::Expression*, 8> args;
for (uint32_t i = 0; i < param.columns; i++) {
- auto* valid_vec_type = param.create_column_ast_type(*this);
+ ast::Type valid_vec_type = param.create_column_ast_type(*this);
args.Push(Call(valid_vec_type));
if (i > 0) {
args_tys << ", ";
@@ -2650,11 +2594,11 @@
args_tys << "vec" << param.rows << "<" + element_type_name + ">";
}
const size_t kInvalidLoc = 2 * (param.columns - 1);
- auto* invalid_vec_type = ty.vec(param.create_element_ast_type(*this), param.rows - 1);
+ auto invalid_vec_type = ty.vec(param.create_element_ast_type(*this), param.rows - 1);
args.Push(Call(Source{{12, kInvalidLoc}}, invalid_vec_type));
args_tys << ", vec" << (param.rows - 1) << "<" + element_type_name + ">";
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(Source{{12, 34}}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2680,18 +2624,18 @@
std::stringstream args_tys;
utils::Vector<const ast::Expression*, 8> args;
for (uint32_t i = 0; i < param.columns; i++) {
- auto* valid_vec_type = param.create_column_ast_type(*this);
+ ast::Type valid_vec_type = param.create_column_ast_type(*this);
args.Push(Call(valid_vec_type));
if (i > 0) {
args_tys << ", ";
}
args_tys << "vec" << param.rows << "<" + element_type_name + ">";
}
- auto* invalid_vec_type = ty.vec(param.create_element_ast_type(*this), param.rows + 1);
+ auto invalid_vec_type = ty.vec(param.create_element_ast_type(*this), param.rows + 1);
args.Push(Call(invalid_vec_type));
args_tys << ", vec" << (param.rows + 1) << "<" + element_type_name + ">";
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(Source{{12, 34}}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2708,7 +2652,7 @@
Enable(ast::Extension::kF16);
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(Source{{12, 40}}, matrix_type);
WrapInFunction(tc);
@@ -2725,11 +2669,11 @@
utils::Vector<const ast::Expression*, 4> args;
for (uint32_t i = 0; i < param.columns; i++) {
- auto* vec_type = param.create_column_ast_type(*this);
+ ast::Type vec_type = param.create_column_ast_type(*this);
args.Push(Call(vec_type));
}
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2749,7 +2693,7 @@
args.Push(Call(param.create_element_ast_type(*this)));
}
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2769,7 +2713,7 @@
std::stringstream args_tys;
utils::Vector<const ast::Expression*, 4> args;
for (uint32_t i = 0; i < param.columns; i++) {
- auto* vec_type = ty.vec(ty.u32(), param.rows);
+ auto vec_type = ty.vec(ty.u32(), param.rows);
args.Push(Call(vec_type));
if (i > 0) {
args_tys << ", ";
@@ -2777,7 +2721,7 @@
args_tys << "vec" << param.rows << "<u32>";
}
- auto* matrix_type = ty.mat(ty.Of(elem_type_alias), param.columns, param.rows);
+ auto matrix_type = ty.mat(ty.Of(elem_type_alias), param.columns, param.rows);
auto* tc = Call(Source{{12, 34}}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2798,11 +2742,11 @@
utils::Vector<const ast::Expression*, 8> args;
for (uint32_t i = 0; i < param.columns; i++) {
- auto* vec_type = param.create_column_ast_type(*this);
+ ast::Type vec_type = param.create_column_ast_type(*this);
args.Push(Call(vec_type));
}
- auto* matrix_type = ty.mat(ty.Of(elem_type_alias), param.columns, param.rows);
+ auto matrix_type = ty.mat(ty.Of(elem_type_alias), param.columns, param.rows);
auto* tc = Call(Source{}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2825,8 +2769,8 @@
Enable(ast::Extension::kF16);
- auto* matrix_type = param.create_mat_ast_type(*this);
- auto* vec_type = param.create_column_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
+ ast::Type vec_type = param.create_column_ast_type(*this);
auto* vec_alias = Alias("ColVectorAlias", vec_type);
utils::Vector<const ast::Expression*, 4> args;
@@ -2845,13 +2789,13 @@
Enable(ast::Extension::kF16);
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* u32_type_alias = Alias("UnsignedInt", ty.u32());
std::stringstream args_tys;
utils::Vector<const ast::Expression*, 4> args;
for (uint32_t i = 0; i < param.columns; i++) {
- auto* vec_type = ty.vec(ty.Of(u32_type_alias), param.rows);
+ auto vec_type = ty.vec(ty.Of(u32_type_alias), param.rows);
args.Push(Call(vec_type));
if (i > 0) {
args_tys << ", ";
@@ -2876,11 +2820,11 @@
utils::Vector<const ast::Expression*, 4> args;
for (uint32_t i = 0; i < param.columns; i++) {
- auto* vec_type = ty.vec(ty.Of(elem_type_alias), param.rows);
+ auto vec_type = ty.vec(ty.Of(elem_type_alias), param.rows);
args.Push(Call(vec_type));
}
- auto* matrix_type = param.create_mat_ast_type(*this);
+ ast::Type matrix_type = param.create_mat_ast_type(*this);
auto* tc = Call(Source{}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2897,7 +2841,7 @@
args.Push(Call(param.create_column_ast_type(*this)));
}
- auto* matrix_type = create<ast::Matrix>(nullptr, param.rows, param.columns);
+ auto matrix_type = ty.mat<Infer>(param.columns, param.rows);
auto* tc = Call(Source{}, matrix_type, std::move(args));
WrapInFunction(tc);
@@ -2914,7 +2858,7 @@
args.Push(param.create_element_ast_value(*this, static_cast<double>(i)));
}
- auto* matrix_type = create<ast::Matrix>(nullptr, param.rows, param.columns);
+ auto matrix_type = ty.mat<Infer>(param.columns, param.rows);
WrapInFunction(Call(Source{{12, 34}}, matrix_type, std::move(args)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -2944,7 +2888,7 @@
}
}
- auto* matrix_type = create<ast::Matrix>(nullptr, param.rows, param.columns);
+ auto matrix_type = ty.mat<Infer>(param.columns, param.rows);
WrapInFunction(Call(Source{{12, 34}}, matrix_type, std::move(args)));
EXPECT_FALSE(r()->Resolve());
@@ -2976,7 +2920,7 @@
err << ")";
- auto* matrix_type = create<ast::Matrix>(nullptr, param.rows, param.columns);
+ auto matrix_type = ty.mat<Infer>(param.columns, param.rows);
WrapInFunction(Call(Source{{12, 34}}, matrix_type, std::move(args)));
EXPECT_FALSE(r()->Resolve());
@@ -3049,7 +2993,7 @@
utils::Vector<const ast::StructMember*, 16> members;
utils::Vector<const ast::Expression*, 16> values;
for (uint32_t i = 0; i < N; i++) {
- auto* struct_type = str_params.ast(*this);
+ ast::Type struct_type = str_params.ast(*this);
members.Push(Member("member_" + std::to_string(i), struct_type));
if (i < N - 1) {
auto* ctor_value_expr = str_params.expr_from_double(*this, 0);
@@ -3075,7 +3019,7 @@
utils::Vector<const ast::Expression*, 8> values;
for (uint32_t i = 0; i < N + 1; i++) {
if (i < N) {
- auto* struct_type = str_params.ast(*this);
+ ast::Type struct_type = str_params.ast(*this);
members.Push(Member("member_" + std::to_string(i), struct_type));
}
auto* ctor_value_expr = str_params.expr_from_double(*this, 0);
@@ -3113,7 +3057,7 @@
// make the last value of the initializer to have a different type
uint32_t initializer_value_with_different_type = N - 1;
for (uint32_t i = 0; i < N; i++) {
- auto* struct_type = str_params.ast(*this);
+ ast::Type struct_type = str_params.ast(*this);
members.Push(Member("member_" + std::to_string(i), struct_type));
auto* ctor_value_expr = (i == initializer_value_with_different_type)
? ctor_params.expr_from_double(*this, 0)
@@ -3124,11 +3068,9 @@
auto* tc = Call(ty.Of(s), values);
WrapInFunction(tc);
- std::string found = FriendlyName(ctor_params.ast(*this));
- std::string expected = FriendlyName(str_params.ast(*this));
std::stringstream err;
err << "error: type in struct initializer does not match struct member ";
- err << "type: expected '" << expected << "', found '" << found << "'";
+ err << "type: expected '" << str_params.name() << "', found '" << ctor_params.name() << "'";
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), err.str());
}
diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc
index 70b3837..0e69da8 100644
--- a/src/tint/resolver/type_validation_test.cc
+++ b/src/tint/resolver/type_validation_test.cc
@@ -612,11 +612,11 @@
// };
Structure("S", utils::Vector{
- Member("a", create<ast::Vector>(Source{{12, 34}}, nullptr, 3u)),
+ Member("a", ty.vec3<Infer>(Source{{12, 34}})),
});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: missing vector element type");
+ EXPECT_EQ(r()->error(), "12:34 error: expected '<' for 'vec3'");
}
TEST_F(ResolverTypeValidationTest, Struct_Member_MatrixNoType) {
@@ -624,11 +624,11 @@
// a: mat3x3;
// };
Structure("S", utils::Vector{
- Member("a", create<ast::Matrix>(Source{{12, 34}}, nullptr, 3u, 3u)),
+ Member("a", ty.mat3x3<Infer>(Source{{12, 34}})),
});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: missing matrix element type");
+ EXPECT_EQ(r()->error(), "12:34 error: expected '<' for 'mat3x3'");
}
TEST_F(ResolverTypeValidationTest, Struct_TooBig) {
@@ -641,11 +641,9 @@
// }
Structure(Source{{10, 34}}, "Bar", utils::Vector{Member("a", ty.array<f32, 10000>())});
- Structure(
- Source{{12, 34}}, "Foo",
- utils::Vector{
- Member("a", ty.array(ty(Source{{12, 30}}, "Bar"), Expr(Source{{12, 34}}, 65535_a))),
- Member("b", ty.array(ty(Source{{12, 30}}, "Bar"), Expr(Source{{12, 34}}, 65535_a)))});
+ Structure(Source{{12, 34}}, "Foo",
+ utils::Vector{Member("a", ty.array(ty(Source{{12, 30}}, "Bar"), Expr(65535_a))),
+ Member("b", ty.array(ty("Bar"), Expr(65535_a)))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@@ -865,7 +863,7 @@
}
TEST_F(ResolverTypeValidationTest, ArrayOfNonStorableType) {
- auto* tex_ty = ty.sampled_texture(Source{{12, 34}}, type::TextureDimension::k2d, ty.f32());
+ auto tex_ty = ty.sampled_texture(Source{{12, 34}}, type::TextureDimension::k2d, ty.f32());
GlobalVar("arr", ty.array(tex_ty, 4_i), type::AddressSpace::kPrivate);
EXPECT_FALSE(r()->Resolve());
@@ -874,7 +872,7 @@
}
TEST_F(ResolverTypeValidationTest, ArrayOfNonStorableTypeWithStride) {
- auto* ptr_ty = ty.pointer<u32>(Source{{12, 34}}, type::AddressSpace::kUniform);
+ auto ptr_ty = ty.pointer<u32>(Source{{12, 34}}, type::AddressSpace::kUniform);
GlobalVar("arr", ty.array(ptr_ty, 4_i, utils::Vector{Stride(16)}),
type::AddressSpace::kPrivate);
@@ -922,7 +920,7 @@
TEST_P(CanonicalTest, All) {
auto& params = GetParam();
- auto* type = params.create_ast_type(*this);
+ ast::Type type = params.create_ast_type(*this);
auto* var = Var("v", type);
auto* expr = Expr("v");
@@ -941,15 +939,11 @@
} // namespace GetCanonicalTests
namespace SampledTextureTests {
-struct DimensionParams {
- type::TextureDimension dim;
- bool is_valid;
-};
-using SampledTextureDimensionTest = ResolverTestWithParam<DimensionParams>;
+using SampledTextureDimensionTest = ResolverTestWithParam<type::TextureDimension>;
TEST_P(SampledTextureDimensionTest, All) {
auto& params = GetParam();
- GlobalVar(Source{{12, 34}}, "a", ty.sampled_texture(params.dim, ty.i32()), Group(0_a),
+ GlobalVar(Source{{12, 34}}, "a", ty.sampled_texture(params, ty.i32()), Group(0_a),
Binding(0_a));
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -957,35 +951,24 @@
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
SampledTextureDimensionTest,
testing::Values( //
- DimensionParams{type::TextureDimension::k1d, true},
- DimensionParams{type::TextureDimension::k2d, true},
- DimensionParams{type::TextureDimension::k2dArray, true},
- DimensionParams{type::TextureDimension::k3d, true},
- DimensionParams{type::TextureDimension::kCube, true},
- DimensionParams{type::TextureDimension::kCubeArray, true}));
+ type::TextureDimension::k1d,
+ type::TextureDimension::k2d,
+ type::TextureDimension::k2dArray,
+ type::TextureDimension::k3d,
+ type::TextureDimension::kCube,
+ type::TextureDimension::kCubeArray));
-using MultisampledTextureDimensionTest = ResolverTestWithParam<DimensionParams>;
+using MultisampledTextureDimensionTest = ResolverTestWithParam<type::TextureDimension>;
TEST_P(MultisampledTextureDimensionTest, All) {
auto& params = GetParam();
- GlobalVar("a", ty.multisampled_texture(Source{{12, 34}}, params.dim, ty.i32()), Group(0_a),
+ GlobalVar("a", ty.multisampled_texture(Source{{12, 34}}, params, ty.i32()), Group(0_a),
Binding(0_a));
- if (params.is_valid) {
- EXPECT_TRUE(r()->Resolve()) << r()->error();
- } else {
- EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: only 2d multisampled textures are supported");
- }
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
MultisampledTextureDimensionTest,
- testing::Values( //
- DimensionParams{type::TextureDimension::k1d, false},
- DimensionParams{type::TextureDimension::k2d, true},
- DimensionParams{type::TextureDimension::k2dArray, false},
- DimensionParams{type::TextureDimension::k3d, false},
- DimensionParams{type::TextureDimension::kCube, false},
- DimensionParams{type::TextureDimension::kCubeArray, false}));
+ testing::Values(type::TextureDimension::k2d));
struct TypeParams {
builder::ast_type_func_ptr type_func;
@@ -1082,8 +1065,8 @@
// var a : texture_storage_*<r32uint, write>;
auto& params = GetParam();
- auto* st = ty(Source{{12, 34}}, params.name, utils::ToString(type::TexelFormat::kR32Uint),
- utils::ToString(type::Access::kWrite));
+ auto st = ty(Source{{12, 34}}, params.name, utils::ToString(type::TexelFormat::kR32Uint),
+ utils::ToString(type::Access::kWrite));
GlobalVar("a", st, Group(0_a), Binding(0_a));
@@ -1133,19 +1116,19 @@
// @group(0) @binding(3)
// var d : texture_storage_3d<*, write>;
- auto* st_a = ty.storage_texture(Source{{12, 34}}, type::TextureDimension::k1d, params.format,
- type::Access::kWrite);
+ auto st_a = ty.storage_texture(Source{{12, 34}}, type::TextureDimension::k1d, params.format,
+ type::Access::kWrite);
GlobalVar("a", st_a, Group(0_a), Binding(0_a));
- auto* st_b =
+ ast::Type st_b =
ty.storage_texture(type::TextureDimension::k2d, params.format, type::Access::kWrite);
GlobalVar("b", st_b, Group(0_a), Binding(1_a));
- auto* st_c =
+ ast::Type st_c =
ty.storage_texture(type::TextureDimension::k2dArray, params.format, type::Access::kWrite);
GlobalVar("c", st_c, Group(0_a), Binding(2_a));
- auto* st_d =
+ ast::Type st_d =
ty.storage_texture(type::TextureDimension::k3d, params.format, type::Access::kWrite);
GlobalVar("d", st_d, Group(0_a), Binding(3_a));
@@ -1155,7 +1138,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: image format must be one of the texel formats specified for "
- "storage textues in https://gpuweb.github.io/gpuweb/wgsl/#texel-formats");
+ "storage textures in https://gpuweb.github.io/gpuweb/wgsl/#texel-formats");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
@@ -1168,7 +1151,7 @@
// @group(0) @binding(0)
// var a : texture_storage_1d<r32uint>;
- auto* st = ty(Source{{12, 34}}, "texture_storage_1d");
+ auto st = ty(Source{{12, 34}}, "texture_storage_1d");
GlobalVar("a", st, Group(0_a), Binding(0_a));
@@ -1180,7 +1163,7 @@
// @group(0) @binding(0)
// var a : texture_storage_1d<r32uint>;
- auto* st = ty(Source{{12, 34}}, "texture_storage_1d", "r32uint");
+ auto st = ty(Source{{12, 34}}, "texture_storage_1d", "r32uint");
GlobalVar("a", st, Group(0_a), Binding(0_a));
@@ -1192,8 +1175,8 @@
// @group(0) @binding(0)
// var a : texture_storage_1d<r32uint, read_write>;
- auto* st = ty.storage_texture(Source{{12, 34}}, type::TextureDimension::k1d,
- type::TexelFormat::kR32Uint, type::Access::kReadWrite);
+ auto st = ty.storage_texture(Source{{12, 34}}, type::TextureDimension::k1d,
+ type::TexelFormat::kR32Uint, type::Access::kReadWrite);
GlobalVar("a", st, Group(0_a), Binding(0_a));
@@ -1206,8 +1189,8 @@
// @group(0) @binding(0)
// var a : texture_storage_1d<r32uint, read>;
- auto* st = ty.storage_texture(Source{{12, 34}}, type::TextureDimension::k1d,
- type::TexelFormat::kR32Uint, type::Access::kRead);
+ auto st = ty.storage_texture(Source{{12, 34}}, type::TextureDimension::k1d,
+ type::TexelFormat::kR32Uint, type::Access::kRead);
GlobalVar("a", st, Group(0_a), Binding(0_a));
@@ -1220,8 +1203,8 @@
// @group(0) @binding(0)
// var a : texture_storage_1d<r32uint, write>;
- auto* st = ty.storage_texture(type::TextureDimension::k1d, type::TexelFormat::kR32Uint,
- type::Access::kWrite);
+ auto st = ty.storage_texture(type::TextureDimension::k1d, type::TexelFormat::kR32Uint,
+ type::Access::kWrite);
GlobalVar("a", st, Group(0_a), Binding(0_a));
@@ -1250,8 +1233,9 @@
Enable(ast::Extension::kF16);
- GlobalVar("a", ty.mat(params.elem_ty(*this), params.columns, params.rows),
- type::AddressSpace::kPrivate);
+ ast::Type el_ty = params.elem_ty(*this);
+
+ GlobalVar("a", ty.mat(el_ty, params.columns, params.rows), type::AddressSpace::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
@@ -1289,7 +1273,9 @@
Enable(ast::Extension::kF16);
- GlobalVar("a", ty.mat(Source{{12, 34}}, params.elem_ty(*this), params.columns, params.rows),
+ ast::Type el_ty = params.elem_ty(*this);
+
+ GlobalVar("a", ty.mat(Source{{12, 34}}, el_ty, params.columns, params.rows),
type::AddressSpace::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: matrix element type must be 'f32' or 'f16'");
@@ -1417,7 +1403,7 @@
Enable(ast::Extension::kF16);
- WrapInFunction(Decl(Var("v", params.type(*this), Call(ty(params.alias)))));
+ WrapInFunction(Decl(Var("v", params.type(*this), Call(params.alias))));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
@@ -1474,7 +1460,7 @@
TEST_P(ResolverUntemplatedTypeUsedWithTemplateArgs, BuiltinAlias_UseWithTemplateArgs) {
// enable f16;
// alias A = f32;
- // var<private> v : S<true>;
+ // var<private> v : A<true>;
Enable(ast::Extension::kF16);
Alias(Source{{56, 78}}, "A", ty(GetParam()));
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index 403ac53..0b8fd7f 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -386,12 +386,27 @@
return current_function_->CreateNode(std::move(tag_list), ast);
}
- /// Get the symbol name of an AST node.
- /// @param ast the AST node to get the symbol name of
+ /// Get the symbol name of an AST expression.
+ /// @param expr the expression to get the symbol name of
/// @returns the symbol name
- template <typename T>
- inline std::string NameFor(const T* ast) {
- return builder_->Symbols().NameFor(ast->symbol);
+ inline std::string NameFor(const ast::IdentifierExpression* expr) {
+ return builder_->Symbols().NameFor(expr->identifier->symbol);
+ }
+
+ /// @param var the variable to get the name of
+ /// @returns the name of the variable @p var
+ inline std::string NameFor(const ast::Variable* var) {
+ return builder_->Symbols().NameFor(var->name->symbol);
+ }
+
+ /// @param var the variable to get the name of
+ /// @returns the name of the variable @p var
+ inline std::string NameFor(const sem::Variable* var) { return NameFor(var->Declaration()); }
+
+ /// @param fn the function to get the name of
+ /// @returns the name of the function @p fn
+ inline std::string NameFor(const sem::Function* fn) {
+ return builder_->Symbols().NameFor(fn->Declaration()->name->symbol);
}
/// Process a function.
@@ -621,7 +636,7 @@
// Add an edge from the variable exit node to its value at this point.
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
- auto name = NameFor(var->Declaration()->name);
+ auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
exit_node->AddEdge(current_function_->variables.Get(var));
@@ -657,7 +672,7 @@
// Add an edge from the variable exit node to its value at this point.
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
- auto name = NameFor(var->Declaration()->name);
+ auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
@@ -730,8 +745,7 @@
// Create input nodes for any variables declared before this loop.
for (auto* v : current_function_->local_var_decls) {
- auto* in_node =
- CreateNode({NameFor(v->Declaration()->name), "_value_forloop_in"});
+ auto* in_node = CreateNode({NameFor(v), "_value_forloop_in"});
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node);
@@ -748,7 +762,7 @@
// Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
- auto name = NameFor(var->Declaration()->name);
+ auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
exit_node->AddEdge(current_function_->variables.Get(var));
@@ -806,8 +820,7 @@
// Create input nodes for any variables declared before this loop.
for (auto* v : current_function_->local_var_decls) {
- auto* in_node =
- CreateNode({NameFor(v->Declaration()->name), "_value_forloop_in"});
+ auto* in_node = CreateNode({NameFor(v), "_value_forloop_in"});
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes.Replace(v, in_node);
current_function_->variables.Set(v, in_node);
@@ -825,7 +838,7 @@
// Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
- auto name = NameFor(var->Declaration()->name);
+ auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
exit_node->AddEdge(current_function_->variables.Get(var));
@@ -909,8 +922,7 @@
}
// Create an exit node for the variable.
- auto* out_node =
- CreateNode({NameFor(var->Declaration()->name), "_value_if_exit"});
+ auto* out_node = CreateNode({NameFor(var), "_value_if_exit"});
// Add edges to the assigned value or the initial value.
// Only add edges if the behavior for that block contains 'Next'.
@@ -966,7 +978,7 @@
// Create input nodes for any variables declared before this loop.
for (auto* v : current_function_->local_var_decls) {
- auto name = NameFor(v->Declaration()->name);
+ auto name = NameFor(v);
auto* in_node = CreateNode({name, "_value_loop_in"}, v->Declaration());
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes.Replace(v, in_node);
@@ -1065,7 +1077,7 @@
// Add an edge from the variable exit node to its new value.
auto* exit_node = info.var_exit_nodes.GetOrCreate(var, [&]() {
- auto name = NameFor(var->Declaration()->name);
+ auto name = NameFor(var);
return CreateNode({name, "_value_", info.type, "_exit"});
});
exit_node->AddEdge(current_function_->variables.Get(var));
@@ -1144,7 +1156,7 @@
return true;
};
- auto* node = CreateNode({NameFor(ident->identifier), "_ident_expr"}, ident);
+ auto* node = CreateNode({NameFor(ident), "_ident_expr"}, ident);
auto* sem_ident = sem_.GetVal(ident);
TINT_ASSERT(Resolver, sem_ident);
auto* var_user = sem_ident->Unwrap()->As<sem::VariableUser>();
@@ -1367,7 +1379,7 @@
return std::make_pair(cf, current_function_->may_be_non_uniform);
} else if (auto* local = sem->Variable()->As<sem::LocalVariable>()) {
// Create a new value node for this variable.
- auto* value = CreateNode({NameFor(i->identifier), "_lvalue"});
+ auto* value = CreateNode({NameFor(i), "_lvalue"});
auto* old_value = current_function_->variables.Set(local, value);
// If i is part of an expression that is a partial reference to a variable (e.g.
@@ -1404,7 +1416,7 @@
// Cut the analysis short, since we only need to know the originating variable
// that is being written to.
auto* root_ident = sem_.Get(u)->RootIdentifier();
- auto* deref = CreateNode({NameFor(root_ident->Declaration()->name), "_deref"});
+ auto* deref = CreateNode({NameFor(root_ident), "_deref"});
auto* old_value = current_function_->variables.Set(root_ident, deref);
if (old_value) {
@@ -1432,12 +1444,7 @@
/// @param call the function call to process
/// @returns a pair of (control flow node, value node)
std::pair<Node*, Node*> ProcessCall(Node* cf, const ast::CallExpression* call) {
- std::string name;
- if (call->target.name) {
- name = NameFor(call->target.name);
- } else {
- name = call->target.type->FriendlyName(builder_->Symbols());
- }
+ std::string name = NameFor(call->target);
// Process call arguments
Node* cf_last_arg = cf;
@@ -1771,10 +1778,10 @@
std::ostringstream ss;
if (auto* param = var->As<sem::Parameter>()) {
auto* func = param->Owner()->As<sem::Function>();
- ss << param_type(param) << "'" << NameFor(ident->identifier) << "' of '"
- << NameFor(func->Declaration()->name) << "' may be non-uniform";
+ ss << param_type(param) << "'" << NameFor(ident) << "' of '" << NameFor(func)
+ << "' may be non-uniform";
} else {
- ss << "reading from " << var_type(var) << "'" << NameFor(ident->identifier)
+ ss << "reading from " << var_type(var) << "'" << NameFor(ident)
<< "' may result in a non-uniform value";
}
diagnostics_.add_note(diag::System::Resolver, ss.str(), ident->source);
@@ -1782,12 +1789,12 @@
[&](const ast::Variable* v) {
auto* var = sem_.Get(v);
std::ostringstream ss;
- ss << "reading from " << var_type(var) << "'" << NameFor(v->name)
+ ss << "reading from " << var_type(var) << "'" << NameFor(v)
<< "' may result in a non-uniform value";
diagnostics_.add_note(diag::System::Resolver, ss.str(), v->source);
},
[&](const ast::CallExpression* c) {
- auto target_name = NameFor(c->target.name);
+ auto target_name = NameFor(c->target);
switch (non_uniform_source->type) {
case Node::kFunctionCallReturnValue: {
diagnostics_.add_note(
@@ -1799,8 +1806,7 @@
auto* arg = c->args[non_uniform_source->arg_index];
auto* var = sem_.GetVal(arg)->RootIdentifier();
std::ostringstream ss;
- ss << "reading from " << var_type(var) << "'"
- << NameFor(var->Declaration()->name)
+ ss << "reading from " << var_type(var) << "'" << NameFor(var)
<< "' may result in a non-uniform value";
diagnostics_.add_note(diag::System::Resolver, ss.str(),
var->Declaration()->source);
@@ -1867,7 +1873,7 @@
auto* call = cause->ast->As<ast::CallExpression>();
TINT_ASSERT(Resolver, call);
auto* target = SemCall(call)->Target();
- auto func_name = NameFor(call->target.name);
+ auto func_name = NameFor(call->target);
if (cause->type == Node::kFunctionCallArgumentValue ||
cause->type == Node::kFunctionCallArgumentContents) {
@@ -1897,7 +1903,7 @@
// Show a builtin was reachable from this call (which may be the call itself).
// This will be the trigger location for the failure.
std::ostringstream ss;
- ss << "'" << NameFor(builtin_call->target.name)
+ ss << "'" << NameFor(builtin_call->target)
<< "' must only be called from uniform control flow";
report(builtin_call->source, ss.str(), /* note */ false);
}
diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc
index 719db08..1cf996d 100644
--- a/src/tint/resolver/uniformity_test.cc
+++ b/src/tint/resolver/uniformity_test.cc
@@ -5330,7 +5330,7 @@
args.Push(b.AddressOf(name));
}
main_body.Push(b.Assign("v0", "non_uniform_global"));
- main_body.Push(b.CallStmt(b.create<ast::CallExpression>(b.Ident("foo"), args)));
+ main_body.Push(b.CallStmt(b.Call("foo", args)));
main_body.Push(b.If(b.Equal("v254", 0_i), b.Block(b.CallStmt(b.Call("workgroupBarrier")))));
b.Func("main", utils::Empty, ty.void_(), main_body);
diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc
index bbfd4aa..634e3a2 100644
--- a/src/tint/resolver/validation_test.cc
+++ b/src/tint/resolver/validation_test.cc
@@ -294,7 +294,7 @@
}
TEST_F(ResolverValidationTest, AddressSpace_SamplerExplicitAddressSpace) {
- auto* t = ty.sampler(type::SamplerKind::kSampler);
+ auto t = ty.sampler(type::SamplerKind::kSampler);
GlobalVar(Source{{12, 34}}, "var", t, type::AddressSpace::kHandle, Binding(0_a), Group(0_a));
EXPECT_FALSE(r()->Resolve());
@@ -304,7 +304,7 @@
}
TEST_F(ResolverValidationTest, AddressSpace_TextureExplicitAddressSpace) {
- auto* t = ty.sampled_texture(type::TextureDimension::k1d, ty.f32());
+ auto t = ty.sampled_texture(type::TextureDimension::k1d, ty.f32());
GlobalVar(Source{{12, 34}}, "var", t, type::AddressSpace::kHandle, Binding(0_a), Group(0_a));
EXPECT_FALSE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 71e39af..0a7abcf 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -19,7 +19,6 @@
#include <utility>
#include "src/tint/ast/alias.h"
-#include "src/tint/ast/array.h"
#include "src/tint/ast/assignment_statement.h"
#include "src/tint/ast/bitcast_expression.h"
#include "src/tint/ast/break_statement.h"
@@ -33,16 +32,11 @@
#include "src/tint/ast/internal_attribute.h"
#include "src/tint/ast/interpolate_attribute.h"
#include "src/tint/ast/loop_statement.h"
-#include "src/tint/ast/matrix.h"
-#include "src/tint/ast/pointer.h"
#include "src/tint/ast/return_statement.h"
-#include "src/tint/ast/sampled_texture.h"
#include "src/tint/ast/switch_statement.h"
#include "src/tint/ast/traverse_expressions.h"
-#include "src/tint/ast/type_name.h"
#include "src/tint/ast/unary_op_expression.h"
#include "src/tint/ast/variable_decl_statement.h"
-#include "src/tint/ast/vector.h"
#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/sem/break_if_statement.h"
#include "src/tint/sem/call.h"
@@ -283,28 +277,28 @@
return nullptr;
}
-bool Validator::Atomic(const ast::Atomic* a, const type::Atomic* s) const {
+bool Validator::Atomic(const ast::TemplatedIdentifier* a, const type::Atomic* s) const {
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
// T must be either u32 or i32.
if (!s->Type()->IsAnyOf<type::U32, type::I32>()) {
- AddError("atomic only supports i32 or u32 types", a->type ? a->type->source : a->source);
+ AddError("atomic only supports i32 or u32 types", a->arguments[0]->source);
return false;
}
return true;
}
-bool Validator::Pointer(const ast::Pointer* a, const type::Pointer* s) const {
+bool Validator::Pointer(const ast::TemplatedIdentifier* a, const type::Pointer* s) const {
if (s->AddressSpace() == type::AddressSpace::kUndefined) {
AddError("ptr missing address space", a->source);
return false;
}
- if (a->access != type::Access::kUndefined) {
+ if (a->arguments.Length() > 2) { // ptr<address-space, type [, access]>
// https://www.w3.org/TR/WGSL/#access-mode-defaults
// When writing a variable declaration or a pointer type in WGSL source:
// * For the storage address space, the access mode is optional, and defaults to read.
// * For other address spaces, the access mode must not be written.
- if (a->address_space != type::AddressSpace::kStorage) {
+ if (s->AddressSpace() != type::AddressSpace::kStorage) {
AddError("only pointers in <storage> address space may declare an access mode",
a->source);
return false;
@@ -1567,10 +1561,10 @@
}
if (!is_call_statement) {
// https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
- // If the called function does not return a value, a function call
- // statement should be used instead.
- auto* ident = call->Declaration()->target.name;
- auto name = symbols_.NameFor(ident->symbol);
+ // If the called function does not return a value, a function call statement should be
+ // used instead.
+ auto* builtin = call->Target()->As<sem::Builtin>();
+ auto name = utils::ToString(builtin->Type());
AddError("builtin '" + name + "' does not return a value", call->Declaration()->source);
return false;
}
@@ -1685,7 +1679,7 @@
bool Validator::FunctionCall(const sem::Call* call, sem::Statement* current_statement) const {
auto* decl = call->Declaration();
auto* target = call->Target()->As<sem::Function>();
- auto sym = decl->target.name->symbol;
+ auto sym = target->Declaration()->name->symbol;
auto name = symbols_.NameFor(sym);
if (!current_statement) { // Function call at module-scope.
@@ -1852,16 +1846,16 @@
return true;
}
-bool Validator::Vector(const type::Vector* ty, const Source& source) const {
- if (!ty->type()->is_scalar()) {
+bool Validator::Vector(const type::Type* el_ty, const Source& source) const {
+ if (!el_ty->is_scalar()) {
AddError("vector element type must be 'bool', 'f32', 'f16', 'i32' or 'u32'", source);
return false;
}
return true;
}
-bool Validator::Matrix(const type::Matrix* ty, const Source& source) const {
- if (!ty->is_float_matrix()) {
+bool Validator::Matrix(const type::Type* el_ty, const Source& source) const {
+ if (!el_ty->is_float_scalar()) {
AddError("matrix element type must be 'f32' or 'f16'", source);
return false;
}
diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h
index 9e11f50..dc1bea6 100644
--- a/src/tint/resolver/validator.h
+++ b/src/tint/resolver/validator.h
@@ -185,13 +185,13 @@
/// @param a the atomic ast node
/// @param s the atomic sem node
/// @returns true on success, false otherwise.
- bool Atomic(const ast::Atomic* a, const type::Atomic* s) const;
+ bool Atomic(const ast::TemplatedIdentifier* a, const type::Atomic* s) const;
/// Validates a pointer type
/// @param a the pointer ast node
/// @param s the pointer sem node
/// @returns true on success, false otherwise.
- bool Pointer(const ast::Pointer* a, const type::Pointer* s) const;
+ bool Pointer(const ast::TemplatedIdentifier* a, const type::Pointer* s) const;
/// Validates an assignment
/// @param a the assignment statement
@@ -343,10 +343,10 @@
bool Materialize(const type::Type* to, const type::Type* from, const Source& source) const;
/// Validates a matrix
- /// @param ty the matrix to validate
+ /// @param el_ty the matrix element type to validate
/// @param source the source of the matrix
/// @returns true on success, false otherwise
- bool Matrix(const type::Matrix* ty, const Source& source) const;
+ bool Matrix(const type::Type* el_ty, const Source& source) const;
/// Validates a function parameter
/// @param func the function the variable is for
@@ -440,10 +440,10 @@
const sem::ValueExpression* initializer) const;
/// Validates a vector
- /// @param ty the vector to validate
+ /// @param el_ty the vector element type to validate
/// @param source the source of the vector
/// @returns true on success, false otherwise
- bool Vector(const type::Vector* ty, const Source& source) const;
+ bool Vector(const type::Type* el_ty, const Source& source) const;
/// Validates an array initializer
/// @param ctor the call expresion to validate
diff --git a/src/tint/resolver/variable_test.cc b/src/tint/resolver/variable_test.cc
index 9d00a5a..c5af49c 100644
--- a/src/tint/resolver/variable_test.cc
+++ b/src/tint/resolver/variable_test.cc
@@ -902,7 +902,7 @@
auto* c_vu32 = Const("e", ty.vec3<u32>(), vec3<u32>());
auto* c_vf32 = Const("f", ty.vec3<f32>(), vec3<f32>());
auto* c_mf32 = Const("g", ty.mat3x3<f32>(), mat3x3<f32>());
- auto* c_s = Const("h", ty("S"), Call(ty("S")));
+ auto* c_s = Const("h", ty("S"), Call("S"));
WrapInFunction(c_i32, c_u32, c_f32, c_vi32, c_vu32, c_vf32, c_mf32, c_s);
@@ -947,14 +947,14 @@
auto* c_vi32 = Const("f", vec3<i32>());
auto* c_vu32 = Const("g", vec3<u32>());
auto* c_vf32 = Const("h", vec3<f32>());
- auto* c_vai = Const("i", Call(ty.vec(nullptr, 3), Expr(0_a)));
- auto* c_vaf = Const("j", Call(ty.vec(nullptr, 3), Expr(0._a)));
+ auto* c_vai = Const("i", Call(ty.vec<Infer>(3), Expr(0_a)));
+ auto* c_vaf = Const("j", Call(ty.vec<Infer>(3), Expr(0._a)));
auto* c_mf32 = Const("k", mat3x3<f32>());
auto* c_maf32 =
- Const("l", Call(ty.mat(nullptr, 3, 3), //
- Call(ty.vec(nullptr, 3), Expr(0._a)), Call(ty.vec(nullptr, 3), Expr(0._a)),
- Call(ty.vec(nullptr, 3), Expr(0._a))));
- auto* c_s = Const("m", Call(ty("S")));
+ Const("l", Call(ty.mat3x3<Infer>(), //
+ Call(ty.vec<Infer>(3), Expr(0._a)), Call(ty.vec<Infer>(3), Expr(0._a)),
+ Call(ty.vec<Infer>(3), Expr(0._a))));
+ auto* c_s = Const("m", Call("S"));
WrapInFunction(c_i32, c_u32, c_f32, c_ai, c_af, c_vi32, c_vu32, c_vf32, c_vai, c_vaf, c_mf32,
c_maf32, c_s);
@@ -1123,13 +1123,13 @@
auto* c_vi32 = GlobalConst("f", vec3<i32>());
auto* c_vu32 = GlobalConst("g", vec3<u32>());
auto* c_vf32 = GlobalConst("h", vec3<f32>());
- auto* c_vai = GlobalConst("i", Call(ty.vec(nullptr, 3), Expr(0_a)));
- auto* c_vaf = GlobalConst("j", Call(ty.vec(nullptr, 3), Expr(0._a)));
+ auto* c_vai = GlobalConst("i", Call(ty.vec<Infer>(3), Expr(0_a)));
+ auto* c_vaf = GlobalConst("j", Call(ty.vec<Infer>(3), Expr(0._a)));
auto* c_mf32 = GlobalConst("k", mat3x3<f32>());
auto* c_maf32 = GlobalConst(
- "l", Call(ty.mat(nullptr, 3, 3), //
- Call(ty.vec(nullptr, 3), Expr(0._a)), Call(ty.vec(nullptr, 3), Expr(0._a)),
- Call(ty.vec(nullptr, 3), Expr(0._a))));
+ "l", Call(ty.mat3x3<Infer>(), //
+ Call(ty.vec<Infer>(3), Expr(0._a)), Call(ty.vec<Infer>(3), Expr(0._a)),
+ Call(ty.vec<Infer>(3), Expr(0._a))));
ASSERT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/variable_validation_test.cc b/src/tint/resolver/variable_validation_test.cc
index 5434042..2649562 100644
--- a/src/tint/resolver/variable_validation_test.cc
+++ b/src/tint/resolver/variable_validation_test.cc
@@ -369,52 +369,51 @@
}
TEST_F(ResolverVariableValidationTest, VectorConstNoType) {
- // const a : mat3x3 = mat3x3<f32>();
- WrapInFunction(Const("a", create<ast::Vector>(Source{{12, 34}}, nullptr, 3u), vec3<f32>()));
+ // const a : vec3 = vec3<f32>();
+ WrapInFunction(Const("a", ty.vec3<Infer>(Source{{12, 34}}), vec3<f32>()));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: missing vector element type");
+ EXPECT_EQ(r()->error(), "12:34 error: expected '<' for 'vec3'");
}
TEST_F(ResolverVariableValidationTest, VectorLetNoType) {
- // let a : mat3x3 = mat3x3<f32>();
- WrapInFunction(Let("a", create<ast::Vector>(Source{{12, 34}}, nullptr, 3u), vec3<f32>()));
+ // let a : vec3 = vec3<f32>();
+ WrapInFunction(Let("a", ty.vec3<Infer>(Source{{12, 34}}), vec3<f32>()));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: missing vector element type");
+ EXPECT_EQ(r()->error(), "12:34 error: expected '<' for 'vec3'");
}
TEST_F(ResolverVariableValidationTest, VectorVarNoType) {
- // var a : mat3x3;
- WrapInFunction(Var("a", create<ast::Vector>(Source{{12, 34}}, nullptr, 3u)));
+ // var a : vec3;
+ WrapInFunction(Var("a", ty.vec3<Infer>(Source{{12, 34}})));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: missing vector element type");
+ EXPECT_EQ(r()->error(), "12:34 error: expected '<' for 'vec3'");
}
TEST_F(ResolverVariableValidationTest, MatrixConstNoType) {
// const a : mat3x3 = mat3x3<f32>();
- WrapInFunction(
- Const("a", create<ast::Matrix>(Source{{12, 34}}, nullptr, 3u, 3u), mat3x3<f32>()));
+ WrapInFunction(Const("a", ty.mat3x3<Infer>(Source{{12, 34}}), mat3x3<f32>()));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: missing matrix element type");
+ EXPECT_EQ(r()->error(), "12:34 error: expected '<' for 'mat3x3'");
}
TEST_F(ResolverVariableValidationTest, MatrixLetNoType) {
// let a : mat3x3 = mat3x3<f32>();
- WrapInFunction(Let("a", create<ast::Matrix>(Source{{12, 34}}, nullptr, 3u, 3u), mat3x3<f32>()));
+ WrapInFunction(Let("a", ty.mat3x3<Infer>(Source{{12, 34}}), mat3x3<f32>()));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: missing matrix element type");
+ EXPECT_EQ(r()->error(), "12:34 error: expected '<' for 'mat3x3'");
}
TEST_F(ResolverVariableValidationTest, MatrixVarNoType) {
// var a : mat3x3;
- WrapInFunction(Var("a", create<ast::Matrix>(Source{{12, 34}}, nullptr, 3u, 3u)));
+ WrapInFunction(Var("a", ty.mat3x3<Infer>(Source{{12, 34}})));
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: missing matrix element type");
+ EXPECT_EQ(r()->error(), "12:34 error: expected '<' for 'mat3x3'");
}
TEST_F(ResolverVariableValidationTest, GlobalConstWithRuntimeExpression) {